pose and id validation tool

This commit is contained in:
chenxuanhong
2022-03-05 01:18:41 +08:00
parent e02f756116
commit f65c0dfa09
2 changed files with 35 additions and 161 deletions
+10
View File
@@ -17,4 +17,14 @@
"ckp_path": "train_logs",
"logfilename": "filestate_machine1.json"
}
,
{
"ip": "192.168.4.120",
"user": "gdp",
"port": 22,
"passwd": "glass123456",
"path": "/home/gdp/harddisk/Data2/simswap_plus",
"ckp_path": "train_logs",
"logfilename": "filestate_machine2.json"
}
]
+25 -161
View File
@@ -5,7 +5,7 @@
# Created Date: Friday March 4th 2022
# Author: Liu Naiyuan
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 4th March 2022 5:33:47 pm
# Last Modified: Saturday, 5th March 2022 1:00:29 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -18,7 +18,6 @@ import glob
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils import data
@@ -29,16 +28,12 @@ import PIL
from PIL import Image
class TotalDataset(data.Dataset):
"""Dataset class for the vggface dataset with precalulated face landmarks."""
def __init__(self,image_dir,content_transform, img_size=224):
def __init__(self,image_dir,content_transform):
self.image_dir= image_dir
self.content_transform= content_transform
self.img_size = img_size
self.dataset = []
self.preprocess()
self.num_images = len(self.dataset)
@@ -70,50 +65,30 @@ class TotalDataset(data.Dataset):
"""Return the number of images."""
return len(self.dataset)
def getLoader_sourceface(c_image_dir,
img_size=224, batch_size=16, num_workers=8):
def getLoader(c_image_dir, batch_size=16):
"""Build and return a data loader."""
c_transforms = []
c_transforms.append(T.ToTensor())
c_transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms = T.Compose(c_transforms)
num_workers = 8
content_dataset = TotalDataset(c_image_dir, c_transforms, 224)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True)
return content_data_loader, len(content_dataset)
def getLoader_targetface(c_image_dir,
img_size=224, batch_size=16, num_workers=8):
"""Build and return a data loader."""
c_transforms = []
c_transforms = []
c_transforms.append(transforms.ToTensor())
# c_transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
c_transforms.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms = transforms.Compose(c_transforms)
content_dataset = TotalDataset(c_image_dir, c_transforms, 224)
content_dataset = TotalDataset(c_image_dir, c_transforms)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True)
return content_data_loader, len(content_dataset)
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
@@ -161,52 +136,30 @@ class Tester(object):
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
batch_size = self.config["batch_size"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
save_dir = os.path.join(save_dir,"v_%s_step_%d"%(version,self.config["checkpoint_step"]))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
for item in imgs:
imgs_list.append(item)
print(imgs_list)
else:
print("Input an image....")
imgs_list.append(attr_files)
id_basename = os.path.basename(id_imgs)
id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
source_loader, dataet_len = getLoader_sourceface(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=opt.batchSize)
target_loader, dataet_len = getLoader_targetface(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=opt.batchSize)
source_loader, dataet_len = getLoader(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=batch_size)
target_loader, dataet_len = getLoader(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=batch_size)
source_iter = iter(source_loader)
target_iter = iter(target_loader)
# models
self.__init_framework__()
id_img = cv2.imread(id_imgs)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
@@ -217,7 +170,7 @@ class Tester(object):
for profile_batch, filename_batch in tqdm(source_iter):
profile_batch = profile_batch.cuda()
profile_id_downsample = F.interpolate(profile_batch, (112,112), mode='bicubic')
profile_latent_id = model.netArc(profile_id_downsample)
profile_latent_id = self.arcface(profile_id_downsample)
profile_latent_id = F.normalize(profile_latent_id, p=2, dim=1)
if init_batch ==True:
wholeid_batch = profile_latent_id.cpu()
@@ -241,106 +194,17 @@ class Tester(object):
init_id_batch = False
else:
batch_id = torch.cat([batch_id, wholeid_batch[target_index][None].cuda()],dim = 0)
img_fakes = model(None, target_batch.cuda(), batch_id, None, True)
img_fakes = self.network(target_batch.cuda(), batch_id)
for img_fake, target_index_tmp,filename_tmp in zip(img_fakes, target_index_list,filename_batch):
filename_tmp_split = filename_tmp.split('_')
final_filename = filename_tmp_split[0] + '_' +str(target_index_tmp) + '_' + filename_tmp_split[-1]
save_path = os.path.join(simswap_eval_save_image_path,final_filename)
save_image = postprocess(img_fake.cpu().numpy().transpose(1,2,0))
PIL.Image.fromarray(save_image).save(save_path,quality=95)
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
results = self.network(attr_img, latend_id)
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
final_img = img1.astype(np.uint8)
attr_basename = os.path.splitext(os.path.basename(img))[0]
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
average_cos /= len(imgs_list)
save_path = os.path.join(save_dir,final_filename)
img_fake = img_fake * self.imagenet_std + self.imagenet_mean
img_fake = img_fake.numpy().transpose(1,2,0)
img_fake = np.clip(img_fake,0.0,1.0) * 255
PIL.Image.fromarray(img_fake).save(save_path,quality=100)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
if __name__ == '__main__':
opt = TestOptions().parse()
with torch.no_grad():
source_loader, dataet_len = getLoader_sourceface('/home/gdp/harddisk/Data2/Faceswap/FaceForensics++_image_hififacestyle_source_Nonearcstyle', batch_size=opt.batchSize)
target_loader, dataet_len = getLoader_targetface('/home/gdp/harddisk/Data2/Faceswap/FaceForensics++_image_target_even10_pro_withmat_Nonearcstyle_256', batch_size=opt.batchSize)
simswap_eval_save_image_path = opt.output_path
criterion = nn.L1Loss()
if not os.path.exists(simswap_eval_save_image_path):
os.makedirs(simswap_eval_save_image_path)
torch.nn.Module.dump_patches = True
model = create_model(opt)
model.eval()
source_iter = iter(source_loader)
target_iter = iter(target_loader)
init_batch = True
for profile_batch, filename_batch in tqdm(source_iter):
# src_batch, filename_batch = data_iter.next()
profile_batch = profile_batch.cuda()
profile_id_downsample = F.interpolate(profile_batch, (112,112))
profile_latent_id = model.netArc(profile_id_downsample)
profile_latent_id = F.normalize(profile_latent_id, p=2, dim=1)
if init_batch ==True:
wholeid_batch = profile_latent_id.cpu()
init_batch = False
else:
wholeid_batch = torch.cat([wholeid_batch,profile_latent_id.cpu()],dim=0)
print(wholeid_batch.shape)
# np.save("simswap_wholeid_batch.npy", wholeid_batch.detach().cpu().numpy())
target_source_pair_dict = np.load('/home/gdp/harddisk/Data2/Faceswap/npy_file/target_source_pair.npy' ,allow_pickle=True).item()
for target_batch, filename_batch in tqdm(target_iter):
target_index_list = []
init_id_batch = True
for filename_tmp in filename_batch:
source_index = int(filename_tmp.split('_')[0])
target_index = target_source_pair_dict[source_index]
target_index_list.append(target_index)
if init_id_batch:
batch_id = wholeid_batch[target_index][None].cuda()
init_id_batch = False
else:
batch_id = torch.cat([batch_id, wholeid_batch[target_index][None].cuda()],dim = 0)
img_fakes = model(None, target_batch.cuda(), batch_id, None, True)
for img_fake, target_index_tmp,filename_tmp in zip(img_fakes, target_index_list,filename_batch):
filename_tmp_split = filename_tmp.split('_')
final_filename = filename_tmp_split[0] + '_' +str(target_index_tmp) + '_' + filename_tmp_split[-1]
save_path = os.path.join(simswap_eval_save_image_path,final_filename)
save_image = postprocess(img_fake.cpu().numpy().transpose(1,2,0))
PIL.Image.fromarray(save_image).save(save_path,quality=95)
print("Elapsed [{}]".format(elapsed))