update
This commit is contained in:
+3
-1
@@ -126,4 +126,6 @@ wandb/
|
||||
train_logs/
|
||||
test_logs/
|
||||
arcface_ckpt/
|
||||
GUI/
|
||||
GUI/
|
||||
insightface_func/
|
||||
parsing_model/
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th January 2022 10:51:02 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class InstanceNorm(nn.Module):
    """Per-sample, per-channel instance normalization without affine params.

    Each (H, W) feature map is shifted to zero mean and scaled to unit
    variance.  All ops are out-of-place on purpose — in-place ops break
    autograd here, see
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        # Small constant added to the variance for numerical stability.
        self.epsilon = epsilon

    def forward(self, x):
        # Center each feature map over its spatial dimensions.
        centered = x - x.mean(dim=(2, 3), keepdim=True)
        # E[(x - mean)^2] == variance; rsqrt gives 1/std in one call.
        variance = (centered * centered).mean(dim=(2, 3), keepdim=True)
        return centered * torch.rsqrt(variance + self.epsilon)
|
||||
|
||||
class ApplyStyle(nn.Module):
    """AdaIN-style modulation: predict a per-channel (scale, bias) pair from
    a latent code and apply it to the feature map.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        # One affine pair (scale, bias) per feature channel.
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        # [batch, 2*C] -> [batch, 2, C, 1, 1] so it broadcasts over H and W.
        params = self.linear(latent).view(-1, 2, x.size(1), 1, 1)
        scale = params[:, 0]
        bias = params[:, 1]
        # Scale is offset by 1 so a zero style acts as the identity.
        return x * (scale + 1.) + bias
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
    """Residual block with AdaIN-style conditioning.

    Two padded 3x3 conv + InstanceNorm stages, each followed by an
    ApplyStyle modulation driven by the same latent slice; the input is
    added back as a residual connection, so spatial size and channel
    count are preserved.

    Args:
        dim: number of input/output channels.
        latent_size: dimensionality of the conditioning latent.
        padding_type: one of 'reflect', 'replicate', 'zero'.
        activation: activation module applied after the first stage.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()
        # The two conv stages were previously built with duplicated
        # copy/paste code; both now come from one helper.
        self.conv1 = self._make_conv(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        self.conv2 = self._make_conv(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    @staticmethod
    def _make_conv(dim, padding_type):
        """Build one padded 3x3 conv + InstanceNorm stage (size-preserving)."""
        layers = []
        p = 0
        if padding_type == 'reflect':
            layers += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            layers += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        return nn.Sequential(*layers)

    def forward(self, x, dlatents_in_slice):
        """Apply conv/style/act, conv/style, then add the residual input."""
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y
|
||||
|
||||
class newres(nn.Module):
    """Residual block variant: 1x1 first conv, padded 3x3 second conv,
    both with InstanceNorm and AdaIN-style conditioning; the input is
    added back as a residual, so the spatial size must be preserved.

    Args:
        dim: number of input/output channels.
        latent_size: dimensionality of the conditioning latent.
        padding_type: one of 'reflect', 'replicate', 'zero' (second stage).
        activation: activation module applied after the first stage.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        # BUG FIX: was `super(ResnetBlock_Adain, self).__init__()`, which
        # raises TypeError because newres is not a subclass of
        # ResnetBlock_Adain.
        super(newres, self).__init__()

        # BUG FIX: the original prepended a ReflectionPad2d/ReplicationPad2d
        # layer before this 1x1 conv, which grew the feature map by 2 px for
        # 'reflect'/'replicate' and broke the residual addition in forward().
        # A 1x1 conv preserves spatial size and needs no padding at all.
        self.conv1 = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm())
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        # Second stage: size-preserving padded 3x3 conv, matching
        # ResnetBlock_Adain.
        p = 0
        conv2 = []
        if padding_type == 'reflect':
            conv2 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv2 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv2 = nn.Sequential(*conv2)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        """Apply conv/style/act, conv/style, then add the residual input."""
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y
|
||||
|
||||
class Generator(nn.Module):
    """Encoder / AdaIN-bottleneck / decoder generator.

    Expected kwargs:
        g_conv_dim: latent size passed to each ResnetBlock_Adain.
        g_kernel_size: generator kernel size (read so a missing key still
            fails fast; the value itself is not used by this architecture).
        res_num: number of AdaIN residual blocks in the bottleneck.
    """

    def __init__(self, **kwargs):
        super().__init__()

        chn = kwargs["g_conv_dim"]
        _ = kwargs["g_kernel_size"]  # kept only to validate the key exists
        res_num = kwargs["res_num"]

        padding_type = 'reflect'
        activation = nn.ReLU(True)

        self.first_layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), activation)

        ### downsample: 64 -> 128 -> 256 -> 512 -> 512, halving H/W each step
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128), activation)
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256), activation)
        self.down3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512), activation)
        self.down4 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512), activation)

        ### AdaIN residual bottleneck
        blocks = [
            ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type,
                              activation=activation)
            for _ in range(res_num)
        ]
        self.BottleNeck = nn.Sequential(*blocks)

        ### upsample: mirror of the encoder
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512), activation)
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation)
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation)
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation)

        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))

    def forward(self, input, id):
        """Run the encoder, the id-conditioned bottleneck, and the decoder.

        Args:
            input: image batch, 3 channels (e.g. 3*224*224 per sample).
            id: identity latent passed to every bottleneck block.
                (Parameter name shadows the builtin, kept for caller
                compatibility.)
        """
        x = input
        # NOTE(review): the skip activations are computed but never fused
        # back in the decoder — presumably planned U-Net skips; confirm.
        skip1 = self.first_layer(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        x = self.down4(skip4)

        # BUG FIX: the original passed the encoder output `res` into every
        # bottleneck block, so all but the last block's output was
        # discarded; the blocks must be chained.
        for block in self.BottleNeck:
            x = block(x, id)

        x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        return self.last_layer(x)
|
||||
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: PatchNCE.py
|
||||
# Created Date: Friday January 21st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 21st January 2022 5:04:43 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 7:44:02 pm
|
||||
# Last Modified: Friday, 21st January 2022 10:55:59 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -30,24 +30,23 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='fastnst_3',
|
||||
parser.add_argument('-v', '--version', type=str, default='2layerFM',
|
||||
help="version name for train, test, finetune")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=-1) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19,
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=310000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# test
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='FastNST')
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='video')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=1)
|
||||
parser.add_argument('-n', '--node_name', type=str, default='localhost',
|
||||
choices=['localhost', '4card','8card','new4card'])
|
||||
|
||||
parser.add_argument('--save_test_result', action='store_false')
|
||||
|
||||
parser.add_argument('--test_dataloader', type=str, default='dir')
|
||||
|
||||
parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark')
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\dlrb2.jpeg')
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\G2010.mp4',
|
||||
help="file path for attribute images or video")
|
||||
|
||||
parser.add_argument('--use_specified_data', action='store_true')
|
||||
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
|
||||
@@ -235,10 +234,10 @@ def main():
|
||||
|
||||
# TODO get the checkpoint file path
|
||||
sys_state["ckp_name"] = {}
|
||||
for data_key in sys_state["checkpoint_names"].keys():
|
||||
sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
|
||||
"%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
sys_state["checkpoint_names"][data_key]))
|
||||
# for data_key in sys_state["checkpoint_names"].keys():
|
||||
# sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
|
||||
# "%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
# sys_state["checkpoint_names"][data_key]))
|
||||
|
||||
# Get the test configurations
|
||||
sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 8:22:37 pm
|
||||
# Last Modified: Sunday, 4th July 2021 11:32:14 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -35,13 +35,13 @@ class Tester(object):
|
||||
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'TestDataset')
|
||||
dataloader = dataloaderClass(config["test_data_path"],
|
||||
1,
|
||||
config["batch_size"],
|
||||
["png","jpg"])
|
||||
self.test_loader= dataloader
|
||||
|
||||
self.test_iter = len(dataloader)
|
||||
# if len(dataloader)%config["batch_size"]>0:
|
||||
# self.test_iter+=1
|
||||
self.test_iter = len(dataloader)//config["batch_size"]
|
||||
if len(dataloader)%config["batch_size"]>0:
|
||||
self.test_iter+=1
|
||||
|
||||
|
||||
def __init_framework__(self):
|
||||
@@ -52,14 +52,19 @@ class Tester(object):
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
model_config = self.config["model_configs"]
|
||||
script_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
script_name = "components."+self.config["module_script_name"]
|
||||
class_name = self.config["class_name"]
|
||||
package = __import__(script_name, fromlist=True)
|
||||
network_class = getattr(package, class_name)
|
||||
n_class = len(self.config["selectedStyleDir"])
|
||||
|
||||
# TODO replace below lines to define the model framework
|
||||
self.network = network_class(**model_config["g_model"]["module_params"])
|
||||
self.network = network_class(self.config["GConvDim"],
|
||||
self.config["GKS"],
|
||||
self.config["resNum"],
|
||||
n_class
|
||||
#**self.config["module_params"]
|
||||
)
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Model structure:")
|
||||
@@ -68,14 +73,12 @@ class Tester(object):
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
|
||||
self.network.load_state_dict(torch.load(model_path))
|
||||
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
|
||||
# print(loader1.key())
|
||||
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
|
||||
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
|
||||
# self.network.load_state_dict(torch.load(pathwocao))
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
|
||||
|
||||
def test(self):
|
||||
|
||||
@@ -84,13 +87,18 @@ class Tester(object):
|
||||
ckp_epoch = self.config["checkpoint_epoch"]
|
||||
version = self.config["version"]
|
||||
batch_size = self.config["batch_size"]
|
||||
win_size = self.config["model_configs"]["g_model"]["module_params"]["window_size"]
|
||||
style_names = self.config["selectedStyleDir"]
|
||||
n_class = len(style_names)
|
||||
|
||||
# models
|
||||
self.__init_framework__()
|
||||
|
||||
condition_labels = torch.ones((n_class, batch_size, 1)).long()
|
||||
for i in range(n_class):
|
||||
condition_labels[i,:,:] = condition_labels[i,:,:]*i
|
||||
if self.config["cuda"] >=0:
|
||||
condition_labels = condition_labels.cuda()
|
||||
total = len(self.test_loader)
|
||||
print("total:", total)
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
@@ -98,25 +106,18 @@ class Tester(object):
|
||||
start_time = time.time()
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
for _ in tqdm(range(total)):
|
||||
for _ in tqdm(range(total//batch_size)):
|
||||
contents, img_names = self.test_loader()
|
||||
B, C, H, W = contents.shape
|
||||
crop_h = H - H%32
|
||||
crop_w = W - W%32
|
||||
crop_s = min(crop_h, crop_w)
|
||||
contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2),
|
||||
(W//2 - crop_s//2):(crop_s//2 + W//2)]
|
||||
if self.config["cuda"] >=0:
|
||||
contents = contents.cuda()
|
||||
res = self.network(contents, (crop_s, crop_s))
|
||||
print("res shape:", res.shape)
|
||||
res = tensor2img(res.cpu())
|
||||
temp_img = res[0,:,:,:]
|
||||
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
|
||||
print(save_dir)
|
||||
print(img_names[0])
|
||||
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format(
|
||||
img_names[0], version, ckp_epoch)),temp_img)
|
||||
for i in range(n_class):
|
||||
if self.config["cuda"] >=0:
|
||||
contents = contents.cuda()
|
||||
res, _ = self.network(contents, condition_labels[i, 0, :])
|
||||
res = tensor2img(res.cpu())
|
||||
for t in range(batch_size):
|
||||
temp_img = res[t,:,:,:]
|
||||
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
|
||||
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
+174
-51
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:32:14 am
|
||||
# Last Modified: Friday, 21st January 2022 11:06:37 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -15,12 +15,24 @@
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from utilities.utilities import tensor2img
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
# from utilities.Reporter import Reporter
|
||||
from moviepy.editor import AudioFileClip, VideoFileClip
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import glob
|
||||
|
||||
from utilities.ImagenetNorm import ImagenetNorm
|
||||
from parsing_model.model import BiSeNet
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop
|
||||
from utilities.reverse2original import reverse2wholeimage
|
||||
|
||||
class Tester(object):
|
||||
def __init__(self, config, reporter):
|
||||
@@ -29,20 +41,126 @@ class Tester(object):
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
#============build evaluation dataloader==============#
|
||||
print("Prepare the test dataloader...")
|
||||
dlModulename = config["test_dataloader"]
|
||||
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'TestDataset')
|
||||
dataloader = dataloaderClass(config["test_data_path"],
|
||||
config["batch_size"],
|
||||
["png","jpg"])
|
||||
self.test_loader= dataloader
|
||||
self.transformer_Arcface = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
|
||||
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
|
||||
|
||||
def cv2totensor(self, cv2_img):
|
||||
"""
|
||||
cv2_img: an image read by cv2, H*W*C
|
||||
return: an 1*C*H*W tensor
|
||||
"""
|
||||
cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB)
|
||||
cv2_img = torch.from_numpy(cv2_img)
|
||||
cv2_img = cv2_img.permute(2,0,1).cuda()
|
||||
temp = cv2_img / 255.0
|
||||
temp -= self.imagenet_mean
|
||||
temp /= self.imagenet_std
|
||||
return temp.unsqueeze(0)
|
||||
|
||||
self.test_iter = len(dataloader)//config["batch_size"]
|
||||
if len(dataloader)%config["batch_size"]>0:
|
||||
self.test_iter+=1
|
||||
def video_swap(
|
||||
self,
|
||||
video_path,
|
||||
id_vetor,
|
||||
save_path,
|
||||
temp_results_dir='./temp_results',
|
||||
crop_size=512,
|
||||
use_mask =False
|
||||
):
|
||||
|
||||
video_forcheck = VideoFileClip(video_path)
|
||||
if video_forcheck.audio is None:
|
||||
no_audio = True
|
||||
else:
|
||||
no_audio = False
|
||||
|
||||
del video_forcheck
|
||||
|
||||
if not no_audio:
|
||||
video_audio_clip = AudioFileClip(video_path)
|
||||
|
||||
video = cv2.VideoCapture(video_path)
|
||||
ret = True
|
||||
frame_index = 0
|
||||
|
||||
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
# video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
|
||||
# video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
fps = video.get(cv2.CAP_PROP_FPS)
|
||||
if os.path.exists(temp_results_dir):
|
||||
shutil.rmtree(temp_results_dir)
|
||||
spNorm =ImagenetNorm()
|
||||
if use_mask:
|
||||
n_classes = 19
|
||||
net = BiSeNet(n_classes=n_classes)
|
||||
net.cuda()
|
||||
save_pth = os.path.join('./parsing_model', '79999_iter.pth')
|
||||
net.load_state_dict(torch.load(save_pth))
|
||||
net.eval()
|
||||
else:
|
||||
net =None
|
||||
|
||||
# while ret:
|
||||
for frame_index in tqdm(range(frame_count)):
|
||||
ret, frame = video.read()
|
||||
if ret:
|
||||
detect_results = self.detect.get(frame,crop_size)
|
||||
|
||||
if detect_results is not None:
|
||||
# print(frame_index)
|
||||
if not os.path.exists(temp_results_dir):
|
||||
os.mkdir(temp_results_dir)
|
||||
frame_align_crop_list = detect_results[0]
|
||||
frame_mat_list = detect_results[1]
|
||||
swap_result_list = []
|
||||
frame_align_crop_tenor_list = []
|
||||
for frame_align_crop in frame_align_crop_list:
|
||||
frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
|
||||
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
|
||||
swap_result = swap_result* self.imagenet_std + self.imagenet_mean
|
||||
swap_result = torch.clip(swap_result,0.0,1.0)
|
||||
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
|
||||
swap_result_list.append(swap_result)
|
||||
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
|
||||
reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\
|
||||
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm)
|
||||
|
||||
else:
|
||||
if not os.path.exists(temp_results_dir):
|
||||
os.mkdir(temp_results_dir)
|
||||
frame = frame.astype(np.uint8)
|
||||
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
|
||||
else:
|
||||
break
|
||||
|
||||
video.release()
|
||||
|
||||
# image_filename_list = []
|
||||
path = os.path.join(temp_results_dir,'*.jpg')
|
||||
image_filenames = sorted(glob.glob(path))
|
||||
|
||||
clips = ImageSequenceClip(image_filenames,fps = fps)
|
||||
|
||||
if not no_audio:
|
||||
clips = clips.set_audio(video_audio_clip)
|
||||
basename = os.path.basename(video_path)
|
||||
basename = os.path.splitext(basename)[0]
|
||||
save_filename = os.path.join(save_path, basename+".mp4")
|
||||
index = 0
|
||||
while(True):
|
||||
if os.path.exists(save_filename):
|
||||
save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
clips.write_videofile(save_filename,audio_codec='aac')
|
||||
|
||||
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
@@ -52,53 +170,68 @@ class Tester(object):
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
script_name = "components."+self.config["module_script_name"]
|
||||
class_name = self.config["class_name"]
|
||||
package = __import__(script_name, fromlist=True)
|
||||
network_class = getattr(package, class_name)
|
||||
n_class = len(self.config["selectedStyleDir"])
|
||||
model_config = self.config["model_configs"]
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.network = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# TODO replace below lines to define the model framework
|
||||
self.network = network_class(self.config["GConvDim"],
|
||||
self.config["GKS"],
|
||||
self.config["resNum"],
|
||||
n_class
|
||||
#**self.config["module_params"]
|
||||
)
|
||||
|
||||
self.network = gen_class(**model_config["g_model"]["module_params"])
|
||||
self.network = self.network.eval()
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Model structure:")
|
||||
self.reporter.writeModel(self.network.__str__())
|
||||
|
||||
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
|
||||
self.arcface = arcface1['model'].module
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
self.arcface = self.arcface.cuda()
|
||||
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
|
||||
# print(loader1.key())
|
||||
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
|
||||
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.network.load_state_dict(torch.load(model_path))
|
||||
# self.network.load_state_dict(torch.load(pathwocao))
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
|
||||
|
||||
def test(self):
|
||||
|
||||
# save_result = self.config["saveTestResult"]
|
||||
save_dir = self.config["test_samples_path"]
|
||||
ckp_epoch = self.config["checkpoint_epoch"]
|
||||
ckp_step = self.config["checkpoint_step"]
|
||||
version = self.config["version"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_names = self.config["selectedStyleDir"]
|
||||
n_class = len(style_names)
|
||||
id_imgs = self.config["id_imgs"]
|
||||
attr_files = self.config["attr_files"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
|
||||
# models
|
||||
self.__init_framework__()
|
||||
|
||||
condition_labels = torch.ones((n_class, batch_size, 1)).long()
|
||||
for i in range(n_class):
|
||||
condition_labels[i,:,:] = condition_labels[i,:,:]*i
|
||||
if self.config["cuda"] >=0:
|
||||
condition_labels = condition_labels.cuda()
|
||||
total = len(self.test_loader)
|
||||
|
||||
|
||||
mode = None
|
||||
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
|
||||
self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
|
||||
|
||||
id_img = cv2.imread(id_imgs)
|
||||
id_img_align_crop, _ = self.detect.get(id_img,512)
|
||||
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
|
||||
id_img = self.transformer_Arcface(id_img_align_crop_pil)
|
||||
id_img = id_img.unsqueeze(0).cuda()
|
||||
|
||||
#create latent id
|
||||
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
|
||||
latend_id = self.arcface(id_img)
|
||||
latend_id = F.normalize(latend_id, p=2, dim=1)
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
@@ -106,18 +239,8 @@ class Tester(object):
|
||||
start_time = time.time()
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
for _ in tqdm(range(total//batch_size)):
|
||||
contents, img_names = self.test_loader()
|
||||
for i in range(n_class):
|
||||
if self.config["cuda"] >=0:
|
||||
contents = contents.cuda()
|
||||
res, _ = self.network(contents, condition_labels[i, 0, :])
|
||||
res = tensor2img(res.cpu())
|
||||
for t in range(batch_size):
|
||||
temp_img = res[t,:,:,:]
|
||||
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
|
||||
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
|
||||
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
|
||||
use_mask=False,crop_size=512)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: ImagenetNorm.py
|
||||
# Created Date: Friday January 21st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 21st January 2022 10:41:50 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
class ImagenetNorm(nn.Module):
    """Normalize an image batch with the ImageNet channel mean/std.

    @notice: avoid in-place ops.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3

    NOTE(review): the statistics are moved to CUDA at construction time, so
    this module requires a GPU; they are plain attributes (not registered
    buffers), so .to()/state_dict() will not track them — behavior
    preserved as-is.
    """

    def __init__(self, epsilon=1e-8):
        super(ImagenetNorm, self).__init__()
        # Per-channel ImageNet statistics, shaped (1, 3, 1, 1) for
        # broadcasting over the batch and spatial dimensions.
        mean = np.array([0.485, 0.456, 0.406])
        self.mean = torch.from_numpy(mean).float().cuda().view([1, 3, 1, 1])

        std = np.array([0.229, 0.224, 0.225])
        self.std = torch.from_numpy(std).float().cuda().view([1, 3, 1, 1])

    def forward(self, x):
        # Expand to the input's spatial size, then standardize.
        spatial = [1, 3, x.shape[2], x.shape[3]]
        mean = self.mean.expand(spatial)
        std = self.std.expand(spatial)
        return (x - mean) / std
|
||||
@@ -0,0 +1,173 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
# import time
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def encode_segmentation_rgb(segmentation, no_neck=True):
    """Convert a face-parsing label map into two 0/255 region masks.

    Args:
        segmentation: H*W integer label map from the parsing model.
        no_neck: when True, exclude labels 7, 8 and 14 from the face mask.

    Returns:
        H*W*2 float array; channel 0 is the face mask, channel 1 the mouth
        mask, each with values in {0, 255}.
    """
    if no_neck:
        face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_part_ids = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_id = 11
    # hair (label 17) intentionally excluded, matching the original

    height, width = segmentation.shape[0], segmentation.shape[1]
    face_map = np.zeros([height, width])
    mouth_map = np.zeros([height, width])

    # Mark every pixel whose label belongs to the face region / the mouth.
    face_map[np.isin(segmentation, face_part_ids)] = 255
    mouth_map[segmentation == mouth_id] = 255

    return np.stack([face_map, mouth_map], axis=2)
|
||||
|
||||
|
||||
class SoftErosion(nn.Module):
    """Soften a binary mask's boundary via repeated weighted blurring.

    A radially-decaying kernel is min-convolved with the mask
    ``iterations - 1`` times, then blurred once more; values at or above
    ``threshold`` snap to 1 and the rest are rescaled by their maximum.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        radius = kernel_size // 2
        self.padding = radius
        self.iterations = iterations
        self.threshold = threshold

        # Kernel: largest at the center, decaying linearly with distance,
        # normalized to sum to 1.
        ys, xs = torch.meshgrid(torch.arange(0., kernel_size),
                                torch.arange(0., kernel_size))
        dist = torch.sqrt((xs - radius) ** 2 + (ys - radius) ** 2)
        kernel = dist.max() - dist
        kernel = kernel / kernel.sum()
        self.register_buffer('weight', kernel.view(1, 1, *kernel.shape))

    def forward(self, x):
        x = x.float()
        # Each extra iteration takes the element-wise min with the blurred
        # mask, eating away at the boundary.
        for _ in range(self.iterations - 1):
            blurred = F.conv2d(x, weight=self.weight, groups=x.shape[1],
                               padding=self.padding)
            x = torch.min(x, blurred)
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1],
                     padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0
        # Rescale the soft fringe so its maximum reaches 1.
        x[~mask] /= x[~mask].max()

        return x, mask
|
||||
|
||||
|
||||
def postprocess(swapped_face, target, target_mask, smooth_mask):
    """Blend `swapped_face` into `target` under a softened parsing mask.

    Args:
        swapped_face: HxWxC float face crop to paste in.
        target: HxWxC float crop to paste onto.
        target_mask: HxWx2 mask in [0, 255] (face channel, mouth channel).
        smooth_mask: SoftErosion-style module; called with a 1x1xHxW tensor,
            returns (softened mask, boolean mask).

    Returns:
        The blended crop with channels reversed (BGR <-> RGB flip).
    """
    # HWC [0, 255] mask -> CHW float tensor in [0, 1], moved to the GPU.
    mask_chw = (torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float() * (1 / 255.0)).cuda()

    # Merge face and mouth channels into one mask, add batch/channel dims.
    combined = (mask_chw[0] + mask_chw[1]).unsqueeze(0).unsqueeze(0)
    soft_mask_tensor, _ = smooth_mask(combined)

    # Back to an HxWx1 numpy array for broadcasting against the images.
    soft_mask = soft_mask_tensor.squeeze().cpu().numpy()[:, :, np.newaxis]

    blended = swapped_face * soft_mask + target * (1 - soft_mask)
    # Reverse the channel order, matching the original's BGR/RGB flip.
    return blended[:, :, ::-1]
|
||||
|
||||
def reverse2wholeimage(b_align_crop_tenor_list, swaped_imgs, mats, crop_size, oriimg, save_path = '', \
    pasring_model =None, norm = None, use_mask = False):
    """Paste swapped, aligned face crops back into the original frame and save it.

    Args:
        b_align_crop_tenor_list: source crop tensors (one per face); fed to the
            parsing model when `use_mask` is set.
        swaped_imgs: swapped face crops as CHW tensors (values presumably in
            [0, 1] — confirm against the generator's output range).
        mats: 2x3 forward affine matrices that produced each aligned crop.
        crop_size: side length of the square aligned crops.
        oriimg: original whole frame (HxWxC, OpenCV BGR uint8).
        save_path: output path for the composited image.
        pasring_model: face-parsing network (expects 512x512); required when use_mask.
        norm: normalization module applied before parsing; required when use_mask.
        use_mask: if True, blend with a parsing-based soft face mask.
    """
    target_image_list = []
    img_mask_list = []
    if use_mask:
        smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()

    for swaped_img, mat, source_img in zip(swaped_imgs, mats, b_align_crop_tenor_list):
        swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0))
        img_white = np.full((crop_size, crop_size), 255, dtype=float)

        # Analytically invert the 2x3 affine transform
        # (aligned-crop space -> original-image space).
        mat_rev = np.zeros([2, 3])
        div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
        mat_rev[0][0] = mat[1][1] / div1
        mat_rev[0][1] = -mat[0][1] / div1
        mat_rev[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
        div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
        mat_rev[1][0] = mat[1][0] / div2
        mat_rev[1][1] = -mat[0][0] / div2
        mat_rev[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2

        orisize = (oriimg.shape[1], oriimg.shape[0])
        if use_mask:
            # Parse the source crop at 512x512 to get a face/mouth label map.
            source_img_norm = norm(source_img)
            source_img_512 = F.interpolate(source_img_norm, size=(512, 512))
            out = pasring_model(source_img_512)[0]
            parsing = out.squeeze(0).detach().cpu().numpy().argmax(0)
            vis_parsing_anno = parsing.copy().astype(np.uint8)
            tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
            if tgt_mask.sum() >= 5000:
                # Enough face pixels found: soft-blend the swapped crop with
                # the source crop inside the parsing mask, then warp back.
                target_mask = cv2.resize(tgt_mask, (crop_size, crop_size))
                target_image_parsing = postprocess(
                    swaped_img,
                    source_img[0].cpu().detach().numpy().transpose((1, 2, 0)),
                    target_mask, smooth_mask)
                target_image = cv2.warpAffine(target_image_parsing, mat_rev, orisize)
            else:
                # Parsing too sparse; fall back to the raw swapped crop
                # (channel flip mirrors postprocess's BGR/RGB reversal).
                target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)[..., ::-1]
        else:
            target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)

        # Warp a white square back to mark the crop's footprint in the frame.
        img_white = cv2.warpAffine(img_white, mat_rev, orisize)
        img_white[img_white > 20] = 255
        img_mask = img_white

        # Erode then Gaussian-blur the footprint so the paste seam is soft.
        kernel = np.ones((40, 40), np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations=1)
        kernel_size = (20, 20)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

        img_mask /= 255
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

        # np.float was removed in NumPy 1.24; np.float64 is the identical type.
        if use_mask:
            target_image = np.array(target_image, dtype=np.float64) * 255
        else:
            target_image = np.array(target_image, dtype=np.float64)[..., ::-1] * 255

        img_mask_list.append(img_mask)
        target_image_list.append(target_image)

    # Composite every face into the frame; later faces paint over earlier ones.
    img = np.array(oriimg, dtype=np.float64)
    for img_mask, target_image in zip(img_mask_list, target_image_list):
        img = img_mask * target_image + (1 - img_mask) * img

    final_img = img.astype(np.uint8)
    cv2.imwrite(save_path, final_img)
|
||||
Reference in New Issue
Block a user