Update projected_model.py

This commit is contained in:
Marco Cheung
2022-08-01 22:26:20 +08:00
committed by GitHub
parent dd1ecdd2a7
commit dde1148355
+3 -1
View File
@@ -50,15 +50,16 @@ class fsModel(BaseModel):
self.netArc = self.netArc.cuda()
self.netArc.eval()
self.netArc.requires_grad_(False)
if not self.isTrain:
pretrained_path = opt.checkpoints_dir
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
return
self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
# self.netD.feature_network.requires_grad_(False)
self.netD.cuda()
if self.isTrain:
# define loss functions
self.criterionFeat = nn.L1Loss()
@@ -83,6 +84,7 @@ class fsModel(BaseModel):
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
torch.cuda.empty_cache()
def cosin_metric(self, x1, x2):