This commit is contained in:
chenxuanhong
2022-04-24 15:44:47 +08:00
parent 99ed65aaa3
commit 29d8914c0a
138 changed files with 24864 additions and 353 deletions
+279
View File
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th April 2022 9:04:01 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
class Tester(object):
    def __init__(self, config, reporter):
        """Keep the run configuration and reporter, and prepare image helpers.

        Args:
            config:   dict-like run configuration.
            reporter: logging helper with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter
        # ArcFace input pipeline: HWC uint8 -> CHW float, ImageNet-normalized.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Broadcastable (3,1,1) constants used to undo the normalization on
        # generator outputs.
        # NOTE(review): .cuda() here assumes a GPU exists even though model
        # placement elsewhere is gated on config["cuda"] — confirm.
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# for name in self.network.state_dict():
# print(name)
self.features = {}
mapping_layers = [
"first_layer",
"down4",
"BottleNeck.2"
]
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
crop_mode = self.config["crop_mode"]
list_txt = self.config["img_list_txt"]
record_metric= self.config["record_metric"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
imgs_list = []
with open(list_txt,'r') as logf:
for line in logf:
cells = line.split(";")
imgs_list.append([cells[0],cells[1],cells[2].replace("\n","")])
# models
self.__init_framework__()
mode = crop_mode.lower()
if mode == "vggface":
mode = "none"
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
cos_loss = torch.nn.CosineSimilarity()
font = cv2.FONT_HERSHEY_SIMPLEX
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
cos_dict = {}
average_cos = 0
with torch.no_grad():
for img in imgs_list:
id_img_n, attr_img_n, fusion= img
print("id image:%s---attr image:%s"%(id_img_n, attr_img_n))
id_img = cv2.imread(id_img_n)
print(fusion)
if fusion.lower() == "fusion":
try:
id_img_align_crop, _ = self.detect.get(id_img,512)
except:
print("Image %s Do not detect a face!"%id_img_n)
continue
# id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
# cv2.imwrite(os.path.join(save_dir, "id_%s.png"%(id_basename)),id_img_align_crop[0])
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
else:
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
attr_img_ori= cv2.imread(attr_img_n)
if fusion.lower() == "fusion":
try:
attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
except:
print("Image %s Do not detect a face!"%attr_img_n)
continue
# attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
# cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)),attr_img_align_crop[0])
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
else:
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori,cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
cos_dis = 1 - cos_loss(latend_id, attr_id)
results,mask_lr,mask_hr= self.network(attr_img, latend_id)
mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...]
mask_lr = mask_lr.numpy()
# mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
mask_lr = np.clip(mask_lr,0.0,1.0) * 255
mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...]
mask_hr = mask_hr.numpy()
# mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
mask_hr = np.clip(mask_hr,0.0,1.0) * 255
results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
results_arc = self.arcface(results_arc)
results_arc = F.normalize(results_arc, p=2, dim=1)
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
average_cos += results_cos_dis
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
if fusion.lower() == "fusion":
mat = mat[0]
img_white = np.full((512,512), 255, dtype=float)
# inverse the Affine transformation matrix
mat_rev = np.zeros([2,3])
div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
mat_rev[0][0] = mat[1][1]/div1
mat_rev[0][1] = -mat[0][1]/div1
mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
mat_rev[1][0] = mat[1][0]/div2
mat_rev[1][1] = -mat[0][0]/div2
mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
target_image = cv2.warpAffine(results, mat_rev, orisize)
img_white = cv2.warpAffine(img_white, mat_rev, orisize)
img_white[img_white>20] =255
img_mask = img_white
kernel = np.ones((40,40),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel_size = (20, 20)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
img_mask /= 255
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
img1 = np.array(attr_img_ori, dtype=np.float)
img1 = img_mask * target_image + (1-img_mask) * img1
else:
results = results*255
img1 = cv2.cvtColor(results,cv2.COLOR_RGB2BGR)
final_img = img1.astype(np.uint8)
id_basename = os.path.basename(id_img_n)
id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
if record_metric:
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
print(save_dir)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_lr)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_hr)
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
+328
View File
@@ -0,0 +1,328 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 23rd April 2022 10:04:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
class Tester(object):
    def __init__(self, config, reporter):
        """Hold config and reporter; set up ArcFace preprocessing constants.

        Args:
            config:   dict-like run configuration.
            reporter: logging helper with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter
        # Convert BGR->RGB PIL crops into ImageNet-normalized CHW tensors.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # (3,1,1) tensors used to undo normalization on generator outputs.
        # NOTE(review): .cuda() assumes a GPU regardless of config["cuda"].
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# for name in self.network.state_dict():
# print(name)
self.features = {}
mapping_layers = [
"first_layer",
"down4",
"BottleNeck.2"
]
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
if self.config["preprocess"]:
print("Employ GFPGAN to upsampling detected face images!")
from face_enhancer.gfpgan import GFPGANer
version = '1.2'
if version == '1':
arch = 'original'
channel_multiplier = 1
model_name = 'GFPGANv1'
elif version == '1.2':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANCleanv1-NoCE-C2'
elif version == '1.3':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.3'
# determine model paths
model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
if not os.path.isfile(model_path):
raise ValueError(f'Model {model_name} does not exist.')
self.restorer = GFPGANer(
model_path=model_path,
upscale=1,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=None)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
crop_mode = self.config["crop_mode"]
list_txt = self.config["img_list_txt"]
record_metric= self.config["record_metric"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
imgs_list = []
with open(list_txt,'r') as logf:
for line in logf:
cells = line.split(";")
imgs_list.append([cells[0],cells[1],cells[2].replace("\n","")])
# models
self.__init_framework__()
mode = crop_mode.lower()
if mode == "vggface":
mode = "none"
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
cos_loss = torch.nn.CosineSimilarity()
font = cv2.FONT_HERSHEY_SIMPLEX
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
cos_dict = {}
average_cos = 0
with torch.no_grad():
for img in imgs_list:
id_img_n, attr_img_n, fusion= img
print("id image:%s---attr image:%s"%(id_img_n, attr_img_n))
id_img = cv2.imread(id_img_n)
print(fusion)
if fusion.lower() == "fusion":
try:
id_img_align_crop, _ = self.detect.get(id_img,512)
except:
print("Image %s Do not detect a face!"%id_img_n)
continue
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
else:
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
attr_img_ori= cv2.imread(attr_img_n)
if fusion.lower() == "fusion":
try:
attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
except:
print("Image %s Do not detect a face!"%attr_img_n)
continue
# attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
# cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)),attr_img_align_crop[0])
restored_face = attr_img_align_crop[0]
if self.config["preprocess"]:
_, _, restored_face = self.restorer.enhance(
restored_face, has_aligned=False, only_center_face=True, paste_back=True)
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
else:
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori,cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
cos_dis = 1 - cos_loss(latend_id, attr_id)
# results,mask= self.network(attr_img, latend_id)
pred = self.network(attr_img, latend_id)
results = pred[0]
results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
results_arc = self.arcface(results_arc)
results_arc = F.normalize(results_arc, p=2, dim=1)
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
average_cos += results_cos_dis
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
if fusion.lower() == "fusion":
mat = mat[0]
img_white = np.full((512,512), 255, dtype=float)
# inverse the Affine transformation matrix
mat_rev = np.zeros([2,3])
div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
mat_rev[0][0] = mat[1][1]/div1
mat_rev[0][1] = -mat[0][1]/div1
mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
mat_rev[1][0] = mat[1][0]/div2
mat_rev[1][1] = -mat[0][0]/div2
mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
target_image = cv2.warpAffine(results, mat_rev, orisize)
img_white = cv2.warpAffine(img_white, mat_rev, orisize)
img_white[img_white>20] =255
img_mask = img_white
kernel = np.ones((40,40),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel_size = (20, 20)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
img_mask /= 255
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
img1 = np.array(attr_img_ori, dtype=np.float)
img1 = img_mask * target_image + (1-img_mask) * img1
else:
results = results*255
img1 = cv2.cvtColor(results,cv2.COLOR_RGB2BGR)
final_img = img1.astype(np.uint8)
id_basename = os.path.basename(id_img_n)
id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
if record_metric:
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
print(save_dir)
if self.config["preprocess"]:
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png"%(id_basename,
attr_basename,ckp_step,version))
else:
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
if self.config["save_mask"]:
num = 0
for mask in pred[1:]:
mask = mask.cpu().permute(0,2,3,1)[0,...]
mask = mask.numpy()
mask = (mask - np.min(mask))/np.max(mask)
mask = np.clip(mask,0.0,1.0) * 255
if self.config["preprocess"]:
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png"%(id_basename,
attr_basename,ckp_step,version,num))
else:
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask%d.png"%(id_basename,
attr_basename,ckp_step,version,num))
cv2.imwrite(save_filename,mask)
num += 1
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
+255
View File
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th April 2022 10:09:21 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
class Tester(object):
    def __init__(self, config, reporter):
        """Record run configuration and logger; build image preprocessing.

        Args:
            config:   dict-like run configuration.
            reporter: logging helper with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter
        # ArcFace preprocessing: tensorize then ImageNet-normalize.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Constants for de-normalizing generator outputs, shaped (3,1,1).
        # NOTE(review): .cuda() assumes a GPU regardless of config["cuda"].
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# for name in self.network.state_dict():
# print(name)
self.features = {}
mapping_layers = [
"first_layer",
"down4",
"BottleNeck.2"
]
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
crop_mode = self.config["crop_mode"]
attr_files = self.config["attr_files"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
for item in imgs:
imgs_list.append(item)
print(imgs_list)
else:
print("Input an image....")
imgs_list.append(attr_files)
id_basename = os.path.basename(id_imgs)
id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
# models
self.__init_framework__()
mode = crop_mode.lower()
if mode == "vggface":
mode = "none"
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id_imgs)
id_img_align_crop, _ = self.detect.get(id_img,512)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
cos_loss = torch.nn.CosineSimilarity()
font = cv2.FONT_HERSHEY_SIMPLEX
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
cos_dict = {}
average_cos = 0
with torch.no_grad():
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
try:
attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
except:
continue
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
cos_dis = 1 - cos_loss(latend_id, attr_id)
mat = mat[0]
results,mask_lr,mask_hr= self.network(attr_img, latend_id)
mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...]
mask_lr = mask_lr.numpy()
# mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
mask_lr = np.clip(mask_lr,0.0,1.0) * 255
mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...]
mask_hr = mask_hr.numpy()
# mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
mask_hr = np.clip(mask_hr,0.0,1.0) * 255
results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
results_arc = self.arcface(results_arc)
results_arc = F.normalize(results_arc, p=2, dim=1)
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
average_cos += results_cos_dis
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
img_white = np.full((512,512), 255, dtype=float)
# inverse the Affine transformation matrix
mat_rev = np.zeros([2,3])
div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
mat_rev[0][0] = mat[1][1]/div1
mat_rev[0][1] = -mat[0][1]/div1
mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
mat_rev[1][0] = mat[1][0]/div2
mat_rev[1][1] = -mat[0][0]/div2
mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
target_image = cv2.warpAffine(results, mat_rev, orisize)
img_white = cv2.warpAffine(img_white, mat_rev, orisize)
img_white[img_white>20] =255
img_mask = img_white
kernel = np.ones((40,40),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel_size = (20, 20)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
img_mask /= 255
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
img1 = np.array(attr_img_ori, dtype=np.float)
img1 = img_mask * target_image + (1-img_mask) * img1
final_img = img1.astype(np.uint8)
attr_basename = os.path.splitext(os.path.basename(img))[0]
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_lr)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_hr)
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
+286
View File
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 16th April 2022 5:20:54 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
from face_enhancer.gfpgan import GFPGANer
class Tester(object):
    def __init__(self, config, reporter):
        """Keep configuration and reporter; prepare ArcFace preprocessing.

        Args:
            config:   dict-like run configuration.
            reporter: logging helper with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter
        # ArcFace input pipeline: HWC uint8 -> CHW float, ImageNet-normalized.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Broadcastable (3,1,1) constants for de-normalizing outputs.
        # NOTE(review): .cuda() assumes a GPU regardless of config["cuda"].
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# for name in self.network.state_dict():
# print(name)
self.features = {}
mapping_layers = [
"first_layer",
"down4",
"BottleNeck.2"
]
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
version = '1.2'
if version == '1':
arch = 'original'
channel_multiplier = 1
model_name = 'GFPGANv1'
elif version == '1.2':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANCleanv1-NoCE-C2'
elif version == '1.3':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.3'
# determine model paths
model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
if not os.path.isfile(model_path):
raise ValueError(f'Model {model_name} does not exist.')
self.restorer = GFPGANer(
model_path=model_path,
upscale=1,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=None)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
crop_mode = self.config["crop_mode"]
attr_files = self.config["attr_files"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
for item in imgs:
imgs_list.append(item)
print(imgs_list)
else:
print("Input an image....")
imgs_list.append(attr_files)
id_basename = os.path.basename(id_imgs)
id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
# models
self.__init_framework__()
mode = crop_mode.lower()
if mode == "vggface":
mode = "none"
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id_imgs)
id_img_align_crop, _ = self.detect.get(id_img,512)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
cos_loss = torch.nn.CosineSimilarity()
font = cv2.FONT_HERSHEY_SIMPLEX
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
cos_dict = {}
average_cos = 0
with torch.no_grad():
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
try:
attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
except:
continue
_, _, restored_face = self.restorer.enhance(
attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
cos_dis = 1 - cos_loss(latend_id, attr_id)
mat = mat[0]
results,mask_lr,mask_hr= self.network(attr_img, latend_id)
mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...]
mask_lr = mask_lr.numpy()
# mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
mask_lr = np.clip(mask_lr,0.0,1.0) * 255
mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...]
mask_hr = mask_hr.numpy()
# mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
mask_hr = np.clip(mask_hr,0.0,1.0) * 255
results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
results_arc = self.arcface(results_arc)
results_arc = F.normalize(results_arc, p=2, dim=1)
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
average_cos += results_cos_dis
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
img_white = np.full((512,512), 255, dtype=float)
# inverse the Affine transformation matrix
mat_rev = np.zeros([2,3])
div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
mat_rev[0][0] = mat[1][1]/div1
mat_rev[0][1] = -mat[0][1]/div1
mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
mat_rev[1][0] = mat[1][0]/div2
mat_rev[1][1] = -mat[0][0]/div2
mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
target_image = cv2.warpAffine(results, mat_rev, orisize)
img_white = cv2.warpAffine(img_white, mat_rev, orisize)
img_white[img_white>20] =255
img_mask = img_white
kernel = np.ones((40,40),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel_size = (20, 20)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
img_mask /= 255
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
img1 = np.array(attr_img_ori, dtype=np.float)
img1 = img_mask * target_image + (1-img_mask) * img1
final_img = img1.astype(np.uint8)
attr_basename = os.path.splitext(os.path.basename(img))[0]
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_lr)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename,mask_hr)
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
+301
View File
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 23rd April 2022 10:05:22 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
class Tester(object):
    """Image-based face-swap tester with optional GFPGAN pre-enhancement.

    Builds the swapping generator and an ArcFace identity encoder from the
    experiment config, swaps one source identity onto a set of attribute
    images, and writes blended composites (plus optional auxiliary masks)
    to disk.
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger used to record test information
        self.reporter = reporter
        # ArcFace preprocessing: to-tensor + ImageNet normalization
        self.transformer_Arcface = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        # channel-wise ImageNet constants, shaped (3,1,1) for CHW broadcasting
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)

    def __init_framework__(self):
        '''
        Build the swapping generator and the ArcFace identity encoder,
        load their checkpoints, optionally create the GFPGAN face restorer
        (``config["preprocess"]``), and move everything to GPU when configured.
        '''
        #===============build models================#
        print("build models...")
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        # [fix] the generator was instantiated twice; build it once
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        self.features = {}   # reserved for feature hooks (currently unused)
        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # frozen, evaluation-only ArcFace identity encoder
        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
                                    self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        if self.config["preprocess"]:
            print("Employ GFPGAN to upsampling detected face images!")
            from face_enhancer.gfpgan import GFPGANer
            version = '1.2'
            if version == '1':
                arch = 'original'
                channel_multiplier = 1
                model_name = 'GFPGANv1'
            elif version == '1.2':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANCleanv1-NoCE-C2'
            elif version == '1.3':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANv1.3'
            # determine model paths
            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
            if not os.path.isfile(model_path):
                model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
                if not os.path.isfile(model_path):
                    raise ValueError(f'Model {model_name} does not exist.')
            self.restorer = GFPGANer(
                model_path=model_path,
                upscale=1,
                arch=arch,
                channel_multiplier=channel_multiplier,
                bg_upsampler=None)

        # move models to GPU
        if self.config["cuda"] >=0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        """Swap the identity of ``id_imgs`` onto each attribute image.

        Saves the blended composites (with a ``_gfpgan`` filename suffix when
        preprocessing is enabled), optionally stamps identity-distance
        metrics onto the images (``record_metric``) and dumps the network's
        auxiliary masks (``save_mask``), then reports the average cosine
        distance between the source identity and the swapped results.
        """
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        # __init_framework__ reads this attribute, so set it before building models
        self.arcface_ckpt = self.config["arcface_ckpt"]
        imgs_list = []
        self.reporter.writeInfo("Version %s"%version)
        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)

        # [fix] id_basename was assigned twice; keep only the extension-free form
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img,512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)

        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori= cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
                # [fix] narrowed from a bare except; directory entries from the
                # recursive glob land here (imread returns None for them)
                except Exception:
                    continue
                restored_face = attr_img_align_crop[0]
                if self.config["preprocess"]:
                    _, _, restored_face = self.restorer.enhance(
                        restored_face, has_aligned=False, only_center_face=True, paste_back=True)
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                # cosine distance between the source identity and the attribute face
                cos_dis = 1 - cos_loss(latend_id, attr_id)
                mat = mat[0]
                # pred[0] is the swapped image; pred[1:] are auxiliary masks
                pred = self.network(attr_img, latend_id)
                results = pred[0]
                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis
                # de-normalize the generator output back to [0,1] RGB
                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0,2,3,1)[0,...]
                results = results.numpy()
                results = np.clip(results,0.0,1.0)

                img_white = np.full((512,512), 255, dtype=float)
                # inverse the Affine transformation matrix
                mat_rev = np.zeros([2,3])
                div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
                # paste the swapped face back into the original frame
                target_image = cv2.warpAffine(results, mat_rev, orisize)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white>20] =255
                img_mask = img_white
                # erode + blur the paste-back mask for a seamless blend
                kernel = np.ones((40,40),np.uint8)
                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i+1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
                # [fix] np.float was removed in NumPy 1.24; use the builtin float
                target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
                img1 = np.array(attr_img_ori, dtype=float)
                img1 = img_mask * target_image + (1-img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                if self.config["record_metric"]:
                    final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                if self.config["preprocess"]:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png"%(id_basename,
                                        attr_basename,ckp_step,version))
                else:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
                                        attr_basename,ckp_step,version))
                cv2.imwrite(save_filename, final_img)
                if self.config["save_mask"]:
                    num = 0
                    for mask in pred[1:]:
                        mask = mask.cpu().permute(0,2,3,1)[0,...]
                        mask = mask.numpy()
                        # NOTE(review): rescales by max rather than (max-min) and
                        # divides by zero on an all-zero mask — confirm intent
                        mask = (mask - np.min(mask))/np.max(mask)
                        mask = np.clip(mask,0.0,1.0) * 255
                        if self.config["preprocess"]:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png"%(id_basename,
                                                attr_basename,ckp_step,version,num))
                        else:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d.png"%(id_basename,
                                                attr_basename,ckp_step,version,num))
                        cv2.imwrite(save_filename,mask)
                        num += 1
        # [fix] guard against an empty / all-skipped image list, which previously
        # raised ZeroDivisionError or AttributeError (.item() on the int 0)
        average_cos = average_cos / max(len(imgs_list), 1)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        avg_value = average_cos.item() if torch.is_tensor(average_cos) else float(average_cos)
        print("Average cosin similarity between ID and results [{}]".format(avg_value))
        self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(avg_value))
+280
View File
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 14th April 2022 1:48:18 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from insightface_func.face_detect_crop_single import Face_detect_crop
from face_enhancer.gfpgan import GFPGANer
class Tester(object):
    """Image-based face-swap tester that always runs GFPGAN enhancement.

    Builds the swapping generator and an ArcFace identity encoder from the
    experiment config, enhances each detected attribute crop with GFPGAN,
    swaps one source identity onto it, and writes the blended composite and
    the predicted mask to disk.
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger used to record test information
        self.reporter = reporter
        # ArcFace preprocessing: to-tensor + ImageNet normalization
        self.transformer_Arcface = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        # channel-wise ImageNet constants, shaped (3,1,1) for CHW broadcasting
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)

    def __init_framework__(self):
        '''
        Build the swapping generator and the ArcFace identity encoder,
        load their checkpoints, create the GFPGAN face restorer, and move
        the networks to GPU when configured.
        '''
        #===============build models================#
        print("build models...")
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        # [fix] the generator was instantiated twice; build it once
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        self.features = {}   # reserved for feature hooks (currently unused)
        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # frozen, evaluation-only ArcFace identity encoder
        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
                                    self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        # GFPGAN restorer configuration (fixed to the v1.2 clean model)
        version = '1.2'
        if version == '1':
            arch = 'original'
            channel_multiplier = 1
            model_name = 'GFPGANv1'
        elif version == '1.2':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANCleanv1-NoCE-C2'
        elif version == '1.3':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANv1.3'
        # determine model paths
        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
        if not os.path.isfile(model_path):
            model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
            if not os.path.isfile(model_path):
                raise ValueError(f'Model {model_name} does not exist.')
        self.restorer = GFPGANer(
            model_path=model_path,
            upscale=1,
            arch=arch,
            channel_multiplier=channel_multiplier,
            bg_upsampler=None)

        # move models to GPU
        if self.config["cuda"] >=0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        """Swap the identity of ``id_imgs`` onto each attribute image.

        Saves the blended composite with identity-distance metrics stamped
        on it, plus the network's predicted mask, then reports the average
        cosine distance between the source identity and the results.
        """
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        # __init_framework__ reads this attribute, so set it before building models
        self.arcface_ckpt = self.config["arcface_ckpt"]
        imgs_list = []
        self.reporter.writeInfo("Version %s"%version)
        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)

        # [fix] id_basename was assigned twice; keep only the extension-free form
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img,512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)

        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori= cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
                # [fix] narrowed from a bare except; directory entries from the
                # recursive glob land here (imread returns None for them)
                except Exception:
                    continue
                _, _, restored_face = self.restorer.enhance(
                    attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                # cosine distance between the source identity and the attribute face
                cos_dis = 1 - cos_loss(latend_id, attr_id)
                mat = mat[0]
                results,mask= self.network(attr_img, latend_id)
                mask = mask.cpu().permute(0,2,3,1)[0,...]
                mask = mask.numpy()
                # NOTE(review): rescales by max rather than (max-min) and divides
                # by zero on an all-zero mask — confirm intent
                mask = (mask - np.min(mask))/np.max(mask)
                mask = np.clip(mask,0.0,1.0) * 255
                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis
                # de-normalize the generator output back to [0,1] RGB
                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0,2,3,1)[0,...]
                results = results.numpy()
                results = np.clip(results,0.0,1.0)

                img_white = np.full((512,512), 255, dtype=float)
                # inverse the Affine transformation matrix
                mat_rev = np.zeros([2,3])
                div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
                # paste the swapped face back into the original frame
                target_image = cv2.warpAffine(results, mat_rev, orisize)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white>20] =255
                img_mask = img_white
                # erode + blur the paste-back mask for a seamless blend
                kernel = np.ones((40,40),np.uint8)
                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i+1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
                # [fix] np.float was removed in NumPy 1.24; use the builtin float
                target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
                img1 = np.array(attr_img_ori, dtype=float)
                img1 = img_mask * target_image + (1-img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
                                    attr_basename,ckp_step,version))
                cv2.imwrite(save_filename, final_img)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask.png"%(id_basename,
                                    attr_basename,ckp_step,version))
                cv2.imwrite(save_filename,mask)
        # [fix] guard against an empty / all-skipped image list, which previously
        # raised ZeroDivisionError or AttributeError (.item() on the int 0)
        average_cos = average_cos / max(len(imgs_list), 1)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        avg_value = average_cos.item() if torch.is_tensor(average_cos) else float(average_cos)
        print("Average cosin similarity between ID and results [{}]".format(avg_value))
        self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(avg_value))
