This commit is contained in:
chenxuanhong
2022-02-17 16:30:33 +08:00
parent d82e46f0e7
commit 913e4916d4
17 changed files with 45302 additions and 71 deletions
+8 -6
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 29th January 2022 12:02:31 pm
# Last Modified: Tuesday, 15th February 2022 10:20:50 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -67,17 +67,19 @@ class Tester(object):
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
def test(self):