update test scripts

This commit is contained in:
chenxuanhong
2022-02-17 19:13:16 +08:00
parent 913e4916d4
commit 522cc43d1c
2 changed files with 15 additions and 7 deletions
+4 -4
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 10:23:54 am
# Last Modified: Thursday, 17th February 2022 7:00:12 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -34,7 +34,7 @@ def getParameters():
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=300000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=510000,
help="checkpoint epoch for test phase or finetune phase")
# test
@@ -44,9 +44,9 @@ def getParameters():
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\8',
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
+11 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 15th February 2022 10:20:50 pm
# Last Modified: Thursday, 17th February 2022 7:12:56 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -90,6 +90,9 @@ class Tester(object):
attr_files = self.config["attr_files"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
@@ -127,6 +130,8 @@ class Tester(object):
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
cos_dict = {}
average_cos = 0
with torch.no_grad():
for img in imgs_list:
print(img)
@@ -151,6 +156,7 @@ class Tester(object):
results_arc = self.arcface(results_arc)
results_arc = F.normalize(results_arc, p=2, dim=1)
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
average_cos += results_cos_dis
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
@@ -203,7 +209,9 @@ class Tester(object):
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))