arcface
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 15th February 2022 12:00:24 am
|
||||
# Last Modified: Thursday, 17th March 2022 1:01:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -26,6 +26,8 @@ from torch_utils import training_stats
|
||||
from torch_utils.ops import conv2d_gradfix
|
||||
from torch_utils.ops import grid_sample_gradfix
|
||||
|
||||
from arcface_torch.backbones.iresnet import iresnet100
|
||||
|
||||
from utilities.plot import plot_batch
|
||||
from losses.cos import cosin_metric
|
||||
from train_scripts.trainer_multigpu_base import TrainerBase
|
||||
@@ -95,8 +97,12 @@ def init_framework(config, reporter, device, rank):
|
||||
reporter.writeInfo("Discriminator structure:")
|
||||
reporter.writeModel(dis.__str__())
|
||||
|
||||
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
arcface = arcface1['model'].module
|
||||
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
# arcface = arcface1['model'].module
|
||||
|
||||
arcface = iresnet100(pretrained=False, fp16=False)
|
||||
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
|
||||
arcface.eval()
|
||||
|
||||
# train in GPU
|
||||
|
||||
|
||||
Reference in New Issue
Block a user