This commit is contained in:
chenxuanhong
2022-01-21 18:01:36 +08:00
parent e698d99173
commit bebaeef2ce
8 changed files with 678 additions and 101 deletions
+3 -1
View File
@@ -126,4 +126,6 @@ wandb/
train_logs/
test_logs/
arcface_ckpt/
GUI/
GUI/
insightface_func/
parsing_model/
+228
View File
@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th January 2022 10:51:02 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class InstanceNorm(nn.Module):
    """Parameter-free instance normalization over the spatial dims (H, W).

    @notice: avoid in-place ops.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """
    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        # small constant guarding the rsqrt against a zero variance
        self.epsilon = epsilon

    def forward(self, x):
        # center each (N, C) feature map over its own spatial extent
        centered = x - torch.mean(x, (2, 3), True)
        # per-map mean square of the centered values (no in-place ops)
        mean_sq = torch.mean(centered * centered, (2, 3), True)
        return centered * torch.rsqrt(mean_sq + self.epsilon)
class ApplyStyle(nn.Module):
    """Per-channel affine (AdaIN-style) modulation driven by a latent code.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """
    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        # one scale and one shift per channel -> channels * 2 outputs
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        # [batch, 2*C] -> [batch, 2, C, 1, 1] so it broadcasts over H and W
        params = self.linear(latent).view(-1, 2, x.size(1), 1, 1)
        scale = params[:, 0]
        shift = params[:, 1]
        # scale around 1 so an all-zero style leaves the features unchanged
        return x * (scale + 1.0) + shift
class ResnetBlock_Adain(nn.Module):
    """Residual block whose two 3x3 convolutions are style-modulated (AdaIN)."""

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()

        def build_conv():
            # [optional pad layer] + 3x3 conv + instance norm, honoring the
            # requested padding strategy
            layers = []
            conv_pad = 0
            if padding_type == 'reflect':
                layers.append(nn.ReflectionPad2d(1))
            elif padding_type == 'replicate':
                layers.append(nn.ReplicationPad2d(1))
            elif padding_type == 'zero':
                conv_pad = 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad), InstanceNorm()]
            return nn.Sequential(*layers)

        # registration order (conv1, style1, act1, conv2, style2) preserved
        # so state_dict keys match existing checkpoints
        self.conv1 = build_conv()
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        self.conv2 = build_conv()
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        """conv -> style -> act -> conv -> style, added back onto the input."""
        residual = self.style1(self.conv1(x), dlatents_in_slice)
        residual = self.act1(residual)
        residual = self.style2(self.conv2(residual), dlatents_in_slice)
        return x + residual
