#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th January 2022 10:51:02 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


def _conv3x3_block(dim, padding_type):
    """Build [optional padding, 3x3 conv, InstanceNorm] for the given padding type.

    Shared by ResnetBlock_Adain and newres so the padding dispatch lives in
    one place. Raises NotImplementedError for unknown padding types.
    """
    p = 0
    layers = []
    if padding_type == 'reflect':
        layers += [nn.ReflectionPad2d(1)]
    elif padding_type == 'replicate':
        layers += [nn.ReplicationPad2d(1)]
    elif padding_type == 'zero':
        p = 1
    else:
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)
    layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
    return nn.Sequential(*layers)


class InstanceNorm(nn.Module):
    """Parameter-free instance normalization over the spatial dims (H, W).

    @notice: avoid in-place ops.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon  # numerical-stability term added to the variance

    def forward(self, x):
        # Center, then scale by the reciprocal std of each (sample, channel) map.
        x = x - torch.mean(x, (2, 3), True)
        tmp = torch.mul(x, x)  # or x ** 2
        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
        return x * tmp


class ApplyStyle(nn.Module):
    """AdaIN-style affine modulation: a latent vector produces per-channel scale/shift.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        style = self.linear(latent)            # [batch_size, n_channels*2]
        style = style.view(-1, 2, x.size(1), 1, 1)
        # style[:, 0] is a residual scale (identity at 0); style[:, 1] is a shift.
        return x * (style[:, 0] + 1.) + style[:, 1]


