import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Strided-convolution patch discriminator.

    Five stride-2 downsampling stages (input -> 512 channels) followed by two
    stride-1 convolutions.  ``forward`` returns the feature maps of the last
    three stages so callers can apply feature-matching / GAN losses on them.
    """

    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        """
        Args:
            input_nc: number of channels of the input image.
            norm_layer: normalisation layer constructor applied after each conv.
            use_sigmoid: append a Sigmoid to the final 1-channel prediction.
        """
        super(Discriminator, self).__init__()

        kernel, pad = 4, 1

        def _stage(cin, cout, stride):
            # conv -> norm -> LeakyReLU building block shared by every stage
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=kernel, stride=stride, padding=pad),
                norm_layer(cout),
                nn.LeakyReLU(0.2, True),
            )

        self.down1 = _stage(input_nc, 64, 2)
        self.down2 = _stage(64, 128, 2)
        self.down3 = _stage(128, 256, 2)
        self.down4 = _stage(256, 512, 2)
        self.down5 = _stage(512, 512, 2)
        self.conv1 = _stage(512, 512, 1)

        final = [nn.Conv2d(512, 1, kernel_size=kernel, stride=1, padding=pad)]
        if use_sigmoid:
            final.append(nn.Sigmoid())
        self.conv2 = nn.Sequential(*final)

    def forward(self, input):
        """Run the cascade; return the last three intermediate feature maps."""
        feats = []
        x = self.down1(input)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.down5(x)
        feats.append(x)
        x = self.conv1(x)
        feats.append(x)
        x = self.conv2(x)
        feats.append(x)
        return feats
- norm_layer = nn.LayerNorm - qk_scale = None - qkv_bias = True - self.patch_norm = True - self.lnnorm = norm_layer(embed_dim) - - self.encoder = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - ImageLN(chn), - nn.ReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - ImageLN(chn * 2), - nn.ReLU(), - nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False), - ImageLN(embed_dim), - nn.ReLU(), - ) - - # self.encoder2 = nn.Sequential( - - # nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU() - # ) - - self.fea_size = (image_size//4, image_size//4) - # self.conditional_GPT = GPT_Spatial(2, res_dim, res_num, class_num) - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=embed_dim, input_resolution=self.fea_size, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=0.0, attn_drop=0.0, - drop_path=0.1, - norm_layer=norm_layer) - for i in range(res_num)]) - - self.decoder = nn.Sequential( - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, 
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator.py
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
"""AdaIN-conditioned encoder / bottleneck / decoder generator."""

from functools import partial

import torch
from torch import nn

# BUG FIX: this module lives in the `components` package; the original bare
# `from ResBlock_Adain import ResBlock_Adain` only resolves with components/
# on sys.path, so use the package-qualified form like the rest of the project.
from components.ResBlock_Adain import ResBlock_Adain


class Generator(nn.Module):
    """Encoder -> AdaIN residual bottleneck -> upsampling decoder.

    Expected kwargs:
        g_conv_dim:   input channel count (see NOTE below).
        g_kernel_size: output channel count (see NOTE below).
        latent_size:  dimensionality of the AdaIN conditioning latent.
        resblock_num: number of ResBlock_Adain blocks in the bottleneck.
        norm_name:    "bn" or "in".
        padding_type: "reflect" / "replicate" / "zero" (optional).
        deep:         use the deepest down4/up4 stage (optional, default True).
    """

    def __init__(self, **kwargs):
        super(Generator, self).__init__()

        # NOTE(review): these key names look copy-pasted from another config —
        # "g_conv_dim" is used as the *input channel count* and
        # "g_kernel_size" as the *output channel count*; for RGB images both
        # should normally be 3.  Keys kept for backward compatibility —
        # TODO confirm against the training config.
        input_nc = kwargs["g_conv_dim"]
        output_nc = kwargs["g_kernel_size"]
        latent_size = kwargs["latent_size"]
        n_blocks = kwargs["resblock_num"]
        norm_name = kwargs["norm_name"]
        # BUG FIX: original read kwargs["reflect"] (KeyError unless the caller
        # literally passed a "reflect" key).  Accept the intended
        # "padding_type" key first and fall back to the legacy one.
        padding_type = kwargs.get("padding_type", kwargs.get("reflect", "reflect"))
        # BUG FIX: the original referenced self.deep without ever defining it,
        # raising AttributeError at construction time.
        self.deep = kwargs.get("deep", True)

        if norm_name == "bn":
            norm_layer = partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        elif norm_name == "in":
            # BUG FIX: original assigned nn.InstanceNorm2d to *norm_name*,
            # leaving norm_layer undefined (NameError on first use below).
            norm_layer = nn.InstanceNorm2d
        else:
            raise ValueError("unsupported norm_name: %s" % norm_name)

        assert n_blocks >= 0
        activation = nn.ReLU(True)

        self.first_layer = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
            norm_layer(64),
            activation,
        )

        # downsampling path: 64 -> 128 -> 256 -> 512 (-> 512 when deep)
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            norm_layer(128), activation)
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            norm_layer(256), activation)
        self.down3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            norm_layer(512), activation)
        if self.deep:
            self.down4 = nn.Sequential(
                nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                norm_layer(512), activation)

        # AdaIN residual bottleneck, conditioned on the identity latent
        self.BottleNeck = nn.Sequential(*[
            ResBlock_Adain(512, latent_size=latent_size,
                           padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        # upsampling path mirrors the encoder
        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512), activation)
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation)
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation)
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation)
        self.last_layer = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

    def forward(self, input, dlatents):
        """Translate `input` conditioned on per-sample AdaIN latents.

        Args:
            input:    image batch, (B, input_nc, H, W).
            dlatents: conditioning latent, (B, latent_size).
        """
        res = self.first_layer(input)
        res = self.down1(res)
        res = self.down2(res)
        # BUG FIX: original applied down4 *before* down3, feeding the
        # 512-input conv a 256-channel tensor (runtime shape error).
        res = self.down3(res)
        if self.deep:
            res = self.down4(res)

        for block in self.BottleNeck:
            res = block(res, dlatents)

        if self.deep:
            res = self.up4(res)
        res = self.up3(res)
        res = self.up2(res)
        res = self.up1(res)
        return self.last_layer(res)


if __name__ == '__main__':
    # Smoke test.  BUG FIX: the original called Generator() with no kwargs
    # (KeyError) and model(x) without the required dlatents (TypeError).
    model = Generator(g_conv_dim=3, g_kernel_size=3, latent_size=512,
                      resblock_num=2, norm_name="bn")
    print(model)
    x = torch.randn((1, 3, 256, 256))
    latents = torch.randn((1, 512))
    print(model(x, latents).shape)
import torch
import torch.nn as nn


class InstanceNorm(nn.Module):
    """Parameter-free instance normalisation over the spatial dimensions.

    Written without in-place ops so autograd graphs stay valid, see
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        # subtract the per-(sample, channel) spatial mean, then divide by the
        # spatial RMS of the centred values
        centered = x - torch.mean(x, (2, 3), True)
        inv_std = torch.rsqrt(torch.mean(centered * centered, (2, 3), True) + self.epsilon)
        return centered * inv_std


