This commit is contained in:
chenxuanhong
2023-06-01 23:32:08 +08:00
parent 7a6c92ae69
commit 70b22c4e4b
8 changed files with 77 additions and 15 deletions
-3
View File
@@ -59,9 +59,6 @@ class fsModel(BaseModel):
self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False)
self.netG.to(device)
# Id network
netArc_checkpoint = opt.Arc_path
netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
+3 -7
View File
@@ -7,13 +7,9 @@ from .config import device, num_classes
def create_model(opt):
if opt.model == 'pix2pixHD':
#from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
from .fs_model import fsModel
model = fsModel()
else:
from .ui_model import UIModel
model = UIModel()
#from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
from .fs_model import fsModel
model = fsModel()
model.initialize(opt)
if opt.verbose:
+2 -2
View File
@@ -5,7 +5,7 @@
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Last Modified: Saturday, 13th May 2023 9:56:35 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -46,7 +46,7 @@ class fsModel(BaseModel):
# Id network
netArc_checkpoint = opt.Arc_path
netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
self.netArc = netArc_checkpoint['model'].module
self.netArc = netArc_checkpoint
self.netArc = self.netArc.cuda()
self.netArc.eval()
self.netArc.requires_grad_(False)