This commit is contained in:
chenxuanhong
2022-03-19 00:43:03 +08:00
parent f65c0dfa09
commit 489a6c5f68
49 changed files with 10736 additions and 5 deletions
+9 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 15th February 2022 12:00:24 am
# Last Modified: Thursday, 17th March 2022 1:01:52 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -26,6 +26,8 @@ from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
from arcface_torch.backbones.iresnet import iresnet100
from utilities.plot import plot_batch
from losses.cos import cosin_metric
from train_scripts.trainer_multigpu_base import TrainerBase
@@ -95,8 +97,12 @@ def init_framework(config, reporter, device, rank):
reporter.writeInfo("Discriminator structure:")
reporter.writeModel(dis.__str__())
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
arcface = arcface1['model'].module
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
# arcface = arcface1['model'].module
arcface = iresnet100(pretrained=False, fp16=False)
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
arcface.eval()
# train in GPU