fix save model bugs
This commit is contained in:
@@ -52,7 +52,7 @@ class BaseModel(torch.nn.Module):
|
||||
torch.save(network.state_dict(), save_path)
|
||||
|
||||
# helper saving function that can be used by subclasses
|
||||
def save_network(self, network, network_label, epoch_label, gpu_ids):
|
||||
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
|
||||
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
||||
save_path = os.path.join(self.save_dir, save_filename)
|
||||
torch.save(network.cpu().state_dict(), save_path)
|
||||
|
||||
@@ -89,9 +89,11 @@ class Generator_Adain_Upsample(nn.Module):
|
||||
padding_type='reflect'):
|
||||
assert (n_blocks >= 0)
|
||||
super(Generator_Adain_Upsample, self).__init__()
|
||||
activation = nn.ReLU(True)
|
||||
self.deep = deep
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.deep = deep
|
||||
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
|
||||
norm_layer(64), activation)
|
||||
### downsample
|
||||
@@ -101,6 +103,7 @@ class Generator_Adain_Upsample(nn.Module):
|
||||
norm_layer(256), activation)
|
||||
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
||||
norm_layer(512), activation)
|
||||
|
||||
if self.deep:
|
||||
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
||||
norm_layer(512), activation)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Wednesday January 12th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 20th April 2022 6:34:47 pm
|
||||
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -94,7 +94,7 @@ class fsModel(BaseModel):
|
||||
def save(self, which_epoch):
|
||||
self.save_network(self.netG, 'G', which_epoch)
|
||||
self.save_network(self.netD, 'D', which_epoch)
|
||||
self.save_optim(self.optimizer_G, 'G', which_epoch,)
|
||||
self.save_optim(self.optimizer_G, 'G', which_epoch)
|
||||
self.save_optim(self.optimizer_D, 'D', which_epoch)
|
||||
'''if self.gen_features:
|
||||
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Monday December 27th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 21st April 2022 6:21:17 pm
|
||||
# Last Modified: Thursday, 21st April 2022 8:10:05 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -43,9 +43,6 @@ class TrainOptions:
|
||||
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
||||
self.parser.add_argument('--isTrain', type=str2bool, default='True')
|
||||
|
||||
# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
|
||||
# choices=['True', 'False'], help='enable the tensorboard')
|
||||
|
||||
# input/output sizes
|
||||
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user