update
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th April 2022 9:04:01 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop


class Tester(object):
    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
        This function is designed to define the framework
        and print the framework information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name   = model_config["g_model"]["class_name"]
        package      = __import__(gscript_name, fromlist=True)
        gen_class    = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace the lines below to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]

        # print and record the model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        # run on GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        save_dir      = self.config["test_samples_path"]
        ckp_step      = self.config["checkpoint_step"]
        version       = self.config["version"]
        crop_mode     = self.config["crop_mode"]
        list_txt      = self.config["img_list_txt"]
        record_metric = self.config["record_metric"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt   = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s" % version)

        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path
        imgs_list = []
        with open(list_txt, 'r') as logf:
            for line in logf:
                cells = line.split(";")
                imgs_list.append([cells[0], cells[1], cells[2].replace("\n", "")])

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                id_img_n, attr_img_n, fusion = img
                print("id image:%s---attr image:%s" % (id_img_n, attr_img_n))
                id_img = cv2.imread(id_img_n)
                print(fusion)
                if fusion.lower() == "fusion":
                    try:
                        id_img_align_crop, _ = self.detect.get(id_img, 512)
                    except:
                        print("No face detected in image %s!" % id_img_n)
                        continue
                    # id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
                    # cv2.imwrite(os.path.join(save_dir, "id_%s.png"%(id_basename)), id_img_align_crop[0])
                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0], cv2.COLOR_BGR2RGB))
                else:
                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img, cv2.COLOR_BGR2RGB))

                id_img = self.transformer_Arcface(id_img_align_crop_pil)
                id_img = id_img.unsqueeze(0).cuda()

                # create latent id
                id_img = F.interpolate(id_img, size=(112, 112), mode='bicubic')
                latend_id = self.arcface(id_img)
                latend_id = F.normalize(latend_id, p=2, dim=1)
                attr_img_ori = cv2.imread(attr_img_n)

                if fusion.lower() == "fusion":
                    try:
                        attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                    except:
                        print("No face detected in image %s!" % attr_img_n)
                        continue
                    # attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
                    # cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)), attr_img_align_crop[0])
                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0], cv2.COLOR_BGR2RGB))
                else:
                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori, cv2.COLOR_BGR2RGB))

                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112, 112), mode='bicubic')

                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                results, mask_lr, mask_hr = self.network(attr_img, latend_id)

                mask_lr = mask_lr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_lr = mask_lr.numpy()
                # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
                mask_lr = np.clip(mask_lr, 0.0, 1.0) * 255
                mask_hr = mask_hr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_hr = mask_hr.numpy()
                # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
                mask_hr = np.clip(mask_hr, 0.0, 1.0) * 255

                results_arc = F.interpolate(results, size=(112, 112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...]
                results = results.numpy()
                results = np.clip(results, 0.0, 1.0)
                if fusion.lower() == "fusion":
                    mat = mat[0]
                    img_white = np.full((512, 512), 255, dtype=float)

                    # invert the affine transformation matrix
                    mat_rev = np.zeros([2, 3])
                    div1 = mat[0][0]*mat[1][1] - mat[0][1]*mat[1][0]
                    mat_rev[0][0] = mat[1][1]/div1
                    mat_rev[0][1] = -mat[0][1]/div1
                    mat_rev[0][2] = -(mat[0][2]*mat[1][1] - mat[0][1]*mat[1][2])/div1
                    div2 = mat[0][1]*mat[1][0] - mat[0][0]*mat[1][1]
                    mat_rev[1][0] = mat[1][0]/div2
                    mat_rev[1][1] = -mat[0][0]/div2
                    mat_rev[1][2] = -(mat[0][2]*mat[1][0] - mat[0][0]*mat[1][2])/div2

                    orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                    target_image = cv2.warpAffine(results, mat_rev, orisize)
                    img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                    img_white[img_white > 20] = 255
                    img_mask = img_white

                    kernel = np.ones((40, 40), np.uint8)
                    img_mask = cv2.erode(img_mask, kernel, iterations=1)
                    kernel_size = (20, 20)
                    blur_size = tuple(2*i + 1 for i in kernel_size)
                    img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

                    img_mask /= 255
                    img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

                    target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                    img1 = np.array(attr_img_ori, dtype=np.float)
                    img1 = img_mask * target_image + (1 - img_mask) * img1
                else:
                    results = results * 255
                    img1 = cv2.cvtColor(results, cv2.COLOR_RGB2BGR)

                final_img = img1.astype(np.uint8)
                id_basename = os.path.basename(id_img_n)
                id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
                attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
                if record_metric:
                    final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                print(save_dir)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, final_img)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_lr)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_hr)

        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
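Note: the hand-rolled div1/div2 inverse of the 2x3 alignment matrix above can be cross-checked against OpenCV's built-in helper. A minimal sketch, with a hypothetical matrix standing in for the one returned by Face_detect_crop.get():

import cv2
import numpy as np

# hypothetical 2x3 alignment matrix; cv2.invertAffineTransform computes the
# same inverse as the manual div1/div2 arithmetic in the tester above
mat = np.array([[0.8, 0.1, 30.0],
                [-0.1, 0.8, 12.0]], dtype=np.float64)
mat_rev = cv2.invertAffineTransform(mat)

# sanity check: composing the two 3x3 homogeneous forms gives the identity
full = np.vstack([mat, [0.0, 0.0, 1.0]])
full_rev = np.vstack([mat_rev, [0.0, 0.0, 1.0]])
print(np.allclose(full @ full_rev, np.eye(3)))  # True up to float error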
@@ -0,0 +1,328 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 23rd April 2022 10:04:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop


class Tester(object):
    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
        This function is designed to define the framework
        and print the framework information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name   = model_config["g_model"]["class_name"]
        package      = __import__(gscript_name, fromlist=True)
        gen_class    = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace the lines below to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]

        # print and record the model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        if self.config["preprocess"]:
            print("Employ GFPGAN to upsample the detected face images!")
            from face_enhancer.gfpgan import GFPGANer
            version = '1.2'
            if version == '1':
                arch = 'original'
                channel_multiplier = 1
                model_name = 'GFPGANv1'
            elif version == '1.2':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANCleanv1-NoCE-C2'
            elif version == '1.3':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANv1.3'

            # determine model paths
            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
            if not os.path.isfile(model_path):
                model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
            if not os.path.isfile(model_path):
                raise ValueError(f'Model {model_name} does not exist.')

            self.restorer = GFPGANer(
                model_path=model_path,
                upscale=1,
                arch=arch,
                channel_multiplier=channel_multiplier,
                bg_upsampler=None)

        # run on GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        save_dir      = self.config["test_samples_path"]
        ckp_step      = self.config["checkpoint_step"]
        version       = self.config["version"]
        crop_mode     = self.config["crop_mode"]
        list_txt      = self.config["img_list_txt"]
        record_metric = self.config["record_metric"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt   = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s" % version)

        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path
        imgs_list = []
        with open(list_txt, 'r') as logf:
            for line in logf:
                cells = line.split(";")
                imgs_list.append([cells[0], cells[1], cells[2].replace("\n", "")])

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                id_img_n, attr_img_n, fusion = img
                print("id image:%s---attr image:%s" % (id_img_n, attr_img_n))
                id_img = cv2.imread(id_img_n)
                print(fusion)
                if fusion.lower() == "fusion":
                    try:
                        id_img_align_crop, _ = self.detect.get(id_img, 512)
                    except:
                        print("No face detected in image %s!" % id_img_n)
                        continue
                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0], cv2.COLOR_BGR2RGB))
                else:
                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img, cv2.COLOR_BGR2RGB))

                id_img = self.transformer_Arcface(id_img_align_crop_pil)
                id_img = id_img.unsqueeze(0).cuda()

                # create latent id
                id_img = F.interpolate(id_img, size=(112, 112), mode='bicubic')
                latend_id = self.arcface(id_img)
                latend_id = F.normalize(latend_id, p=2, dim=1)
                attr_img_ori = cv2.imread(attr_img_n)

                if fusion.lower() == "fusion":
                    try:
                        attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                    except:
                        print("No face detected in image %s!" % attr_img_n)
                        continue
                    # attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
                    # cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)), attr_img_align_crop[0])
                    restored_face = attr_img_align_crop[0]
                    if self.config["preprocess"]:
                        _, _, restored_face = self.restorer.enhance(
                            restored_face, has_aligned=False, only_center_face=True, paste_back=True)
                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
                else:
                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori, cv2.COLOR_BGR2RGB))

                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112, 112), mode='bicubic')

                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                # results, mask = self.network(attr_img, latend_id)
                pred = self.network(attr_img, latend_id)
                results = pred[0]

                results_arc = F.interpolate(results, size=(112, 112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...]
                results = results.numpy()
                results = np.clip(results, 0.0, 1.0)
                if fusion.lower() == "fusion":
                    mat = mat[0]
                    img_white = np.full((512, 512), 255, dtype=float)

                    # invert the affine transformation matrix
                    mat_rev = np.zeros([2, 3])
                    div1 = mat[0][0]*mat[1][1] - mat[0][1]*mat[1][0]
                    mat_rev[0][0] = mat[1][1]/div1
                    mat_rev[0][1] = -mat[0][1]/div1
                    mat_rev[0][2] = -(mat[0][2]*mat[1][1] - mat[0][1]*mat[1][2])/div1
                    div2 = mat[0][1]*mat[1][0] - mat[0][0]*mat[1][1]
                    mat_rev[1][0] = mat[1][0]/div2
                    mat_rev[1][1] = -mat[0][0]/div2
                    mat_rev[1][2] = -(mat[0][2]*mat[1][0] - mat[0][0]*mat[1][2])/div2

                    orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                    target_image = cv2.warpAffine(results, mat_rev, orisize)
                    img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                    img_white[img_white > 20] = 255
                    img_mask = img_white

                    kernel = np.ones((40, 40), np.uint8)
                    img_mask = cv2.erode(img_mask, kernel, iterations=1)
                    kernel_size = (20, 20)
                    blur_size = tuple(2*i + 1 for i in kernel_size)
                    img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

                    img_mask /= 255
                    img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

                    target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                    img1 = np.array(attr_img_ori, dtype=np.float)
                    img1 = img_mask * target_image + (1 - img_mask) * img1
                else:
                    results = results * 255
                    img1 = cv2.cvtColor(results, cv2.COLOR_RGB2BGR)

                final_img = img1.astype(np.uint8)
                id_basename = os.path.basename(id_img_n)
                id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
                attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
                if record_metric:
                    final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                print(save_dir)
                if self.config["preprocess"]:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png" % (id_basename,
                                    attr_basename, ckp_step, version))
                else:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename,
                                    attr_basename, ckp_step, version))

                cv2.imwrite(save_filename, final_img)

                if self.config["save_mask"]:
                    num = 0
                    for mask in pred[1:]:
                        mask = mask.cpu().permute(0, 2, 3, 1)[0, ...]
                        mask = mask.numpy()
                        mask = (mask - np.min(mask)) / np.max(mask)
                        mask = np.clip(mask, 0.0, 1.0) * 255

                        if self.config["preprocess"]:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png" % (id_basename,
                                            attr_basename, ckp_step, version, num))
                        else:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d.png" % (id_basename,
                                            attr_basename, ckp_step, version, num))

                        cv2.imwrite(save_filename, mask)
                        num += 1
        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
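Note: both list-driven testers above read img_list_txt as one semicolon-separated triple per line. A hypothetical example line and the parsing it goes through (paths are illustrative, not from the repo):

# hypothetical img_list_txt entry: id image, attribute image, and a third
# field that switches the fusion paste-back branch on when it equals "fusion"
line = "./ids/person_a.jpg;./attrs/person_b.jpg;fusion\n"
cells = line.split(";")
id_path, attr_path, fusion = cells[0], cells[1], cells[2].replace("\n", "")
assert fusion.lower() == "fusion"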
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th April 2022 10:09:21 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop


class Tester(object):
    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
        This function is designed to define the framework
        and print the framework information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name   = model_config["g_model"]["class_name"]
        package      = __import__(gscript_name, fromlist=True)
        gen_class    = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace the lines below to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]

        # print and record the model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        # run on GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        save_dir  = self.config["test_samples_path"]
        ckp_step  = self.config["checkpoint_step"]
        version   = self.config["version"]
        id_imgs   = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt   = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s" % version)

        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files, "**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        id_basename = os.path.basename(id_imgs)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img, 512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0], cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img, size=(112, 112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                except:
                    continue
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0], cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112, 112), mode='bicubic')
                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)), id_img_align_crop[0])
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                mat = mat[0]
                results, mask_lr, mask_hr = self.network(attr_img, latend_id)

                mask_lr = mask_lr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_lr = mask_lr.numpy()
                # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
                mask_lr = np.clip(mask_lr, 0.0, 1.0) * 255
                mask_hr = mask_hr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_hr = mask_hr.numpy()
                # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
                mask_hr = np.clip(mask_hr, 0.0, 1.0) * 255

                results_arc = F.interpolate(results, size=(112, 112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...]
                results = results.numpy()
                results = np.clip(results, 0.0, 1.0)
                img_white = np.full((512, 512), 255, dtype=float)

                # invert the affine transformation matrix
                mat_rev = np.zeros([2, 3])
                div1 = mat[0][0]*mat[1][1] - mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1] - mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0] - mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0] - mat[0][0]*mat[1][2])/div2

                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                target_image = cv2.warpAffine(results, mat_rev, orisize)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white > 20] = 255
                img_mask = img_white

                kernel = np.ones((40, 40), np.uint8)
                img_mask = cv2.erode(img_mask, kernel, iterations=1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i + 1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

                target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                img1 = np.array(attr_img_ori, dtype=np.float)
                img1 = img_mask * target_image + (1 - img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, final_img)

                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_lr)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_hr)

        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 16th April 2022 5:20:54 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop
from face_enhancer.gfpgan import GFPGANer


class Tester(object):
    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
        This function is designed to define the framework
        and print the framework information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name   = model_config["g_model"]["class_name"]
        package      = __import__(gscript_name, fromlist=True)
        gen_class    = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace the lines below to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]

        # print and record the model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        version = '1.2'
        if version == '1':
            arch = 'original'
            channel_multiplier = 1
            model_name = 'GFPGANv1'
        elif version == '1.2':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANCleanv1-NoCE-C2'
        elif version == '1.3':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANv1.3'

        # determine model paths
        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
        if not os.path.isfile(model_path):
            model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            raise ValueError(f'Model {model_name} does not exist.')

        self.restorer = GFPGANer(
            model_path=model_path,
            upscale=1,
            arch=arch,
            channel_multiplier=channel_multiplier,
            bg_upsampler=None)

        # run on GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        save_dir  = self.config["test_samples_path"]
        ckp_step  = self.config["checkpoint_step"]
        version   = self.config["version"]
        id_imgs   = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt   = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s" % version)

        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files, "**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        id_basename = os.path.basename(id_imgs)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img, 512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0], cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img, size=(112, 112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                except:
                    continue
                _, _, restored_face = self.restorer.enhance(
                    attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112, 112), mode='bicubic')
                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)), id_img_align_crop[0])
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                mat = mat[0]
                results, mask_lr, mask_hr = self.network(attr_img, latend_id)

                mask_lr = mask_lr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_lr = mask_lr.numpy()
                # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
                mask_lr = np.clip(mask_lr, 0.0, 1.0) * 255
                mask_hr = mask_hr.cpu().permute(0, 2, 3, 1)[0, ...]
                mask_hr = mask_hr.numpy()
                # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
                mask_hr = np.clip(mask_hr, 0.0, 1.0) * 255

                results_arc = F.interpolate(results, size=(112, 112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...]
                results = results.numpy()
                results = np.clip(results, 0.0, 1.0)
                img_white = np.full((512, 512), 255, dtype=float)

                # invert the affine transformation matrix
                mat_rev = np.zeros([2, 3])
                div1 = mat[0][0]*mat[1][1] - mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1] - mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0] - mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0] - mat[0][0]*mat[1][2])/div2

                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                target_image = cv2.warpAffine(results, mat_rev, orisize)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white > 20] = 255
                img_mask = img_white

                kernel = np.ones((40, 40), np.uint8)
                img_mask = cv2.erode(img_mask, kernel, iterations=1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i + 1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

                target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                img1 = np.array(attr_img_ori, dtype=np.float)
                img1 = img_mask * target_image + (1 - img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, final_img)

                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_lr)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png" % (id_basename,
                                attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, mask_hr)

        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 23rd April 2022 10:05:22 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop


class Tester(object):
    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
        This function is designed to define the framework
        and print the framework information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name   = model_config["g_model"]["class_name"]
        package      = __import__(gscript_name, fromlist=True)
        gen_class    = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace the lines below to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]

        # print and record the model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

        if self.config["preprocess"]:
            print("Employ GFPGAN to upsample the detected face images!")
            from face_enhancer.gfpgan import GFPGANer
            version = '1.2'
            if version == '1':
                arch = 'original'
                channel_multiplier = 1
                model_name = 'GFPGANv1'
            elif version == '1.2':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANCleanv1-NoCE-C2'
            elif version == '1.3':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANv1.3'

            # determine model paths
            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
            if not os.path.isfile(model_path):
                model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
            if not os.path.isfile(model_path):
                raise ValueError(f'Model {model_name} does not exist.')

            self.restorer = GFPGANer(
                model_path=model_path,
                upscale=1,
                arch=arch,
                channel_multiplier=channel_multiplier,
                bg_upsampler=None)

        # run on GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

    def test(self):
        save_dir  = self.config["test_samples_path"]
        ckp_step  = self.config["checkpoint_step"]
        version   = self.config["version"]
        id_imgs   = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt   = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s" % version)

        if os.path.isdir(specified_save_path):
            print("Input a legal specified save path!")
            save_dir = specified_save_path

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files, "**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        id_basename = os.path.basename(id_imgs)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img, 512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0], cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img, size=(112, 112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                except:
                    continue
                restored_face = attr_img_align_crop[0]
                if self.config["preprocess"]:
                    _, _, restored_face = self.restorer.enhance(
                        restored_face, has_aligned=False, only_center_face=True, paste_back=True)
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112, 112), mode='bicubic')
                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)), id_img_align_crop[0])
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                mat = mat[0]
                pred = self.network(attr_img, latend_id)
                results = pred[0]

                results_arc = F.interpolate(results, size=(112, 112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...]
                results = results.numpy()
                results = np.clip(results, 0.0, 1.0)
                img_white = np.full((512, 512), 255, dtype=float)

                # invert the affine transformation matrix
                mat_rev = np.zeros([2, 3])
                div1 = mat[0][0]*mat[1][1] - mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1] - mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0] - mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0] - mat[0][0]*mat[1][2])/div2

                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                target_image = cv2.warpAffine(results, mat_rev, orisize)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white > 20] = 255
                img_mask = img_white

                kernel = np.ones((40, 40), np.uint8)
                img_mask = cv2.erode(img_mask, kernel, iterations=1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i + 1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)

                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

                target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                img1 = np.array(attr_img_ori, dtype=np.float)
                img1 = img_mask * target_image + (1 - img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                if self.config["record_metric"]:
                    final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                if self.config["preprocess"]:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png" % (id_basename,
                                    attr_basename, ckp_step, version))
                else:
                    save_filename = os.path.join(save_dir,
                                    "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename,
                                    attr_basename, ckp_step, version))
                cv2.imwrite(save_filename, final_img)

                if self.config["save_mask"]:
                    num = 0
                    for mask in pred[1:]:
                        mask = mask.cpu().permute(0, 2, 3, 1)[0, ...]
                        mask = mask.numpy()
                        mask = (mask - np.min(mask)) / np.max(mask)
                        mask = np.clip(mask, 0.0, 1.0) * 255

                        if self.config["preprocess"]:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png" % (id_basename,
                                            attr_basename, ckp_step, version, num))
                        else:
                            save_filename = os.path.join(save_dir,
                                            "id_%s--attr_%s_ckp_%s_v_%s_mask%d.png" % (id_basename,
                                            attr_basename, ckp_step, version, num))

                        cv2.imwrite(save_filename, mask)
                        num += 1

        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
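Note: every tester variant above pastes the swapped face back with the same eroded, Gaussian-feathered alpha mask. A standalone sketch of that blend, with assumed inputs (a 512x512 BGR swapped face already scaled to [0, 255] and the precomputed inverse affine matrix), not code taken from the repo:

import cv2
import numpy as np

def paste_back(swapped_512, original_bgr, mat_rev):
    # warp the swapped face and a white coverage mask back into the original frame
    h, w = original_bgr.shape[:2]
    face = cv2.warpAffine(swapped_512, mat_rev, (w, h))
    mask = cv2.warpAffine(np.full((512, 512), 255.0), mat_rev, (w, h))
    mask[mask > 20] = 255
    # shrink then feather the mask so the seam is hidden, as in the testers above
    mask = cv2.erode(mask, np.ones((40, 40), np.uint8), iterations=1)
    mask = cv2.GaussianBlur(mask, (41, 41), 0) / 255.0
    mask = mask[..., None]
    # alpha blend: mask * swapped face + (1 - mask) * original image
    return (mask * face + (1.0 - mask) * original_bgr).astype(np.uint8)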
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 14th April 2022 1:48:18 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################



import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop
from face_enhancer.gfpgan import GFPGANer

class Tester(object):
    def __init__(self, config, reporter):

        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)


    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace below lines to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # for name in self.network.state_dict():
        #     print(name)
        self.features = {}
        mapping_layers = [
            "first_layer",
            "down4",
            "BottleNeck.2"
        ]


        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth"%(self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))


        version = '1.2'
        if version == '1':
            arch = 'original'
            channel_multiplier = 1
            model_name = 'GFPGANv1'
        elif version == '1.2':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANCleanv1-NoCE-C2'
        elif version == '1.3':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANv1.3'

        # determine model paths
        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
        if not os.path.isfile(model_path):
            model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            raise ValueError(f'Model {model_name} does not exist.')

        self.restorer = GFPGANer(
            model_path=model_path,
            upscale=1,
            arch=arch,
            channel_multiplier=channel_multiplier,
            bg_upsampler=None)
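        # Note: GFPGANer is used here purely as a face restorer on already-aligned
        # crops (upscale=1, no background upsampler), not as a full-image enhancer.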

        # train in GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()



    def test(self):

        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        crop_mode = self.config["crop_mode"]
        attr_files = self.config["attr_files"]
        specified_save_path = self.config["specified_save_path"]
        self.arcface_ckpt = self.config["arcface_ckpt"]
        imgs_list = []

        self.reporter.writeInfo("Version %s"%version)

        if os.path.isdir(specified_save_path):
            save_dir = specified_save_path
        else:
            print("Input a legal specified save path!")

        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        id_basename = os.path.basename(id_imgs)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        mode = crop_mode.lower()
        if mode == "vggface":
            mode = "none"
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640,640), mode=mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img, 512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img, size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
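        # ArcFace expects 112x112 input; because the embedding is L2-normalized,
        # the cosine similarities computed below reduce to plain dot products.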
        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        cos_dict = {}
        average_cos = 0
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori, 512)
                except:
                    continue
                _, _, restored_face = self.restorer.enhance(
                        attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
                # cv2.imwrite("id_wocao.png",restored_face)
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                attr_img_arc = F.interpolate(attr_img, size=(112,112), mode='bicubic')
                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)
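                # cosine *distance* between the ID embedding and the attribute-face
                # embedding; smaller values mean the two faces already look more alike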

                mat = mat[0]
                results, mask = self.network(attr_img, latend_id)

                mask = mask.cpu().permute(0,2,3,1)[0,...]
                mask = mask.numpy()
                mask = (mask - np.min(mask))/np.max(mask)
                mask = np.clip(mask,0.0,1.0) * 255

                results_arc = F.interpolate(results, size=(112,112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
                average_cos += results_cos_dis

                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0,2,3,1)[0,...]
                results = results.numpy()
                results = np.clip(results,0.0,1.0)
                img_white = np.full((512,512), 255, dtype=float)

                # invert the affine transformation matrix
                mat_rev = np.zeros([2,3])
                div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
                mat_rev[0][0] = mat[1][1]/div1
                mat_rev[0][1] = -mat[0][1]/div1
                mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
                div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
                mat_rev[1][0] = mat[1][0]/div2
                mat_rev[1][1] = -mat[0][0]/div2
                mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
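                # the hand-computed 2x3 inverse maps the aligned 512x512 crop back to the
                # original frame; cv2.invertAffineTransform(mat) would give the same matrix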

                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])

                target_image = cv2.warpAffine(results, mat_rev, orisize)

                img_white = cv2.warpAffine(img_white, mat_rev, orisize)

                img_white[img_white>20] = 255

                img_mask = img_white

                kernel = np.ones((40,40),np.uint8)
                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i+1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
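                # eroding and then blurring the warped white mask yields a feathered
                # blending mask, which hides the seam when the swapped face is pasted
                # back into the original image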

                img_mask /= 255

                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])

                target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255

                img1 = np.array(attr_img_ori, dtype=np.float)
                img1 = img_mask * target_image + (1-img_mask) * img1
                final_img = img1.astype(np.uint8)
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(save_dir,
                            "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
                            attr_basename,ckp_step,version))

                cv2.imwrite(save_filename, final_img)

                save_filename = os.path.join(save_dir,
                            "id_%s--attr_%s_ckp_%s_v_%s_mask.png"%(id_basename,
                            attr_basename,ckp_step,version))
                cv2.imwrite(save_filename, mask)

        average_cos /= len(imgs_list)
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
        print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
        self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 11:06:37 am
# Last Modified: Friday, 22nd April 2022 11:20:19 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -33,6 +33,8 @@ from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
from face_enhancer.gfpgan import GFPGANer
from utilities.utilities import load_file_from_url

class Tester(object):
    def __init__(self, config, reporter):
@@ -64,6 +66,7 @@ class Tester(object):
    def video_swap(
            self,
            video_path,
            gfpgan,
            id_vetor,
            save_path,
            temp_results_dir='./temp_results',
@@ -121,8 +124,11 @@ class Tester(object):
                    swap_result_list = []
                    frame_align_crop_tenor_list = []
                    for frame_align_crop in frame_align_crop_list:
                        if gfpgan:
                            _, _, frame_align_crop = gfpgan.enhance(
                                frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
                        frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
                        swap_result = swap_result * self.imagenet_std + self.imagenet_mean
                        swap_result = torch.clip(swap_result,0.0,1.0)
                        cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
@@ -216,6 +222,39 @@ class Tester(object):
        # models
        self.__init_framework__()

        if self.config["preprocess"]:
            print("Employ GFPGAN to upsample detected face images!")
            version = '1.2'
            if version == '1':
                arch = 'original'
                channel_multiplier = 1
                model_name = 'GFPGANv1'
            elif version == '1.2':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANCleanv1-NoCE-C2'
            elif version == '1.3':
                arch = 'clean'
                channel_multiplier = 2
                model_name = 'GFPGANv1.3'

            # determine model paths
            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
            url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"
            if not os.path.isfile(model_path):
                # raise ValueError(f'Model {model_name} does not exist.')
                print(f'Model {model_name} does not exist, preparing to download it...')
                model_path = load_file_from_url(
                    url=url_path, model_dir=model_path, progress=True, file_name=None)
            restorer = GFPGANer(
                model_path=model_path,
                upscale=1,
                arch=arch,
                channel_multiplier=channel_multiplier,
                bg_upsampler=None)
        else:
            restorer = None



        mode = None
@@ -239,7 +278,7 @@ class Tester(object):
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
            self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
                            use_mask=False, crop_size=512)

        elapsed = time.time() - start_time

@@ -0,0 +1,285 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 14th April 2022 11:40:45 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################



import os
import cv2
import time
import shutil

import torch
import torch.nn.functional as F
from torchvision import transforms

from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip

import numpy as np
from tqdm import tqdm
from PIL import Image
import glob

from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
from face_enhancer.gfpgan import GFPGANer
from utilities.utilities import load_file_from_url

class Tester(object):
    def __init__(self, config, reporter):

        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)

    def cv2totensor(self, cv2_img):
        """
        cv2_img: an image read by cv2, H*W*C
        return: a 1*C*H*W tensor
        """
        cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        cv2_img = torch.from_numpy(cv2_img)
        cv2_img = cv2_img.permute(2,0,1).cuda()
        temp = cv2_img / 255.0
        temp -= self.imagenet_mean
        temp /= self.imagenet_std
        return temp.unsqueeze(0)

    def video_swap(
            self,
            video_path,
            gfpgan,
            id_vetor,
            save_path,
            temp_results_dir='./temp_results',
            crop_size=512,
            use_mask=False
        ):
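        # gfpgan: a GFPGANer instance; each aligned face crop is passed through
        # gfpgan.enhance() for restoration before being fed to the swapping network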

        video_forcheck = VideoFileClip(video_path)
        if video_forcheck.audio is None:
            no_audio = True
        else:
            no_audio = False

        del video_forcheck

        if not no_audio:
            video_audio_clip = AudioFileClip(video_path)

        video = cv2.VideoCapture(video_path)
        ret = True
        frame_index = 0

        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))

        # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fps = video.get(cv2.CAP_PROP_FPS)
        if os.path.exists(temp_results_dir):
            shutil.rmtree(temp_results_dir)
        spNorm = ImagenetNorm()
        if use_mask:
            n_classes = 19
            net = BiSeNet(n_classes=n_classes)
            net.cuda()
            save_pth = os.path.join('./parsing_model', '79999_iter.pth')
            net.load_state_dict(torch.load(save_pth))
            net.eval()
        else:
            net = None
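        # when use_mask is enabled, the 19-class BiSeNet face-parsing model (weights
        # expected at ./parsing_model/79999_iter.pth) is handed to reverse2wholeimage
        # below as its parsing model, presumably to refine the blending mask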

        # while ret:
        for frame_index in tqdm(range(frame_count)):
            ret, frame = video.read()
            if ret:
                detect_results = self.detect.get(frame, crop_size)

                if detect_results is not None:
                    # print(frame_index)
                    if not os.path.exists(temp_results_dir):
                        os.mkdir(temp_results_dir)
                    frame_align_crop_list = detect_results[0]
                    frame_mat_list = detect_results[1]
                    swap_result_list = []
                    frame_align_crop_tenor_list = []
                    for frame_align_crop in frame_align_crop_list:
                        _, _, restored_face = gfpgan.enhance(
                            frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
                        frame_align_crop_tenor = self.cv2totensor(restored_face)
                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
                        swap_result = swap_result * self.imagenet_std + self.imagenet_mean
                        swap_result = torch.clip(swap_result,0.0,1.0)
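                        # the generator output is ImageNet-normalized, so multiplying by
                        # the std and adding the mean maps it back to the [0, 1] image
                        # range before it is pasted into the frame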
                        cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
                        swap_result_list.append(swap_result)
                        frame_align_crop_tenor_list.append(frame_align_crop_tenor)
                    reverse2wholeimage(frame_align_crop_tenor_list, swap_result_list, frame_mat_list, crop_size, frame,\
                        os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), pasring_model=net, use_mask=use_mask, norm=spNorm)

                else:
                    if not os.path.exists(temp_results_dir):
                        os.mkdir(temp_results_dir)
                    frame = frame.astype(np.uint8)
                    cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
            else:
                break

        video.release()

        # image_filename_list = []
        path = os.path.join(temp_results_dir,'*.jpg')
        image_filenames = sorted(glob.glob(path))

        clips = ImageSequenceClip(image_filenames, fps=fps)

        if not no_audio:
            clips = clips.set_audio(video_audio_clip)
        basename = os.path.basename(video_path)
        basename = os.path.splitext(basename)[0]
        save_filename = os.path.join(save_path, basename+".mp4")
        index = 0
        while(True):
            if os.path.exists(save_filename):
                save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
                index += 1
            else:
                break
        clips.write_videofile(save_filename, audio_codec='aac')

    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace below lines to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # train in GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()
        # loader1 = torch.load(self.config["ckp_name"]["generator_name"])
        # print(loader1.key())
        # pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth"%(self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path))
        # self.network.load_state_dict(torch.load(pathwocao))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
    def test(self):

        # save_result = self.config["saveTestResult"]
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        attr_files = self.config["attr_files"]
        self.arcface_ckpt = self.config["arcface_ckpt"]

        # models
        self.__init_framework__()
        version = '1.2'
        if version == '1':
            arch = 'original'
            channel_multiplier = 1
            model_name = 'GFPGANv1'
        elif version == '1.2':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANCleanv1-NoCE-C2'
        elif version == '1.3':
            arch = 'clean'
            channel_multiplier = 2
            model_name = 'GFPGANv1.3'

        # determine model paths
        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
        url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"

        if not os.path.isfile(model_path):
            # raise ValueError(f'Model {model_name} does not exist.')
            print(f'Model {model_name} does not exist, preparing to download it...')
            model_path = load_file_from_url(
                url=url_path, model_dir=model_path, progress=True, file_name=None)

        restorer = GFPGANer(
            model_path=model_path,
            upscale=1,
            arch=arch,
            channel_multiplier=channel_multiplier,
            bg_upsampler=None)



        mode = None
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640,640), mode=mode)

        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img, 512)
        # _, _, restored_face = restorer.enhance(
        #     id_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
        # cv2.imwrite("id_wocao.png",restored_face)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id
        id_img = F.interpolate(id_img, size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
                            use_mask=False, crop_size=512)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))