+42 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 11:06:37 am
# Last Modified: Friday, 22nd April 2022 11:20:19 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -33,6 +33,8 @@ from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
from face_enhancer.gfpgan import GFPGANer
from utilities.utilities import load_file_from_url
class Tester(object):
def __init__(self, config, reporter):
@@ -64,6 +66,7 @@ class Tester(object):
def video_swap(
self,
video_path,
gfpgan,
id_vetor,
save_path,
temp_results_dir='./temp_results',
@@ -121,8 +124,11 @@ class Tester(object):
swap_result_list = []
frame_align_crop_tenor_list = []
for frame_align_crop in frame_align_crop_list:
if gfpgan:
_, _, frame_align_crop = gfpgan.enhance(
frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
swap_result = swap_result* self.imagenet_std + self.imagenet_mean
swap_result = torch.clip(swap_result,0.0,1.0)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
@@ -216,6 +222,39 @@ class Tester(object):
# models
self.__init_framework__()
if self.config["preprocess"]:
print("Employ GFPGAN to upsampling detected face images!")
version = '1.2'
if version == '1':
arch = 'original'
channel_multiplier = 1
model_name = 'GFPGANv1'
elif version == '1.2':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANCleanv1-NoCE-C2'
elif version == '1.3':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.3'
# determine model paths
model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"
if not os.path.isfile(model_path):
# raise ValueError(f'Model {model_name} does not exist.')
print(f'Model {model_name} does not exist. Prepare to download it......')
model_path = load_file_from_url(
url=url_path, model_dir=model_path, progress=True, file_name=None)
restorer = GFPGANer(
model_path=model_path,
upscale=1,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=None)
else:
restorer = None
mode = None
@@ -239,7 +278,7 @@ class Tester(object):
start_time = time.time()
self.network.eval()
with torch.no_grad():
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
use_mask=False,crop_size=512)
elapsed = time.time() - start_time
+285
View File
@@ -0,0 +1,285 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 14th April 2022 11:40:45 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import shutil
import torch
import torch.nn.functional as F
from torchvision import transforms
from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import numpy as np
from tqdm import tqdm
from PIL import Image
import glob
from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
from face_enhancer.gfpgan import GFPGANer
from utilities.utilities import load_file_from_url
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
def cv2totensor(self, cv2_img):
"""
cv2_img: an image read by cv2, H*W*C
return: an 1*C*H*W tensor
"""
cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB)
cv2_img = torch.from_numpy(cv2_img)
cv2_img = cv2_img.permute(2,0,1).cuda()
temp = cv2_img / 255.0
temp -= self.imagenet_mean
temp /= self.imagenet_std
return temp.unsqueeze(0)
def video_swap(
self,
video_path,
gfpgan,
id_vetor,
save_path,
temp_results_dir='./temp_results',
crop_size=512,
use_mask =False
):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
else:
no_audio = False
del video_forcheck
if not no_audio:
video_audio_clip = AudioFileClip(video_path)
video = cv2.VideoCapture(video_path)
ret = True
frame_index = 0
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
# video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
# video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = video.get(cv2.CAP_PROP_FPS)
if os.path.exists(temp_results_dir):
shutil.rmtree(temp_results_dir)
spNorm =ImagenetNorm()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
if ret:
detect_results = self.detect.get(frame,crop_size)
if detect_results is not None:
# print(frame_index)
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame_align_crop_list = detect_results[0]
frame_mat_list = detect_results[1]
swap_result_list = []
frame_align_crop_tenor_list = []
for frame_align_crop in frame_align_crop_list:
_, _, restored_face = gfpgan.enhance(
frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
frame_align_crop_tenor = self.cv2totensor(restored_face)
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
swap_result = swap_result* self.imagenet_std + self.imagenet_mean
swap_result = torch.clip(swap_result,0.0,1.0)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
swap_result_list.append(swap_result)
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame = frame.astype(np.uint8)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
else:
break
video.release()
# image_filename_list = []
path = os.path.join(temp_results_dir,'*.jpg')
image_filenames = sorted(glob.glob(path))
clips = ImageSequenceClip(image_filenames,fps = fps)
if not no_audio:
clips = clips.set_audio(video_audio_clip)
basename = os.path.basename(video_path)
basename = os.path.splitext(basename)[0]
save_filename = os.path.join(save_path, basename+".mp4")
index = 0
while(True):
if os.path.exists(save_filename):
save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
index += 1
else:
break
clips.write_videofile(save_filename,audio_codec='aac')
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
    def test(self):
        '''
        Run the full face-swap inference pipeline on a video:
        build the models, set up the GFPGAN face restorer (downloading its
        checkpoint if missing), crop/encode the identity image with ArcFace,
        then hand everything to self.video_swap and report elapsed time.
        '''
        # Pull the run settings out of the config dict.
        # save_result = self.config["saveTestResult"]
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]       # path to the identity (source face) image
        attr_files = self.config["attr_files"] # attribute (target) video/file(s)
        self.arcface_ckpt= self.config["arcface_ckpt"]
        # models
        self.__init_framework__()
        # NOTE(review): this hardcoded assignment shadows the config value read
        # above, so the config "version" is never used; any value other than
        # '1'/'1.2'/'1.3' would also leave model_name undefined below. Confirm
        # whether the override is intentional debug code.
        version = '1.2'
        # Select the GFPGAN architecture/checkpoint matching the version tag.
        if version == '1':
            arch = 'original'
            channel_multiplier = 1
            model_name = 'GFPGANv1'
        elif version == '1.2':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANCleanv1-NoCE-C2'
        elif version == '1.3':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANv1.3'
        # determine model paths
        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
        url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"
        if not os.path.isfile(model_path):
            # raise ValueError(f'Model {model_name} does not exist.')
            print(f'Model {model_name} does not exist. Prepare to download it......')
            # NOTE(review): model_dir is given the .pth FILE path here; the
            # upstream GFPGAN code passes a directory -- verify the download
            # lands where the isfile() check above expects it.
            model_path = load_file_from_url(
                url=url_path, model_dir=model_path, progress=True, file_name=None)
        # GFPGAN restorer used to enhance swapped faces (no background upsampling).
        restorer = GFPGANer(
            model_path=model_path,
            upscale=1,
            arch=arch,
            channel_multiplier=channel_multiplier,
            bg_upsampler=None)
        # Face detector/aligner from insightface; mode=None uses its default crop mode.
        mode = None
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
        # Detect and align the identity face at 512x512 (BGR, as read by cv2).
        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img,512)
        # _, _, restored_face = restorer.enhance(
        #         id_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
        # cv2.imwrite("id_wocao.png",restored_face)
        # BGR -> RGB, normalize with ImageNet stats, add batch dim, move to GPU.
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()
        #create latent id
        # ArcFace expects 112x112 input; L2-normalize the identity embedding.
        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        # Run the swap over the whole video without gradients; frames are
        # staged in ./.temples and faces are processed at crop_size=512.
        with torch.no_grad():
            self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
                use_mask=False,crop_size=512)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))