Update projected_model.py #310
@@ -50,15 +50,16 @@ class fsModel(BaseModel):
|
||||
self.netArc = self.netArc.cuda()
|
||||
self.netArc.eval()
|
||||
self.netArc.requires_grad_(False)
|
||||
|
||||
if not self.isTrain:
|
||||
pretrained_path = opt.checkpoints_dir
|
||||
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
|
||||
return
|
||||
|
||||
self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
|
||||
# self.netD.feature_network.requires_grad_(False)
|
||||
self.netD.cuda()
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
# define loss functions
|
||||
self.criterionFeat = nn.L1Loss()
|
||||
@@ -83,6 +84,7 @@ class fsModel(BaseModel):
|
||||
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
|
||||
self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
|
||||
self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def cosin_metric(self, x1, x2):
|
||||
|
||||
Reference in New Issue
Block a user