class ResnetBlock_Adain(nn.Module):
    """Residual block whose two convs are modulated by a latent via ApplyStyle."""

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()
        self.conv1 = _conv3x3_block(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        self.conv2 = _conv3x3_block(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y  # residual connection


class newres(nn.Module):
    """Variant of ResnetBlock_Adain whose first conv is a 1x1 (pointwise) conv."""

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        # BUG FIX: the original called super(ResnetBlock_Adain, self).__init__(),
        # which raises TypeError because newres does not inherit ResnetBlock_Adain.
        super(newres, self).__init__()
        # BUG FIX: a 1x1 conv needs no padding; the original prepended a padding
        # layer, which grows the spatial size and breaks the residual addition.
        self.conv1 = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm())
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        self.conv2 = _conv3x3_block(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y  # residual connection


class Generator(nn.Module):
    """U-Net-style encoder / AdaIN bottleneck / decoder generator.

    Expected kwargs:
        g_conv_dim:    latent size fed to the AdaIN blocks.
        g_kernel_size: nominal kernel size (only used to derive padding_size).
        res_num:       number of ResnetBlock_Adain blocks in the bottleneck.

    forward(input, id): input is an N*3*H*W image (H, W divisible by 16),
    id is an N*g_conv_dim identity latent; returns an N*3*H*W image.
    """

    def __init__(self, **kwargs):
        super().__init__()

        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]

        padding_size = int((k_size - 1) / 2)
        padding_type = 'reflect'

        activation = nn.ReLU(True)

        self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
                                         nn.BatchNorm2d(64), activation)
        ### downsample (each stage halves H and W)
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(512), activation)
        self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                                   nn.BatchNorm2d(512), activation)

        ### resnet blocks, each modulated by the identity latent
        BN = []
        for i in range(res_num):
            BN += [ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type,
                                     activation=activation)]
        self.BottleNeck = nn.Sequential(*BN)

        ### upsample (mirror of the encoder)
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512), activation
        )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation
        )

        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))

    def forward(self, input, id):
        x = input  # 3*224*224
        skip1 = self.first_layer(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        res = self.down4(skip4)

        # BUG FIX: the original passed `res` to every bottleneck block
        # (x = self.BottleNeck[i](res, id)), so only the last block's output
        # was used; the blocks are meant to be applied in sequence.
        x = res
        for block in self.BottleNeck:
            x = block(x, id)

        x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        x = self.last_layer(x)

        return x
Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 12th October 2021 7:44:02 pm +# Last Modified: Friday, 21st January 2022 10:55:59 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -30,24 +30,23 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='fastnst_3', + parser.add_argument('-v', '--version', type=str, default='2layerFM', help="version name for train, test, finetune") - parser.add_argument('-c', '--cuda', type=int, default=-1) # >0 if it is set as -1, program will use CPU - parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19, + parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU + parser.add_argument('-s', '--checkpoint_step', type=int, default=310000, help="checkpoint epoch for test phase or finetune phase") # test - parser.add_argument('-t', '--test_script_name', type=str, default='FastNST') + parser.add_argument('-t', '--test_script_name', type=str, default='video') parser.add_argument('-b', '--batch_size', type=int, default=1) parser.add_argument('-n', '--node_name', type=str, default='localhost', choices=['localhost', '4card','8card','new4card']) - - parser.add_argument('--save_test_result', action='store_false') - parser.add_argument('--test_dataloader', type=str, default='dir') - parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark') + parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\dlrb2.jpeg') + parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\G2010.mp4', + help="file path for attribute images or video") parser.add_argument('--use_specified_data', action='store_true') parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files') 
@@ -235,10 +234,10 @@ def main(): # TODO get the checkpoint file path sys_state["ckp_name"] = {} - for data_key in sys_state["checkpoint_names"].keys(): - sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"], - "%d_%s.pth"%(sys_state["checkpoint_epoch"], - sys_state["checkpoint_names"][data_key])) + # for data_key in sys_state["checkpoint_names"].keys(): + # sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"], + # "%d_%s.pth"%(sys_state["checkpoint_epoch"], + # sys_state["checkpoint_names"][data_key])) # Get the test configurations sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"] diff --git a/test_scripts/tester_FastNST.py b/test_scripts/tester_common copy.py similarity index 57% rename from test_scripts/tester_FastNST.py rename to test_scripts/tester_common copy.py index bc969c2..30ec590 100644 --- a/test_scripts/tester_FastNST.py +++ b/test_scripts/tester_common copy.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 12th October 2021 8:22:37 pm +# Last Modified: Sunday, 4th July 2021 11:32:14 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -35,13 +35,13 @@ class Tester(object): package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True) dataloaderClass = getattr(package, 'TestDataset') dataloader = dataloaderClass(config["test_data_path"], - 1, + config["batch_size"], ["png","jpg"]) self.test_loader= dataloader - self.test_iter = len(dataloader) - # if len(dataloader)%config["batch_size"]>0: - # self.test_iter+=1 + self.test_iter = len(dataloader)//config["batch_size"] + if len(dataloader)%config["batch_size"]>0: + self.test_iter+=1 def __init_framework__(self): @@ -52,14 +52,19 @@ class Tester(object): #===============build models================# print("build models...") # 
TODO [import models here] - model_config = self.config["model_configs"] - script_name = self.config["com_base"] + model_config["g_model"]["script"] - class_name = model_config["g_model"]["class_name"] + script_name = "components."+self.config["module_script_name"] + class_name = self.config["class_name"] package = __import__(script_name, fromlist=True) network_class = getattr(package, class_name) + n_class = len(self.config["selectedStyleDir"]) # TODO replace below lines to define the model framework - self.network = network_class(**model_config["g_model"]["module_params"]) + self.network = network_class(self.config["GConvDim"], + self.config["GKS"], + self.config["resNum"], + n_class + #**self.config["module_params"] + ) # print and recorde model structure self.reporter.writeInfo("Model structure:") @@ -68,14 +73,12 @@ class Tester(object): # train in GPU if self.config["cuda"] >=0: self.network = self.network.cuda() - - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_epoch"], - self.config["checkpoint_names"]["generator_name"])) - - self.network.load_state_dict(torch.load(model_path)) + # loader1 = torch.load(self.config["ckp_name"]["generator_name"]) + # print(loader1.key()) + # pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"] + self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"]) # self.network.load_state_dict(torch.load(pathwocao)) - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) + print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"])) def test(self): @@ -84,13 +87,18 @@ class Tester(object): ckp_epoch = self.config["checkpoint_epoch"] version = self.config["version"] batch_size = self.config["batch_size"] - win_size = 
self.config["model_configs"]["g_model"]["module_params"]["window_size"] + style_names = self.config["selectedStyleDir"] + n_class = len(style_names) # models self.__init_framework__() + condition_labels = torch.ones((n_class, batch_size, 1)).long() + for i in range(n_class): + condition_labels[i,:,:] = condition_labels[i,:,:]*i + if self.config["cuda"] >=0: + condition_labels = condition_labels.cuda() total = len(self.test_loader) - print("total:", total) # Start time import datetime print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) @@ -98,25 +106,18 @@ class Tester(object): start_time = time.time() self.network.eval() with torch.no_grad(): - for _ in tqdm(range(total)): + for _ in tqdm(range(total//batch_size)): contents, img_names = self.test_loader() - B, C, H, W = contents.shape - crop_h = H - H%32 - crop_w = W - W%32 - crop_s = min(crop_h, crop_w) - contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2), - (W//2 - crop_s//2):(crop_s//2 + W//2)] - if self.config["cuda"] >=0: - contents = contents.cuda() - res = self.network(contents, (crop_s, crop_s)) - print("res shape:", res.shape) - res = tensor2img(res.cpu()) - temp_img = res[0,:,:,:] - temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR) - print(save_dir) - print(img_names[0]) - cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format( - img_names[0], version, ckp_epoch)),temp_img) + for i in range(n_class): + if self.config["cuda"] >=0: + contents = contents.cuda() + res, _ = self.network(contents, condition_labels[i, 0, :]) + res = tensor2img(res.cpu()) + for t in range(batch_size): + temp_img = res[t,:,:,:] + temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format( + img_names[t], version, ckp_epoch, style_names[i])),temp_img) elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) diff --git a/test_scripts/tester_video.py 
b/test_scripts/tester_video.py index 30ec590..bfe6627 100644 --- a/test_scripts/tester_video.py +++ b/test_scripts/tester_video.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Sunday, 4th July 2021 11:32:14 am +# Last Modified: Friday, 21st January 2022 11:06:37 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -15,12 +15,24 @@ import os import cv2 import time +import shutil import torch -from utilities.utilities import tensor2img +import torch.nn.functional as F +from torchvision import transforms -# from utilities.Reporter import Reporter +from moviepy.editor import AudioFileClip, VideoFileClip +from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + +import numpy as np from tqdm import tqdm +from PIL import Image +import glob + +from utilities.ImagenetNorm import ImagenetNorm +from parsing_model.model import BiSeNet +from insightface_func.face_detect_crop_single import Face_detect_crop +from utilities.reverse2original import reverse2wholeimage class Tester(object): def __init__(self, config, reporter): @@ -29,20 +41,126 @@ class Tester(object): # logger self.reporter = reporter - #============build evaluation dataloader==============# - print("Prepare the test dataloader...") - dlModulename = config["test_dataloader"] - package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'TestDataset') - dataloader = dataloaderClass(config["test_data_path"], - config["batch_size"], - ["png","jpg"]) - self.test_loader= dataloader + self.transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1) + self.imagenet_mean = torch.tensor([0.485, 0.456, 
0.406]).cuda().view(3,1,1) + + def cv2totensor(self, cv2_img): + """ + cv2_img: an image read by cv2, H*W*C + return: an 1*C*H*W tensor + """ + cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB) + cv2_img = torch.from_numpy(cv2_img) + cv2_img = cv2_img.permute(2,0,1).cuda() + temp = cv2_img / 255.0 + temp -= self.imagenet_mean + temp /= self.imagenet_std + return temp.unsqueeze(0) - self.test_iter = len(dataloader)//config["batch_size"] - if len(dataloader)%config["batch_size"]>0: - self.test_iter+=1 + def video_swap( + self, + video_path, + id_vetor, + save_path, + temp_results_dir='./temp_results', + crop_size=512, + use_mask =False + ): + + video_forcheck = VideoFileClip(video_path) + if video_forcheck.audio is None: + no_audio = True + else: + no_audio = False + + del video_forcheck + + if not no_audio: + video_audio_clip = AudioFileClip(video_path) + + video = cv2.VideoCapture(video_path) + ret = True + frame_index = 0 + + frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + + # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = video.get(cv2.CAP_PROP_FPS) + if os.path.exists(temp_results_dir): + shutil.rmtree(temp_results_dir) + spNorm =ImagenetNorm() + if use_mask: + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = os.path.join('./parsing_model', '79999_iter.pth') + net.load_state_dict(torch.load(save_pth)) + net.eval() + else: + net =None + + # while ret: + for frame_index in tqdm(range(frame_count)): + ret, frame = video.read() + if ret: + detect_results = self.detect.get(frame,crop_size) + + if detect_results is not None: + # print(frame_index) + if not os.path.exists(temp_results_dir): + os.mkdir(temp_results_dir) + frame_align_crop_list = detect_results[0] + frame_mat_list = detect_results[1] + swap_result_list = [] + frame_align_crop_tenor_list = [] + for frame_align_crop in frame_align_crop_list: + frame_align_crop_tenor = 
self.cv2totensor(frame_align_crop) + swap_result = self.network(frame_align_crop_tenor, id_vetor)[0] + swap_result = swap_result* self.imagenet_std + self.imagenet_mean + swap_result = torch.clip(swap_result,0.0,1.0) + cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) + swap_result_list.append(swap_result) + frame_align_crop_tenor_list.append(frame_align_crop_tenor) + reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\ + os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm) + + else: + if not os.path.exists(temp_results_dir): + os.mkdir(temp_results_dir) + frame = frame.astype(np.uint8) + cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) + else: + break + + video.release() + + # image_filename_list = [] + path = os.path.join(temp_results_dir,'*.jpg') + image_filenames = sorted(glob.glob(path)) + + clips = ImageSequenceClip(image_filenames,fps = fps) + + if not no_audio: + clips = clips.set_audio(video_audio_clip) + basename = os.path.basename(video_path) + basename = os.path.splitext(basename)[0] + save_filename = os.path.join(save_path, basename+".mp4") + index = 0 + while(True): + if os.path.exists(save_filename): + save_filename = os.path.join(save_path, basename+"_%d.mp4"%index) + index += 1 + else: + break + clips.write_videofile(save_filename,audio_codec='aac') + def __init_framework__(self): ''' @@ -52,53 +170,68 @@ class Tester(object): #===============build models================# print("build models...") # TODO [import models here] - script_name = "components."+self.config["module_script_name"] - class_name = self.config["class_name"] - package = __import__(script_name, fromlist=True) - network_class = getattr(package, class_name) - n_class = len(self.config["selectedStyleDir"]) + model_config = self.config["model_configs"] + gscript_name = 
self.config["com_base"] + model_config["g_model"]["script"] + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.network = gen_class(**model_config["g_model"]["module_params"]) # TODO replace below lines to define the model framework - self.network = network_class(self.config["GConvDim"], - self.config["GKS"], - self.config["resNum"], - n_class - #**self.config["module_params"] - ) - + self.network = gen_class(**model_config["g_model"]["module_params"]) + self.network = self.network.eval() # print and recorde model structure self.reporter.writeInfo("Model structure:") self.reporter.writeModel(self.network.__str__()) + + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + self.arcface.eval() + self.arcface.requires_grad_(False) # train in GPU if self.config["cuda"] >=0: self.network = self.network.cuda() + self.arcface = self.arcface.cuda() # loader1 = torch.load(self.config["ckp_name"]["generator_name"]) # print(loader1.key()) # pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"] - self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"]) + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.network.load_state_dict(torch.load(model_path)) # self.network.load_state_dict(torch.load(pathwocao)) - print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"])) + print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"])) def test(self): # save_result = self.config["saveTestResult"] save_dir = self.config["test_samples_path"] - ckp_epoch = self.config["checkpoint_epoch"] 
+ ckp_step = self.config["checkpoint_step"] version = self.config["version"] - batch_size = self.config["batch_size"] - style_names = self.config["selectedStyleDir"] - n_class = len(style_names) + id_imgs = self.config["id_imgs"] + attr_files = self.config["attr_files"] + self.arcface_ckpt= self.config["arcface_ckpt"] # models self.__init_framework__() - condition_labels = torch.ones((n_class, batch_size, 1)).long() - for i in range(n_class): - condition_labels[i,:,:] = condition_labels[i,:,:]*i - if self.config["cuda"] >=0: - condition_labels = condition_labels.cuda() - total = len(self.test_loader) + + + mode = None + self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models') + self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode) + + id_img = cv2.imread(id_imgs) + id_img_align_crop, _ = self.detect.get(id_img,512) + id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB)) + id_img = self.transformer_Arcface(id_img_align_crop_pil) + id_img = id_img.unsqueeze(0).cuda() + + #create latent id + id_img = F.interpolate(id_img,size=(112,112), mode='bicubic') + latend_id = self.arcface(id_img) + latend_id = F.normalize(latend_id, p=2, dim=1) # Start time import datetime print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) @@ -106,18 +239,8 @@ class Tester(object): start_time = time.time() self.network.eval() with torch.no_grad(): - for _ in tqdm(range(total//batch_size)): - contents, img_names = self.test_loader() - for i in range(n_class): - if self.config["cuda"] >=0: - contents = contents.cuda() - res, _ = self.network(contents, condition_labels[i, 0, :]) - res = tensor2img(res.cpu()) - for t in range(batch_size): - temp_img = res[t,:,:,:] - temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR) - cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format( - img_names[t], version, ckp_epoch, style_names[i])),temp_img) + 
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\ + use_mask=False,crop_size=512) elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) diff --git a/utilities/ImagenetNorm.py b/utilities/ImagenetNorm.py new file mode 100644 index 0000000..89c57ca --- /dev/null +++ b/utilities/ImagenetNorm.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: ImagenetNorm.py +# Created Date: Friday January 21st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 21st January 2022 10:41:50 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn +import numpy as np +import torch +class ImagenetNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. + https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 + """ + super(ImagenetNorm, self).__init__() + self.mean = np.array([0.485, 0.456, 0.406]) + self.mean = torch.from_numpy(self.mean).float().cuda() + self.mean = self.mean.view([1, 3, 1, 1]) + + self.std = np.array([0.229, 0.224, 0.225]) + self.std = torch.from_numpy(self.std).float().cuda() + self.std = self.std.view([1, 3, 1, 1]) + + def forward(self, x): + mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]]) + std = self.std.expand([1, 3, x.shape[2], x.shape[3]]) + + x = (x - mean) / std + + return x \ No newline at end of file diff --git a/utilities/reverse2original.py b/utilities/reverse2original.py new file mode 100644 index 0000000..8b7fbc9 --- /dev/null +++ b/utilities/reverse2original.py @@ -0,0 +1,173 @@ +import cv2 +import numpy as np +# import time +import torch +from torch.nn import functional as F +import torch.nn as nn + + +def 
def encode_segmentation_rgb(segmentation, no_neck=True):
    """Turn an H*W integer parsing map into a 2-channel 0/255 float mask.

    Channel 0 marks the face region (the label ids listed below, presumably
    skin/brows/eyes/nose/lips from the BiSeNet face-parsing model — confirm
    against parsing_model), channel 1 marks the mouth (label 11). When
    no_neck is False, three extra label ids are included in the face region.
    Returns an H*W*2 float array.
    """
    parse = segmentation

    face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_id = 11
    # hair_id = 17
    face_map = np.zeros([parse.shape[0], parse.shape[1]])
    mouth_map = np.zeros([parse.shape[0], parse.shape[1]])

    for valid_id in face_part_ids:
        face_map[parse == valid_id] = 255
    mouth_map[parse == mouth_id] = 255

    return np.stack([face_map, mouth_map], axis=2)