class ApplyStyle(nn.Module):
    """Affine AdaIN modulation: per-channel scale/shift predicted from a latent.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        # [batch, 2*C] -> [batch, 2, C, 1, 1]; slot 0 = scale, slot 1 = shift
        style = self.linear(latent).view(-1, 2, x.size(1), 1, 1)
        return x * (style[:, 0] + 1.) + style[:, 1]


class ResBlock_Adain(nn.Module):
    """Residual block whose two 3x3 convs are modulated by AdaIN styles."""

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResBlock_Adain, self).__init__()
        self.conv1 = self._make_conv(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        self.conv2 = self._make_conv(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    @staticmethod
    def _make_conv(dim, padding_type):
        # explicit padding layer for reflect/replicate, conv-level for zero
        pad_layers = {
            'reflect': [nn.ReflectionPad2d(1)],
            'replicate': [nn.ReplicationPad2d(1)],
            'zero': [],
        }
        if padding_type not in pad_layers:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_pad = 1 if padding_type == 'zero' else 0
        layers = pad_layers[padding_type] + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
            InstanceNorm(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, dlatents_in_slice):
        y = self.act1(self.style1(self.conv1(x), dlatents_in_slice))
        y = self.style2(self.conv2(y), dlatents_in_slice)
        return x + y
style_tensor.add(imagenet_neg_mean) - B, C, H, W = style_tensor.shape - style_features = VGG(style_tensor.expand([batch_size, C, H, W])) - style_gram = {} - for key, value in style_features.items(): - style_gram[key] = Gram(value) # step_epoch = 2 for epoch in range(start, total_epoch): for step in range(step_epoch): + + self.gen.train() + + src_image1, src_image2 = self.train_loader.next() + + + img_att = src_image1 + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic') + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + + img_id_112_norm = spnorm(img_id_112) + + latent_id = model.netArc(img_id_112_norm) + + latent_id = F.normalize(latent_id, p=2, dim=1) + + losses, img_fake= model(None, src_image1, latent_id, None, for_G=True) + + # update Generator weights + losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] + loss_dict = dict(zip(model.loss_names, losses)) + + loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict['G_ID'] * opt.lambda_id + if step%2 == 0: + loss_G += loss_dict['G_Rec'] + + optimizer_G.zero_grad() + loss_G.backward(retain_graph=True) + optimizer_G.step() + + loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_dict['D_GP'] + optimizer_D.zero_grad() + loss_D.backward() + optimizer_D.step() + self.gen.train() content_images = self.train_loader.next()