This commit is contained in:
chenxuanhong
2022-03-25 18:52:25 +08:00
parent 17c5d6a6b5
commit 99ed65aaa3
13 changed files with 136078 additions and 26 deletions
+15 -10
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 27th February 2022 7:50:18 pm
# Last Modified: Thursday, 24th March 2022 11:24:26 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -255,7 +255,8 @@ class Generator(nn.Module):
padding_size= int((k_size -1)/2)
padding_type= 'reflect'
activation = nn.LeakyReLU(0.2)
# activation = nn.LeakyReLU(0.2)
activation = nn.ReLU()
# self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
# nn.BatchNorm2d(64), activation)
@@ -266,13 +267,13 @@ class Generator(nn.Module):
# self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(64), activation)
### downsample
self.down1 = ResDownSampleBlock(in_channel, in_channel*2,res_mode=res_mode)
self.down1 = ResDownSampleBlock(in_channel, in_channel*2, activation=activation, res_mode=res_mode) # 128
# nn.Sequential(
# nn.Conv2d(in_channel, in_channel*2, stride=2, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(in_channel*2),
# activation) # 128
self.down2 = ResDownSampleBlock(in_channel*2, in_channel*4,res_mode=res_mode)
self.down2 = ResDownSampleBlock(in_channel*2, in_channel*4, activation=activation, res_mode=res_mode) # 64
# nn.Sequential(
# nn.Conv2d(in_channel*2, in_channel*4, stride=2, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(in_channel*4),
@@ -280,7 +281,9 @@ class Generator(nn.Module):
# self.lstu = LSTU(in_channel*4,in_channel*4,in_channel*8,4)
self.down3 = ResDownSampleBlock(in_channel*4, in_channel*8,res_mode=res_mode)
self.down3 = ResDownSampleBlock(in_channel*4, in_channel*8, activation=activation, res_mode=res_mode) # 32
self.down4 = ResDownSampleBlock(in_channel*8, in_channel*8, activation=activation, res_mode=res_mode) # 16
# nn.Sequential(
# nn.Conv2d(in_channel*4, in_channel*8, stride=2, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(in_channel*8),
@@ -297,10 +300,12 @@ class Generator(nn.Module):
BN = []
for i in range(res_num):
BN += [
ResnetBlock_Adain(in_channel*8, latent_size=id_dim,res_mode=res_mode)]
ResnetBlock_Adain(in_channel*8, latent_size=id_dim, activation=activation, res_mode=res_mode)]
self.BottleNeck = nn.Sequential(*BN)
self.up5 = ResUpSampleBlock(in_channel*8, in_channel*8, id_dim, activation=activation, res_mode=res_mode) # 32
self.up4 = ResUpSampleBlock(in_channel*8,in_channel*8,id_dim,res_mode=res_mode) # 64
self.up4 = ResUpSampleBlock(in_channel*8, in_channel*4, id_dim, activation=activation, res_mode=res_mode) # 64
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
@@ -308,7 +313,7 @@ class Generator(nn.Module):
# activation
# )
self.up3 = ResUpSampleBlock(in_channel*8,in_channel*4,id_dim,res_mode=res_mode) # 128
self.up3 = ResUpSampleBlock(in_channel*4, in_channel*2, id_dim, activation=activation, res_mode=res_mode) # 128
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*8, in_channel*4, kernel_size=3, stride=1, padding=1, bias=False),
@@ -316,7 +321,7 @@ class Generator(nn.Module):
# activation
# )
self.up2 = ResUpSampleBlock(in_channel*4,in_channel*2,id_dim,res_mode=res_mode) # 256
self.up2 = ResUpSampleBlock(in_channel*2, in_channel, id_dim, activation=activation, res_mode=res_mode) # 256
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*4, in_channel*2, kernel_size=3, stride=1, padding=1, bias=False),
@@ -324,7 +329,7 @@ class Generator(nn.Module):
# activation
# )
self.up1 = ResUpSampleBlock(in_channel*2,in_channel,id_dim,res_mode=res_mode) # 512
self.up1 = ResUpSampleBlock(in_channel, in_channel , id_dim, activation=activation, res_mode=res_mode) # 512
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),