fix save model bugs

This commit is contained in:
chenxuanhong
2022-04-21 20:15:21 +08:00
parent b893316e41
commit 7ed12d218f
4 changed files with 9 additions and 9 deletions
+1 -1
View File
@@ -52,7 +52,7 @@ class BaseModel(torch.nn.Module):
torch.save(network.state_dict(), save_path)
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
+5 -2
View File
@@ -89,9 +89,11 @@ class Generator_Adain_Upsample(nn.Module):
padding_type='reflect'):
assert (n_blocks >= 0)
super(Generator_Adain_Upsample, self).__init__()
activation = nn.ReLU(True)
self.deep = deep
activation = nn.ReLU(True)
self.deep = deep
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
norm_layer(64), activation)
### downsample
@@ -101,6 +103,7 @@ class Generator_Adain_Upsample(nn.Module):
norm_layer(256), activation)
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
if self.deep:
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
+2 -2
View File
@@ -5,7 +5,7 @@
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 20th April 2022 6:34:47 pm
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -94,7 +94,7 @@ class fsModel(BaseModel):
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch)
self.save_network(self.netD, 'D', which_epoch)
self.save_optim(self.optimizer_G, 'G', which_epoch,)
self.save_optim(self.optimizer_G, 'G', which_epoch)
self.save_optim(self.optimizer_D, 'D', which_epoch)
'''if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
+1 -4
View File
@@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 6:21:17 pm
# Last Modified: Thursday, 21st April 2022 8:10:05 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -43,9 +43,6 @@ class TrainOptions:
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--isTrain', type=str2bool, default='True')
# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
# choices=['True', 'False'], help='enable the tensorboard')
# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')