#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 29th January 2022 12:02:31 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import cv2
import time
import glob

import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
from PIL import Image

from insightface_func.face_detect_crop_single import Face_detect_crop


class Tester(object):
    """Run a trained face-swapping generator over one image or a directory.

    The tester crops the identity face, embeds it with ArcFace, swaps it onto
    every attribute image, pastes the result back into the original frame and
    writes an annotated PNG per attribute image.
    """

    def __init__(self, config, reporter):
        """Store the configuration and build the preprocessing constants.

        Args:
            config: Mapping with the run configuration. Keys read by this
                class: "model_configs", "com_base", "cuda",
                "project_checkpoints", "checkpoint_step", "checkpoint_names",
                "test_samples_path", "version", "id_imgs", "attr_files" and
                "arcface_ckpt".
            reporter: Logger object; ``writeInfo`` and ``writeModel`` are
                called on it.
        """
        self.config = config
        # logger
        self.reporter = reporter

        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Same statistics as the Normalize above, kept as (3, 1, 1) CUDA
        # tensors so the generator output can be de-normalized by broadcasting.
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)

    def __init_framework__(self):
        '''
            This function is designed to define the framework,
            and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        # A non-empty fromlist makes __import__ return the leaf module rather
        # than the top-level package.
        package = __import__(gscript_name, fromlist=[class_name])
        gen_class = getattr(package, class_name)

        # TODO replace below lines to define the model framework
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()

        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(str(self.network))

        # test() sets self.arcface_ckpt before calling this method; fall back
        # to the configuration so the method also works when called directly.
        arcface_ckpt = getattr(self, "arcface_ckpt", None)
        if arcface_ckpt is None:
            arcface_ckpt = self.config["arcface_ckpt"]

        # NOTE(security): this checkpoint stores a pickled model object, so
        # torch.load executes arbitrary pickle code. Only load trusted files.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which presumably rejects such a checkpoint - confirm against the
        # installed version.
        arcface1 = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # train in GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

        model_path = os.path.join(
            self.config["project_checkpoints"],
            "step%d_%s.pth" % (self.config["checkpoint_step"],
                               self.config["checkpoint_names"]["generator_name"]))
        # Deserialize on CPU; load_state_dict copies the values onto whatever
        # device the network parameters already live on.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        self.network.load_state_dict(state_dict)
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

    @staticmethod
    def _collect_images(attr_files):
        """Return the attribute image paths for a file or a directory tree."""
        if os.path.isdir(attr_files):
            print("Input a dir....")
            pattern = os.path.join(attr_files, "**")
            # The recursive glob also yields directories; keep files only.
            imgs_list = sorted(
                path for path in glob.glob(pattern, recursive=True)
                if os.path.isfile(path)
            )
            print(imgs_list)
            return imgs_list
        print("Input an image....")
        return [attr_files]

    def _crop_to_tensor(self, bgr_crop):
        """Convert an aligned BGR crop into a normalized 1xCxHxW CUDA tensor."""
        rgb_pil = Image.fromarray(cv2.cvtColor(bgr_crop, cv2.COLOR_BGR2RGB))
        return self.transformer_Arcface(rgb_pil).unsqueeze(0).cuda()

    def _embed(self, image):
        """Return the L2-normalized ArcFace embedding of an image batch."""
        resized = F.interpolate(image, size=(112, 112), mode='bicubic')
        return F.normalize(self.arcface(resized), p=2, dim=1)

    @staticmethod
    def _paste_back(swapped_rgb, mat, original_bgr, crop_size=512):
        """Warp the swapped crop back into the original frame and blend it.

        Args:
            swapped_rgb: HxWx3 float RGB image with values in [0, 1].
            mat: 2x3 affine matrix that mapped the original frame to the crop.
            original_bgr: Original BGR frame as read by ``cv2.imread``.
            crop_size: Side length of the aligned crop the mask is built for.

        Returns:
            uint8 BGR image with the same size as ``original_bgr``.
        """
        # inverse the Affine transformation matrix
        mat_rev = cv2.invertAffineTransform(np.asarray(mat, dtype=np.float64))
        orisize = (original_bgr.shape[1], original_bgr.shape[0])

        target_image = cv2.warpAffine(swapped_rgb, mat_rev, orisize)

        # Warp an all-white crop to find where the face lands in the frame,
        # then shrink and feather that region to hide the seam.
        img_mask = np.full((crop_size, crop_size), 255, dtype=np.float64)
        img_mask = cv2.warpAffine(img_mask, mat_rev, orisize)
        img_mask[img_mask > 20] = 255

        kernel = np.ones((40, 40), np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations=1)
        kernel_size = (20, 20)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
        img_mask /= 255
        img_mask = img_mask[..., np.newaxis]

        # np.float was removed in NumPy 1.24; use the explicit 64-bit type.
        target_bgr = np.array(target_image, dtype=np.float64)[..., ::-1] * 255
        background = np.array(original_bgr, dtype=np.float64)
        blended = img_mask * target_bgr + (1 - img_mask) * background
        return blended.astype(np.uint8)

    def test(self):
        """Swap the identity face onto every attribute image and save results.

        Reads all inputs from ``self.config``. Attribute images that cannot be
        read, or for which face detection fails, are skipped with a message.

        Raises:
            ValueError: If the identity image cannot be read.
            RuntimeError: If no face is detected in the identity image.
        """
        import datetime

        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        attr_files = self.config["attr_files"]
        self.arcface_ckpt = self.config["arcface_ckpt"]

        imgs_list = self._collect_images(attr_files)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # cv2.imwrite fails silently when the target directory is missing.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # models
        self.__init_framework__()

        mode = None
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode=mode)

        id_img = cv2.imread(id_imgs)
        if id_img is None:
            raise ValueError("Cannot read identity image: %s" % id_imgs)
        # NOTE(review): the detector presumably returns None when no face is
        # found - confirm against Face_detect_crop.get.
        id_detection = self.detect.get(id_img, 512)
        if id_detection is None:
            raise RuntimeError("No face detected in identity image: %s" % id_imgs)
        id_img_align_crop, _ = id_detection

        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Start time
        print("Start to test at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            # create latent id
            latend_id = self._embed(self._crop_to_tensor(id_img_align_crop[0]))

            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                if attr_img_ori is None:
                    print("Skip %s: not a readable image" % img)
                    continue
                try:
                    detection = self.detect.get(attr_img_ori, 512)
                except Exception as err:
                    # Best effort: one bad image must not abort the whole run.
                    print("Skip %s: face detection failed (%s)" % (img, err))
                    continue
                if detection is None:
                    print("Skip %s: no face detected" % img)
                    continue
                attr_img_align_crop, mat = detection

                attr_img = self._crop_to_tensor(attr_img_align_crop[0])
                attr_id = self._embed(attr_img)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                results = self.network(attr_img, latend_id)
                results_arc = self._embed(results)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)

                # Undo the ImageNet normalization and move to HxWxC numpy.
                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0, 2, 3, 1)[0, ...].numpy()
                results = np.clip(results, 0.0, 1.0)

                final_img = self._paste_back(results, mat[0], attr_img_ori)

                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f' % results_cos_dis.item(),
                                        (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f' % cos_dis.item(),
                                        (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(
                    save_dir,
                    "id_%s--attr_%s_ckp_%s_v_%s.png" % (id_basename, attr_basename, ckp_step, version))
                if not cv2.imwrite(save_filename, final_img):
                    print("Failed to write %s" % save_filename)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))