class SoftErosion(nn.Module):
    """Soft-erode a binary mask with repeated distance-weighted convolutions.

    forward(x) returns (soft mask in [0, 1], boolean mask of pixels that
    reached `threshold`). Values at or above the threshold are snapped to 1,
    the rest are rescaled so the maximum below-threshold value becomes 1.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Radially decaying kernel: weight falls off with distance from center.
        # NOTE(review): meshgrid is called without indexing=; relies on the
        # default 'ij' behavior — confirm on the targeted torch version.
        y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer('weight', kernel)

    def forward(self, x):
        x = x.float()
        # Each iteration takes the element-wise min with the blurred mask,
        # shrinking the mask from its border inward.
        for i in range(self.iterations - 1):
            x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0
        x[~mask] /= x[~mask].max()

        return x, mask


def postprocess(swapped_face, target, target_mask, smooth_mask):
    """Blend `swapped_face` over `target` using a soft-eroded face mask.

    `target_mask` is the H*W*2 (face, mouth) mask from
    encode_segmentation_rgb; `smooth_mask` is a SoftErosion module.
    Returns the blended image with channels reversed (RGB<->BGR).
    Requires CUDA (the mask tensor is moved to the GPU).
    """
    mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1/255.0).cuda()
    face_mask_tensor = mask_tensor[0] + mask_tensor[1]

    soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
    soft_face_mask_tensor.squeeze_()

    soft_face_mask = soft_face_mask_tensor.cpu().numpy()
    soft_face_mask = soft_face_mask[:, :, np.newaxis]

    result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
    result = result[:, :, ::-1]  # .astype(np.uint8)
    return result


def reverse2wholeimage(b_align_crop_tenor_list, swaped_imgs, mats, crop_size, oriimg, save_path='',
                       pasring_model=None, norm=None, use_mask=False):
    """Paste each swapped face crop back into the original frame and save it.

    b_align_crop_tenor_list: aligned source crops (1*C*H*W tensors).
    swaped_imgs:             swapped crops (C*H*W tensors in [0, 1]).
    mats:                    2x3 affine matrices used for alignment.
    oriimg:                  the full original frame (BGR, H*W*C).
    save_path:               output path passed to cv2.imwrite.
    pasring_model/norm:      BiSeNet model and ImagenetNorm, used only when
                             use_mask is True (note: parameter name keeps the
                             original 'pasring' spelling for caller compatibility).
    """
    target_image_list = []
    img_mask_list = []
    if use_mask:
        smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()

    for swaped_img, mat, source_img in zip(swaped_imgs, mats, b_align_crop_tenor_list):
        swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0))
        img_white = np.full((crop_size, crop_size), 255, dtype=float)

        # Analytically invert the 2x3 affine alignment matrix.
        mat_rev = np.zeros([2, 3])
        div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
        mat_rev[0][0] = mat[1][1] / div1
        mat_rev[0][1] = -mat[0][1] / div1
        mat_rev[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
        div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
        mat_rev[1][0] = mat[1][0] / div2
        mat_rev[1][1] = -mat[0][0] / div2
        mat_rev[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2

        orisize = (oriimg.shape[1], oriimg.shape[0])
        if use_mask:
            source_img_norm = norm(source_img)
            source_img_512 = F.interpolate(source_img_norm, size=(512, 512))
            out = pasring_model(source_img_512)[0]
            parsing = out.squeeze(0).detach().cpu().numpy().argmax(0)
            vis_parsing_anno = parsing.copy().astype(np.uint8)
            tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
            if tgt_mask.sum() >= 5000:
                target_mask = cv2.resize(tgt_mask, (crop_size, crop_size))
                target_image_parsing = postprocess(
                    swaped_img,
                    source_img[0].cpu().detach().numpy().transpose((1, 2, 0)),
                    target_mask, smooth_mask)
                target_image = cv2.warpAffine(target_image_parsing, mat_rev, orisize)
            else:
                # Parsing found too little face area; fall back to plain paste.
                target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)[..., ::-1]
        else:
            target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)

        # Warp a white crop-sized plate to get the pasted region's footprint,
        # then erode and blur it for a feathered blending mask.
        img_white = cv2.warpAffine(img_white, mat_rev, orisize)
        img_white[img_white > 20] = 255
        img_mask = img_white

        kernel = np.ones((40, 40), np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations=1)
        kernel_size = (20, 20)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

        img_mask /= 255
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

        # BUG FIX: np.float was removed in NumPy 1.24; the builtin float is the
        # documented equivalent (float64).
        if use_mask:
            target_image = np.array(target_image, dtype=float) * 255
        else:
            target_image = np.array(target_image, dtype=float)[..., ::-1] * 255

        img_mask_list.append(img_mask)
        target_image_list.append(target_image)

    img = np.array(oriimg, dtype=float)
    for img_mask, target_image in zip(img_mask_list, target_image_list):
        img = img_mask * target_image + (1 - img_mask) * img

    final_img = img.astype(np.uint8)
    cv2.imwrite(save_path, final_img)