Update projected_model.py #310

Open
MARCOCHEUNG0124 wants to merge 1 commits from MARCOCHEUNG0124/patch-2 into main

View File

@@ -50,15 +50,16 @@ class fsModel(BaseModel):
self.netArc = self.netArc.cuda()
self.netArc.eval()
self.netArc.requires_grad_(False)
if not self.isTrain:
pretrained_path = opt.checkpoints_dir
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
return
self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
# self.netD.feature_network.requires_grad_(False)
self.netD.cuda()
if self.isTrain:
# define loss functions
self.criterionFeat = nn.L1Loss()
@@ -83,6 +84,7 @@ class fsModel(BaseModel):
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
torch.cuda.empty_cache()
def cosin_metric(self, x1, x2):