This commit is contained in:
chenxuanhong
2022-01-17 13:17:49 +08:00
parent bf2df5c5a6
commit 601d2ee43d
58 changed files with 2748 additions and 5696 deletions
+132 -58
View File
@@ -1,112 +1,186 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Generator_gpt_LN_encoder copy.py
# Created Date: Saturday October 9th 2021
# File: Generator.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 26th October 2021 3:25:47 pm
# Last Modified: Sunday, 16th January 2022 11:42:14 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from ResBlock_Adain import ResBlock_Adain
from torch.nn import init
from torch.nn import functional as F
class InstanceNorm(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(InstanceNorm, self).__init__()
self.epsilon = epsilon
def forward(self, x):
x = x - torch.mean(x, (2, 3), True)
tmp = torch.mul(x, x) # or x ** 2
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
return x * tmp
class ApplyStyle(nn.Module):
"""
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
"""
def __init__(self, latent_size, channels):
super(ApplyStyle, self).__init__()
self.linear = nn.Linear(latent_size, channels * 2)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, 2, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
#x = x * (style[:, 0] + 1.) + style[:, 1]
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
return x
class ResnetBlock_Adain(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
super(ResnetBlock_Adain, self).__init__()
p = 0
conv1 = []
if padding_type == 'reflect':
conv1 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv1 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = ApplyStyle(latent_size, dim)
self.act1 = activation
p = 0
conv2 = []
if padding_type == 'reflect':
conv2 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv2 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = ApplyStyle(latent_size, dim)
def forward(self, x, dlatents_in_slice):
y = self.conv1(x)
y = self.style1(y, dlatents_in_slice)
y = self.act1(y)
y = self.conv2(y)
y = self.style2(y, dlatents_in_slice)
out = x + y
return out
from functools import partial
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super(Generator, self).__init__()
super().__init__()
input_nc = kwargs["g_conv_dim"]
output_nc = kwargs["g_kernel_size"]
latent_size = kwargs["latent_size"]
n_blocks = kwargs["resblock_num"]
norm_name = kwargs["norm_name"]
padding_type= kwargs["reflect"]
if norm_name == "bn":
norm_layer = partial(nn.BatchNorm2d, affine = True, track_running_stats=True)
elif norm_name == "in":
norm_name = nn.InstanceNorm2d
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
padding_size= int((k_size -1)/2)
padding_type= 'reflect'
assert (n_blocks >= 0)
activation = nn.ReLU(True)
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
norm_layer(64), activation)
self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64), activation)
### downsample
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
norm_layer(128), activation)
nn.BatchNorm2d(128), activation)
self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
norm_layer(256), activation)
nn.BatchNorm2d(256), activation)
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
nn.BatchNorm2d(512), activation)
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
nn.BatchNorm2d(512), activation)
### resnet blocks
BN = []
for i in range(n_blocks):
for i in range(res_num):
BN += [
ResBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
self.BottleNeck = nn.Sequential(*BN)
if self.deep:
self.up4 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512), activation
)
self.up4 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512), activation
)
self.up3 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256), activation
)
self.up2 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128), activation
)
self.up1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64), activation
)
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))
self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input, id):
x = input # 3*224*224
res = self.first_layer(x)
res = self.down1(res)
res = self.down2(res)
res = self.down4(res)
res = self.down3(res)
skip1 = self.first_layer(x)
skip2 = self.down1(skip1)
skip3 = self.down2(skip2)
skip4 = self.down3(skip3)
res = self.down4(skip4)
for i in range(len(self.BottleNeck)):
res = self.BottleNeck[i](res, id)
x = self.BottleNeck[i](res, id)
res = self.up4(res)
res = self.up3(res)
res = self.up2(res)
res = self.up1(res)
res = self.last_layer(res)
return res
if __name__ == '__main__':
upscale = 4
window_size = 8
height = 1024
width = 1024
model = Generator()
print(model)
x = self.up4(x)
x = self.up3(x)
x = self.up2(x)
x = self.up1(x)
x = self.last_layer(x)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
return x