diff --git a/test.py b/test.py index 851dc33..8157120 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th February 2022 10:23:54 am +# Last Modified: Thursday, 17th February 2022 7:00:12 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -34,7 +34,7 @@ def getParameters(): help="version name for train, test, finetune") parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU - parser.add_argument('-s', '--checkpoint_step', type=int, default=300000, + parser.add_argument('-s', '--checkpoint_step', type=int, default=510000, help="checkpoint epoch for test phase or finetune phase") # test @@ -44,9 +44,9 @@ def getParameters(): choices=['localhost', '4card','8card','new4card']) - parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg') + parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg') # parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg') - parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\8', + parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID', help="file path for attribute images or video") parser.add_argument('--use_specified_data', action='store_true') diff --git a/test_scripts/tester_image.py b/test_scripts/tester_image.py index 7f7c525..7be905a 100644 --- a/test_scripts/tester_image.py +++ b/test_scripts/tester_image.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 15th February 2022 10:20:50 pm +# Last Modified: Thursday, 17th February 2022 7:12:56 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -90,6 +90,9 @@ class Tester(object): attr_files = self.config["attr_files"] self.arcface_ckpt= self.config["arcface_ckpt"] imgs_list = [] + + self.reporter.writeInfo("Version %s"%version) + if os.path.isdir(attr_files): print("Input a dir....") imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True) @@ -127,6 +130,8 @@ class Tester(object): print('Start =================================== test...') start_time = time.time() self.network.eval() + cos_dict = {} + average_cos = 0 with torch.no_grad(): for img in imgs_list: print(img) @@ -151,6 +156,7 @@ class Tester(object): results_arc = self.arcface(results_arc) results_arc = F.normalize(results_arc, p=2, dim=1) results_cos_dis = 1 - cos_loss(latend_id, results_arc) + average_cos += results_cos_dis results = results * self.imagenet_std + self.imagenet_mean results = results.cpu().permute(0,2,3,1)[0,...] @@ -203,7 +209,9 @@ class Tester(object): attr_basename,ckp_step,version)) cv2.imwrite(save_filename, final_img) - + average_cos /= len(imgs_list) elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) - print("Elapsed [{}]".format(elapsed)) \ No newline at end of file + print("Elapsed [{}]".format(elapsed)) + print("Average cosin similarity between ID and results [{}]".format(average_cos.item())) + self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item())) \ No newline at end of file