This commit is contained in:
XHChen0528
2022-03-24 00:32:14 +08:00
parent 34c80c5315
commit fe52d5fbd5
5 changed files with 383 additions and 17 deletions
+11 -12
View File
@@ -120,14 +120,14 @@ class ResUpSampleBlock(nn.Module):
latent_size,
activation=nn.LeakyReLU(0.2),
res_mode="depthwise"):
super(ResnetBlock_Adain, self).__init__()
super(ResUpSampleBlock, self).__init__()
conv1 = []
self.in1 = InstanceNorm()
self.in2 = InstanceNorm()
if res_mode.lower() == "conv":
conv1 += [activation,
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)]
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, bias=False)]
elif res_mode.lower() == "depthwise":
conv1 += [activation,
@@ -145,7 +145,7 @@ class ResUpSampleBlock(nn.Module):
conv2 = []
if res_mode.lower() == "conv":
conv2 += [activation,
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)]
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False)]
elif res_mode.lower() == "depthwise":
conv2 += [activation,
@@ -183,14 +183,14 @@ class ResDownSampleBlock(nn.Module):
out_dim,
activation=nn.LeakyReLU(0.2),
res_mode="depthwise"):
super(ResnetBlock_Adain, self).__init__()
super(ResDownSampleBlock, self).__init__()
conv1 = []
if res_mode.lower() == "conv":
conv1 += [
nn.BatchNorm2d(in_dim),
activation,
nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1)]
nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False)]
elif res_mode.lower() == "depthwise":
conv1 += [
@@ -213,7 +213,7 @@ class ResDownSampleBlock(nn.Module):
conv2 += [
nn.BatchNorm2d(in_dim),
activation,
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)]
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, bias=False)]
elif res_mode.lower() == "depthwise":
conv2 += [
@@ -235,7 +235,7 @@ class ResDownSampleBlock(nn.Module):
def forward(self, x):
y = self.conv1(y)
y = self.conv1(x)
y = self.resampling(y)
y = self.conv2(y)
res = self.reshape1_1(x)
@@ -264,7 +264,6 @@ class Generator(nn.Module):
padding_type= 'reflect'
activation = nn.LeakyReLU(0.2)
from components.DeConv_Depthwise import DeConv
# self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
# nn.BatchNorm2d(64), activation)
@@ -310,7 +309,7 @@ class Generator(nn.Module):
# padding_type=padding_type, activation=activation, res_mode=res_mode)]
# self.BottleNeck = nn.Sequential(*BN)
self.up4 = ResDownSampleBlock(in_channel*8,in_channel*8,id_dim,res_mode=res_mode) # 64
self.up4 = ResUpSampleBlock(in_channel*8,in_channel*8,id_dim,res_mode=res_mode) # 64
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
@@ -318,7 +317,7 @@ class Generator(nn.Module):
# activation
# )
self.up3 = ResDownSampleBlock(in_channel*8,in_channel*4,id_dim,res_mode=res_mode) # 128
self.up3 = ResUpSampleBlock(in_channel*8,in_channel*4,id_dim,res_mode=res_mode) # 128
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*8, in_channel*4, kernel_size=3, stride=1, padding=1, bias=False),
@@ -326,7 +325,7 @@ class Generator(nn.Module):
# activation
# )
self.up2 = ResDownSampleBlock(in_channel*4,in_channel*2,id_dim,res_mode=res_mode) # 256
self.up2 = ResUpSampleBlock(in_channel*4,in_channel*2,id_dim,res_mode=res_mode) # 256
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*4, in_channel*2, kernel_size=3, stride=1, padding=1, bias=False),
@@ -334,7 +333,7 @@ class Generator(nn.Module):
# activation
# )
self.up1 = ResDownSampleBlock(in_channel*2,in_channel,id_dim,res_mode=res_mode) # 512
self.up1 = ResUpSampleBlock(in_channel*2,in_channel,id_dim,res_mode=res_mode) # 512
# nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),