class newres(nn.Module):
    """AdaIN residual block variant: a pointwise (1x1) conv, then a 3x3 conv.

    Like ResnetBlock_Adain but the first convolution is 1x1.
    """
    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        # BUG FIX: the original called super(ResnetBlock_Adain, self).__init__(),
        # which raises TypeError because newres does not subclass
        # ResnetBlock_Adain.
        super(newres, self).__init__()
        # BUG FIX: the original prepended a ReflectionPad2d(1) before the 1x1
        # conv, which grows H and W by 2 and makes the residual addition in
        # forward() fail with a size mismatch.  A pointwise conv preserves the
        # spatial size, so it needs no padding at all.
        conv1 = [nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm()]
        self.conv1 = nn.Sequential(*conv1)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        p = 0
        conv2 = []
        if padding_type == 'reflect':
            conv2 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv2 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv2 = nn.Sequential(*conv2)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        """Style-modulated 1x1 -> act -> 3x3 path, added back onto the input."""
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y
class Generator(nn.Module):
    """Conv encoder, identity-conditioned AdaIN bottleneck, conv decoder.

    Required kwargs:
        g_conv_dim:    latent size fed to each ResnetBlock_Adain block
        g_kernel_size: kernel size from the config (only used to derive
                       padding_size, which is currently unused)
        res_num:       number of AdaIN residual blocks in the bottleneck
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        padding_size = int((k_size - 1) / 2)  # NOTE(review): computed but unused
        padding_type = 'reflect'
        activation = nn.ReLU(True)

        self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
                                         nn.BatchNorm2d(64), activation)
        # downsampling path; each stage halves H and W.  NOTE(review): the
        # "skip" names suggest U-Net skip connections, but the decoder below
        # never consumes them -- they are just the encoder chain.
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(512), activation)
        self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(512), activation)
        # identity-conditioned residual bottleneck
        BN = []
        for _ in range(res_num):
            BN += [
                ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
        self.BottleNeck = nn.Sequential(*BN)
        # upsampling path, mirroring the encoder
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512), activation
        )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation
        )
        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))

    def forward(self, input, id):
        """Translate `input` (N,3,H,W) conditioned on identity latent `id`.

        `id` shadows the builtin but is kept for caller compatibility.
        """
        x = input  # 3*224*224
        skip1 = self.first_layer(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        x = self.down4(skip4)
        # BUG FIX: the original ran `x = self.BottleNeck[i](res, id)` with the
        # same `res` every iteration, so all blocks but the last were
        # discarded.  Chain the blocks instead.
        for block in self.BottleNeck:
            x = block(x, id)
        x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        x = self.last_layer(x)
        return x
+13
View File
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: PatchNCE.py
# Created Date: Friday January 21st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 5:04:43 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
+12 -13
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 7:44:02 pm
# Last Modified: Friday, 21st January 2022 10:55:59 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -30,24 +30,23 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='fastnst_3',
parser.add_argument('-v', '--version', type=str, default='2layerFM',
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=-1) # >0 if it is set as -1, program will use CPU
parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19,
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=310000,
help="checkpoint epoch for test phase or finetune phase")
# test
parser.add_argument('-t', '--test_script_name', type=str, default='FastNST')
parser.add_argument('-t', '--test_script_name', type=str, default='video')
parser.add_argument('-b', '--batch_size', type=int, default=1)
parser.add_argument('-n', '--node_name', type=str, default='localhost',
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('--save_test_result', action='store_false')
parser.add_argument('--test_dataloader', type=str, default='dir')
parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\dlrb2.jpeg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\G2010.mp4',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
@@ -235,10 +234,10 @@ def main():
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
for data_key in sys_state["checkpoint_names"].keys():
sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
"%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"][data_key]))
# for data_key in sys_state["checkpoint_names"].keys():
# sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
# "%d_%s.pth"%(sys_state["checkpoint_epoch"],
# sys_state["checkpoint_names"][data_key]))
# Get the test configurations
sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 8:22:37 pm
# Last Modified: Sunday, 4th July 2021 11:32:14 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -35,13 +35,13 @@ class Tester(object):
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
1,
config["batch_size"],
["png","jpg"])
self.test_loader= dataloader
self.test_iter = len(dataloader)
# if len(dataloader)%config["batch_size"]>0:
# self.test_iter+=1
self.test_iter = len(dataloader)//config["batch_size"]
if len(dataloader)%config["batch_size"]>0:
self.test_iter+=1
def __init_framework__(self):
@@ -52,14 +52,19 @@ class Tester(object):
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
script_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
script_name = "components."+self.config["module_script_name"]
class_name = self.config["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
n_class = len(self.config["selectedStyleDir"])
# TODO replace below lines to define the model framework
self.network = network_class(**model_config["g_model"]["module_params"])
self.network = network_class(self.config["GConvDim"],
self.config["GKS"],
self.config["resNum"],
n_class
#**self.config["module_params"]
)
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
@@ -68,14 +73,12 @@ class Tester(object):
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
def test(self):
@@ -84,13 +87,18 @@ class Tester(object):
ckp_epoch = self.config["checkpoint_epoch"]
version = self.config["version"]
batch_size = self.config["batch_size"]
win_size = self.config["model_configs"]["g_model"]["module_params"]["window_size"]
style_names = self.config["selectedStyleDir"]
n_class = len(style_names)
# models
self.__init_framework__()
condition_labels = torch.ones((n_class, batch_size, 1)).long()
for i in range(n_class):
condition_labels[i,:,:] = condition_labels[i,:,:]*i
if self.config["cuda"] >=0:
condition_labels = condition_labels.cuda()
total = len(self.test_loader)
print("total:", total)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
@@ -98,25 +106,18 @@ class Tester(object):
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total)):
for _ in tqdm(range(total//batch_size)):
contents, img_names = self.test_loader()
B, C, H, W = contents.shape
crop_h = H - H%32
crop_w = W - W%32
crop_s = min(crop_h, crop_w)
contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2),
(W//2 - crop_s//2):(crop_s//2 + W//2)]
if self.config["cuda"] >=0:
contents = contents.cuda()
res = self.network(contents, (crop_s, crop_s))
print("res shape:", res.shape)
res = tensor2img(res.cpu())
temp_img = res[0,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
print(save_dir)
print(img_names[0])
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format(
img_names[0], version, ckp_epoch)),temp_img)
for i in range(n_class):
if self.config["cuda"] >=0:
contents = contents.cuda()
res, _ = self.network(contents, condition_labels[i, 0, :])
res = tensor2img(res.cpu())
for t in range(batch_size):
temp_img = res[t,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
+174 -51
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 4th July 2021 11:32:14 am
# Last Modified: Friday, 21st January 2022 11:06:37 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -15,12 +15,24 @@
import os
import cv2
import time
import shutil
import torch
from utilities.utilities import tensor2img
import torch.nn.functional as F
from torchvision import transforms
# from utilities.Reporter import Reporter
from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import numpy as np
from tqdm import tqdm
from PIL import Image
import glob
from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
class Tester(object):
def __init__(self, config, reporter):
@@ -29,20 +41,126 @@ class Tester(object):
# logger
self.reporter = reporter
#============build evaluation dataloader==============#
print("Prepare the test dataloader...")
dlModulename = config["test_dataloader"]
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
config["batch_size"],
["png","jpg"])
self.test_loader= dataloader
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
def cv2totensor(self, cv2_img):
"""
cv2_img: an image read by cv2, H*W*C
return: an 1*C*H*W tensor
"""
cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB)
cv2_img = torch.from_numpy(cv2_img)
cv2_img = cv2_img.permute(2,0,1).cuda()
temp = cv2_img / 255.0
temp -= self.imagenet_mean
temp /= self.imagenet_std
return temp.unsqueeze(0)
self.test_iter = len(dataloader)//config["batch_size"]
if len(dataloader)%config["batch_size"]>0:
self.test_iter+=1
def video_swap(
self,
video_path,
id_vetor,
save_path,
temp_results_dir='./temp_results',
crop_size=512,
use_mask =False
):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
else:
no_audio = False
del video_forcheck
if not no_audio:
video_audio_clip = AudioFileClip(video_path)
video = cv2.VideoCapture(video_path)
ret = True
frame_index = 0
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
# video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
# video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = video.get(cv2.CAP_PROP_FPS)
if os.path.exists(temp_results_dir):
shutil.rmtree(temp_results_dir)
spNorm =ImagenetNorm()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
if ret:
detect_results = self.detect.get(frame,crop_size)
if detect_results is not None:
# print(frame_index)
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame_align_crop_list = detect_results[0]
frame_mat_list = detect_results[1]
swap_result_list = []
frame_align_crop_tenor_list = []
for frame_align_crop in frame_align_crop_list:
frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
swap_result = swap_result* self.imagenet_std + self.imagenet_mean
swap_result = torch.clip(swap_result,0.0,1.0)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
swap_result_list.append(swap_result)
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame = frame.astype(np.uint8)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
else:
break
video.release()
# image_filename_list = []
path = os.path.join(temp_results_dir,'*.jpg')
image_filenames = sorted(glob.glob(path))
clips = ImageSequenceClip(image_filenames,fps = fps)
if not no_audio:
clips = clips.set_audio(video_audio_clip)
basename = os.path.basename(video_path)
basename = os.path.splitext(basename)[0]
save_filename = os.path.join(save_path, basename+".mp4")
index = 0
while(True):
if os.path.exists(save_filename):
save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
index += 1
else:
break
clips.write_videofile(save_filename,audio_codec='aac')
def __init_framework__(self):
'''
@@ -52,53 +170,68 @@ class Tester(object):
#===============build models================#
print("build models...")
# TODO [import models here]
script_name = "components."+self.config["module_script_name"]
class_name = self.config["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
n_class = len(self.config["selectedStyleDir"])
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = network_class(self.config["GConvDim"],
self.config["GKS"],
self.config["resNum"],
n_class
#**self.config["module_params"]
)
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
def test(self):
# save_result = self.config["saveTestResult"]
save_dir = self.config["test_samples_path"]
ckp_epoch = self.config["checkpoint_epoch"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
batch_size = self.config["batch_size"]
style_names = self.config["selectedStyleDir"]
n_class = len(style_names)
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
self.arcface_ckpt= self.config["arcface_ckpt"]
# models
self.__init_framework__()
condition_labels = torch.ones((n_class, batch_size, 1)).long()
for i in range(n_class):
condition_labels[i,:,:] = condition_labels[i,:,:]*i
if self.config["cuda"] >=0:
condition_labels = condition_labels.cuda()
total = len(self.test_loader)
mode = None
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id_imgs)
id_img_align_crop, _ = self.detect.get(id_img,512)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
@@ -106,18 +239,8 @@ class Tester(object):
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total//batch_size)):
contents, img_names = self.test_loader()
for i in range(n_class):
if self.config["cuda"] >=0:
contents = contents.cuda()
res, _ = self.network(contents, condition_labels[i, 0, :])
res = tensor2img(res.cpu())
for t in range(batch_size):
temp_img = res[t,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
use_mask=False,crop_size=512)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: ImagenetNorm.py
# Created Date: Friday January 21st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 10:41:50 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn
import numpy as np
import torch
class ImagenetNorm(nn.Module):
    """Normalize an image batch with the standard ImageNet mean/std.

    @notice: avoid in-place ops.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """
    def __init__(self, epsilon=1e-8):
        super(ImagenetNorm, self).__init__()
        # kept for interface compatibility; not used by forward()
        self.epsilon = epsilon
        # BUG FIX: the original called .cuda() at construction time, which
        # crashes on CPU-only hosts and ignores later .to(device) calls.
        # Non-persistent buffers follow .cuda()/.to() and keep the
        # state_dict empty (as it was before).
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std', std, persistent=False)

    def forward(self, x):
        """x: (N, 3, H, W) image tensor; returns (x - mean) / std."""
        # follow the input's device so existing callers that feed CUDA
        # tensors without moving this module keep working
        mean = self.mean.to(x.device)
        std = self.std.to(x.device)
        # broadcasting expands the (1, 3, 1, 1) constants over N, H and W
        return (x - mean) / std
+173
View File
@@ -0,0 +1,173 @@
import cv2
import numpy as np
# import time
import torch
from torch.nn import functional as F
import torch.nn as nn
def encode_segmentation_rgb(segmentation, no_neck=True):
    """Convert a face-parsing label map into a 2-channel (face, mouth) mask.

    Args:
        segmentation: (H, W) integer array of per-pixel part labels
            (19-class BiSeNet face-parsing labels, per the caller).
        no_neck: when True, use the smaller label set that excludes
            neck/ear regions.

    Returns:
        (H, W, 2) float array: channel 0 is the face region, channel 1 the
        mouth region; each is 255 inside the region and 0 elsewhere.
    """
    parse = segmentation
    face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_id = 11
    # vectorized membership test replaces the original per-id np.where loop
    face_map = np.where(np.isin(parse, face_part_ids), 255.0, 0.0)
    mouth_map = np.where(parse == mouth_id, 255.0, 0.0)
    return np.stack([face_map, mouth_map], axis=2)
class SoftErosion(nn.Module):
    """Soft morphological erosion of a mask via repeated blurring.

    Convolves the mask with a radially-decreasing kernel; all but the last
    pass take the elementwise minimum with the input so the mask can only
    shrink.  Values at or above `threshold` snap to 1; the rest are rescaled
    by their maximum.
    """
    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        radius = kernel_size // 2
        self.padding = radius
        self.iterations = iterations
        self.threshold = threshold
        # radially symmetric kernel: weight falls off with distance from center
        ys, xs = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
        dist = torch.sqrt((xs - radius) ** 2 + (ys - radius) ** 2)
        kernel = dist.max() - dist
        kernel = kernel / kernel.sum()
        self.register_buffer('weight', kernel.view(1, 1, kernel_size, kernel_size))

    def forward(self, x):
        x = x.float()

        def blur(t):
            return F.conv2d(t, weight=self.weight, groups=t.shape[1], padding=self.padding)

        # erosion passes: keep the pointwise minimum so values never grow
        for _ in range(self.iterations - 1):
            x = torch.min(x, blur(x))
        # final pass: plain blur (no min), matching the original behavior
        x = blur(x)
        mask = x >= self.threshold
        x[mask] = 1.0
        # rescale the soft boundary region to top out at 1
        x[~mask] /= x[~mask].max()
        return x, mask
def postprocess(swapped_face, target, target_mask, smooth_mask):
    """Blend a swapped face into the target image using a softened mask.

    Args:
        swapped_face: (H, W, 3) float image from the generator.
        target: (H, W, 3) image to blend into.
        target_mask: (H, W, 2) face/mouth mask (0/255 values) as produced
            by encode_segmentation_rgb.
        smooth_mask: callable (e.g. SoftErosion) that softens the mask edges.

    Returns:
        (H, W, 3) blended image with the channel order reversed.
    """
    # (H, W, 2) 0..255 mask -> 1x1xHxW float tensor in [0, 1] on the GPU
    mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1 / 255.0).cuda()
    face_mask_tensor = mask_tensor[0] + mask_tensor[1]
    soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
    soft_face_mask_tensor.squeeze_()
    # back to (H, W, 1) numpy for broadcasting against HWC images
    soft_face_mask = soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis]
    # alpha-blend, then flip channel order for the BGR consumer downstream
    blended = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
    return blended[:, :, ::-1]
def reverse2wholeimage(b_align_crop_tenor_list, swaped_imgs, mats, crop_size, oriimg,
                       save_path='', pasring_model=None, norm=None, use_mask=False):
    """Paste swapped face crops back into the original whole image.

    Args:
        b_align_crop_tenor_list: list of 1x3xHxW source crop tensors (fed to
            the parsing model when use_mask is True).
        swaped_imgs: list of 3xHxW swapped-face tensors in [0, 1].
        mats: list of 2x3 forward affine matrices (original -> crop).
        crop_size: side length of the square face crops.
        oriimg: original BGR frame (H, W, 3) to composite into.
        save_path: file path the composited image is written to.
        pasring_model: BiSeNet parsing model; parameter name misspelled but
            kept as-is because callers pass it by keyword.
        norm: ImageNet normalizer applied before the parsing model.
        use_mask: blend with a parsing-based soft face mask when True.
    """
    target_image_list = []
    img_mask_list = []
    if use_mask:
        smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()
    # loop-invariant erosion kernel and blur size hoisted out of the loop
    erode_kernel = np.ones((40, 40), np.uint8)
    blur_size = tuple(2 * i + 1 for i in (20, 20))
    for swaped_img, mat, source_img in zip(swaped_imgs, mats, b_align_crop_tenor_list):
        swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0))
        img_white = np.full((crop_size, crop_size), 255, dtype=float)
        # analytically invert the 2x3 affine transform (equivalent to
        # cv2.invertAffineTransform) so we can warp crop -> original
        mat_rev = np.zeros([2, 3])
        div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
        mat_rev[0][0] = mat[1][1] / div1
        mat_rev[0][1] = -mat[0][1] / div1
        mat_rev[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
        div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
        mat_rev[1][0] = mat[1][0] / div2
        mat_rev[1][1] = -mat[0][0] / div2
        mat_rev[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2
        orisize = (oriimg.shape[1], oriimg.shape[0])
        if use_mask:
            source_img_norm = norm(source_img)
            source_img_512 = F.interpolate(source_img_norm, size=(512, 512))
            out = pasring_model(source_img_512)[0]
            parsing = out.squeeze(0).detach().cpu().numpy().argmax(0)
            vis_parsing_anno = parsing.copy().astype(np.uint8)
            tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
            if tgt_mask.sum() >= 5000:
                target_mask = cv2.resize(tgt_mask, (crop_size, crop_size))
                target_image_parsing = postprocess(
                    swaped_img,
                    source_img[0].cpu().detach().numpy().transpose((1, 2, 0)),
                    target_mask, smooth_mask)
                target_image = cv2.warpAffine(target_image_parsing, mat_rev, orisize)
            else:
                # too few face pixels parsed: fall back to the raw swap
                target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)[..., ::-1]
        else:
            target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)
        # soft alpha mask covering the warped crop's footprint
        img_mask = cv2.warpAffine(img_white, mat_rev, orisize)
        img_mask[img_mask > 20] = 255
        img_mask = cv2.erode(img_mask, erode_kernel, iterations=1)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
        img_mask /= 255
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
        # BUG FIX: `np.float` was removed in NumPy 1.24; np.float64 is the
        # exact alias it used to resolve to.
        if use_mask:
            target_image = np.array(target_image, dtype=np.float64) * 255
        else:
            target_image = np.array(target_image, dtype=np.float64)[..., ::-1] * 255
        img_mask_list.append(img_mask)
        target_image_list.append(target_image)
    # alpha-composite every face onto the original frame, in order
    img = np.array(oriimg, dtype=np.float64)
    for img_mask, target_image in zip(img_mask_list, target_image_list):
        img = img_mask * target_image + (1 - img_mask) * img
    final_img = img.astype(np.uint8)
    cv2.imwrite(save_path, final_img)