update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 12:02:31 pm
|
||||
# Last Modified: Tuesday, 15th February 2022 10:20:50 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -67,17 +67,19 @@ class Tester(object):
|
||||
self.arcface = arcface1['model'].module
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
self.arcface = self.arcface.cuda()
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.network.load_state_dict(torch.load(model_path))
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
|
||||
|
||||
|
||||
def test(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user