diff --git a/models/base_model.py b/models/base_model.py index 1799129..3c6ca43 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -52,7 +52,7 @@ class BaseModel(torch.nn.Module): torch.save(network.state_dict(), save_path) # helper saving function that can be used by subclasses - def save_network(self, network, network_label, epoch_label, gpu_ids): + def save_network(self, network, network_label, epoch_label, gpu_ids=None): save_filename = '%s_net_%s.pth' % (epoch_label, network_label) save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) diff --git a/models/fs_networks_fix.py b/models/fs_networks_fix.py index af641c8..c7b0525 100644 --- a/models/fs_networks_fix.py +++ b/models/fs_networks_fix.py @@ -89,9 +89,11 @@ class Generator_Adain_Upsample(nn.Module): padding_type='reflect'): assert (n_blocks >= 0) super(Generator_Adain_Upsample, self).__init__() - activation = nn.ReLU(True) - self.deep = deep + activation = nn.ReLU(True) + + self.deep = deep + self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0), norm_layer(64), activation) ### downsample @@ -101,6 +103,7 @@ class Generator_Adain_Upsample(nn.Module): norm_layer(256), activation) self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), norm_layer(512), activation) + if self.deep: self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), norm_layer(512), activation) diff --git a/models/projected_model.py b/models/projected_model.py index 5c6e81d..477f63e 100644 --- a/models/projected_model.py +++ b/models/projected_model.py @@ -5,7 +5,7 @@ # Created Date: Wednesday January 12th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Wednesday, 20th April 2022 6:34:47 pm +# Last Modified: Thursday, 21st April 2022 8:13:37 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -94,7 +94,7 @@ class fsModel(BaseModel): def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch) self.save_network(self.netD, 'D', which_epoch) - self.save_optim(self.optimizer_G, 'G', which_epoch,) + self.save_optim(self.optimizer_G, 'G', which_epoch) self.save_optim(self.optimizer_D, 'D', which_epoch) '''if self.gen_features: self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)''' diff --git a/train.py b/train.py index 93df868..87a3b58 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Monday December 27th 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 21st April 2022 6:21:17 pm +# Last Modified: Thursday, 21st April 2022 8:10:05 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -43,9 +43,6 @@ class TrainOptions: self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') self.parser.add_argument('--isTrain', type=str2bool, default='True') - # parser.add_argument('--use_tensorboard', type=str2bool, default='True', - # choices=['True', 'False'], help='enable the tensorboard') - # input/output sizes self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')