update test scripts
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 15th February 2022 10:20:50 pm
|
||||
# Last Modified: Thursday, 17th February 2022 7:12:56 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -90,6 +90,9 @@ class Tester(object):
|
||||
attr_files = self.config["attr_files"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
imgs_list = []
|
||||
|
||||
self.reporter.writeInfo("Version %s"%version)
|
||||
|
||||
if os.path.isdir(attr_files):
|
||||
print("Input a dir....")
|
||||
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
|
||||
@@ -127,6 +130,8 @@ class Tester(object):
|
||||
print('Start =================================== test...')
|
||||
start_time = time.time()
|
||||
self.network.eval()
|
||||
cos_dict = {}
|
||||
average_cos = 0
|
||||
with torch.no_grad():
|
||||
for img in imgs_list:
|
||||
print(img)
|
||||
@@ -151,6 +156,7 @@ class Tester(object):
|
||||
results_arc = self.arcface(results_arc)
|
||||
results_arc = F.normalize(results_arc, p=2, dim=1)
|
||||
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
|
||||
average_cos += results_cos_dis
|
||||
|
||||
results = results * self.imagenet_std + self.imagenet_mean
|
||||
results = results.cpu().permute(0,2,3,1)[0,...]
|
||||
@@ -203,7 +209,9 @@ class Tester(object):
|
||||
attr_basename,ckp_step,version))
|
||||
|
||||
cv2.imwrite(save_filename, final_img)
|
||||
|
||||
average_cos /= len(imgs_list)
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
print("Elapsed [{}]".format(elapsed))
|
||||
print("Elapsed [{}]".format(elapsed))
|
||||
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
|
||||
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
|
||||
Reference in New Issue
Block a user