From 601d2ee43dff4f3087c72e682030de15b0b074d9 Mon Sep 17 00:00:00 2001
From: chenxuanhong
Date: Mon, 17 Jan 2022 13:17:49 +0800
Subject: [PATCH] update

---
 .gitignore                                    |  36 +-
 GUI.py                                        |   6 +-
 GUI/file_sync/filestate_machine0.json         |  49 +
 GUI/guiignore.json                            |  22 +
 GUI/machines.json                             |  11 +
 README.md                                     |  24 +-
 ...onditional_Discriminator_Projection_big.py | 124 ---
 components/Conditional_Generator_Noskip.py    | 114 ---
 components/Conditional_ResBlock_ModulaConv.py |  82 --
 components/DeConv.py                          |  20 -
 components/Discriminator.py                   |  67 --
 components/FastNST_CNN.py                     | 129 ---
 components/FastNST_CNN_Resblock.py            | 110 ---
 components/FastNST_Liif.py                    | 144 ---
 components/FastNST_Liif_warp.py               | 150 ---
 components/FastNST_Liif_warpinvo.py           | 146 ---
 components/Generator.py                       | 190 ++--
 components/Involution.py                      | 303 -------
 components/Liif.py                            | 146 ---
 components/Liif_conv.py                       | 156 ----
 components/Liif_invo.py                       | 164 ----
 components/ResBlock.py                        |  38 -
 components/ResBlock_Adain.py                  |  76 --
 components/Transform.py                       |  14 -
 components/network_swin.py                    | 854 ------------------
 components/pg_modules/blocks.py               | 325 +++++++
 components/pg_modules/diffaug.py              |  76 ++
 components/pg_modules/discriminator.py        | 186 ++++
 components/pg_modules/networks_fastgan.py     | 178 ++++
 components/pg_modules/networks_stylegan2.py   | 537 +++++++++++
 components/pg_modules/projector.py            | 158 ++++
 components/projected_discriminator.py         | 194 ++++
 components/warp_invo.py                       |  45 -
 data_tools/data_loader_VGGFace2HQ.py          | 118 ++-
 data_tools/data_loader_place365.py            | 223 -----
 env/env.json                                  |  17 +
 models/__init__.py                            |   4 +
 models/arcface_models.py                      | 162 ++++
 models/config.py                              |  28 +
 train.py                                      |  16 +-
 train_scripts/trainer_FM.py                   | 333 +++++++
 train_scripts/trainer_FastNST.py              | 307 -------
 train_scripts/trainer_FastNST_CNN.py          | 297 ------
 train_scripts/trainer_FastNST_Liif.py         | 296 ------
 train_scripts/trainer_FastNST_SWD.py          | 300 ------
 train_scripts/trainer_base.py                 | 114 +++
 train_scripts/trainer_gan.py                  | 382 --------
 train_scripts/trainer_naiv512.py              |  40 +-
 train_yamls/train_512FM.yaml                  |  62 ++
 train_yamls/train_FastNST.yaml                |  83 --
 train_yamls/train_FastNST_CNN.yaml            | 108 ---
 train_yamls/train_FastNST_CNN_Resblock.yaml   | 108 ---
 train_yamls/train_FastNST_Liif.yaml           | 110 ---
 train_yamls/train_FastNST_Liif_warp.yaml      | 109 ---
 train_yamls/train_FastNST_Liif_warpinvo.yaml  | 109 ---
 train_yamls/train_FastNST_SWD.yaml            | 109 ---
 train_yamls/train_noskip.yaml                 |  98 --
 utilities/plot.py                             |  37 +
 58 files changed, 2748 insertions(+), 5696 deletions(-)
 create mode 100644 GUI/file_sync/filestate_machine0.json
 create mode 100644 GUI/guiignore.json
 create mode 100644 GUI/machines.json
 delete mode 100644 components/Conditional_Discriminator_Projection_big.py
 delete mode 100644 components/Conditional_Generator_Noskip.py
 delete mode 100644 components/Conditional_ResBlock_ModulaConv.py
 delete mode 100644 components/DeConv.py
 delete mode 100644 components/Discriminator.py
 delete mode 100644 components/FastNST_CNN.py
 delete mode 100644 components/FastNST_CNN_Resblock.py
 delete mode 100644 components/FastNST_Liif.py
 delete mode 100644 components/FastNST_Liif_warp.py
 delete mode 100644 components/FastNST_Liif_warpinvo.py
 delete mode 100644 components/Involution.py
 delete mode 100644 components/Liif.py
 delete mode 100644 components/Liif_conv.py
 delete mode 100644 components/Liif_invo.py
 delete mode 100644 components/ResBlock.py
 delete mode 100644 components/ResBlock_Adain.py
 delete mode 100644 components/Transform.py
 delete mode 100644 components/network_swin.py
 create mode 100644 components/pg_modules/blocks.py
 create mode 100644 components/pg_modules/diffaug.py
 create mode 100644 components/pg_modules/discriminator.py
 create mode 100644 components/pg_modules/networks_fastgan.py
 create mode 100644 components/pg_modules/networks_stylegan2.py
 create mode 100644 components/pg_modules/projector.py
 create mode 100644 components/projected_discriminator.py
 delete mode 100644 components/warp_invo.py
 delete mode 100644 data_tools/data_loader_place365.py
 create mode 100644 env/env.json
 create mode 100644 models/__init__.py
 create mode 100644 models/arcface_models.py
 create mode 100644 models/config.py
 create mode 100644 train_scripts/trainer_FM.py
 delete mode 100644 train_scripts/trainer_FastNST.py
 delete mode 100644 train_scripts/trainer_FastNST_CNN.py
 delete mode 100644 train_scripts/trainer_FastNST_Liif.py
 delete mode 100644 train_scripts/trainer_FastNST_SWD.py
 create mode 100644 train_scripts/trainer_base.py
 delete mode 100644 train_scripts/trainer_gan.py
 create mode 100644 train_yamls/train_512FM.yaml
 delete mode 100644 train_yamls/train_FastNST.yaml
 delete mode 100644 train_yamls/train_FastNST_CNN.yaml
 delete mode 100644 train_yamls/train_FastNST_CNN_Resblock.yaml
 delete mode 100644 train_yamls/train_FastNST_Liif.yaml
 delete mode 100644 train_yamls/train_FastNST_Liif_warp.yaml
 delete mode 100644 train_yamls/train_FastNST_Liif_warpinvo.yaml
 delete mode 100644 train_yamls/train_FastNST_SWD.yaml
 delete mode 100644 train_yamls/train_noskip.yaml
 create mode 100644 utilities/plot.py

diff --git a/.gitignore b/.gitignore
index bdaf733..d4c33ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,8 @@ parts/
 sdist/
 var/
 wheels/
+pip-wheel-metadata/
+share/python-wheels/
 *.egg-info/
 .installed.cfg
 *.egg
@@ -45,6 +47,7 @@ htmlcov/
 nosetests.xml
 coverage.xml
 *.cover
+*.py,cover
 .hypothesis/
 .pytest_cache/
 
@@ -56,6 +59,7 @@ coverage.xml
 *.log
 local_settings.py
 db.sqlite3
+db.sqlite3-journal
 
 # Flask stuff:
 instance/
@@ -80,21 +84,19 @@ ipython_config.py
 # pyenv
 .python-version
 
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
 # celery beat schedule file
 celerybeat-schedule
 
 # SageMath parsed files
 *.sage.py
 
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
 # Spyder project settings
 .spyderproject
 .spyproject
@@ -112,9 +114,15 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+test_logs/
+train_logs/
+*.png
+*.PNG
+*.jpg
+*.JPG
 
-/train_logs
-/test_logs
-/GUI
-/benchmark
-/reference
\ No newline at end of file
+./env
+./wandb
+./train_logs
+./test_logs
+./arcface_ckpt
\ No newline at end of file
diff --git a/GUI.py b/GUI.py
index 372598a..2ce3836 100644
--- a/GUI.py
+++ b/GUI.py
@@ -5,7 +5,7 @@
 # Created Date: Wednesday December 22nd 2021
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Monday, 10th January 2022 1:47:55 pm
+# Last Modified: Monday, 17th January 2022 12:45:32 am
 # Modified By: Chen Xuanhong
 # Copyright (c) 2021 Shanghai Jiao Tong University
 #############################################################
@@ -545,16 +545,18 @@ class Application(tk.Frame):
         thread_update.start()
 
     def Machines_Update(self):
-        self.update_log_task()
+        # self.update_log_task()
         thread_update = threading.Thread(target=self.machines_update)
         thread_update.start()
 
     def machines_update(self):
         self.machine_list = read_config(self.machine_json)
+        print(self.machine_list)
         ip_list = []
         for item in self.machine_list:
             self.machine_dict[item["ip"]] = item
             ip_list.append(item["ip"])
+        print(ip_list)
         self.list_com["value"] = ip_list
         self.list_com.current(0)
         ip = self.list_com.get()
diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json
new file mode 100644
index 0000000..cccc24e
--- /dev/null
+++ b/GUI/file_sync/filestate_machine0.json
@@ -0,0 +1,49 @@
+{
+    "GUI.py": 1642351532.4558506,
+    "test.py": 1634039043.4872007,
+    "train.py": 1642351831.0061252,
+    "components\\Generator.py": 1642347735.351465,
+    "components\\Involution.py": 1626748553.9503577,
+    "components\\projected_discriminator.py": 1642348101.4661522,
+    "components\\ResBlock.py": 1625415499.383468,
+    "components\\Transform.py": 1624954083.0098498,
+    "components\\warp_invo.py": 1634614033.6983366,
+    "components\\pg_modules\\blocks.py": 1640773190.0,
+    "components\\pg_modules\\diffaug.py": 1640773190.0,
+    "components\\pg_modules\\discriminator.py": 1642349784.9407308,
+    "components\\pg_modules\\networks_fastgan.py": 1640773190.0,
+    "components\\pg_modules\\networks_stylegan2.py": 1640773190.0,
+    "components\\pg_modules\\projector.py": 1642349764.3896568,
+    "data_tools\\data_loader.py": 1611123530.660446,
+    "data_tools\\data_loader_condition.py": 1625411562.8217106,
+    "data_tools\\data_loader_VGGFace2HQ.py": 1642349144.749807,
+    "losses\\PerceptualLoss.py": 1615020169.668723,
+    "losses\\SliceWassersteinDistance.py": 1634022704.6082795,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\experimental.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\stylize.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\train.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\transformer.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\utils.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\vgg.py": 1633868477.988523,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\video.py": 1583468787.0,
+    "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\webcam.py": 1583468787.0,
+    "test_scripts\\tester_common.py": 1625369535.199175,
+    "test_scripts\\tester_FastNST.py": 1634041357.607633,
+    "train_scripts\\trainer_base.py": 1642347616.205689,
+    "train_scripts\\trainer_FastNST_SWD.py": 1634581704.2218158,
+    "train_scripts\\trainer_FM.py": 1642350579.8586667,
+    "train_scripts\\trainer_gan.py": 1625571403.080787,
+    "utilities\\checkpoint_manager.py": 1611123530.6624403,
+    "utilities\\figure.py": 1611123530.6634378,
+    "utilities\\json_config.py": 1611123530.6614666,
+    "utilities\\learningrate_scheduler.py": 1611123530.675422,
+    "utilities\\logo_class.py": 1633883995.3093486,
+    "utilities\\plot.py": 1641911100.7995758,
+    "utilities\\reporter.py": 1625413813.7213495,
+    "utilities\\save_heatmap.py": 1611123530.679439,
+    "utilities\\sshupload.py": 1611123530.6624403,
+    "utilities\\transfer_checkpoint.py": 1612416429.5316093,
+    "utilities\\utilities.py": 1634019485.0783668,
+    "utilities\\yaml_config.py": 1611123530.6614666,
+    "train_yamls\\train_512FM.yaml": 1642351806.754128
+}
\ No newline at end of file
diff --git a/GUI/guiignore.json b/GUI/guiignore.json
new file mode 100644
index 0000000..7b212a2
--- /dev/null
+++ b/GUI/guiignore.json
@@ -0,0 +1,22 @@
+{
+    "white_list": {
+        "extension": [
+            "py",
+            "yaml"
+        ],
+        "file": [],
+        "path": []
+    },
+    "black_list": {
+        "extension": [
+            "png",
+            "yaml"
+        ],
+        "file": [],
+        "path": [
+            "train_logs/",
+            "test_logs/",
+            "GUI/"
+        ]
+    }
+}
\ No newline at end of file
diff --git a/GUI/machines.json b/GUI/machines.json
new file mode 100644
index 0000000..30b51d2
--- /dev/null
+++ b/GUI/machines.json
@@ -0,0 +1,11 @@
+[
+    {
+        "ip": "101.33.242.26",
+        "user": "ubuntu",
+        "port": 22,
+        "passwd": "zpKlOW0sMlyt!xhE",
+        "path": "/home/ubuntu/CXH/simswap_plus",
+        "ckp_path": "train_logs",
+        "logfilename": "filestate_machine0.json"
+    }
+]
\ No newline at end of file
diff --git a/README.md b/README.md
index 4aeb6c0..5a36251 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,22 @@
-# SimSwap_Plus
- A high resolution and faster face editing framework
+# SimSwap_Plus
+
+## Dependencies
+- python
+- yaml (pip install pyyaml)
+- paramiko (for SSH file transfer)
+- pytorch
+- tkinter (for the GUI)
+- pillow
+- torchvision
+- opencv
+- tensorboard (pip install tensorboard)
+- tensorboardX (pip install tensorboardX)
+
+## Usage
+- Configure the project in ```main.py```.
+
+
+## Acknowledgement
+
+## Related Projects
+Learn about our other projects [[RainNet]](https://neuralchen.github.io/RainNet), [[Sketch Generation]](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale), [[CooGAN]](https://github.com/neuralchen/CooGAN), [[Knowledge Style Transfer]](https://github.com/AceSix/Knowledge_Transfer), [[Youtube downloader]](https://github.com/AIARTSJTU/YoutubeDataCollector).
\ No newline at end of file diff --git a/components/Conditional_Discriminator_Projection_big.py b/components/Conditional_Discriminator_Projection_big.py deleted file mode 100644 index c69114e..0000000 --- a/components/Conditional_Discriminator_Projection_big.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Conditional_Discriminator copy.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 29th June 2021 4:26:33 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from torch.nn import utils - -class Discriminator(nn.Module): - def __init__(self, chn=32, k_size=3, n_class=3): - super().__init__() - # padding_size = int((k_size -1)/2) - slop = 0.2 - enable_bias = True - - # stage 1 - self.block1 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride = 2, padding=2,bias= enable_bias)), - nn.LeakyReLU(slop), - utils.spectral_norm(nn.Conv2d(in_channels = chn, out_channels = chn * 2 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)), # 1/4 - nn.LeakyReLU(slop) - ) - self.aux_classfier1 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn , kernel_size= 5, bias=enable_bias)), - nn.LeakyReLU(slop), - nn.AdaptiveAvgPool2d(1), - ) - self.embed1 = utils.spectral_norm(nn.Embedding(n_class, chn)) - self.linear1= utils.spectral_norm(nn.Linear(chn, 1)) - - # stage 2 - self.block2 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn * 4 , kernel_size= k_size, stride = 2, padding=2, bias= enable_bias)),# 1/8 - nn.LeakyReLU(slop), - utils.spectral_norm(nn.Conv2d(in_channels = chn * 4, out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)),# 1/16 - nn.LeakyReLU(slop) - ) - self.aux_classfier2 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn , kernel_size= 5, bias= enable_bias)), - nn.LeakyReLU(slop), - nn.AdaptiveAvgPool2d(1), - ) - self.embed2 = utils.spectral_norm(nn.Embedding(n_class, chn)) - self.linear2= utils.spectral_norm(nn.Linear(chn, 1)) - - # stage 3 - self.block3 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/32 - nn.LeakyReLU(slop), - utils.spectral_norm(nn.Conv2d(in_channels = chn * 8, out_channels = chn * 16 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/64 - nn.LeakyReLU(slop) - ) - self.aux_classfier3 = nn.Sequential( - utils.spectral_norm(nn.Conv2d(in_channels = chn * 16 , out_channels = chn, kernel_size= 5, bias= enable_bias)), - nn.LeakyReLU(slop), - nn.AdaptiveAvgPool2d(1), - ) - self.embed3 = utils.spectral_norm(nn.Embedding(n_class, chn)) - self.linear3= utils.spectral_norm(nn.Linear(chn, 1)) - self.__weights_init__() - - def __weights_init__(self): - print("Init weights") - for m in self.modules(): - if isinstance(m,nn.Conv2d) or isinstance(m,nn.Linear): - nn.init.xavier_uniform_(m.weight) - try: - nn.init.zeros_(m.bias) - except: - print("No bias found!") - - if isinstance(m, nn.Embedding): - nn.init.xavier_uniform_(m.weight) - - def 
forward(self, input, condition): - - h = self.block1(input) - prep1 = self.aux_classfier1(h) - prep1 = prep1.view(prep1.size()[0], -1) - y1 = self.embed1(condition) - y1 = torch.sum(y1 * prep1, dim=1, keepdim=True) - prep1 = self.linear1(prep1) + y1 - - h = self.block2(h) - prep2 = self.aux_classfier2(h) - prep2 = prep2.view(prep2.size()[0], -1) - y2 = self.embed2(condition) - y2 = torch.sum(y2 * prep2, dim=1, keepdim=True) - prep2 = self.linear2(prep2) + y2 - - h = self.block3(h) - prep3 = self.aux_classfier3(h) - prep3 = prep3.view(prep3.size()[0], -1) - y3 = self.embed3(condition) - y3 = torch.sum(y3 * prep3, dim=1, keepdim=True) - prep3 = self.linear3(prep3) + y3 - - out_prep = [prep1,prep2,prep3] - return out_prep - - def get_outputs_len(self): - num = 0 - for m in self.modules(): - if isinstance(m,nn.Linear): - num+=1 - return num - -if __name__ == "__main__": - wocao = Discriminator().cuda() - from torchsummary import summary - summary(wocao, input_size=(3, 512, 512)) \ No newline at end of file diff --git a/components/Conditional_Generator_Noskip.py b/components/Conditional_Generator_Noskip.py deleted file mode 100644 index 8a1a046..0000000 --- a/components/Conditional_Generator_Noskip.py +++ /dev/null @@ -1,114 +0,0 @@ - -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Conditional_Generator_tanh.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 6th July 2021 1:16:46 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv -from components.Conditional_ResBlock_ModulaConv import Conditional_ResBlock - -class Generator(nn.Module): - def __init__( - self, - chn=32, - k_size=3, - res_num = 5, - class_num = 3, - **kwargs): - super().__init__() - padding_size = int((k_size -1)/2) - self.resblock_list = [] - self.n_class = class_num - self.encoder1 = nn.Sequential( - # nn.InstanceNorm2d(3, affine=True), - # nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - # nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size= k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn*2, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - # nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = chn*2, out_channels = chn * 4, kernel_size= k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - # nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - # # nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU() - ) - - - res_size = chn * 8 - for _ in range(res_num-1): - self.resblock_list += 
[ResBlock(res_size,k_size),] - self.resblocks = nn.Sequential(*self.resblock_list) - self.conditional_res = Conditional_ResBlock(res_size, k_size, class_num) - self.decoder1 = nn.Sequential( - DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size), - nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size= k_size), - nn.InstanceNorm2d(chn *4, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - DeConv(in_channels = chn * 4, out_channels = chn * 2 , kernel_size= k_size), - nn.InstanceNorm2d(chn*2, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - DeConv(in_channels = chn *2, out_channels = chn, kernel_size= k_size), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - # nn.ReLU(), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn, out_channels =3, kernel_size= k_size, stride=1, padding=1,bias =True) - # nn.Tanh() - ) - - self.__weights_init__() - - def __weights_init__(self): - for layer in self.encoder1: - if isinstance(layer,nn.Conv2d): - nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input, condition=None, get_feature = False): - feature = self.encoder1(input) - if get_feature: - return feature - out = self.conditional_res(feature, condition) - out = self.resblocks(out) - # n, _,h,w = out.size() - # attr = condition.view((n, self.n_class, 1, 1)).expand((n, self.n_class, h, w)) - # out = torch.cat([out, attr], dim=1) - out = self.decoder1(out) - return out,feature \ No newline at end of file diff --git a/components/Conditional_ResBlock_ModulaConv.py b/components/Conditional_ResBlock_ModulaConv.py deleted file mode 100644 index dfc4401..0000000 --- a/components/Conditional_ResBlock_ModulaConv.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Conditional_ResBlock_v2.py -# Created Date: Tuesday June 29th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 29th June 2021 3:59:44 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -# -*- coding:utf-8 -*- -################################################################### -### @FilePath: \ASMegaGAN\components\Conditional_ResBlock_v2.py -### @Author: Ziang Liu -### @Date: 2021-06-28 21:30:17 -### @LastEditors: Ziang Liu -### @LastEditTime: 2021-06-28 21:46:24 -### @Copyright (C) 2021 SJTU. All rights reserved. 
-################################################################### -import torch -from torch import nn -import torch.nn.functional as F -# from ops.Conditional_BN import Conditional_BN -# from components.Adain import Adain - -class Conv2DMod(nn.Module): - def __init__(self, in_channels, out_channels, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs): - super().__init__() - self.filters = out_channels - self.demod = demod - self.kernel = kernel - self.stride = stride - self.dilation = dilation - self.weight = nn.Parameter(torch.randn((out_channels, in_channels, kernel, kernel))) - self.eps = eps - - padding_size = int((kernel -1)/2) - self.same_padding = nn.ReplicationPad2d(padding_size) - nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') - - def forward(self, x, y): - b, c, h, w = x.shape - - w1 = y[:, None, :, None, None] - w2 = self.weight[None, :, :, :, :] - weights = w2 * (w1 + 1) - - if self.demod: - d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps) - weights = weights * d - - x = x.reshape(1, -1, h, w) - - _, _, *ws = weights.shape - weights = weights.reshape(b * self.filters, *ws) - - x = self.same_padding(x) - x = F.conv2d(x, weights, groups=b) - - x = x.reshape(-1, self.filters, h, w) - return x - -class Conditional_ResBlock(nn.Module): - def __init__(self, in_channel, k_size = 3, n_class = 2, stride=1): - super().__init__() - - self.embed1 = nn.Embedding(n_class, in_channel) - self.embed2 = nn.Embedding(n_class, in_channel) - self.conv1 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride) - self.conv2 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride) - - def forward(self, input, condition): - res = input - style1 = self.embed1(condition) - h = self.conv1(res, style1) - style2 = self.embed2(condition) - h = self.conv2(h, style2) - out = h + res - return out \ No newline at end of file diff --git a/components/DeConv.py b/components/DeConv.py deleted file mode 100644 index ed31179..0000000 --- a/components/DeConv.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -from torch import nn - -class DeConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size = 3, upsampl_scale = 2): - super().__init__() - self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale) - padding_size = int((kernel_size -1)/2) - # self.same_padding = nn.ReflectionPad2d(padding_size) - self.conv = nn.Conv2d(in_channels = in_channels ,padding=padding_size, out_channels = out_channels , kernel_size= kernel_size, bias= False) - self.__weights_init__() - - def __weights_init__(self): - nn.init.xavier_uniform_(self.conv.weight) - - def forward(self, input): - h = self.upsampling(input) - # h = self.same_padding(h) - h = self.conv(h) - return h \ No newline at end of file diff --git a/components/Discriminator.py b/components/Discriminator.py deleted file mode 100644 index adf89e6..0000000 --- a/components/Discriminator.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -import torch.nn as nn - -class Discriminator(nn.Module): - def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False): - super(Discriminator, self).__init__() - - kw = 4 - padw = 1 - self.down1 = nn.Sequential( - nn.Conv2d(input_nc, 64, kernel_size=kw, stride=2, padding=padw), - norm_layer(64), - nn.LeakyReLU(0.2, True) - ) - self.down2 = nn.Sequential( - nn.Conv2d(64, 128, kernel_size=kw, stride=2, padding=padw), - norm_layer(128), - 
nn.LeakyReLU(0.2, True) - ) - self.down3 = nn.Sequential( - nn.Conv2d(128, 256, kernel_size=kw, stride=2, padding=padw), - norm_layer(256), - nn.LeakyReLU(0.2, True) - ) - self.down4 = nn.Sequential( - nn.Conv2d(256, 512, kernel_size=kw, stride=2, padding=padw), - norm_layer(512), - nn.LeakyReLU(0.2, True) - ) - self.down5 = nn.Sequential( - nn.Conv2d(512, 512, kernel_size=kw, stride=2, padding=padw), - norm_layer(512), - nn.LeakyReLU(0.2, True) - ) - self.conv1 = nn.Sequential( - nn.Conv2d(512, 512, kernel_size=kw, stride=1, padding=padw), - norm_layer(512), - nn.LeakyReLU(0.2, True) - ) - - if use_sigmoid: - self.conv2 = nn.Sequential( - nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw), - nn.Sigmoid() - ) - else: - self.conv2 = nn.Sequential( - nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw) - ) - - def forward(self, input): - out = [] - x = self.down1(input) - #out.append(x) - x = self.down2(x) - #out.append(x) - x = self.down3(x) - #out.append(x) - x = self.down4(x) - x = self.down5(x) - out.append(x) - x = self.conv1(x) - out.append(x) - x = self.conv2(x) - out.append(x) - - return out \ No newline at end of file diff --git a/components/FastNST_CNN.py b/components/FastNST_CNN.py deleted file mode 100644 index 591de93..0000000 --- a/components/FastNST_CNN.py +++ /dev/null @@ -1,129 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Conditional_Generator_gpt_LN_encoder copy.py -# Created Date: Saturday October 9th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 11th October 2021 5:22:22 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv - -class ImageLN(nn.Module): - def __init__(self, dim) -> None: - super().__init__() - self.layer = nn.LayerNorm(dim) - def forward(self, x): - y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2) - return y - -class Generator(nn.Module): - def __init__( - self, - **kwargs - ): - super().__init__() - - chn = kwargs["g_conv_dim"] - k_size = kwargs["g_kernel_size"] - res_num = kwargs["res_num"] - class_num = kwargs["n_class"] - window_size = kwargs["window_size"] - image_size = kwargs["image_size"] - - padding_size = int((k_size -1)/2) - - self.resblock_list = [] - embed_dim = 96 - window_size = 8 - num_heads = 8 - mlp_ratio = 2. 
- norm_layer = nn.LayerNorm - qk_scale = None - qkv_bias = True - self.patch_norm = True - self.lnnorm = norm_layer(embed_dim) - - self.encoder = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn * 2), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(embed_dim), - nn.LeakyReLU(), - ) - - # self.encoder2 = nn.Sequential( - - # nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU() - # ) - self.decoder = nn.Sequential( - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - # nn.LeakyReLU(), - DeConv(in_channels = embed_dim, out_channels = chn * 2 , kernel_size=k_size), - # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.InstanceNorm2d(chn * 2), - nn.LeakyReLU(), - DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size), - nn.InstanceNorm2d(chn), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - ) - - - # self.__weights_init__() - - # def __weights_init__(self): - # for layer in self.encoder: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input): - x2 = self.encoder(input) - out = self.decoder(x2) - return out - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = 1024 - width = 1024 - model = Generator() - print(model) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) \ No newline at end of file diff --git a/components/FastNST_CNN_Resblock.py b/components/FastNST_CNN_Resblock.py deleted file mode 100644 index 017b534..0000000 --- a/components/FastNST_CNN_Resblock.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Conditional_Generator_gpt_LN_encoder copy.py -# Created Date: Saturday October 9th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 7:35:08 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv - -class 
ImageLN(nn.Module): - def __init__(self, dim) -> None: - super().__init__() - self.layer = nn.LayerNorm(dim) - def forward(self, x): - y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2) - return y - -class Generator(nn.Module): - def __init__( - self, - **kwargs - ): - super().__init__() - - chn = kwargs["g_conv_dim"] - k_size = kwargs["g_kernel_size"] - res_num = kwargs["res_num"] - - padding_size = int((k_size -1)/2) - - self.resblock_list = [] - - self.encoder = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - nn.LeakyReLU(), - ) - for _ in range(res_num): - self.resblock_list += [ResBlock(chn * 4,k_size),] - self.resblocks = nn.Sequential(*self.resblock_list) - self.decoder = nn.Sequential( - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size), - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - DeConv(in_channels = chn * 2, out_channels = chn * 2 , kernel_size=k_size), - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.ReLU(), - DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - nn.ReLU(), - nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - ) - - - # self.__weights_init__() - - # def __weights_init__(self): - # for layer in self.encoder: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input): - x2 = self.encoder(input) - x2 = self.resblocks(x2) - out = self.decoder(x2) - return out - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = 1024 - width = 1024 - model = Generator() - print(model) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) \ No newline at end of file diff --git a/components/FastNST_Liif.py b/components/FastNST_Liif.py deleted file mode 100644 index 9d47f9d..0000000 --- a/components/FastNST_Liif.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: FastNST_Liif.py -# Created Date: Thursday October 14th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 2:39:09 am -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn 
import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv -from components.Liif import LIIF - -class ImageLN(nn.Module): - def __init__(self, dim) -> None: - super().__init__() - self.layer = nn.LayerNorm(dim) - def forward(self, x): - y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2) - return y - -class Generator(nn.Module): - def __init__( - self, - **kwargs - ): - super().__init__() - - chn = kwargs["g_conv_dim"] - k_size = kwargs["g_kernel_size"] - res_num = kwargs["res_num"] - class_num = kwargs["n_class"] - window_size = kwargs["window_size"] - image_size = kwargs["image_size"] - batch_size = kwargs["batch_size"] - # mlp_in_dim = kwargs["mlp_in_dim"] - # mlp_out_dim = kwargs["mlp_out_dim"] - mlp_hidden_list = kwargs["mlp_hidden_list"] - - padding_size = int((k_size -1)/2) - - self.resblock_list = [] - embed_dim = 96 - window_size = 8 - num_heads = 8 - mlp_ratio = 2. - norm_layer = nn.LayerNorm - qk_scale = None - qkv_bias = True - self.patch_norm = True - self.lnnorm = norm_layer(embed_dim) - - self.encoder = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn * 2), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False), - ImageLN(chn * 4), - nn.LeakyReLU(), - ) - for _ in range(res_num): - self.resblock_list += [ResBlock(chn * 4,k_size),] - self.resblocks = nn.Sequential(*self.resblock_list) - # self.encoder2 = nn.Sequential( - - # nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU() - # ) - self.decoder = nn.Sequential( - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size), - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.InstanceNorm2d(chn), - nn.LeakyReLU() - # DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size), - # nn.InstanceNorm2d(chn), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - ) - - self.upsample = LIIF(chn, 3, mlp_hidden_list) - self.upsample.gen_coord((batch_size, \ - chn,image_size//2,image_size//2),(image_size,image_size)) - - # self.__weights_init__() - - # def __weights_init__(self): - # for layer in self.encoder: - # if 
isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input): - x2 = self.encoder(input) - x2 = self.resblocks(x2) - out = self.decoder(x2) - out = self.upsample(out) - return out - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = 1024 - width = 1024 - model = Generator() - print(model) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) \ No newline at end of file diff --git a/components/FastNST_Liif_warp.py b/components/FastNST_Liif_warp.py deleted file mode 100644 index c684ecf..0000000 --- a/components/FastNST_Liif_warp.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: FastNST_Liif.py -# Created Date: Thursday October 14th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 4:33:51 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv -from components.Liif_conv import LIIF - -class ImageLN(nn.Module): - def __init__(self, dim) -> None: - super().__init__() - self.layer = nn.LayerNorm(dim) - def forward(self, x): - y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2) - return y - -class Generator(nn.Module): - def __init__( - self, - **kwargs - ): - super().__init__() - - chn = kwargs["g_conv_dim"] - k_size = kwargs["g_kernel_size"] - res_num = kwargs["res_num"] - class_num = kwargs["n_class"] - window_size = kwargs["window_size"] - image_size = kwargs["image_size"] - batch_size = kwargs["batch_size"] - # mlp_in_dim = kwargs["mlp_in_dim"] - # mlp_out_dim = kwargs["mlp_out_dim"] - - - padding_size = int((k_size -1)/2) - - self.resblock_list = [] - embed_dim = 96 - window_size = 8 - num_heads = 8 - mlp_ratio = 2. 
- norm_layer = nn.LayerNorm - qk_scale = None - qkv_bias = True - self.patch_norm = True - self.lnnorm = norm_layer(embed_dim) - - self.encoder = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False), - nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - nn.LeakyReLU(), - ) - for _ in range(res_num): - self.resblock_list += [ResBlock(chn * 4,k_size),] - self.resblocks = nn.Sequential(*self.resblock_list) - # self.encoder2 = nn.Sequential( - - # nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False), - # ImageLN(chn * 8), - # nn.LeakyReLU() - # ) - self.decoder = nn.Sequential( - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # nn.LeakyReLU(), - DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size), - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - # DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size), - # # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - # nn.InstanceNorm2d(chn, affine=True, momentum=0), - # nn.LeakyReLU() - # DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size), - # nn.InstanceNorm2d(chn), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - ) - - self.upsample1 = LIIF(chn*2, chn) - self.upsample1.gen_coord((batch_size, \ - chn,image_size//4,image_size//4),(image_size//2,image_size//2)) - - self.upsample2 = LIIF(chn, chn) - self.upsample2.gen_coord((batch_size, \ - chn,image_size//2,image_size//2),(image_size,image_size)) - self.out_conv = nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - # self.__weights_init__() - - # def __weights_init__(self): - # for layer in self.encoder: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input): - x2 = self.encoder(input) - x2 = self.resblocks(x2) - out = self.decoder(x2) - out = self.upsample1(out) - out = self.upsample2(out) - out = self.out_conv(out) - return out - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = 1024 - width = 1024 - model = Generator() - print(model) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) 
\ No newline at end of file diff --git a/components/FastNST_Liif_warpinvo.py b/components/FastNST_Liif_warpinvo.py deleted file mode 100644 index 9d94d1e..0000000 --- a/components/FastNST_Liif_warpinvo.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: FastNST_Liif.py -# Created Date: Thursday October 14th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 8:47:28 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch - -from torch import nn -from torch.nn import init -from torch.nn import functional as F - -from components.ResBlock import ResBlock -from components.DeConv import DeConv -from components.Liif_invo import LIIF - - -class Generator(nn.Module): - def __init__( - self, - **kwargs - ): - super().__init__() - - chn = kwargs["g_conv_dim"] - k_size = kwargs["g_kernel_size"] - res_num = kwargs["res_num"] - class_num = kwargs["n_class"] - window_size = kwargs["window_size"] - image_size = kwargs["image_size"] - batch_size = kwargs["batch_size"] - # mlp_in_dim = kwargs["mlp_in_dim"] - # mlp_out_dim = kwargs["mlp_out_dim"] - - - padding_size = int((k_size -1)/2) - - self.resblock_list = [] - embed_dim = 96 - norm_layer = nn.LayerNorm - - self.img_token = nn.Sequential( - nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), # - nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False), - # nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - # nn.LeakyReLU(), - # nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False), - # nn.InstanceNorm2d(chn * 4, affine=True, momentum=0), - # nn.LeakyReLU(), - ) - image_size = image_size // 2 - self.downsample1 = LIIF(chn * 2, chn * 4) - self.downsample1.gen_coord((batch_size, \ - chn,image_size,image_size),(image_size//2,image_size//2)) - image_size = image_size // 2 - self.downsample2 = LIIF(chn * 4, chn * 4) - self.downsample2.gen_coord((batch_size, \ - chn,image_size,image_size),(image_size//2,image_size//2)) - - - for _ in range(res_num): - self.resblock_list += [ResBlock(chn * 4,k_size),] - self.resblocks = nn.Sequential(*self.resblock_list) - # self.decoder = nn.Sequential( - # # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # # nn.LeakyReLU(), - # # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size), - # # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0), - # # nn.LeakyReLU(), - # DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size), - # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - # nn.LeakyReLU(), - # # DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size), - # # # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0), - # # nn.InstanceNorm2d(chn, affine=True, momentum=0), - # # nn.LeakyReLU() - # # DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size), - # # nn.InstanceNorm2d(chn), - # # 
nn.LeakyReLU(), - # # nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - # ) - image_size = image_size // 2 - self.upsample1 = LIIF(chn*4, chn * 4) - self.upsample1.gen_coord((batch_size, \ - chn,image_size,image_size),(image_size*2,image_size*2)) - image_size = image_size * 2 - self.upsample2 = LIIF(chn*4, chn * 2) - self.upsample2.gen_coord((batch_size, \ - chn,image_size,image_size),(image_size*2,image_size*2)) - # image_size = image_size * 2 - # self.upsample2 = LIIF(chn, chn) - # self.upsample2.gen_coord((batch_size, \ - # chn,image_size,image_size),(image_size*2,image_size*2)) - self.decoder = nn.Sequential( - DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size), - nn.InstanceNorm2d(chn, affine=True, momentum=0), - nn.LeakyReLU(), - nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - ) - # self.out_conv = nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True) - # self.__weights_init__() - - # def __weights_init__(self): - # for layer in self.encoder: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - # for layer in self.encoder2: - # if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - - def forward(self, input): - out = self.img_token(input) - out = self.downsample1(out) - out = self.downsample2(out) - out = self.resblocks(out) - - out = self.upsample1(out) - out = self.upsample2(out) - out = self.decoder(out) - return out - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = 1024 - width = 1024 - model = Generator() - print(model) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) \ No newline at end of file diff --git a/components/Generator.py b/components/Generator.py index ec91658..56552d2 100644 --- a/components/Generator.py +++ b/components/Generator.py @@ -1,112 +1,186 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- ############################################################# -# File: Conditional_Generator_gpt_LN_encoder copy.py -# Created Date: Saturday October 9th 2021 +# File: Generator.py +# Created Date: Sunday January 16th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 26th October 2021 3:25:47 pm +# Last Modified: Sunday, 16th January 2022 11:42:14 pm # Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University +# Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# - import torch from torch import nn -from ResBlock_Adain import ResBlock_Adain +from torch.nn import init +from torch.nn import functional as F + +class InstanceNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. 
+            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
+        """
+        super(InstanceNorm, self).__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        x = x - torch.mean(x, (2, 3), True)
+        tmp = torch.mul(x, x)  # or x ** 2
+        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
+        return x * tmp
+
+class ApplyStyle(nn.Module):
+    """
+        @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
+    """
+    def __init__(self, latent_size, channels):
+        super(ApplyStyle, self).__init__()
+        self.linear = nn.Linear(latent_size, channels * 2)
+
+    def forward(self, x, latent):
+        style = self.linear(latent)  # style => [batch_size, n_channels*2]
+        shape = [-1, 2, x.size(1), 1, 1]
+        style = style.view(shape)    # [batch_size, 2, n_channels, ...]
+        #x = x * (style[:, 0] + 1.) + style[:, 1]
+        x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
+        return x
+
+class ResnetBlock_Adain(nn.Module):
+    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
+        super(ResnetBlock_Adain, self).__init__()
+
+        p = 0
+        conv1 = []
+        if padding_type == 'reflect':
+            conv1 += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv1 += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+        conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
+        self.conv1 = nn.Sequential(*conv1)
+        self.style1 = ApplyStyle(latent_size, dim)
+        self.act1 = activation
+
+        p = 0
+        conv2 = []
+        if padding_type == 'reflect':
+            conv2 += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv2 += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
+        self.conv2 = nn.Sequential(*conv2)
+        self.style2 = ApplyStyle(latent_size, dim)
+
+
+    def forward(self, x, dlatents_in_slice):
+        y = self.conv1(x)
+        y = self.style1(y, dlatents_in_slice)
+        y = self.act1(y)
+        y = self.conv2(y)
+        y = self.style2(y, dlatents_in_slice)
+        out = x + y
+        return out
 
-from functools import partial
 
 class Generator(nn.Module):
     def __init__(
         self,
         **kwargs
     ):
-        super(Generator, self).__init__()
+        super().__init__()
 
-        input_nc = kwargs["g_conv_dim"]
-        output_nc = kwargs["g_kernel_size"]
-        latent_size = kwargs["latent_size"]
-        n_blocks = kwargs["resblock_num"]
-        norm_name = kwargs["norm_name"]
-        padding_type= kwargs["reflect"]
-
-        if norm_name == "bn":
-            norm_layer = partial(nn.BatchNorm2d, affine = True, track_running_stats=True)
-        elif norm_name == "in":
-            norm_name = nn.InstanceNorm2d
+        chn = kwargs["g_conv_dim"]
+        k_size = kwargs["g_kernel_size"]
+        res_num = kwargs["res_num"]
+
+        padding_size= int((k_size -1)/2)
+        padding_type= 'reflect'
 
-        assert (n_blocks >= 0)
         activation = nn.ReLU(True)
 
-        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
-                                         norm_layer(64), activation)
+        self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
+                                         nn.BatchNorm2d(64), activation)
         ### downsample
         self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
-                                   norm_layer(128), activation)
+                                   nn.BatchNorm2d(128), activation)
+
         self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
-                                   norm_layer(256), activation)
+                                   nn.BatchNorm2d(256), activation)
+
         self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
-                                   norm_layer(512), activation)
+                                   nn.BatchNorm2d(512), activation)
+
         self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
-                                   norm_layer(512), activation)
+                                   nn.BatchNorm2d(512), activation)
 
         ### resnet blocks
         BN = []
-        for i in range(n_blocks):
+        for i in range(res_num):
             BN += [
-                ResBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
+                ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
         self.BottleNeck = nn.Sequential(*BN)
 
-        if self.deep:
-            self.up4 = nn.Sequential(
-                nn.Upsample(scale_factor=2, mode='bilinear'),
-                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
-                nn.BatchNorm2d(512), activation
-            )
+        self.up4 = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear'),
+            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(512), activation
+        )
+
         self.up3 = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear'),
             nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm2d(256), activation
         )
+
         self.up2 = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear'),
             nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm2d(128), activation
         )
+
         self.up1 = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear'),
             nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm2d(64), activation
         )
-        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))
+
+        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
+
+
+        # self.__weights_init__()
+
+    # def __weights_init__(self):
+    #     for layer in self.encoder:
+    #         if isinstance(layer,nn.Conv2d):
+    #             nn.init.xavier_uniform_(layer.weight)
+
+    #     for layer in self.encoder2:
+    #         if isinstance(layer,nn.Conv2d):
+    #             nn.init.xavier_uniform_(layer.weight)
 
     def forward(self, input, id):
         x = input  # 3*224*224
-        res = self.first_layer(x)
-        res = self.down1(res)
-        res = self.down2(res)
-        res = self.down4(res)
-        res = self.down3(res)
+        skip1 = self.first_layer(x)
+        skip2 = self.down1(skip1)
+        skip3 = self.down2(skip2)
+        skip4 = self.down3(skip3)
+        res = self.down4(skip4)
 
+        # chain the AdaIN bottleneck blocks; feeding `res` to every block
+        # would discard all but the last block's output
+        x = res
         for i in range(len(self.BottleNeck)):
-            res = self.BottleNeck[i](res, id)
+            x = self.BottleNeck[i](x, id)
 
-        res = self.up4(res)
-        res = self.up3(res)
-        res = self.up2(res)
-        res = self.up1(res)
-        res = self.last_layer(res)
-        return res
-
-if __name__ == '__main__':
-    upscale = 4
-    window_size = 8
-    height = 1024
-    width = 1024
-    model = Generator()
-    print(model)
+        x = self.up4(x)
+        x = self.up3(x)
+        x = self.up2(x)
+        x = self.up1(x)
+        x = self.last_layer(x)
 
-    x = torch.randn((1, 3, height, width))
-    x = model(x)
-    print(x.shape)
\ No newline at end of file
+        return x
diff --git a/components/Involution.py b/components/Involution.py
deleted file mode 100644
index c0bd15e..0000000
--- a/components/Involution.py
+++ /dev/null
@@ -1,303 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: Involution.py
-# Created Date: Tuesday July 20th 2021
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Tuesday, 20th July 2021 10:35:52 am
-# Modified By: Chen Xuanhong
-# Copyright (c) 2021 Shanghai Jiao Tong University
-#############################################################
-
-
-from torch.autograd import Function
-import torch
-from
torch.nn.modules.utils import _pair -import torch.nn.functional as F -import torch.nn as nn -from mmcv.cnn import ConvModule - - -from collections import namedtuple -import cupy -from string import Template - - -Stream = namedtuple('Stream', ['ptr']) - - -def Dtype(t): - if isinstance(t, torch.cuda.FloatTensor): - return 'float' - elif isinstance(t, torch.cuda.DoubleTensor): - return 'double' - - -@cupy._util.memoize(for_each_device=True) -def load_kernel(kernel_name, code, **kwargs): - code = Template(code).substitute(**kwargs) - kernel_code = cupy.cuda.compile_with_cache(code) - return kernel_code.get_function(kernel_name) - - -CUDA_NUM_THREADS = 1024 - -kernel_loop = ''' -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ - i < (n); \ - i += blockDim.x * gridDim.x) -''' - - -def GET_BLOCKS(N): - return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS - - -_involution_kernel = kernel_loop + ''' -extern "C" -__global__ void involution_forward_kernel( -const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { - CUDA_KERNEL_LOOP(index, ${nthreads}) { - const int n = index / ${channels} / ${top_height} / ${top_width}; - const int c = (index / ${top_height} / ${top_width}) % ${channels}; - const int h = (index / ${top_width}) % ${top_height}; - const int w = index % ${top_width}; - const int g = c / (${channels} / ${groups}); - ${Dtype} value = 0; - #pragma unroll - for (int kh = 0; kh < ${kernel_h}; ++kh) { - #pragma unroll - for (int kw = 0; kw < ${kernel_w}; ++kw) { - const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; - const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; - if ((h_in >= 0) && (h_in < ${bottom_height}) - && (w_in >= 0) && (w_in < ${bottom_width})) { - const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in) - * ${bottom_width} + w_in; - const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h) - * ${top_width} + w; - value += weight_data[offset_weight] * bottom_data[offset]; - } - } - } - top_data[index] = value; - } -} -''' - - -_involution_kernel_backward_grad_input = kernel_loop + ''' -extern "C" -__global__ void involution_backward_grad_input_kernel( - const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) { - CUDA_KERNEL_LOOP(index, ${nthreads}) { - const int n = index / ${channels} / ${bottom_height} / ${bottom_width}; - const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels}; - const int h = (index / ${bottom_width}) % ${bottom_height}; - const int w = index % ${bottom_width}; - const int g = c / (${channels} / ${groups}); - ${Dtype} value = 0; - #pragma unroll - for (int kh = 0; kh < ${kernel_h}; ++kh) { - #pragma unroll - for (int kw = 0; kw < ${kernel_w}; ++kw) { - const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; - const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; - if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { - const int h_out = h_out_s / ${stride_h}; - const int w_out = w_out_s / ${stride_w}; - if ((h_out >= 0) && (h_out < ${top_height}) - && (w_out >= 0) && (w_out < ${top_width})) { - const int offset = ((n * ${channels} + c) * ${top_height} + h_out) - * ${top_width} + w_out; - const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out) - * ${top_width} + w_out; - value += weight_data[offset_weight] * top_diff[offset]; - } - } - } - } - bottom_diff[index] = value; 
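These hand-written CUDA kernels (forward above, the gradient kernels around it) all walk the same per-pixel offsets; what they compute is easier to see in plain PyTorch. A rough reference for the involution forward pass built on F.unfold follows; this is an illustrative sketch only, not part of the patch, and far slower than the fused kernels:

import torch
import torch.nn.functional as F

def involution_forward_reference(x, weight, stride=1, padding=0, dilation=1):
    # x: (B, C, H, W); weight: (B, groups, K, K, H_out, W_out), the layout the
    # involution module builds via weight.view(b, groups, K, K, h, w)
    B, C, H, W = x.shape
    _, groups, K, _, Ho, Wo = weight.shape
    # gather every KxK neighbourhood: (B, C*K*K, Ho*Wo), channel-major order
    cols = F.unfold(x, K, dilation=dilation, padding=padding, stride=stride)
    cols = cols.view(B, groups, C // groups, K * K, Ho, Wo)
    w = weight.reshape(B, groups, 1, K * K, Ho, Wo)
    # weighted sum over the K*K taps; kernels are shared within each channel group
    return (cols * w).sum(dim=3).view(B, C, Ho, Wo)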
- } -} -''' - - -_involution_kernel_backward_grad_weight = kernel_loop + ''' -extern "C" -__global__ void involution_backward_grad_weight_kernel( - const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) { - CUDA_KERNEL_LOOP(index, ${nthreads}) { - const int h = (index / ${top_width}) % ${top_height}; - const int w = index % ${top_width}; - const int kh = (index / ${kernel_w} / ${top_height} / ${top_width}) - % ${kernel_h}; - const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w}; - const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; - const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; - if ((h_in >= 0) && (h_in < ${bottom_height}) - && (w_in >= 0) && (w_in < ${bottom_width})) { - const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups}; - const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num}; - ${Dtype} value = 0; - #pragma unroll - for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) { - const int top_offset = ((n * ${channels} + c) * ${top_height} + h) - * ${top_width} + w; - const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in) - * ${bottom_width} + w_in; - value += top_diff[top_offset] * bottom_data[bottom_offset]; - } - buffer_data[index] = value; - } else { - buffer_data[index] = 0; - } - } -} -''' - - -class _involution(Function): - @staticmethod - def forward(ctx, input, weight, stride, padding, dilation): - assert input.dim() == 4 and input.is_cuda - assert weight.dim() == 6 and weight.is_cuda - batch_size, channels, height, width = input.size() - kernel_h, kernel_w = weight.size()[2:4] - output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1) - output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1) - - output = input.new(batch_size, channels, output_h, output_w) - n = output.numel() - - with torch.cuda.device_of(input): - f = load_kernel('involution_forward_kernel', _involution_kernel, Dtype=Dtype(input), nthreads=n, - num=batch_size, channels=channels, groups=weight.size()[1], - bottom_height=height, bottom_width=width, - top_height=output_h, top_width=output_w, - kernel_h=kernel_h, kernel_w=kernel_w, - stride_h=stride[0], stride_w=stride[1], - dilation_h=dilation[0], dilation_w=dilation[1], - pad_h=padding[0], pad_w=padding[1]) - f(block=(CUDA_NUM_THREADS,1,1), - grid=(GET_BLOCKS(n),1,1), - args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], - stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) - - ctx.save_for_backward(input, weight) - ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation - return output - - @staticmethod - def backward(ctx, grad_output): - assert grad_output.is_cuda and grad_output.is_contiguous() - input, weight = ctx.saved_tensors - stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation - - batch_size, channels, height, width = input.size() - kernel_h, kernel_w = weight.size()[2:4] - output_h, output_w = grad_output.size()[2:] - - grad_input, grad_weight = None, None - - opt = dict(Dtype=Dtype(grad_output), - num=batch_size, channels=channels, groups=weight.size()[1], - bottom_height=height, bottom_width=width, - top_height=output_h, top_width=output_w, - kernel_h=kernel_h, kernel_w=kernel_w, - stride_h=stride[0], stride_w=stride[1], - dilation_h=dilation[0], dilation_w=dilation[1], - pad_h=padding[0], 
pad_w=padding[1]) - - with torch.cuda.device_of(input): - if ctx.needs_input_grad[0]: - grad_input = input.new(input.size()) - - n = grad_input.numel() - opt['nthreads'] = n - - f = load_kernel('involution_backward_grad_input_kernel', - _involution_kernel_backward_grad_input, **opt) - f(block=(CUDA_NUM_THREADS,1,1), - grid=(GET_BLOCKS(n),1,1), - args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], - stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) - - if ctx.needs_input_grad[1]: - grad_weight = weight.new(weight.size()) - - n = grad_weight.numel() - opt['nthreads'] = n - - f = load_kernel('involution_backward_grad_weight_kernel', - _involution_kernel_backward_grad_weight, **opt) - f(block=(CUDA_NUM_THREADS,1,1), - grid=(GET_BLOCKS(n),1,1), - args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], - stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) - - return grad_input, grad_weight, None, None, None - - -def _involution_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1): - """ involution kernel - """ - assert input.size(0) == weight.size(0) - assert input.size(-2)//stride == weight.size(-2) - assert input.size(-1)//stride == weight.size(-1) - if input.is_cuda: - out = _involution.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation)) - if bias is not None: - out += bias.view(1,-1,1,1) - else: - raise NotImplementedError - return out - - -class involution(nn.Module): - - def __init__(self, - channels, - kernel_size, - stride): - super(involution, self).__init__() - self.kernel_size = kernel_size - self.stride = stride - self.channels = channels - reduction_ratio = 4 - self.group_channels = 8 - self.groups = self.channels // self.group_channels - self.seblock = nn.Sequential( - nn.Conv2d(in_channels = channels, out_channels = channels // reduction_ratio, kernel_size= 1), - nn.InstanceNorm2d(channels // reduction_ratio, affine=True, momentum=0), - nn.ReLU(), - nn.Conv2d(in_channels = channels // reduction_ratio, out_channels = kernel_size**2 * self.groups, kernel_size= 1) - ) - - # self.conv1 = ConvModule( - # in_channels=channels, - # out_channels=channels // reduction_ratio, - # kernel_size=1, - # conv_cfg=None, - # norm_cfg=dict(type='BN'), - # act_cfg=dict(type='ReLU')) - # self.conv2 = ConvModule( - # in_channels=channels // reduction_ratio, - # out_channels=kernel_size**2 * self.groups, - # kernel_size=1, - # stride=1, - # conv_cfg=None, - # norm_cfg=None, - # act_cfg=None) - if stride > 1: - self.avgpool = nn.AvgPool2d(stride, stride) - - def forward(self, x): - # weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x))) - weight = self.seblock(x) - b, c, h, w = weight.shape - weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w) - out = _involution_cuda(x, weight, stride=self.stride, padding=(self.kernel_size-1)//2) - return out diff --git a/components/Liif.py b/components/Liif.py deleted file mode 100644 index b6d59d7..0000000 --- a/components/Liif.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Liif.py -# Created Date: Monday October 18th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 10:27:09 am -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch -import torch.nn as nn -import 
torch.nn.functional as F - - -def make_coord(shape, ranges=None, flatten=True): - """ Make coordinates at grid centers. - """ - coord_seqs = [] - for i, n in enumerate(shape): - print("i: %d, n: %d"%(i,n)) - if ranges is None: - v0, v1 = -1, 1 - else: - v0, v1 = ranges[i] - r = (v1 - v0) / (2 * n) - seq = v0 + r + (2 * r) * torch.arange(n).float() - coord_seqs.append(seq) - ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) - if flatten: - ret = ret.view(-1, ret.shape[-1]) - return ret - -class MLP(nn.Module): - - def __init__(self, in_dim, out_dim, hidden_list): - super().__init__() - layers = [] - lastv = in_dim - for hidden in hidden_list: - layers.append(nn.Linear(lastv, hidden)) - layers.append(nn.ReLU()) - lastv = hidden - layers.append(nn.Linear(lastv, out_dim)) - self.layers = nn.Sequential(*layers) - - def forward(self, x): - shape = x.shape[:-1] - x = self.layers(x.view(-1, x.shape[-1])) - return x.view(*shape, -1) - -class LIIF(nn.Module): - - def __init__(self, mlp_in_dim, mlp_out_dim, mlp_hidden_list): - super().__init__() - - imnet_in_dim = mlp_in_dim - imnet_in_dim *= 9 - imnet_in_dim += 2 # attach coord - imnet_in_dim += 2 - self.imnet = MLP(imnet_in_dim, mlp_out_dim, mlp_hidden_list).cuda() - - def gen_coord(self, in_shape, output_size): - - self.vx_lst = [-1, 1] - self.vy_lst = [-1, 1] - eps_shift = 1e-6 - self.image_size=output_size - - # field radius (global: [-1, 1]) - rx = 2 / in_shape[-2] / 2 - ry = 2 / in_shape[-1] / 2 - - coord = make_coord(output_size,flatten=False) \ - .expand(in_shape[0],output_size[0],output_size[1],2) \ - .view(in_shape[0],output_size[0]*output_size[1],2) - - cell = torch.ones_like(coord) - cell[:, :, 0] *= 2 / coord.shape[-2] - cell[:, :, 1] *= 2 / coord.shape[-1] - - feat_coord = make_coord(in_shape[-2:], flatten=False) \ - .permute(2, 0, 1) \ - .unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:]) - - areas = [] - - self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - for vx in self.vx_lst: - for vy in self.vy_lst: - self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone() - self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift - self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift - self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6) - q_coord = F.grid_sample( - feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1), - mode='nearest', align_corners=False)[:, :, 0, :] \ - .permute(0, 2, 1) - self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord - self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - - self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone() - self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1]) - areas.append(area + 1e-9) - tot_area = torch.stack(areas).sum(dim=0) - t = areas[0]; areas[0] = areas[3]; areas[3] = t - t = areas[1]; areas[1] = areas[2]; areas[2] = t - self.area_weights = [] - for item in areas: - self.area_weights.append((item / tot_area).unsqueeze(-1).cuda()) - - self.rel_coord = self.rel_coord.cuda() - self.rel_cell = self.rel_cell.cuda() - self.coord_ = self.coord_.cuda() - - def forward(self, feat): - # B K*K*Cin H W - feat = 
F.unfold(feat, 3, padding=1).view( - feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) - - preds = [] - for vx in [0,1]: - for vy in [0,1]: - q_feat = F.grid_sample( - feat, self.coord_[vx,vy,:,:,:].flip(-1).unsqueeze(1), - mode='nearest', align_corners=False)[:, :, 0, :] \ - .permute(0, 2, 1) - inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1) - - bs, q = self.coord_[0,0,:,:,:].shape[:2] - pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) - # print("pred shape: ",pred.shape) - preds.append(pred) - ret = 0 - for pred, area in zip(preds, self.area_weights): - ret = ret + pred * area - - return ret.permute(0, 2, 1).view(-1,3,self.image_size[0],self.image_size[1]) \ No newline at end of file diff --git a/components/Liif_conv.py b/components/Liif_conv.py deleted file mode 100644 index aef7171..0000000 --- a/components/Liif_conv.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Liif.py -# Created Date: Monday October 18th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 4:26:26 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_coord(shape, ranges=None, flatten=True): - """ Make coordinates at grid centers. - """ - coord_seqs = [] - for i, n in enumerate(shape): - print("i: %d, n: %d"%(i,n)) - if ranges is None: - v0, v1 = -1, 1 - else: - v0, v1 = ranges[i] - r = (v1 - v0) / (2 * n) - seq = v0 + r + (2 * r) * torch.arange(n).float() - coord_seqs.append(seq) - ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) - if flatten: - ret = ret.view(-1, ret.shape[-1]) - return ret - -class MLP(nn.Module): - - def __init__(self, in_dim, out_dim, hidden_list): - super().__init__() - layers = [] - lastv = in_dim - for hidden in hidden_list: - layers.append(nn.Linear(lastv, hidden)) - layers.append(nn.ReLU()) - lastv = hidden - layers.append(nn.Linear(lastv, out_dim)) - self.layers = nn.Sequential(*layers) - - def forward(self, x): - shape = x.shape[:-1] - x = self.layers(x.view(-1, x.shape[-1])) - return x.view(*shape, -1) - -class LIIF(nn.Module): - - def __init__(self, in_dim, out_dim): - super().__init__() - - imnet_in_dim = in_dim - # imnet_in_dim += 2 # attach coord - # imnet_in_dim += 2 - self.imnet = nn.Sequential( \ - nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 3,padding=1), - nn.InstanceNorm2d(out_dim, affine=True, momentum=0), - nn.LeakyReLU(), - # nn.Conv2d(in_channels = out_dim, out_channels = out_dim, kernel_size= 3,padding=1), - # nn.InstanceNorm2d(out_dim), - # nn.LeakyReLU(), - ) - - def gen_coord(self, in_shape, output_size): - - self.vx_lst = [-1, 1] - self.vy_lst = [-1, 1] - eps_shift = 1e-6 - self.image_size=output_size - - # field radius (global: [-1, 1]) - rx = 2 / in_shape[-2] / 2 - ry = 2 / in_shape[-1] / 2 - - self.coord = make_coord(output_size,flatten=False) \ - .expand(in_shape[0],output_size[0],output_size[1],2) - - # cell = torch.ones_like(coord) - # cell[:, :, 0] *= 2 / coord.shape[-2] - # cell[:, :, 1] *= 2 / coord.shape[-1] - - # feat_coord = make_coord(in_shape[-2:], flatten=False) \ - # .permute(2, 0, 1) \ - # .unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:]) - - # areas = [] - - # self.rel_coord = 
torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # for vx in self.vx_lst: - # for vy in self.vy_lst: - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone() - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift - # self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6) - # q_coord = F.grid_sample( - # feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1), - # mode='nearest', align_corners=False)[:, :, 0, :] \ - # .permute(0, 2, 1) - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone() - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - # area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1]) - # areas.append(area + 1e-9) - # tot_area = torch.stack(areas).sum(dim=0) - # t = areas[0]; areas[0] = areas[3]; areas[3] = t - # t = areas[1]; areas[1] = areas[2]; areas[2] = t - # self.area_weights = [] - # for item in areas: - # self.area_weights.append((item / tot_area).unsqueeze(-1).cuda()) - - # self.rel_coord = self.rel_coord.cuda() - # self.rel_cell = self.rel_cell.cuda() - # self.coord_ = self.coord_.cuda() - self.coord = self.coord.cuda() - - - def forward(self, feat): - # B K*K*Cin H W - # feat = F.unfold(feat, 3, padding=1).view( - # feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) - - # preds = [] - # for vx in [0,1]: - # for vy in [0,1]: - # print("feat shape: ", feat.shape) - # print("coor shape: ", self.coord.shape) - q_feat = F.grid_sample( - feat, self.coord, - mode='bilinear', align_corners=False) - out = self.imnet(q_feat) - # inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1) - - # bs, q = self.coord_[0,0,:,:,:].shape[:2] - # pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) - # # print("pred shape: ",pred.shape) - # preds.append(pred) - # ret = 0 - # for pred, area in zip(preds, self.area_weights): - # ret = ret + pred * area - # print("warp output shape: ",out.shape) - - return out \ No newline at end of file diff --git a/components/Liif_invo.py b/components/Liif_invo.py deleted file mode 100644 index aad6e0c..0000000 --- a/components/Liif_invo.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: Liif.py -# Created Date: Monday October 18th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 8:25:18 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - - -import torch -import torch.nn as nn -import torch.nn.functional as F -from components.Involution import involution - - -def make_coord(shape, ranges=None, flatten=True): - """ Make coordinates at grid centers. 
- """ - coord_seqs = [] - for i, n in enumerate(shape): - print("i: %d, n: %d"%(i,n)) - if ranges is None: - v0, v1 = -1, 1 - else: - v0, v1 = ranges[i] - r = (v1 - v0) / (2 * n) - seq = v0 + r + (2 * r) * torch.arange(n).float() - coord_seqs.append(seq) - ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) - if flatten: - ret = ret.view(-1, ret.shape[-1]) - return ret - -class MLP(nn.Module): - - def __init__(self, in_dim, out_dim, hidden_list): - super().__init__() - layers = [] - lastv = in_dim - for hidden in hidden_list: - layers.append(nn.Linear(lastv, hidden)) - layers.append(nn.ReLU()) - lastv = hidden - layers.append(nn.Linear(lastv, out_dim)) - self.layers = nn.Sequential(*layers) - - def forward(self, x): - shape = x.shape[:-1] - x = self.layers(x.view(-1, x.shape[-1])) - return x.view(*shape, -1) - -class LIIF(nn.Module): - - def __init__(self, in_dim, out_dim): - super().__init__() - - imnet_in_dim = in_dim - # imnet_in_dim += 2 # attach coord - # imnet_in_dim += 2 - - self.conv1x1 = nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 1) - # self.same_padding = nn.ReflectionPad2d(padding_size) - - # self.conv = involution(out_dim,5,1) - self.imnet = nn.Sequential( \ - # nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 3,padding=1), - involution(out_dim,5,1), - nn.InstanceNorm2d(out_dim, affine=True, momentum=0), - nn.LeakyReLU(), - # nn.Conv2d(in_channels = out_dim, out_channels = out_dim, kernel_size= 3,padding=1), - # nn.InstanceNorm2d(out_dim), - # nn.LeakyReLU(), - ) - - def gen_coord(self, in_shape, output_size): - - self.vx_lst = [-1, 1] - self.vy_lst = [-1, 1] - eps_shift = 1e-6 - self.image_size=output_size - - # field radius (global: [-1, 1]) - rx = 2 / in_shape[-2] / 2 - ry = 2 / in_shape[-1] / 2 - - self.coord = make_coord(output_size,flatten=False) \ - .expand(in_shape[0],output_size[0],output_size[1],2) - - # cell = torch.ones_like(coord) - # cell[:, :, 0] *= 2 / coord.shape[-2] - # cell[:, :, 1] *= 2 / coord.shape[-1] - - # feat_coord = make_coord(in_shape[-2:], flatten=False) \ - # .permute(2, 0, 1) \ - # .unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:]) - - # areas = [] - - # self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2)) - # for vx in self.vx_lst: - # for vy in self.vy_lst: - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone() - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift - # self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift - # self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6) - # q_coord = F.grid_sample( - # feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1), - # mode='nearest', align_corners=False)[:, :, 0, :] \ - # .permute(0, 2, 1) - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - # self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone() - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2] - # self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1] - # area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1]) - # areas.append(area + 1e-9) - # tot_area = torch.stack(areas).sum(dim=0) - # t = areas[0]; areas[0] = areas[3]; areas[3] = t - # t = 
areas[1]; areas[1] = areas[2]; areas[2] = t - # self.area_weights = [] - # for item in areas: - # self.area_weights.append((item / tot_area).unsqueeze(-1).cuda()) - - # self.rel_coord = self.rel_coord.cuda() - # self.rel_cell = self.rel_cell.cuda() - # self.coord_ = self.coord_.cuda() - self.coord = self.coord.cuda() - - - def forward(self, feat): - # B K*K*Cin H W - # feat = F.unfold(feat, 3, padding=1).view( - # feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) - - # preds = [] - # for vx in [0,1]: - # for vy in [0,1]: - # print("feat shape: ", feat.shape) - # print("coor shape: ", self.coord.shape) - q_feat = self.conv1x1(feat) - q_feat = F.grid_sample( - q_feat, self.coord, - mode='bilinear', align_corners=False) - out = self.imnet(q_feat) - # inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1) - - # bs, q = self.coord_[0,0,:,:,:].shape[:2] - # pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) - # # print("pred shape: ",pred.shape) - # preds.append(pred) - # ret = 0 - # for pred, area in zip(preds, self.area_weights): - # ret = ret + pred * area - # print("warp output shape: ",out.shape) - - return out \ No newline at end of file diff --git a/components/ResBlock.py b/components/ResBlock.py deleted file mode 100644 index 3ec1add..0000000 --- a/components/ResBlock.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: ResBlock.py -# Created Date: Monday July 5th 2021 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 5th July 2021 12:18:18 am -# Modified By: Chen Xuanhong -# Copyright (c) 2021 Shanghai Jiao Tong University -############################################################# - -from torch import nn - -class ResBlock(nn.Module): - def __init__(self, in_channel, k_size = 3, stride=1): - super().__init__() - padding_size = int((k_size -1)/2) - self.block = nn.Sequential( - nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = in_channel , out_channels = in_channel , kernel_size= k_size, stride=stride, bias= False), - nn.InstanceNorm2d(in_channel, affine=True, momentum=0), - nn.ReflectionPad2d(padding_size), - nn.Conv2d(in_channels = in_channel , out_channels = in_channel , kernel_size= k_size, stride=stride, bias= False), - nn.InstanceNorm2d(in_channel, affine=True, momentum=0) - ) - self.__weights_init__() - - def __weights_init__(self): - for m in self.modules(): - if isinstance(m,nn.Conv2d): - nn.init.xavier_uniform_(m.weight) - - def forward(self, input): - res = input - h = self.block(input) - out = h + res - return out diff --git a/components/ResBlock_Adain.py b/components/ResBlock_Adain.py deleted file mode 100644 index ac4b3ec..0000000 --- a/components/ResBlock_Adain.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn as nn - -class InstanceNorm(nn.Module): - def __init__(self, epsilon=1e-8): - """ - @notice: avoid in-place ops. 
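The ResBlock_Adain.py being removed here duplicates the AdaIN-style residual block that this patch re-adds inside components/Generator.py as ResnetBlock_Adain: each block is modulated by a per-image latent through ApplyStyle. A minimal driver with hypothetical shapes (the 512-dim latent, e.g. an ArcFace identity embedding, is an assumption for illustration, not something this diff pins down):

import torch

block = ResnetBlock_Adain(dim=512, latent_size=512, padding_type='reflect')
feat = torch.randn(2, 512, 28, 28)  # (B, C, H, W) bottleneck features
latent = torch.randn(2, 512)        # (B, latent_size) per-image style code
out = block(feat, latent)           # AdaIN-styled residual update
assert out.shape == feat.shape

Note that Generator.__init__ above wires latent_size=chn, i.e. g_conv_dim, into every block, so that config value has to match the width of the id vector handed to forward.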
- https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 - """ - super(InstanceNorm, self).__init__() - self.epsilon = epsilon - - def forward(self, x): - x = x - torch.mean(x, (2, 3), True) - tmp = torch.mul(x, x) # or x ** 2 - tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) - return x * tmp - -class ApplyStyle(nn.Module): - """ - @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb - """ - def __init__(self, latent_size, channels): - super(ApplyStyle, self).__init__() - self.linear = nn.Linear(latent_size, channels * 2) - - def forward(self, x, latent): - style = self.linear(latent) # style => [batch_size, n_channels*2] - shape = [-1, 2, x.size(1), 1, 1] - style = style.view(shape) # [batch_size, 2, n_channels, ...] - #x = x * (style[:, 0] + 1.) + style[:, 1] - x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1 - return x - -class ResBlock_Adain(nn.Module): - def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): - super(ResBlock_Adain, self).__init__() - - p = 0 - conv1 = [] - if padding_type == 'reflect': - conv1 += [nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - conv1 += [nn.ReplicationPad2d(1)] - elif padding_type == 'zero': - p = 1 - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] - self.conv1 = nn.Sequential(*conv1) - self.style1 = ApplyStyle(latent_size, dim) - self.act1 = activation - - p = 0 - conv2 = [] - if padding_type == 'reflect': - conv2 += [nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - conv2 += [nn.ReplicationPad2d(1)] - elif padding_type == 'zero': - p = 1 - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] - self.conv2 = nn.Sequential(*conv2) - self.style2 = ApplyStyle(latent_size, dim) - - - def forward(self, x, dlatents_in_slice): - y = self.conv1(x) - y = self.style1(y, dlatents_in_slice) - y = self.act1(y) - y = self.conv2(y) - y = self.style2(y, dlatents_in_slice) - out = x + y - return out \ No newline at end of file diff --git a/components/Transform.py b/components/Transform.py deleted file mode 100644 index 1888c60..0000000 --- a/components/Transform.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -from torch import nn - -class Transform_block(nn.Module): - def __init__(self, k_size = 10): - super().__init__() - padding_size = int((k_size -1)/2) - # self.padding = nn.ReplicationPad2d(padding_size) - self.pool = nn.AvgPool2d(k_size, stride=1,padding=padding_size) - - def forward(self, input_image): - # h = self.padding(input) - out = self.pool(input_image) - return out \ No newline at end of file diff --git a/components/network_swin.py b/components/network_swin.py deleted file mode 100644 index 8a75fdd..0000000 --- a/components/network_swin.py +++ /dev/null @@ -1,854 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. 
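network_swin.py, removed below, is the SwinIR port; everything in it hangs off window attention, where the feature map is cut into non-overlapping windows, attended per window, and stitched back. The two helpers the file defines are exact inverses, which a few standalone lines can verify (the sketch mirrors the deleted functions):

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # inverse of window_partition
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 16, 16, 32)    # (B, H, W, C) with H, W divisible by the window
windows = window_partition(x, 8)  # (2*4, 8, 8, 32): four windows per image
assert torch.equal(window_reverse(windows, 8, 16, 16), x)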
-# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - nn.init - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. 
- num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, - input_resolution, - dim, norm_layer = nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. 
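PatchMerging, defined a little above, is a 2x2 space-to-depth step: the four interleaved sub-grids are concatenated channel-wise and a linear layer reduces the result from 4C to 2C. In shapes, with illustrative values:

import torch

B, H, W, C = 1, 8, 8, 96
x = torch.randn(B, H, W, C)
x0, x1 = x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :]
x2, x3 = x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # (B, H/2*W/2, 4*C)
# nn.Linear(4 * C, 2 * C) then halves the channels -> (B, H*W/4, 2*C) tokens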
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. 
or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = 
nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + 
self.conv_last(res) - - x = x / self.img_range + self.mean - - return x - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/components/pg_modules/blocks.py b/components/pg_modules/blocks.py new file mode 100644 index 0000000..78bd113 --- /dev/null +++ b/components/pg_modules/blocks.py @@ -0,0 +1,325 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm + + +### single layers + + +def conv2d(*args, **kwargs): + return spectral_norm(nn.Conv2d(*args, **kwargs)) + + +def convTranspose2d(*args, **kwargs): + return spectral_norm(nn.ConvTranspose2d(*args, **kwargs)) + + +def embedding(*args, **kwargs): + return spectral_norm(nn.Embedding(*args, **kwargs)) + + +def linear(*args, **kwargs): + return spectral_norm(nn.Linear(*args, **kwargs)) + + +def NormLayer(c, mode='batch'): + if mode == 'group': + return nn.GroupNorm(c//2, c) + elif mode == 'batch': + return nn.BatchNorm2d(c) + + +### Activations + + +class GLU(nn.Module): + def forward(self, x): + nc = x.size(1) + assert nc % 2 == 0, 'channels dont divide 2!' 
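+        # Gated Linear Unit: split the channels in half and gate the first half
+        # elementwise with a sigmoid of the second half (output has nc/2 channels).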
+ nc = int(nc/2) + return x[:, :nc] * torch.sigmoid(x[:, nc:]) + + +class Swish(nn.Module): + def forward(self, feat): + return feat * torch.sigmoid(feat) + + +### Upblocks + + +class InitLayer(nn.Module): + def __init__(self, nz, channel, sz=4): + super().__init__() + + self.init = nn.Sequential( + convTranspose2d(nz, channel*2, sz, 1, 0, bias=False), + NormLayer(channel*2), + GLU(), + ) + + def forward(self, noise): + noise = noise.view(noise.shape[0], -1, 1, 1) + return self.init(noise) + + +def UpBlockSmall(in_planes, out_planes): + block = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), + NormLayer(out_planes*2), GLU()) + return block + + +class UpBlockSmallCond(nn.Module): + def __init__(self, in_planes, out_planes, z_dim): + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.up = nn.Upsample(scale_factor=2, mode='nearest') + self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) + + which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) + self.bn = which_bn(2*out_planes) + self.act = GLU() + + def forward(self, x, c): + x = self.up(x) + x = self.conv(x) + x = self.bn(x, c) + x = self.act(x) + return x + + +def UpBlockBig(in_planes, out_planes): + block = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), + NoiseInjection(), + NormLayer(out_planes*2), GLU(), + conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False), + NoiseInjection(), + NormLayer(out_planes*2), GLU() + ) + return block + + +class UpBlockBigCond(nn.Module): + def __init__(self, in_planes, out_planes, z_dim): + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.up = nn.Upsample(scale_factor=2, mode='nearest') + self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) + self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False) + + which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) + self.bn1 = which_bn(2*out_planes) + self.bn2 = which_bn(2*out_planes) + self.act = GLU() + self.noise = NoiseInjection() + + def forward(self, x, c): + # block 1 + x = self.up(x) + x = self.conv1(x) + x = self.noise(x) + x = self.bn1(x, c) + x = self.act(x) + + # block 2 + x = self.conv2(x) + x = self.noise(x) + x = self.bn2(x, c) + x = self.act(x) + + return x + + +class SEBlock(nn.Module): + def __init__(self, ch_in, ch_out): + super().__init__() + self.main = nn.Sequential( + nn.AdaptiveAvgPool2d(4), + conv2d(ch_in, ch_out, 4, 1, 0, bias=False), + Swish(), + conv2d(ch_out, ch_out, 1, 1, 0, bias=False), + nn.Sigmoid(), + ) + + def forward(self, feat_small, feat_big): + return feat_big * self.main(feat_small) + + +### Downblocks + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size, + groups=in_channels, bias=bias, padding=1) + self.pointwise = conv2d(in_channels, out_channels, + kernel_size=1, bias=bias) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class DownBlock(nn.Module): + def __init__(self, in_planes, out_planes, separable=False): + super().__init__() + if not separable: + self.main = nn.Sequential( + conv2d(in_planes, out_planes, 4, 2, 1), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + ) + else: + self.main = 
nn.Sequential( + SeparableConv2d(in_planes, out_planes, 3), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + nn.AvgPool2d(2, 2), + ) + + def forward(self, feat): + return self.main(feat) + + +class DownBlockPatch(nn.Module): + def __init__(self, in_planes, out_planes, separable=False): + super().__init__() + self.main = nn.Sequential( + DownBlock(in_planes, out_planes, separable), + conv2d(out_planes, out_planes, 1, 1, 0, bias=False), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, feat): + return self.main(feat) + + +### CSM + + +class ResidualConvUnit(nn.Module): + def __init__(self, cin, activation, bn): + super().__init__() + self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + return self.skip_add.add(self.conv(x), x) + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False): + super().__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + output = xs[0] + + if len(xs) == 2: + output = self.skip_add.add(output, xs[1]) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +### Misc + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1), requires_grad=True) + + def forward(self, feat, noise=None): + if noise is None: + batch, _, height, width = feat.shape + noise = torch.randn(batch, 1, height, width).to(feat.device) + + return feat + self.weight * noise + + +class CCBN(nn.Module): + ''' conditional batchnorm ''' + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1): + super().__init__() + self.output_size, self.input_size = output_size, input_size + + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + return out * gain + bias + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, size, mode='bilinear', align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.size = size + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. 
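+        Resizes x to the target size fixed at construction, using
+        nn.functional.interpolate with the stored mode and align_corners.
+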
+ Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + size=self.size, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x diff --git a/components/pg_modules/diffaug.py b/components/pg_modules/diffaug.py new file mode 100644 index 0000000..54020be --- /dev/null +++ b/components/pg_modules/diffaug.py @@ -0,0 +1,76 @@ +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 + +import torch +import torch.nn.functional as F + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + + +def rand_cutout(x, ratio=0.2): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} diff --git a/components/pg_modules/discriminator.py b/components/pg_modules/discriminator.py new file mode 100644 index 0000000..02728e9 --- /dev/null +++ b/components/pg_modules/discriminator.py @@ -0,0 +1,186 @@ +from functools import partial +import numpy as np +import torch +import 
torch.nn as nn +import torch.nn.functional as F + +from components.pg_modules.blocks import DownBlock, DownBlockPatch, conv2d +from components.pg_modules.projector import F_RandomProj +from components.pg_modules.diffaug import DiffAugment + + +class SingleDisc(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False): + super().__init__() + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + + layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False)) + self.main = nn.Sequential(*layers) + + def forward(self, x, c): + return self.main(x) + + +class SingleDiscCond(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128): + super().__init__() + self.cmap_dim = cmap_dim + + # midas channels + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + self.main = nn.Sequential(*layers) + + # additions for conditioning on class information + self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False) + self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim) + self.embed_proj = nn.Sequential( + nn.Linear(self.embed.embedding_dim, self.cmap_dim), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, x, c): + h = self.main(x) + out = self.cls(h) + + # conditioning via projection + cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1) + out = (out * cmap).sum(dim=1, keepdim=True) * (1 
/ np.sqrt(self.cmap_dim)) + + return out + + +class MultiScaleD(nn.Module): + def __init__( + self, + channels, + resolutions, + num_discs=1, + proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing + cond=0, + separable=False, + patch=False, + **kwargs, + ): + super().__init__() + + assert num_discs in [1, 2, 3, 4] + + # the first disc is on the lowest level of the backbone + self.disc_in_channels = channels[:num_discs] + self.disc_in_res = resolutions[:num_discs] + Disc = SingleDiscCond if cond else SingleDisc + + mini_discs = [] + for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)): + start_sz = res if not patch else 16 + mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)], + self.mini_discs = nn.ModuleDict(mini_discs) + + def forward(self, features, c): + all_logits = [] + for k, disc in self.mini_discs.items(): + all_logits.append(disc(features[k], c).view(features[k].size(0), -1)) + + all_logits = torch.cat(all_logits, dim=1) + return all_logits + + +class ProjectedDiscriminator(torch.nn.Module): + def __init__( + self, + diffaug=True, + interp224=True, + backbone_kwargs={}, + **kwargs + ): + super().__init__() + self.diffaug = diffaug + self.interp224 = interp224 + self.feature_network = F_RandomProj(**backbone_kwargs) + self.discriminator = MultiScaleD( + channels=self.feature_network.CHANNELS, + resolutions=self.feature_network.RESOLUTIONS, + **backbone_kwargs, + ) + + def train(self, mode=True): + self.feature_network = self.feature_network.train(False) + self.discriminator = self.discriminator.train(mode) + return self + + def eval(self): + return self.train(False) + + def forward(self, x, c): + if self.diffaug: + x = DiffAugment(x, policy='color,translation,cutout') + + if self.interp224: + x = F.interpolate(x, 224, mode='bilinear', align_corners=False) + + features = self.feature_network(x) + logits = self.discriminator(features, c) + + return logits diff --git a/components/pg_modules/networks_fastgan.py b/components/pg_modules/networks_fastgan.py new file mode 100644 index 0000000..1a32056 --- /dev/null +++ b/components/pg_modules/networks_fastgan.py @@ -0,0 +1,178 @@ +# original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py +# +# modified by Axel Sauer for "Projected GANs Converge Faster" +# +import torch.nn as nn +from components.pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) + + +def normalize_second_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + +class DummyMapping(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z, c, **kwargs): + return z.unsqueeze(1) # to fit the StyleGAN API + + +class FastganSynthesis(nn.Module): + def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False): + super().__init__() + self.img_resolution = img_resolution + self.z_dim = z_dim + + # channel multiplier + nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, + 512:0.25, 1024:0.125} + nfc = {} + for k, v in nfc_multi.items(): + nfc[k] = int(v*ngf) + + # layers + self.init = InitLayer(z_dim, channel=nfc[2], sz=4) + + UpBlock = UpBlockSmall if lite else UpBlockBig + + self.feat_8 = UpBlock(nfc[4], nfc[8]) + self.feat_16 = UpBlock(nfc[8], nfc[16]) + self.feat_32 = UpBlock(nfc[16], nfc[32]) + self.feat_64 = UpBlock(nfc[32], nfc[64]) + self.feat_128 = UpBlock(nfc[64], nfc[128]) + self.feat_256 =
UpBlock(nfc[128], nfc[256]) + + self.se_64 = SEBlock(nfc[4], nfc[64]) + self.se_128 = SEBlock(nfc[8], nfc[128]) + self.se_256 = SEBlock(nfc[16], nfc[256]) + + self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) + + if img_resolution > 256: + self.feat_512 = UpBlock(nfc[256], nfc[512]) + self.se_512 = SEBlock(nfc[32], nfc[512]) + if img_resolution > 512: + self.feat_1024 = UpBlock(nfc[512], nfc[1024]) + + def forward(self, input, c, **kwargs): + # map noise to hypersphere as in "Progressive Growing of GANS" + input = normalize_second_moment(input[:, 0]) + + feat_4 = self.init(input) + feat_8 = self.feat_8(feat_4) + feat_16 = self.feat_16(feat_8) + feat_32 = self.feat_32(feat_16) + feat_64 = self.se_64(feat_4, self.feat_64(feat_32)) + feat_128 = self.se_128(feat_8, self.feat_128(feat_64)) + + if self.img_resolution >= 128: + feat_last = feat_128 + + if self.img_resolution >= 256: + feat_last = self.se_256(feat_16, self.feat_256(feat_last)) + + if self.img_resolution >= 512: + feat_last = self.se_512(feat_32, self.feat_512(feat_last)) + + if self.img_resolution >= 1024: + feat_last = self.feat_1024(feat_last) + + return self.to_big(feat_last) + + +class FastganSynthesisCond(nn.Module): + def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False): + super().__init__() + + self.z_dim = z_dim + nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, + 512:0.25, 1024:0.125, 2048:0.125} + nfc = {} + for k, v in nfc_multi.items(): + nfc[k] = int(v*ngf) + + self.img_resolution = img_resolution + + self.init = InitLayer(z_dim, channel=nfc[2], sz=4) + + UpBlock = UpBlockSmallCond if lite else UpBlockBigCond + + self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim) + self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim) + self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim) + self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim) + self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim) + self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim) + + self.se_64 = SEBlock(nfc[4], nfc[64]) + self.se_128 = SEBlock(nfc[8], nfc[128]) + self.se_256 = SEBlock(nfc[16], nfc[256]) + + self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) + + if img_resolution > 256: + self.feat_512 = UpBlock(nfc[256], nfc[512]) + self.se_512 = SEBlock(nfc[32], nfc[512]) + if img_resolution > 512: + self.feat_1024 = UpBlock(nfc[512], nfc[1024]) + + self.embed = nn.Embedding(num_classes, z_dim) + + def forward(self, input, c, update_emas=False): + c = self.embed(c.argmax(1)) + + # map noise to hypersphere as in "Progressive Growing of GANS" + input = normalize_second_moment(input[:, 0]) + + feat_4 = self.init(input) + feat_8 = self.feat_8(feat_4, c) + feat_16 = self.feat_16(feat_8, c) + feat_32 = self.feat_32(feat_16, c) + feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c)) + feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c)) + + if self.img_resolution >= 128: + feat_last = feat_128 + + if self.img_resolution >= 256: + feat_last = self.se_256(feat_16, self.feat_256(feat_last, c)) + + if self.img_resolution >= 512: + feat_last = self.se_512(feat_32, self.feat_512(feat_last, c)) + + if self.img_resolution >= 1024: + feat_last = self.feat_1024(feat_last, c) + + return self.to_big(feat_last) + + +class Generator(nn.Module): + def __init__( + self, + z_dim=256, + c_dim=0, + w_dim=0, + img_resolution=256, + img_channels=3, + ngf=128, + cond=0, + mapping_kwargs={}, + synthesis_kwargs={} + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = 
img_resolution + self.img_channels = img_channels + + # Mapping and Synthesis Networks + self.mapping = DummyMapping() # to fit the StyleGAN API + Synthesis = FastganSynthesisCond if cond else FastganSynthesis + self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs) + + def forward(self, z, c, **kwargs): + w = self.mapping(z, c) + img = self.synthesis(w, c) + return img diff --git a/components/pg_modules/networks_stylegan2.py b/components/pg_modules/networks_stylegan2.py new file mode 100644 index 0000000..c554a2f --- /dev/null +++ b/components/pg_modules/networks_stylegan2.py @@ -0,0 +1,537 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. +# +# modified by Axel Sauer for "Projected GANs Converge Faster" +# +import numpy as np +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + + +@misc.profiled_function +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. 
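+    # In the non-fused path, modulation is folded into the activations instead of
+    # the weights: scale x per-sample by styles, run a plain batch-shared conv,
+    # then rescale by the demodulation coefficients; fma.fma fuses that final
+    # multiply with the optional noise add.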
+ if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +@persistence.persistent_class +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +@persistence.persistent_class +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output to +-X, None = disable clamping. + channels_last = False, # Expect the input to have memory_format=channels_last? + trainable = True, # Update the weights of this layer during training? 
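+        # NOTE: when trainable=False, weight/bias are registered as buffers below,
+        # so they follow the module across devices but receive no gradient updates.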
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', + f'up={self.up}, down={self.down}']) + + +@persistence.persistent_class +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. 
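+        # z is rescaled to unit second moment; the label c (if present) is embedded,
+        # normalized the same way, and concatenated with z before the FC stack.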
+ x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' + + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + use_noise = True, # Enable noise input? + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last = False, # Use channels_last format for the weights? 
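+        # NOTE: when use_noise is enabled, forward() adds per-pixel noise (random,
+        # or the stored noise_const) scaled by the learned noise_strength.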
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', + f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) + + +@persistence.persistent_class +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + def extra_repr(self): + return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' + + +@persistence.persistent_class +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. 
+ out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. 
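+        # The first block (4x4) has no conv0; otherwise conv0 upsamples 2x. In the
+        # 'resnet' architecture both the skip and the conv path are scaled by
+        # sqrt(0.5) before summing, keeping activation variance roughly constant.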
+ if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.num_fp16_res = num_fp16_res + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + def forward(self, ws, c=None, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + x, img = block(x, img, cur_ws, **block_kwargs) + return img + + def extra_repr(self): + return ' '.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_fp16_res={self.num_fp16_res:d}']) + + +@persistence.persistent_class +class 
Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) + return img diff --git a/components/pg_modules/projector.py b/components/pg_modules/projector.py new file mode 100644 index 0000000..7ca03c1 --- /dev/null +++ b/components/pg_modules/projector.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import timm +from components.pg_modules.blocks import FeatureFusionBlock + + +def _make_scratch_ccm(scratch, in_channels, cout, expand=False): + # shapes + out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4 + + scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True) + + scratch.CHANNELS = out_channels + + return scratch + + +def _make_scratch_csm(scratch, in_channels, cout, expand): + scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True) + scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand) + scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand) + scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False)) + + # last refinenet does not expand to save channels in higher dimensions + scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4 + + return scratch + + +def _make_efficientnet(model): + pretrained = nn.Module() + pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]) + pretrained.layer1 = nn.Sequential(*model.blocks[2:3]) + pretrained.layer2 = nn.Sequential(*model.blocks[3:5]) + pretrained.layer3 = nn.Sequential(*model.blocks[5:9]) + return pretrained + + +def calc_channels(pretrained, inp_res=224): + channels = [] + tmp = torch.zeros(1, 3, inp_res, inp_res) + + # forward pass + tmp = pretrained.layer0(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer1(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer2(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer3(tmp) + channels.append(tmp.shape[1]) + + return channels + + +def _make_projector(im_res, cout, proj_type, expand=False): + assert proj_type in [0, 1, 2], 
"Invalid projection type" + + ### Build pretrained feature network + model = timm.create_model('tf_efficientnet_lite0', pretrained=True) + pretrained = _make_efficientnet(model) + + # determine resolution of feature maps, this is later used to calculate the number + # of down blocks in the discriminators. Interestingly, the best results are achieved + # by fixing this to 256, ie., we use the same number of down blocks per discriminator + # independent of the dataset resolution + im_res = 256 + pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] + pretrained.CHANNELS = calc_channels(pretrained) + + if proj_type == 0: return pretrained, None + + ### Build CCM + scratch = nn.Module() + scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) + pretrained.CHANNELS = scratch.CHANNELS + + if proj_type == 1: return pretrained, scratch + + ### build CSM + scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) + + # CSM upsamples x2 so the feature map resolution doubles + pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] + pretrained.CHANNELS = scratch.CHANNELS + + return pretrained, scratch + + +class F_RandomProj(nn.Module): + def __init__( + self, + im_res=256, + cout=64, + expand=True, + proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing + **kwargs, + ): + super().__init__() + self.proj_type = proj_type + self.cout = cout + self.expand = expand + + # build pretrained feature network and random decoder (scratch) + self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand) + self.CHANNELS = self.pretrained.CHANNELS + self.RESOLUTIONS = self.pretrained.RESOLUTIONS + + def forward(self, x, get_features=False): + # predict feature maps + out0 = self.pretrained.layer0(x) + out1 = self.pretrained.layer1(out0) + out2 = self.pretrained.layer2(out1) + out3 = self.pretrained.layer3(out2) + + # start enumerating at the lowest layer (this is where we put the first discriminator) + backbone_features = { + '0': out0, + '1': out1, + '2': out2, + '3': out3, + } + if get_features: + return backbone_features + + if self.proj_type == 0: return backbone_features + + out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0']) + out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1']) + out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2']) + out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3']) + + out = { + '0': out0_channel_mixed, + '1': out1_channel_mixed, + '2': out2_channel_mixed, + '3': out3_channel_mixed, + } + + if self.proj_type == 1: return out + + # from bottom to top + out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) + out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) + out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) + out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) + + out = { + '0': out0_scale_mixed, + '1': out1_scale_mixed, + '2': out2_scale_mixed, + '3': out3_scale_mixed, + } + + return out, backbone_features diff --git a/components/projected_discriminator.py b/components/projected_discriminator.py new file mode 100644 index 0000000..3f2f23d --- /dev/null +++ b/components/projected_discriminator.py @@ -0,0 +1,194 @@ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from functools import 
partial + +from components.pg_modules.blocks import DownBlock, DownBlockPatch, conv2d +from components.pg_modules.projector import F_RandomProj +# from components.pg_modules.diffaug import DiffAugment + + +class SingleDisc(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False): + super().__init__() + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + + layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False)) + self.main = nn.Sequential(*layers) + + def forward(self, x, c): + return self.main(x) + + +class SingleDiscCond(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128): + super().__init__() + self.cmap_dim = cmap_dim + + # midas channels + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + self.main = nn.Sequential(*layers) + + # additions for conditioning on class information + self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False) + self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim) + self.embed_proj = nn.Sequential( + nn.Linear(self.embed.embedding_dim, self.cmap_dim), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, x, c): + h = self.main(x) + out = self.cls(h) + + # conditioning via projection + cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1) + out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return 
out
+
+
+class MultiScaleD(nn.Module):
+    def __init__(
+        self,
+        channels,
+        resolutions,
+        num_discs=4,
+        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
+        cond=0,
+        separable=False,
+        patch=False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        assert num_discs in [1, 2, 3, 4]
+
+        # the first disc is on the lowest level of the backbone
+        self.disc_in_channels = channels[:num_discs]
+        self.disc_in_res = resolutions[:num_discs]
+        Disc = SingleDiscCond if cond else SingleDisc
+
+        mini_discs = []
+        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
+            start_sz = res if not patch else 16
+            mini_discs.append([str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)])
+        self.mini_discs = nn.ModuleDict(mini_discs)
+
+    def forward(self, features, c):
+        all_logits = []
+        for k, disc in self.mini_discs.items():
+            res = disc(features[k], c).view(features[k].size(0), -1)
+            all_logits.append(res)
+
+        all_logits = torch.cat(all_logits, dim=1)
+        return all_logits
+
+
+class ProjectedDiscriminator(torch.nn.Module):
+    def __init__(
+        self,
+        **kwargs
+    ):
+        super().__init__()
+        self.diffaug = kwargs["diffaug"]
+        self.interp224 = kwargs["interp224"]
+        backbone_kwargs = kwargs["backbone_kwargs"]
+
+        self.interp224 = False
+        self.feature_network = F_RandomProj(**backbone_kwargs)
+        self.discriminator = MultiScaleD(
+            channels=self.feature_network.CHANNELS,
+            resolutions=self.feature_network.RESOLUTIONS,
+            **backbone_kwargs,
+        )
+
+    def train(self, mode=True):
+        self.feature_network = self.feature_network.train(False)
+        self.discriminator = self.discriminator.train(mode)
+        return self
+
+    def eval(self):
+        return self.train(False)
+
+    def get_feature(self, x):
+        features = self.feature_network(x, get_features=True)
+        return features
+
+    def forward(self, x, c):
+        # if self.diffaug:
+        #     x = DiffAugment(x, policy='color,translation,cutout')
+
+        # if self.interp224:
+        #     x = F.interpolate(x, 224, mode='bilinear', align_corners=False)
+
+        features, backbone_features = self.feature_network(x)
+        logits = self.discriminator(features, c)
+
+        return logits, backbone_features
+
diff --git a/components/warp_invo.py b/components/warp_invo.py
deleted file mode 100644
index f0e05e5..0000000
--- a/components/warp_invo.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-#############################################################
-# File: warp_invo.py
-# Created Date: Tuesday October 19th 2021
-# Author: Chen Xuanhong
-# Email: chenxuanhongzju@outlook.com
-# Last Modified: Tuesday, 19th October 2021 11:27:13 am
-# Modified By: Chen Xuanhong
-# Copyright (c) 2021 Shanghai Jiao Tong University
-#############################################################
-
-from torch import nn
-from components.Involution import involution
-
-
-class DeConv(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size = 3, upsampl_scale = 2, padding="reflect"):
-        super().__init__()
-        self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale)
-        padding_size = int((kernel_size -1)/2)
-        self.conv1x1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size= 1)
-        # self.same_padding = nn.ReflectionPad2d(padding_size)
-        if padding.lower() == "reflect":
-
-            self.conv = involution(out_channels,5,1)
-            # self.conv = nn.Sequential(
-            #     nn.ReflectionPad2d(padding_size),
-            #     nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size= kernel_size, bias= False))
-            # for layer in self.conv:
-            
# if isinstance(layer,nn.Conv2d): - # nn.init.xavier_uniform_(layer.weight) - elif padding.lower() == "zero": - self.conv = involution(out_channels,5,1) - # nn.init.xavier_uniform_(self.conv.weight) - # self.__weights_init__() - - # def __weights_init__(self): - # nn.init.xavier_uniform_(self.conv.weight) - - def forward(self, input): - h = self.conv1x1(input) - h = self.upsampling(h) - h = self.conv(h) - return h \ No newline at end of file diff --git a/data_tools/data_loader_VGGFace2HQ.py b/data_tools/data_loader_VGGFace2HQ.py index 9c74fbe..6813a76 100644 --- a/data_tools/data_loader_VGGFace2HQ.py +++ b/data_tools/data_loader_VGGFace2HQ.py @@ -1,4 +1,5 @@ import os +import glob import torch import random from PIL import Image @@ -12,8 +13,8 @@ class data_prefetcher(): self.loader = loader self.dataiter = iter(loader) self.stream = torch.cuda.Stream() - # self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) - # self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) + self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1) + self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1) # With Amp, it isn't necessary to manually convert data to half. # if args.fp16: # self.mean = self.mean.half() @@ -23,13 +24,16 @@ class data_prefetcher(): def preload(self): try: - self.content = next(self.dataiter) + self.src_image1, self.src_image2 = next(self.dataiter) except StopIteration: self.dataiter = iter(self.loader) - self.content = next(self.dataiter) + self.src_image1, self.src_image2 = next(self.dataiter) with torch.cuda.stream(self.stream): - self.content= self.content.cuda(non_blocking=True) + self.src_image1 = self.src_image1.cuda(non_blocking=True) + self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std) + self.src_image2 = self.src_image2.cuda(non_blocking=True) + self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std) # With Amp, it isn't necessary to manually convert data to half. 
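+            # (note) 0.485/0.456/0.406 and 0.229/0.224/0.225 are the standard ImageNet
+            # mean/std. Normalizing here, on the prefetch CUDA stream, keeps the Dataset
+            # transform a bare ToTensor() and matches the statistics assumed by the
+            # ArcFace identity network and the pretrained discriminator backbone.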
         # if args.fp16:
         #     self.next_input = self.next_input.half()
@@ -38,9 +42,10 @@
         #     self.next_input = self.next_input.sub_(self.mean).div_(self.std)
     def next(self):
         torch.cuda.current_stream().wait_stream(self.stream)
-        content = self.content
+        src_image1 = self.src_image1
+        src_image2 = self.src_image2
         self.preload()
-        return content
+        return src_image1, src_image2
 
     def __len__(self):
         """Return the number of images."""
@@ -50,90 +55,69 @@
 class VGGFace2HQDataset(data.Dataset):
     """Dataset class for the VGGFace2-HQ dataset."""
 
     def __init__(self,
-                    content_image_dir,
-                    selectedContent,
-                    content_transform,
+                    image_dir,
+                    img_transform,
                     subffix='jpg',
                     random_seed=1234):
-        """Initialize and preprocess the CelebA dataset."""
-        self.content_image_dir  = content_image_dir
-        self.content_transform  = content_transform
-        self.selectedContent    = selectedContent
-        self.subffix            = subffix
-        self.content_dataset    = []
-        self.random_seed        = random_seed
+        """Initialize and preprocess the VGGFace2 HQ dataset."""
+        self.image_dir      = image_dir
+        self.img_transform  = img_transform
+        self.subffix        = subffix
+        self.dataset        = []
+        self.random_seed    = random_seed
         self.preprocess()
-        self.num_images = len(self.content_dataset)
+        self.num_images = len(self.dataset)
 
     def preprocess(self):
-        """Preprocess the Artworks dataset."""
-        print("processing content images...")
-        for dir_item in self.selectedContent:
-            join_path = Path(self.content_image_dir,dir_item)
-            if join_path.exists():
-                print("processing %s"%dir_item,end='\r')
-                images = join_path.glob('*.%s'%(self.subffix))
-                for item in images:
-                    self.content_dataset.append(item)
-            else:
-                print("%s dir does not exist!"%dir_item,end='\r')
+        """Preprocess the VGGFace2 HQ dataset: one image list per identity directory."""
+        print("processing VGGFace2 HQ dataset images...")
+
+        temp_path = os.path.join(self.image_dir,'*/')
+        paths = glob.glob(temp_path)
+        self.dataset = []
+        for dir_item in paths:
+            join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
+            print("processing %s"%dir_item,end='\r')
+            temp_list = []
+            for item in join_path:
+                temp_list.append(item)
+            self.dataset.append(temp_list)
         random.seed(self.random_seed)
-        random.shuffle(self.content_dataset)
-        print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset))
-
+        random.shuffle(self.dataset)
+        print('Finished preprocessing the VGGFace2 HQ dataset, total number of identity dirs: %d...'%len(self.dataset))
+
     def __getitem__(self, index):
-        """Return one image and its corresponding attribute label."""
-        filename = self.content_dataset[index]
-        image = Image.open(filename)
-        content = self.content_transform(image)
-        return content
+        """Return two randomly chosen images of the same identity."""
+        dir_tmp1 = self.dataset[index]
+        dir_tmp1_len = len(dir_tmp1)
+        filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
+        filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
+        image1 = self.img_transform(Image.open(filename1))
+        image2 = self.img_transform(Image.open(filename2))
+        return image1, image2
+
     def __len__(self):
         """Return the number of images."""
         return self.num_images
 
 def GetLoader(  dataset_roots,
                 batch_size=16,
-                crop_size=512,
                 **kwargs
                 ):
     """Build and return a data loader."""
-    if not kwargs:
-        a = "Input params error!"
- raise ValueError(print(a)) - colorJitterEnable = kwargs["color_jitter"] - colorConfig = kwargs["color_config"] - num_workers = kwargs["dataloader_workers"] - num_workers = kwargs["dataloader_workers"] - place365_root = dataset_roots["Place365_big"] - selected_c_dir = kwargs["selected_content_dir"] - random_seed = kwargs["random_seed"] + data_root = dataset_roots + random_seed = kwargs["random_seed"] + num_workers = kwargs["dataloader_workers"] c_transforms = [] - # s_transforms.append(T.Resize(900)) - c_transforms.append(T.Resize(900)) - c_transforms.append(T.RandomCrop(crop_size)) - c_transforms.append(T.RandomHorizontalFlip()) - c_transforms.append(T.RandomVerticalFlip()) - - if colorJitterEnable: - if colorConfig is not None: - print("Enable color jitter!") - colorBrightness = colorConfig["brightness"] - colorContrast = colorConfig["contrast"] - colorSaturation = colorConfig["saturation"] - colorHue = (-colorConfig["hue"],colorConfig["hue"]) - c_transforms.append(T.ColorJitter(brightness=colorBrightness,\ - contrast=colorContrast,saturation=colorSaturation, hue=colorHue)) c_transforms.append(T.ToTensor()) - c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) c_transforms = T.Compose(c_transforms) - content_dataset = Place365Dataset( - place365_root, - selected_c_dir, + content_dataset = VGGFace2HQDataset( + data_root, c_transforms, "jpg", random_seed) diff --git a/data_tools/data_loader_place365.py b/data_tools/data_loader_place365.py deleted file mode 100644 index 0e339c3..0000000 --- a/data_tools/data_loader_place365.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: data_loader_modify.py -# Created Date: Saturday April 4th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 11th October 2021 12:17:58 am -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - -import os -import torch -import random -from PIL import Image -from pathlib import Path -from torch.utils import data -import torchvision.datasets as dsets -from torchvision import transforms as T -from data_tools.StyleResize import StyleResize -# from StyleResize import StyleResize - -class data_prefetcher(): - def __init__(self, loader): - self.loader = loader - self.dataiter = iter(loader) - self.stream = torch.cuda.Stream() - # self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) - # self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) - # With Amp, it isn't necessary to manually convert data to half. - # if args.fp16: - # self.mean = self.mean.half() - # self.std = self.std.half() - self.num_images = len(loader) - self.preload() - - def preload(self): - try: - self.content = next(self.dataiter) - except StopIteration: - self.dataiter = iter(self.loader) - self.content = next(self.dataiter) - - with torch.cuda.stream(self.stream): - self.content= self.content.cuda(non_blocking=True) - # With Amp, it isn't necessary to manually convert data to half. 
- # if args.fp16: - # self.next_input = self.next_input.half() - # else: - # self.next_input = self.next_input.float() - # self.next_input = self.next_input.sub_(self.mean).div_(self.std) - def next(self): - torch.cuda.current_stream().wait_stream(self.stream) - content = self.content - self.preload() - return content - - def __len__(self): - """Return the number of images.""" - return self.num_images - -class Place365Dataset(data.Dataset): - """Dataset class for the Artworks dataset and content dataset.""" - - def __init__(self, - content_image_dir, - selectedContent, - content_transform, - subffix='jpg', - random_seed=1234): - """Initialize and preprocess the CelebA dataset.""" - self.content_image_dir = content_image_dir - self.content_transform = content_transform - self.selectedContent = selectedContent - self.subffix = subffix - self.content_dataset = [] - self.random_seed = random_seed - self.preprocess() - self.num_images = len(self.content_dataset) - - def preprocess(self): - """Preprocess the Artworks dataset.""" - print("processing content images...") - for dir_item in self.selectedContent: - join_path = Path(self.content_image_dir,dir_item) - if join_path.exists(): - print("processing %s"%dir_item,end='\r') - images = join_path.glob('*.%s'%(self.subffix)) - for item in images: - self.content_dataset.append(item) - else: - print("%s dir does not exist!"%dir_item,end='\r') - random.seed(self.random_seed) - random.shuffle(self.content_dataset) - print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset)) - - def __getitem__(self, index): - """Return one image and its corresponding attribute label.""" - filename = self.content_dataset[index] - image = Image.open(filename) - content = self.content_transform(image) - return content - - def __len__(self): - """Return the number of images.""" - return self.num_images - -def GetLoader( dataset_roots, - batch_size=16, - crop_size=512, - **kwargs - ): - """Build and return a data loader.""" - if not kwargs: - a = "Input params error!" 
- raise ValueError(print(a)) - - colorJitterEnable = kwargs["color_jitter"] - colorConfig = kwargs["color_config"] - num_workers = kwargs["dataloader_workers"] - num_workers = kwargs["dataloader_workers"] - place365_root = dataset_roots["Place365_big"] - selected_c_dir = kwargs["selected_content_dir"] - random_seed = kwargs["random_seed"] - - c_transforms = [] - - # s_transforms.append(T.Resize(900)) - c_transforms.append(T.Resize(900)) - c_transforms.append(T.RandomCrop(crop_size)) - c_transforms.append(T.RandomHorizontalFlip()) - c_transforms.append(T.RandomVerticalFlip()) - - if colorJitterEnable: - if colorConfig is not None: - print("Enable color jitter!") - colorBrightness = colorConfig["brightness"] - colorContrast = colorConfig["contrast"] - colorSaturation = colorConfig["saturation"] - colorHue = (-colorConfig["hue"],colorConfig["hue"]) - c_transforms.append(T.ColorJitter(brightness=colorBrightness,\ - contrast=colorContrast,saturation=colorSaturation, hue=colorHue)) - c_transforms.append(T.ToTensor()) - c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) - c_transforms = T.Compose(c_transforms) - - content_dataset = Place365Dataset( - place365_root, - selected_c_dir, - c_transforms, - "jpg", - random_seed) - content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, - drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True) - prefetcher = data_prefetcher(content_data_loader) - return prefetcher - -def denorm(x): - out = (x + 1) / 2 - return out.clamp_(0, 1) - -if __name__ == "__main__": - from torchvision.utils import save_image - style_class = ["vangogh","picasso","samuel"] - categories_names = \ - ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor', - 'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field', - 'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse', - 'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus', - 'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet', - 'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse', - 'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway', - 'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor', - 'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior', - 'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge', - 'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe', - 'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah', - 'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse', - 'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain', - 'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park', - 'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track', - 'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge', - 'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole', - 's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house', - 'u/underwater/ocean_deep', 
'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-         'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-         'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-         'b/building_facade', 'c/cemetery']
-
-    s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
-    c_datapath = "D:\\Downloads\\data_large"
-    savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"
-
-    imsize = 512
-    s_datasetloader= getLoader(s_datapath,c_datapath,
-                    style_class, categories_names,
-                    crop_size=imsize, batch_size=16, num_workers=4)
-    wocao = iter(s_datasetloader)
-    for i in range(500):
-        print("new batch")
-        s_image,c_image,label = next(wocao)
-        print(label)
-        # print(label)
-        # saved_image1 = torch.cat([denorm(image.data),denorm(hahh.data)],3)
-        # save_image(denorm(image), "%s\\%d-label-%d.jpg"%(savepath,i), nrow=1, padding=1)
-        pass
-    # import cv2
-    # import os
-    # for dir_item in categories_names:
-    #     join_path = Path(contentdatapath,dir_item)
-    #     if join_path.exists():
-    #         print("processing %s"%dir_item,end='\r')
-    #         images = join_path.glob('*.%s'%("jpg"))
-    #         for item in images:
-    #             temp_path = str(item)
-    #             # temp = cv2.imread(temp_path)
-    #             temp = Image.open(temp_path)
-    #             if temp.layers<3:
-    #                 print("remove broken image...")
-    #                 print("image name:%s"%temp_path)
-    #                 del temp
-    #                 os.remove(item)
\ No newline at end of file
diff --git a/env/env.json b/env/env.json
new file mode 100644
index 0000000..107f12a
--- /dev/null
+++ b/env/env.json
@@ -0,0 +1,17 @@
+{
+    "path":{
+        "train_log_root":"./train_logs",
+        "test_log_root":"./test_logs",
+        "systemLog":"./system/system_log.log",
+        "dataset_paths": {
+            "vggface2_hq": "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan",
+            "val_dataset_root": "",
+            "test_dataset_root": ""
+        },
+        "train_config_path":"./train_yamls",
+        "train_scripts_path":"./train_scripts",
+        "test_scripts_path":"./test_scripts",
+        "config_json_name":"model_config.json"
+
+    }
+}
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..289de91
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,4 @@
+from .arcface_models import ArcMarginModel
+from .arcface_models import ResNet
+from .arcface_models import IRBlock
+from .arcface_models import SEBlock
\ No newline at end of file
diff --git a/models/arcface_models.py b/models/arcface_models.py
new file mode 100644
index 0000000..c678011
--- /dev/null
+++ b/models/arcface_models.py
@@ -0,0 +1,162 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+from .config import device, num_classes
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    # 3x3 convolution with padding. IRBlock below calls conv3x3, but the file as
+    # committed never defined or imported it; this standard ResNet-style helper
+    # is assumed here so the module actually runs.
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class SEBlock(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction),
+            nn.PReLU(),
+            nn.Linear(channel // reduction, channel),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+class IRBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+        super(IRBlock, self).__init__()
+        self.bn0 = nn.BatchNorm2d(inplanes)
+        self.conv1 = conv3x3(inplanes, inplanes)
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.prelu = nn.PReLU()
+        self.conv2 = conv3x3(inplanes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
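+        # BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN is the "improved residual"
+        # (IR) ordering used by ArcFace; with the optional SE branch below this
+        # becomes an IR-SE block.
+        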
self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, use_se=True): + self.inplanes = 64 + self.use_se = use_se + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn2 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc = nn.Linear(512 * 7 * 7, 512) + self.bn3 = nn.BatchNorm1d(512) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.bn2(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + x = self.bn3(x) + + return x + + +class ArcMarginModel(nn.Module): + def __init__(self, args): + super(ArcMarginModel, self).__init__() + + self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size)) + nn.init.xavier_uniform_(self.weight) + + self.easy_margin = args.easy_margin + self.m = args.margin_m + self.s = args.margin_s + + self.cos_m = math.cos(self.m) + self.sin_m = math.sin(self.m) + self.th = math.cos(math.pi - self.m) + self.mm = math.sin(math.pi - self.m) * self.m + + def forward(self, input, label): + x = F.normalize(input) + W = F.normalize(self.weight) + cosine = F.linear(x, W) + sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) + phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) + if self.easy_margin: + phi = torch.where(cosine > 0, phi, cosine) + else: + phi = torch.where(cosine > self.th, phi, cosine - self.mm) + one_hot = torch.zeros(cosine.size(), device=device) + one_hot.scatter_(1, label.view(-1, 1).long(), 1) + output = (one_hot * phi) + ((1.0 - one_hot) * cosine) + output *= self.s + return output \ No newline at end of file diff --git a/models/config.py b/models/config.py new file mode 100644 index 
0000000..eb83edb
--- /dev/null
+++ b/models/config.py
@@ -0,0 +1,28 @@
+import os
+
+import torch
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # sets device for model and PyTorch tensors
+
+# Model parameters
+image_w = 112
+image_h = 112
+channel = 3
+emb_size = 512
+
+# Training parameters
+num_workers = 1  # for data-loading; right now, only 1 works with h5py
+grad_clip = 5.  # clip gradients at an absolute value of 5
+print_freq = 100  # print training/validation stats every print_freq batches
+checkpoint = None  # path to checkpoint, None if none
+
+# Data parameters
+num_classes = 93431
+num_samples = 5179510
+DATA_DIR = 'data'
+# faces_ms1m_folder = 'data/faces_ms1m_112x112'
+faces_ms1m_folder = 'data/ms1m-retinaface-t1'
+path_imgidx = os.path.join(faces_ms1m_folder, 'train.idx')
+path_imgrec = os.path.join(faces_ms1m_folder, 'train.rec')
+IMG_DIR = 'data/images'
+pickle_file = 'data/faces_ms1m_112x112.pickle'
diff --git a/train.py b/train.py
index b17cd0c..9f3bb4d 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@
 # Created Date: Tuesday April 28th 2020
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Tuesday, 19th October 2021 8:50:15 pm
+# Last Modified: Monday, 17th January 2022 1:00:00 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2020 Shanghai Jiao Tong University
 #############################################################
@@ -31,22 +31,28 @@ def getParameters():
     parser = argparse.ArgumentParser()
 
     # general settings
-    parser.add_argument('-v', '--version', type=str, default='liff_warpinvo_0',
+    parser.add_argument('-v', '--version', type=str, default='FM',
                             help="version name for train, test, finetune")
+    parser.add_argument('-t', '--tag', type=str, default='test',
+                            help="tag for current experiment")
     parser.add_argument('-p', '--phase', type=str, default="train",
                             choices=['train', 'finetune','debug'],
                                 help="The phase of current project")
-    parser.add_argument('-c', '--cuda', type=int, default=1) # <0 if it is set as -1, program will use CPU
-    parser.add_argument('-e', '--checkpoint_epoch', type=int, default=74,
+    parser.add_argument('-c', '--cuda', type=int, default=0) # if set to a value < 0, the program will use the CPU
+    parser.add_argument('-e', '--ckpt', type=int, default=74,
                             help="checkpoint epoch for test phase or finetune phase")
 
     # training
     parser.add_argument('--experiment_description', type=str,
                             default="Try LIIF+involution as the up-/down-sampling operators: two DSF operators for downsampling and two DSF operators for upsampling")
-    parser.add_argument('--train_yaml', type=str, default="train_FastNST_CNN_Resblock.yaml")
+    parser.add_argument('--train_yaml', type=str, default="train_512FM.yaml")
+
+    # system logger
+    parser.add_argument('--logger', type=str,
+                            default="wandb", choices=['tensorboard', 'wandb','none'], help='system logger')
 
     # # logs (does not need to be changed most of the time)
     # parser.add_argument('--dataloader_workers', type=int, default=6)
diff --git a/train_scripts/trainer_FM.py b/train_scripts/trainer_FM.py
new file mode 100644
index 0000000..40e0fcc
--- /dev/null
+++ b/train_scripts/trainer_FM.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: trainer_FM.py
+# Created Date: Sunday January 9th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Monday, 17th January 2022 1:12:08 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import os
+import time
+import random
+
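+# Loss recipe used by this trainer (see train() below): hinge adversarial loss on
+# the projected discriminator, an ArcFace identity cosine loss, an L1 reconstruction
+# loss on same-identity pairs (every other step), and an L1 feature-matching loss
+# on the frozen discriminator backbone features.
+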
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from utilities.plot import plot_batch
+
+from train_scripts.trainer_base import TrainerBase
+
+class Trainer(TrainerBase):
+
+    def __init__(self, config, reporter):
+        super(Trainer, self).__init__(config, reporter)
+
+        self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
+        self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
+
+    # TODO modify this function to build your models
+    def init_framework(self):
+        '''
+            This function is designed to define the framework,
+            and print the framework information into the log file
+        '''
+        #===============build models================#
+        print("build models...")
+        # TODO [import models here]
+
+        model_config = self.config["model_configs"]
+
+        if self.config["phase"] == "train":
+            gscript_name = "components." + model_config["g_model"]["script"]
+            dscript_name = "components." + model_config["d_model"]["script"]
+
+        elif self.config["phase"] == "finetune":
+            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
+            dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
+
+        class_name = model_config["g_model"]["class_name"]
+        package = __import__(gscript_name, fromlist=True)
+        gen_class = getattr(package, class_name)
+        self.gen = gen_class(**model_config["g_model"]["module_params"])
+
+        # print and record the model structure
+        self.reporter.writeInfo("Generator structure:")
+        self.reporter.writeModel(self.gen.__str__())
+
+        class_name = model_config["d_model"]["class_name"]
+        package = __import__(dscript_name, fromlist=True)
+        dis_class = getattr(package, class_name)
+        self.dis = dis_class(**model_config["d_model"]["module_params"])
+        self.dis.feature_network.requires_grad_(False)
+
+        # print and record the model structure
+        self.reporter.writeInfo("Discriminator structure:")
+        self.reporter.writeModel(self.dis.__str__())
+        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
+        self.arcface = arcface1['model'].module
+
+        # train in GPU
+        if self.config["cuda"] >= 0:
+            self.gen = self.gen.cuda()
+            self.dis = self.dis.cuda()
+            self.arcface = self.arcface.cuda()
+
+        self.arcface.eval()
+        self.arcface.requires_grad_(False)
+
+        # if in finetune phase, load the pretrained checkpoint
+        if self.config["phase"] == "finetune":
+            model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["generator_name"]))
+            self.gen.load_state_dict(torch.load(model_path))
+
+            model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["discriminator_name"]))
+            self.dis.load_state_dict(torch.load(model_path))
+
+            print('loaded the trained models from {}...!'.format(self.config["project_checkpoints"]))
+
+    # TODO modify this function to configure the optimizer of your pipeline
+    def __setup_optimizers__(self):
+        g_train_opt = self.config['g_optim_config']
+        d_train_opt = self.config['d_optim_config']
+
+        g_optim_params = []
+        d_optim_params = []
+        for k, v in self.gen.named_parameters():
+            if v.requires_grad:
+                g_optim_params.append(v)
+            else:
+                self.reporter.writeInfo(f'Params {k} will not be optimized.')
+                print(f'Params {k} will not be optimized.')
+
+        for k, v in self.dis.named_parameters():
+            if v.requires_grad:
+                d_optim_params.append(v)
+            else:
+                self.reporter.writeInfo(f'Params {k} will not be optimized.')
+                print(f'Params {k} will not be optimized.')
+
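+        # Generator and discriminator get separate Adam optimizers; g_optim_config and
+        # d_optim_config come from the training YAML (assumed to hold the usual Adam
+        # hyper-parameters such as lr and betas, since they are unpacked as **kwargs).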
+        optim_type = self.config['optim_type']
+
+        if optim_type == 'Adam':
+            self.g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
+            self.d_optimizer = torch.optim.Adam(d_optim_params, **d_train_opt)
+        else:
+            raise NotImplementedError(
+                f'optimizer {optim_type} is not supported yet.')
+        # self.optimizers.append(self.optimizer_g)
+        if self.config["phase"] == "finetune":
+            # note: the checkpoint saving code below stores optimizer states without a
+            # ".pth" suffix, so the same pattern is used here when loading them
+            opt_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_optim_%s"%(self.config["checkpoint_step"],
+                                    self.config["optimizer_names"]["generator_name"]))
+            self.g_optimizer.load_state_dict(torch.load(opt_path))
+
+            opt_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_optim_%s"%(self.config["checkpoint_step"],
+                                    self.config["optimizer_names"]["discriminator_name"]))
+            self.d_optimizer.load_state_dict(torch.load(opt_path))
+
+            print('loaded the trained optimizers from {}...!'.format(self.config["project_checkpoints"]))
+
+
+    # TODO modify this function to evaluate your model
+    # Evaluate the checkpoint
+    def __evaluation__(self,
+                    step = 0,
+                    **kwargs
+                    ):
+        src_image1 = kwargs["src1"]
+        src_image2 = kwargs["src2"]
+        batch_size = self.batch_size
+        self.gen.eval()
+        with torch.no_grad():
+            imgs = []
+            zero_img = (torch.zeros_like(src_image1[0,...]))
+            imgs.append(zero_img.cpu().numpy())
+            save_img = ((src_image1.cpu()) * self.img_std + self.img_mean).numpy()
+            for r in range(batch_size):
+                imgs.append(save_img[r,...])
+            arcface_112 = F.interpolate(src_image2, size=(112,112), mode='bicubic')
+            id_vector_src1 = self.arcface(arcface_112)
+            id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
+
+            for i in range(batch_size):
+
+                imgs.append(save_img[i,...])
+                image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
+                img_fake = self.gen(image_infer, id_vector_src1).cpu()
+
+                img_fake = img_fake * self.img_std
+                img_fake = img_fake + self.img_mean
+                img_fake = img_fake.numpy()
+                for j in range(batch_size):
+                    imgs.append(img_fake[j,...])
+            print("Save test data")
+            imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+            plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'))
+
+
+
+
+    def train(self):
+
+        ckpt_dir = self.config["project_checkpoints"]
+        log_frep = self.config["log_step"]
+        model_freq = self.config["model_save_step"]
+        total_step = self.config["total_step"]
+        random_seed = self.config["dataset_params"]["random_seed"]
+
+        self.batch_size = self.config["batch_size"]
+        self.sample_dir = self.config["project_samples"]
+        self.arcface_ckpt = self.config["arcface_ckpt"]
+
+
+        # prep_weights = self.config["layersWeight"]
+        id_w = self.config["id_weight"]
+        rec_w = self.config["reconstruct_weight"]
+        feat_w = self.config["feature_match_weight"]
+
+
+
+        super().train()
+
+        #===============build losses===================#
+        # TODO replace below lines to build your losses
+        # MSE_loss = torch.nn.MSELoss()
+        l1_loss = torch.nn.L1Loss()
+        cos_loss = torch.nn.CosineSimilarity()
+
+
+        start_time = time.time()
+
+        # Calculate the total number of training steps
+        print("Total step = %d"%total_step)
+        random.seed(random_seed)
+        randindex = [i for i in range(self.batch_size)]
+        random.shuffle(randindex)
+        import datetime
+        for step in range(self.start, total_step):
+            self.gen.train()
+            self.dis.train()
+            for interval in range(2):
+                random.shuffle(randindex)
+                src_image1, src_image2 = self.train_loader.next()
+
+                if step%2 == 0:
+                    img_id = src_image2
+                else:
+                    img_id = src_image2[randindex]
+
+                img_id_112 = F.interpolate(img_id, size=(112,112), mode='bicubic')
+                latent_id = self.arcface(img_id_112)
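+                # ArcFace embeddings are compared by direction only: L2-normalizing here
+                # makes the identity loss below, 1 - CosineSimilarity, a pure angular term.
+                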
latent_id = F.normalize(latent_id, p=2, dim=1) + if interval: + + img_fake = self.gen(src_image1, latent_id) + gen_logits,_ = self.dis(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = self.dis(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + self.d_optimizer.zero_grad() + loss_D.backward() + self.d_optimizer.step() + else: + + # model.netD.requires_grad_(True) + img_fake = self.gen(src_image1, latent_id) + # G loss + gen_logits,feat = self.dis(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = self.arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() + real_feat = self.dis.get_feature(src_image1) + feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + loss_G = loss_Gmain + loss_G_ID * id_w + \ + feat_match_loss * feat_w + if step%2 == 0: + #G_Rec + loss_G_Rec = l1_loss(img_fake, src_image1) + loss_G += loss_G_Rec * rec_w + + self.g_optimizer.zero_grad() + loss_G.backward() + self.g_optimizer.step() + + # Print out log info + if (step + 1) % log_frep == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + + epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ + G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(self.config["version"], elapsed, step, total_step, \ + loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + print(epochinformation) + self.reporter.writeInfo(epochinformation) + + if self.config["logger"] == "tensorboard": + self.logger.add_scalar('G/G_loss', loss_G.item(), step) + self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step) + self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step) + self.logger.add_scalar('D/D_loss', loss_D.item(), step) + self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step) + self.logger.add_scalar('D/D_real', loss_Dreal.item(), step) + elif self.config["logger"] == "wandb": + self.logger.log({"G_loss": loss_G.item()}, step = step) + self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step) + self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step) + self.logger.log({"D_loss": loss_D.item()}, step = step) + self.logger.log({"D_fake": loss_Dgen.item()}, step = step) + self.logger.log({"D_real": loss_Dreal.item()}, step = step) + + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if (step+1) % model_freq==0: + + torch.save(self.gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + torch.save(self.dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + + torch.save(self.g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + + torch.save(self.d_optimizer.state_dict(), + 
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) \ No newline at end of file diff --git a/train_scripts/trainer_FastNST.py b/train_scripts/trainer_FastNST.py deleted file mode 100644 index 0931509..0000000 --- a/train_scripts/trainer_FastNST.py +++ /dev/null @@ -1,307 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: trainer_condition_SN_multiscale.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 12th October 2021 2:18:26 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import os -import time - -import torch -from torchvision.utils import save_image - -from components.Transform import Transform_block -from utilities.utilities import denorm, Gram, img2tensor255 -from pretrained_weights.vgg import VGG16 - -class Trainer(object): - - def __init__(self, config, reporter): - - self.config = config - # logger - self.reporter = reporter - - # Data loader - #============build train dataloader==============# - # TODO to modify the key: "your_train_dataset" to get your train dataset path - self.train_dataset = config["dataset_paths"][config["dataset_name"]] - #================================================# - print("Prepare the train dataloader...") - dlModulename = config["dataloader"] - package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'GetLoader') - self.dataloader_class = dataloaderClass - dataloader = self.dataloader_class(self.train_dataset, - config["batch_size"], - config["imcrop_size"], - **config["dataset_params"]) - - self.train_loader= dataloader - - #========build evaluation dataloader=============# - # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path - # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] - - # #================================================# - # print("Prepare the evaluation dataloader...") - # dlModulename = config["eval_dataloader"] - # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) - # dataloaderClass = getattr(package, 'EvalDataset') - # dataloader = dataloaderClass(eval_dataset, - # config["eval_batch_size"]) - # self.eval_loader= dataloader - - # self.eval_iter = len(dataloader)//config["eval_batch_size"] - # if len(dataloader)%config["eval_batch_size"]>0: - # self.eval_iter+=1 - - #==============build tensorboard=================# - if self.config["use_tensorboard"]: - from utilities.utilities import build_tensorboard - self.tensorboard_writer = build_tensorboard(self.config["project_summary"]) - - # TODO modify this function to build your models - def __init_framework__(self): - ''' - This function is designed to define the framework, - and print the framework information into the log file - ''' - #===============build models================# - print("build models...") - # TODO [import models here] - - model_config = self.config["model_configs"] - - if self.config["phase"] == "train": - - gscript_name = "components." 
+ model_config["g_model"]["script"] - - # TODO To save the important scripts - # save the yaml file - import shutil - file1 = os.path.join("components", "%s.py"%model_config["g_model"]["script"]) - tgtfile1 = os.path.join(self.config["project_scripts"], "%s.py"%model_config["g_model"]["script"]) - shutil.copyfile(file1,tgtfile1) - - elif self.config["phase"] == "finetune": - gscript_name = self.config["com_base"] + model_config["g_model"]["script"] - - class_name = model_config["g_model"]["class_name"] - package = __import__(gscript_name, fromlist=True) - gen_class = getattr(package, class_name) - self.gen = gen_class(**model_config["g_model"]["module_params"]) - - # print and recorde model structure - self.reporter.writeInfo("Generator structure:") - self.reporter.writeModel(self.gen.__str__()) - - # train in GPU - if self.config["cuda"] >=0: - self.gen = self.gen.cuda() - - # if in finetune phase, load the pretrained checkpoint - if self.config["phase"] == "finetune": - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_step"], - self.config["checkpoint_names"]["generator_name"])) - self.gen.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) - - - # TODO modify this function to evaluate your model - def __evaluation__(self, epoch, step = 0): - # Evaluate the checkpoint - self.network.eval() - total_psnr = 0 - total_num = 0 - with torch.no_grad(): - for _ in range(self.eval_iter): - hr, lr = self.eval_loader() - - if self.config["cuda"] >=0: - hr = hr.cuda() - lr = lr.cuda() - hr = (hr + 1.0)/2.0 * 255.0 - hr = torch.clamp(hr,0.0,255.0) - lr = (lr + 1.0)/2.0 * 255.0 - lr = torch.clamp(lr,0.0,255.0) - res = self.network(lr) - # res = (res + 1.0)/2.0 * 255.0 - # hr = (hr + 1.0)/2.0 * 255.0 - res = torch.clamp(res,0.0,255.0) - diff = (res-hr) ** 2 - diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt() - psnrs = 20. * (255. 
/ diff).log10() - total_psnr+= psnrs.sum() - total_num+=res.shape[0] - final_psnr = total_psnr/total_num - print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"], - epoch, final_psnr)) - self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr)) - self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch) - - # TODO modify this function to configurate the optimizer of your pipeline - def __setup_optimizers__(self): - g_train_opt = self.config['g_optim_config'] - g_optim_params = [] - for k, v in self.gen.named_parameters(): - if v.requires_grad: - g_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - optim_type = self.config['optim_type'] - - if optim_type == 'Adam': - self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - # self.optimizers.append(self.optimizer_g) - - - def train(self): - - ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] - model_freq = self.config["model_save_epoch"] - total_epoch = self.config["total_epoch"] - batch_size = self.config["batch_size"] - style_img = self.config["style_img_path"] - - # prep_weights= self.config["layersWeight"] - content_w = self.config["content_weight"] - style_w = self.config["style_weight"] - crop_size = self.config["imcrop_size"] - - sample_dir = self.config["project_samples"] - - - #===============build framework================# - self.__init_framework__() - - #===============build optimizer================# - # Optimizer - # TODO replace below lines to build your optimizer - print("build the optimizer...") - self.__setup_optimizers__() - - #===============build losses===================# - # TODO replace below lines to build your losses - MSE_loss = torch.nn.MSELoss() - - - # set the start point for training loop - if self.config["phase"] == "finetune": - start = self.config["checkpoint_epoch"] - 1 - else: - start = 0 - - # print("prepare the fixed labels...") - # fix_label = [i for i in range(n_class)] - # fix_label = torch.tensor(fix_label).long().cuda() - # fix_label = fix_label.view(n_class,1) - # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1) - - # Start time - import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) - - from utilities.logo_class import logo_class - logo_class.print_start_training() - start_time = time.time() - - # Caculate the epoch number - step_epoch = len(self.train_loader) - step_epoch = step_epoch // batch_size - print("Total step = %d in each epoch"%step_epoch) - - VGG = VGG16().cuda() - - MEAN_VAL = 127.5 - SCALE_VAL= 127.5 - # Get Style Features - imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda() - imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda() - - style_tensor = img2tensor255(style_img).cuda() - style_tensor = style_tensor.add(imagenet_neg_mean) - B, C, H, W = style_tensor.shape - style_tensor = VGG(style_tensor.expand([batch_size, C, H, W])) - # style_features = VGG(style_tensor) - style_gram = {} - for key, value in style_tensor.items(): - style_gram[key] = Gram(value) - del style_tensor - # step_epoch = 2 - for epoch in range(start, total_epoch): - for step in range(step_epoch): - self.gen.train() - - 
content_images = self.train_loader.next() - fake_image = self.gen(content_images) - generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11)) - content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11)) - content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2']) - - style_loss = 0.0 - for key, value in generated_features.items(): - s_loss = MSE_loss(Gram(value), style_gram[key]) - style_loss += s_loss - - # backward & optimize - g_loss = content_loss* content_w + style_loss* style_w - self.g_optimizer.zero_grad() - g_loss.backward() - self.g_optimizer.step() - - - # Print out log info - if (step + 1) % log_frep == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - # cumulative steps - cum_step = (step_epoch * epoch + step + 1) - - epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item()) - print(epochinformation) - self.reporter.writeInfo(epochinformation) - - if self.config["use_tensorboard"]: - self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step) - - #===============adjust learning rate============# - # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: - # print("Learning rate decay") - # for p in self.optimizer.param_groups: - # p['lr'] *= self.config["lr_decay"] - # print("Current learning rate is %f"%p['lr']) - - #===============save checkpoints================# - if (epoch+1) % model_freq==0: - print("Save epoch %d model checkpoint!"%(epoch+1)) - torch.save(self.gen.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["generator_name"]))) - - torch.cuda.empty_cache() - print('Sample images {}_fake.jpg'.format(epoch + 1)) - self.gen.eval() - with torch.no_grad(): - sample = fake_image - saved_image1 = denorm(sample.cpu().data) - save_image(saved_image1, - os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4) \ No newline at end of file diff --git a/train_scripts/trainer_FastNST_CNN.py b/train_scripts/trainer_FastNST_CNN.py deleted file mode 100644 index 49f1a61..0000000 --- a/train_scripts/trainer_FastNST_CNN.py +++ /dev/null @@ -1,297 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: trainer_condition_SN_multiscale.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 7:38:36 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import os -import time - -import torch -from torchvision.utils import save_image - -from components.Transform import Transform_block -from utilities.utilities import denorm, Gram, img2tensor255crop -from pretrained_weights.vgg import VGG16 - -class Trainer(object): - - def __init__(self, config, reporter): - - self.config = config - # logger - self.reporter = reporter - - # Data loader - #============build train dataloader==============# - # TODO to modify the key: "your_train_dataset" to get 
your train dataset path - self.train_dataset = config["dataset_paths"][config["dataset_name"]] - #================================================# - print("Prepare the train dataloader...") - dlModulename = config["dataloader"] - package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'GetLoader') - self.dataloader_class = dataloaderClass - dataloader = self.dataloader_class(self.train_dataset, - config["batch_size"], - config["imcrop_size"], - **config["dataset_params"]) - - self.train_loader= dataloader - - #========build evaluation dataloader=============# - # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path - # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] - - # #================================================# - # print("Prepare the evaluation dataloader...") - # dlModulename = config["eval_dataloader"] - # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) - # dataloaderClass = getattr(package, 'EvalDataset') - # dataloader = dataloaderClass(eval_dataset, - # config["eval_batch_size"]) - # self.eval_loader= dataloader - - # self.eval_iter = len(dataloader)//config["eval_batch_size"] - # if len(dataloader)%config["eval_batch_size"]>0: - # self.eval_iter+=1 - - #==============build tensorboard=================# - if self.config["use_tensorboard"]: - from utilities.utilities import build_tensorboard - self.tensorboard_writer = build_tensorboard(self.config["project_summary"]) - - # TODO modify this function to build your models - def __init_framework__(self): - ''' - This function is designed to define the framework, - and print the framework information into the log file - ''' - #===============build models================# - print("build models...") - # TODO [import models here] - - model_config = self.config["model_configs"] - - if self.config["phase"] == "train": - gscript_name = "components." 
+ model_config["g_model"]["script"] - - elif self.config["phase"] == "finetune": - gscript_name = self.config["com_base"] + model_config["g_model"]["script"] - - class_name = model_config["g_model"]["class_name"] - package = __import__(gscript_name, fromlist=True) - gen_class = getattr(package, class_name) - self.gen = gen_class(**model_config["g_model"]["module_params"]) - - # print and recorde model structure - self.reporter.writeInfo("Generator structure:") - self.reporter.writeModel(self.gen.__str__()) - - # train in GPU - if self.config["cuda"] >=0: - self.gen = self.gen.cuda() - - # if in finetune phase, load the pretrained checkpoint - if self.config["phase"] == "finetune": - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_step"], - self.config["checkpoint_names"]["generator_name"])) - self.gen.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) - - - # TODO modify this function to evaluate your model - def __evaluation__(self, epoch, step = 0): - # Evaluate the checkpoint - self.network.eval() - total_psnr = 0 - total_num = 0 - with torch.no_grad(): - for _ in range(self.eval_iter): - hr, lr = self.eval_loader() - - if self.config["cuda"] >=0: - hr = hr.cuda() - lr = lr.cuda() - hr = (hr + 1.0)/2.0 * 255.0 - hr = torch.clamp(hr,0.0,255.0) - lr = (lr + 1.0)/2.0 * 255.0 - lr = torch.clamp(lr,0.0,255.0) - res = self.network(lr) - # res = (res + 1.0)/2.0 * 255.0 - # hr = (hr + 1.0)/2.0 * 255.0 - res = torch.clamp(res,0.0,255.0) - diff = (res-hr) ** 2 - diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt() - psnrs = 20. * (255. / diff).log10() - total_psnr+= psnrs.sum() - total_num+=res.shape[0] - final_psnr = total_psnr/total_num - print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"], - epoch, final_psnr)) - self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr)) - self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch) - - # TODO modify this function to configurate the optimizer of your pipeline - def __setup_optimizers__(self): - g_train_opt = self.config['g_optim_config'] - g_optim_params = [] - for k, v in self.gen.named_parameters(): - if v.requires_grad: - g_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - optim_type = self.config['optim_type'] - - if optim_type == 'Adam': - self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - # self.optimizers.append(self.optimizer_g) - - - def train(self): - - ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] - model_freq = self.config["model_save_epoch"] - total_epoch = self.config["total_epoch"] - batch_size = self.config["batch_size"] - style_img = self.config["style_img_path"] - - # prep_weights= self.config["layersWeight"] - content_w = self.config["content_weight"] - style_w = self.config["style_weight"] - crop_size = self.config["imcrop_size"] - - sample_dir = self.config["project_samples"] - - - #===============build framework================# - self.__init_framework__() - - #===============build optimizer================# - # Optimizer - # TODO replace below lines to build your optimizer - print("build the optimizer...") - self.__setup_optimizers__() - - #===============build losses===================# - # TODO 
replace below lines to build your losses - MSE_loss = torch.nn.MSELoss() - - - # set the start point for training loop - if self.config["phase"] == "finetune": - start = self.config["checkpoint_epoch"] - 1 - else: - start = 0 - - # print("prepare the fixed labels...") - # fix_label = [i for i in range(n_class)] - # fix_label = torch.tensor(fix_label).long().cuda() - # fix_label = fix_label.view(n_class,1) - # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1) - - # Start time - import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) - - from utilities.logo_class import logo_class - logo_class.print_start_training() - start_time = time.time() - - # Caculate the epoch number - step_epoch = len(self.train_loader) - step_epoch = step_epoch // batch_size - print("Total step = %d in each epoch"%step_epoch) - - VGG = VGG16().cuda() - - MEAN_VAL = 127.5 - SCALE_VAL= 127.5 - # Get Style Features - imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda() - imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda() - - style_tensor = img2tensor255crop(style_img,crop_size).cuda() - style_tensor = style_tensor.add(imagenet_neg_mean) - B, C, H, W = style_tensor.shape - style_features = VGG(style_tensor.expand([batch_size, C, H, W])) - style_gram = {} - for key, value in style_features.items(): - style_gram[key] = Gram(value) - # step_epoch = 2 - for epoch in range(start, total_epoch): - for step in range(step_epoch): - self.gen.train() - - content_images = self.train_loader.next() - fake_image = self.gen(content_images) - generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11)) - content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11)) - content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2']) - - style_loss = 0.0 - for key, value in generated_features.items(): - s_loss = MSE_loss(Gram(value), style_gram[key]) - style_loss += s_loss - - # backward & optimize - g_loss = content_loss* content_w + style_loss* style_w - self.g_optimizer.zero_grad() - g_loss.backward() - self.g_optimizer.step() - - - # Print out log info - if (step + 1) % log_frep == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - # cumulative steps - cum_step = (step_epoch * epoch + step + 1) - - epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item()) - print(epochinformation) - self.reporter.writeInfo(epochinformation) - - if self.config["use_tensorboard"]: - self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step) - - #===============adjust learning rate============# - # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: - # print("Learning rate decay") - # for p in self.optimizer.param_groups: - # p['lr'] *= self.config["lr_decay"] - # print("Current learning rate is %f"%p['lr']) - - #===============save checkpoints================# - if (epoch+1) % model_freq==0: - print("Save epoch %d model 
checkpoint!"%(epoch+1)) - torch.save(self.gen.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["generator_name"]))) - - torch.cuda.empty_cache() - print('Sample images {}_fake.jpg'.format(epoch + 1)) - self.gen.eval() - with torch.no_grad(): - sample = fake_image - saved_image1 = denorm(sample.cpu().data) - save_image(saved_image1, - os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4) \ No newline at end of file diff --git a/train_scripts/trainer_FastNST_Liif.py b/train_scripts/trainer_FastNST_Liif.py deleted file mode 100644 index 9343ed6..0000000 --- a/train_scripts/trainer_FastNST_Liif.py +++ /dev/null @@ -1,296 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: trainer_condition_SN_multiscale.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 9:25:13 am -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import os -import time - -import torch -from torchvision.utils import save_image - -from utilities.utilities import denorm, Gram, img2tensor255crop -from pretrained_weights.vgg import VGG16 - -class Trainer(object): - - def __init__(self, config, reporter): - - self.config = config - # logger - self.reporter = reporter - - # Data loader - #============build train dataloader==============# - # TODO to modify the key: "your_train_dataset" to get your train dataset path - self.train_dataset = config["dataset_paths"][config["dataset_name"]] - #================================================# - print("Prepare the train dataloader...") - dlModulename = config["dataloader"] - package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'GetLoader') - self.dataloader_class = dataloaderClass - dataloader = self.dataloader_class(self.train_dataset, - config["batch_size"], - config["imcrop_size"], - **config["dataset_params"]) - - self.train_loader= dataloader - - #========build evaluation dataloader=============# - # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path - # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] - - # #================================================# - # print("Prepare the evaluation dataloader...") - # dlModulename = config["eval_dataloader"] - # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) - # dataloaderClass = getattr(package, 'EvalDataset') - # dataloader = dataloaderClass(eval_dataset, - # config["eval_batch_size"]) - # self.eval_loader= dataloader - - # self.eval_iter = len(dataloader)//config["eval_batch_size"] - # if len(dataloader)%config["eval_batch_size"]>0: - # self.eval_iter+=1 - - #==============build tensorboard=================# - if self.config["use_tensorboard"]: - from utilities.utilities import build_tensorboard - self.tensorboard_writer = build_tensorboard(self.config["project_summary"]) - - # TODO modify this function to build your models - def __init_framework__(self): - ''' - This function is designed to define the framework, - and print the framework information into the log file - ''' - #===============build models================# - print("build models...") - # TODO [import models here] - - model_config = self.config["model_configs"] - - if self.config["phase"] == 
"train": - gscript_name = "components." + model_config["g_model"]["script"] - - elif self.config["phase"] == "finetune": - gscript_name = self.config["com_base"] + model_config["g_model"]["script"] - - class_name = model_config["g_model"]["class_name"] - package = __import__(gscript_name, fromlist=True) - gen_class = getattr(package, class_name) - self.gen = gen_class(**model_config["g_model"]["module_params"]) - - # print and recorde model structure - self.reporter.writeInfo("Generator structure:") - self.reporter.writeModel(self.gen.__str__()) - - # train in GPU - if self.config["cuda"] >=0: - self.gen = self.gen.cuda() - - # if in finetune phase, load the pretrained checkpoint - if self.config["phase"] == "finetune": - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_step"], - self.config["checkpoint_names"]["generator_name"])) - self.gen.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) - - - # TODO modify this function to evaluate your model - def __evaluation__(self, epoch, step = 0): - # Evaluate the checkpoint - self.network.eval() - total_psnr = 0 - total_num = 0 - with torch.no_grad(): - for _ in range(self.eval_iter): - hr, lr = self.eval_loader() - - if self.config["cuda"] >=0: - hr = hr.cuda() - lr = lr.cuda() - hr = (hr + 1.0)/2.0 * 255.0 - hr = torch.clamp(hr,0.0,255.0) - lr = (lr + 1.0)/2.0 * 255.0 - lr = torch.clamp(lr,0.0,255.0) - res = self.network(lr) - # res = (res + 1.0)/2.0 * 255.0 - # hr = (hr + 1.0)/2.0 * 255.0 - res = torch.clamp(res,0.0,255.0) - diff = (res-hr) ** 2 - diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt() - psnrs = 20. * (255. / diff).log10() - total_psnr+= psnrs.sum() - total_num+=res.shape[0] - final_psnr = total_psnr/total_num - print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"], - epoch, final_psnr)) - self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr)) - self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch) - - # TODO modify this function to configurate the optimizer of your pipeline - def __setup_optimizers__(self): - g_train_opt = self.config['g_optim_config'] - g_optim_params = [] - for k, v in self.gen.named_parameters(): - if v.requires_grad: - g_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - optim_type = self.config['optim_type'] - - if optim_type == 'Adam': - self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - # self.optimizers.append(self.optimizer_g) - - - def train(self): - - ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] - model_freq = self.config["model_save_epoch"] - total_epoch = self.config["total_epoch"] - batch_size = self.config["batch_size"] - style_img = self.config["style_img_path"] - - # prep_weights= self.config["layersWeight"] - content_w = self.config["content_weight"] - style_w = self.config["style_weight"] - crop_size = self.config["imcrop_size"] - - sample_dir = self.config["project_samples"] - - - #===============build framework================# - self.__init_framework__() - - #===============build optimizer================# - # Optimizer - # TODO replace below lines to build your optimizer - print("build the optimizer...") - self.__setup_optimizers__() - - #===============build 
losses===================# - # TODO replace below lines to build your losses - MSE_loss = torch.nn.MSELoss() - - - # set the start point for training loop - if self.config["phase"] == "finetune": - start = self.config["checkpoint_epoch"] - 1 - else: - start = 0 - - # print("prepare the fixed labels...") - # fix_label = [i for i in range(n_class)] - # fix_label = torch.tensor(fix_label).long().cuda() - # fix_label = fix_label.view(n_class,1) - # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1) - - # Start time - import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) - - from utilities.logo_class import logo_class - logo_class.print_start_training() - start_time = time.time() - - # Caculate the epoch number - step_epoch = len(self.train_loader) - step_epoch = step_epoch // batch_size - print("Total step = %d in each epoch"%step_epoch) - - VGG = VGG16().cuda() - - MEAN_VAL = 127.5 - SCALE_VAL= 127.5 - # Get Style Features - imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda() - imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda() - - style_tensor = img2tensor255crop(style_img,crop_size).cuda() - style_tensor = style_tensor.add(imagenet_neg_mean) - B, C, H, W = style_tensor.shape - style_features = VGG(style_tensor.expand([batch_size, C, H, W])) - style_gram = {} - for key, value in style_features.items(): - style_gram[key] = Gram(value) - # step_epoch = 2 - for epoch in range(start, total_epoch): - for step in range(step_epoch): - self.gen.train() - - content_images = self.train_loader.next() - fake_image = self.gen(content_images) - generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11)) - content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11)) - content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2']) - - style_loss = 0.0 - for key, value in generated_features.items(): - s_loss = MSE_loss(Gram(value), style_gram[key]) - style_loss += s_loss - - # backward & optimize - g_loss = content_loss* content_w + style_loss* style_w - self.g_optimizer.zero_grad() - g_loss.backward() - self.g_optimizer.step() - - - # Print out log info - if (step + 1) % log_frep == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - # cumulative steps - cum_step = (step_epoch * epoch + step + 1) - - epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item()) - print(epochinformation) - self.reporter.writeInfo(epochinformation) - - if self.config["use_tensorboard"]: - self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step) - - #===============adjust learning rate============# - # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: - # print("Learning rate decay") - # for p in self.optimizer.param_groups: - # p['lr'] *= self.config["lr_decay"] - # print("Current learning rate is %f"%p['lr']) - - #===============save checkpoints================# - if (epoch+1) % model_freq==0: - 
print("Save epoch %d model checkpoint!"%(epoch+1)) - torch.save(self.gen.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["generator_name"]))) - - torch.cuda.empty_cache() - print('Sample images {}_fake.jpg'.format(epoch + 1)) - self.gen.eval() - with torch.no_grad(): - sample = fake_image - saved_image1 = denorm(sample.cpu().data) - save_image(saved_image1, - os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4) \ No newline at end of file diff --git a/train_scripts/trainer_FastNST_SWD.py b/train_scripts/trainer_FastNST_SWD.py deleted file mode 100644 index 695f44f..0000000 --- a/train_scripts/trainer_FastNST_SWD.py +++ /dev/null @@ -1,300 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: trainer_condition_SN_multiscale.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 19th October 2021 2:28:24 am -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import os -import time - -import torch -from torchvision.utils import save_image - -from utilities.utilities import denorm, img2tensor255crop -from losses.SliceWassersteinDistance import SWD -from pretrained_weights.vgg import VGG16 - -class Trainer(object): - - def __init__(self, config, reporter): - - self.config = config - # logger - self.reporter = reporter - - # Data loader - #============build train dataloader==============# - # TODO to modify the key: "your_train_dataset" to get your train dataset path - self.train_dataset = config["dataset_paths"][config["dataset_name"]] - #================================================# - print("Prepare the train dataloader...") - dlModulename = config["dataloader"] - package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'GetLoader') - self.dataloader_class = dataloaderClass - dataloader = self.dataloader_class(self.train_dataset, - config["batch_size"], - config["imcrop_size"], - **config["dataset_params"]) - - self.train_loader= dataloader - - #========build evaluation dataloader=============# - # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path - # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] - - # #================================================# - # print("Prepare the evaluation dataloader...") - # dlModulename = config["eval_dataloader"] - # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) - # dataloaderClass = getattr(package, 'EvalDataset') - # dataloader = dataloaderClass(eval_dataset, - # config["eval_batch_size"]) - # self.eval_loader= dataloader - - # self.eval_iter = len(dataloader)//config["eval_batch_size"] - # if len(dataloader)%config["eval_batch_size"]>0: - # self.eval_iter+=1 - - #==============build tensorboard=================# - if self.config["use_tensorboard"]: - from utilities.utilities import build_tensorboard - self.tensorboard_writer = build_tensorboard(self.config["project_summary"]) - - # TODO modify this function to build your models - def __init_framework__(self): - ''' - This function is designed to define the framework, - and print the framework information into the log file - ''' - #===============build models================# - print("build models...") - # TODO [import models here] - - 
model_config = self.config["model_configs"] - - if self.config["phase"] == "train": - gscript_name = "components." + model_config["g_model"]["script"] - - elif self.config["phase"] == "finetune": - gscript_name = self.config["com_base"] + model_config["g_model"]["script"] - - class_name = model_config["g_model"]["class_name"] - package = __import__(gscript_name, fromlist=True) - gen_class = getattr(package, class_name) - self.gen = gen_class(**model_config["g_model"]["module_params"]) - - # print and recorde model structure - self.reporter.writeInfo("Generator structure:") - self.reporter.writeModel(self.gen.__str__()) - - # train in GPU - if self.config["cuda"] >=0: - self.gen = self.gen.cuda() - - # if in finetune phase, load the pretrained checkpoint - if self.config["phase"] == "finetune": - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_epoch"], - self.config["checkpoint_names"]["generator_name"])) - self.gen.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) - - - # TODO modify this function to evaluate your model - def __evaluation__(self, epoch, step = 0): - # Evaluate the checkpoint - self.network.eval() - total_psnr = 0 - total_num = 0 - with torch.no_grad(): - for _ in range(self.eval_iter): - hr, lr = self.eval_loader() - - if self.config["cuda"] >=0: - hr = hr.cuda() - lr = lr.cuda() - hr = (hr + 1.0)/2.0 * 255.0 - hr = torch.clamp(hr,0.0,255.0) - lr = (lr + 1.0)/2.0 * 255.0 - lr = torch.clamp(lr,0.0,255.0) - res = self.network(lr) - # res = (res + 1.0)/2.0 * 255.0 - # hr = (hr + 1.0)/2.0 * 255.0 - res = torch.clamp(res,0.0,255.0) - diff = (res-hr) ** 2 - diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt() - psnrs = 20. * (255. 
/ diff).log10() - total_psnr+= psnrs.sum() - total_num+=res.shape[0] - final_psnr = total_psnr/total_num - print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"], - epoch, final_psnr)) - self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr)) - self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch) - - # TODO modify this function to configurate the optimizer of your pipeline - def __setup_optimizers__(self): - g_train_opt = self.config['g_optim_config'] - g_optim_params = [] - for k, v in self.gen.named_parameters(): - if v.requires_grad: - g_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - optim_type = self.config['optim_type'] - - if optim_type == 'Adam': - self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - # self.optimizers.append(self.optimizer_g) - - - def train(self): - - ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] - model_freq = self.config["model_save_epoch"] - total_epoch = self.config["total_epoch"] - batch_size = self.config["batch_size"] - style_img = self.config["style_img_path"] - - # prep_weights= self.config["layersWeight"] - content_w = self.config["content_weight"] - style_w = self.config["style_weight"] - crop_size = self.config["imcrop_size"] - swd_dim = self.config["swd_dim"] - sample_dir = self.config["project_samples"] - - - #===============build framework================# - self.__init_framework__() - - #===============build optimizer================# - # Optimizer - # TODO replace below lines to build your optimizer - print("build the optimizer...") - self.__setup_optimizers__() - - #===============build losses===================# - # TODO replace below lines to build your losses - MSE_loss = torch.nn.MSELoss() - - - # set the start point for training loop - if self.config["phase"] == "finetune": - start = self.config["checkpoint_epoch"] - 1 - else: - start = 0 - - # print("prepare the fixed labels...") - # fix_label = [i for i in range(n_class)] - # fix_label = torch.tensor(fix_label).long().cuda() - # fix_label = fix_label.view(n_class,1) - # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1) - - # Start time - import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) - - from utilities.logo_class import logo_class - logo_class.print_start_training() - start_time = time.time() - - # Caculate the epoch number - step_epoch = len(self.train_loader) - step_epoch = step_epoch // batch_size - print("Total step = %d in each epoch"%step_epoch) - - VGG = VGG16().cuda() - - MEAN_VAL = 127.5 - SCALE_VAL= 127.5 - # Get Style Features - imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda() - imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda() - - # swd = SWD() - style_tensor = img2tensor255crop(style_img,crop_size).cuda() - style_tensor = style_tensor.add(imagenet_neg_mean) - B, C, H, W = style_tensor.shape - style_features = VGG(style_tensor.expand([batch_size, C, H, W])) - swd_list = {} - for key, value in style_features.items(): - - swd_list[key] = SWD(value.shape[1],swd_dim).cuda() - # step_epoch = 2 - for epoch in range(start, total_epoch): - for step in range(step_epoch): - 
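# The SWD modules built above replace the Gram-matrix style loss of the
# sibling trainer with sliced feature matching: project channel-space feature
# points onto random unit directions, sort the projections, and compare. A
# rough sketch of that idea (the actual losses/SliceWassersteinDistance.py
# implementation may differ; this sketch assumes both feature maps contain the
# same number of spatial positions):
def sliced_distance(feat_a, feat_b, num_proj=32):
    c = feat_a.shape[1]
    a = feat_a.permute(0, 2, 3, 1).reshape(-1, c)   # (N*H*W, C) point cloud
    b = feat_b.permute(0, 2, 3, 1).reshape(-1, c)
    proj = torch.randn(c, num_proj, device=feat_a.device)
    proj = proj / proj.norm(dim=0, keepdim=True)    # random unit directions
    pa, _ = (a @ proj).sort(dim=0)                  # sorted 1-D projections
    pb, _ = (b @ proj).sort(dim=0)
    return ((pa - pb) ** 2).mean()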
self.gen.train() - - content_images = self.train_loader.next() - fake_image = self.gen(content_images) - generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11)) - content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11)) - content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2']) - - style_loss = 0.0 - for key, value in generated_features.items(): - swd_list[key].update() - s_loss = MSE_loss(swd_list[key](value), swd_list[key](style_features[key])) - style_loss += s_loss - - # backward & optimize - g_loss = content_loss* content_w + style_loss* style_w - self.g_optimizer.zero_grad() - g_loss.backward() - self.g_optimizer.step() - - - # Print out log info - if (step + 1) % log_frep == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - # cumulative steps - cum_step = (step_epoch * epoch + step + 1) - - epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item()) - print(epochinformation) - self.reporter.writeInfo(epochinformation) - - if self.config["use_tensorboard"]: - self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step) - - #===============adjust learning rate============# - # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: - # print("Learning rate decay") - # for p in self.optimizer.param_groups: - # p['lr'] *= self.config["lr_decay"] - # print("Current learning rate is %f"%p['lr']) - - #===============save checkpoints================# - if (epoch+1) % model_freq==0: - print("Save epoch %d model checkpoint!"%(epoch+1)) - torch.save(self.gen.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["generator_name"]))) - - torch.cuda.empty_cache() - print('Sample images {}_fake.jpg'.format(epoch + 1)) - self.gen.eval() - with torch.no_grad(): - sample = fake_image - saved_image1 = denorm(sample.cpu().data) - save_image(saved_image1, - os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4) \ No newline at end of file diff --git a/train_scripts/trainer_base.py b/train_scripts/trainer_base.py new file mode 100644 index 0000000..8d6cb49 --- /dev/null +++ b/train_scripts/trainer_base.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_base.py +# Created Date: Sunday January 16th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 17th January 2022 1:08:25 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +class TrainerBase(object): + + def __init__(self, config, reporter): + + self.config = config + # logger + self.reporter = reporter + + # Data loader + #============build train dataloader==============# + # TODO to modify the key: "your_train_dataset" to get your train dataset path + self.train_dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = 
config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + self.dataloader_class = dataloaderClass + dataloader = self.dataloader_class(self.train_dataset, + config["batch_size"], + **config["dataset_params"]) + + self.train_loader= dataloader + + #========build evaluation dataloader=============# + # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path + # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] + + # #================================================# + # print("Prepare the evaluation dataloader...") + # dlModulename = config["eval_dataloader"] + # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) + # dataloaderClass = getattr(package, 'EvalDataset') + # dataloader = dataloaderClass(eval_dataset, + # config["eval_batch_size"]) + # self.eval_loader= dataloader + + # self.eval_iter = len(dataloader)//config["eval_batch_size"] + # if len(dataloader)%config["eval_batch_size"]>0: + # self.eval_iter+=1 + + #==============build tensorboard=================# + if self.config["logger"] == "tensorboard": + from utilities.utilities import build_tensorboard + tensorboard_writer = build_tensorboard(self.config["project_summary"]) + self.logger = tensorboard_writer + elif self.config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[self.config["tag"]], name=self.config["version"]) + + wandb.config = { + "total_step": self.config["total_step"], + "batch_size": self.config["batch_size"] + } + self.logger = wandb + + # TODO modify this function to build your models + def __init_framework__(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + pass + + # TODO modify this function to configurate the optimizer of your pipeline + def __setup_optimizers__(self): + pass + + + # TODO modify this function to evaluate your model + # Evaluate the checkpoint + def __evaluation__(self, + step = 0, + **kwargs + ): + pass + + + def train(self): + #===============build framework================# + self.init_framework() + + #===============build optimizer================# + # Optimizer + # TODO replace below lines to build your optimizer + print("build the optimizer...") + self.__setup_optimizers__() + + # set the start point for training loop + if self.config["phase"] == "finetune": + self.start = self.config["checkpoint_step"] + else: + self.start = 0 + + # Start time + import datetime + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() \ No newline at end of file diff --git a/train_scripts/trainer_gan.py b/train_scripts/trainer_gan.py deleted file mode 100644 index d50a562..0000000 --- a/train_scripts/trainer_gan.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -############################################################# -# File: trainer_condition_SN_multiscale.py -# Created Date: Saturday April 18th 2020 -# Author: Chen Xuanhong -# Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 6th July 2021 7:36:42 pm -# Modified By: Chen Xuanhong -# Copyright (c) 2020 Shanghai Jiao Tong University -############################################################# - - -import os -import time - -import torch 
-from torchvision.utils import save_image - -from components.Transform import Transform_block -from utilities.utilities import denorm - -class Trainer(object): - - def __init__(self, config, reporter): - - self.config = config - # logger - self.reporter = reporter - - # Data loader - #============build train dataloader==============# - # TODO to modify the key: "your_train_dataset" to get your train dataset path - self.train_dataset = config["dataset_paths"][config["dataset_name"]] - #================================================# - print("Prepare the train dataloader...") - dlModulename = config["dataloader"] - package = __import__("data_tools.dataloader_%s"%dlModulename, fromlist=True) - dataloaderClass = getattr(package, 'GetLoader') - self.dataloader_class = dataloaderClass - # dataloader = self.dataloader_class(self.train_dataset, - # config["batch_size_list"][0], - # config["imcrop_size_list"][0], - # **config["dataset_params"]) - - # self.train_loader= dataloader - - #========build evaluation dataloader=============# - # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path - # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]] - - # #================================================# - # print("Prepare the evaluation dataloader...") - # dlModulename = config["eval_dataloader"] - # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True) - # dataloaderClass = getattr(package, 'EvalDataset') - # dataloader = dataloaderClass(eval_dataset, - # config["eval_batch_size"]) - # self.eval_loader= dataloader - - # self.eval_iter = len(dataloader)//config["eval_batch_size"] - # if len(dataloader)%config["eval_batch_size"]>0: - # self.eval_iter+=1 - - #==============build tensorboard=================# - if self.config["use_tensorboard"]: - from utilities.utilities import build_tensorboard - self.tensorboard_writer = build_tensorboard(self.config["project_summary"]) - - # TODO modify this function to build your models - def __init_framework__(self): - ''' - This function is designed to define the framework, - and print the framework information into the log file - ''' - #===============build models================# - print("build models...") - # TODO [import models here] - - model_config = self.config["model_configs"] - - if self.config["phase"] == "train": - gscript_name = "components." + model_config["g_model"]["script"] - dscript_name = "components." 
+ model_config["d_model"]["script"] - elif self.config["phase"] == "finetune": - gscript_name = self.config["com_base"] + model_config["g_model"]["script"] - dscript_name = self.config["com_base"] + model_config["d_model"]["script"] - - class_name = model_config["g_model"]["class_name"] - package = __import__(gscript_name, fromlist=True) - gen_class = getattr(package, class_name) - self.gen = gen_class(**model_config["g_model"]["module_params"]) - - class_name = model_config["d_model"]["class_name"] - package = __import__(dscript_name, fromlist=True) - dis_class = getattr(package, class_name) - self.dis = dis_class(**model_config["d_model"]["module_params"]) - - # print and recorde model structure - self.reporter.writeInfo("Generator structure:") - self.reporter.writeModel(self.gen.__str__()) - self.reporter.writeInfo("Discriminator structure:") - self.reporter.writeModel(self.dis.__str__()) - - # train in GPU - if self.config["cuda"] >=0: - self.gen = self.gen.cuda() - self.dis = self.dis.cuda() - - # if in finetune phase, load the pretrained checkpoint - if self.config["phase"] == "finetune": - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_step"], - self.config["checkpoint_names"]["generator_name"])) - self.gen.load_state_dict(torch.load(model_path)) - - model_path = os.path.join(self.config["project_checkpoints"], - "epoch%d_%s.pth"%(self.config["checkpoint_step"], - self.config["checkpoint_names"]["discriminator_name"])) - self.dis.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"])) - - - # TODO modify this function to evaluate your model - def __evaluation__(self, epoch, step = 0): - # Evaluate the checkpoint - self.network.eval() - total_psnr = 0 - total_num = 0 - with torch.no_grad(): - for _ in range(self.eval_iter): - hr, lr = self.eval_loader() - - if self.config["cuda"] >=0: - hr = hr.cuda() - lr = lr.cuda() - hr = (hr + 1.0)/2.0 * 255.0 - hr = torch.clamp(hr,0.0,255.0) - lr = (lr + 1.0)/2.0 * 255.0 - lr = torch.clamp(lr,0.0,255.0) - res = self.network(lr) - # res = (res + 1.0)/2.0 * 255.0 - # hr = (hr + 1.0)/2.0 * 255.0 - res = torch.clamp(res,0.0,255.0) - diff = (res-hr) ** 2 - diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt() - psnrs = 20. * (255. 
/ diff).log10() - total_psnr+= psnrs.sum() - total_num+=res.shape[0] - final_psnr = total_psnr/total_num - print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"], - epoch, final_psnr)) - self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr)) - self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch) - - # TODO modify this function to configurate the optimizer of your pipeline - def __setup_optimizers__(self): - g_train_opt = self.config['g_optim_config'] - d_train_opt = self.config['d_optim_config'] - g_optim_params = [] - d_optim_params = [] - for k, v in self.gen.named_parameters(): - if v.requires_grad: - g_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - for k, v in self.dis.named_parameters(): - if v.requires_grad: - d_optim_params.append(v) - else: - self.reporter.writeInfo(f'Params {k} will not be optimized.') - print(f'Params {k} will not be optimized.') - - optim_type = self.config['optim_type'] - - if optim_type == 'Adam': - self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) - self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - # self.optimizers.append(self.optimizer_g) - - - def train(self): - - ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] - model_freq = self.config["model_save_epoch"] - total_epoch = self.config["total_epoch"] - - n_class = len(self.config["selected_style_dir"]) - # prep_weights= self.config["layersWeight"] - feature_w = self.config["feature_weight"] - transform_w = self.config["transform_weight"] - d_step = self.config["d_step"] - g_step = self.config["g_step"] - - batch_size_list = self.config["batch_size_list"] - switch_epoch_list = self.config["switch_epoch_list"] - imcrop_size_list = self.config["imcrop_size_list"] - sample_dir = self.config["project_samples"] - - current_epoch_index = 0 - - #===============build framework================# - self.__init_framework__() - - #===============build optimizer================# - # Optimizer - # TODO replace below lines to build your optimizer - print("build the optimizer...") - self.__setup_optimizers__() - - #===============build losses===================# - # TODO replace below lines to build your losses - Transform = Transform_block().cuda() - L1_loss = torch.nn.L1Loss() - MSE_loss = torch.nn.MSELoss() - Hinge_loss = torch.nn.ReLU().cuda() - - - # set the start point for training loop - if self.config["phase"] == "finetune": - start = self.config["checkpoint_epoch"] - 1 - else: - start = 0 - - - output_size = self.dis.get_outputs_len() - - print("prepare the fixed labels...") - fix_label = [i for i in range(n_class)] - fix_label = torch.tensor(fix_label).long().cuda() - # fix_label = fix_label.view(n_class,1) - # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1) - - # Start time - import datetime - print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) - - from utilities.logo_class import logo_class - logo_class.print_start_training() - start_time = time.time() - - for epoch in range(start, total_epoch): - - # switch training image size - if epoch in switch_epoch_list: - print('Current epoch: {}'.format(epoch)) - print('***Redefining the dataloader for progressive training.***') - print('***Current spatial size is {} and batch size is {}.***'.format( - 
imcrop_size_list[current_epoch_index], batch_size_list[current_epoch_index])) - del self.train_loader - self.train_loader = self.dataloader_class(self.train_dataset, - batch_size_list[current_epoch_index], - imcrop_size_list[current_epoch_index], - **self.config["dataset_params"]) - # Caculate the epoch number - step_epoch = len(self.train_loader) - step_epoch = step_epoch // (d_step + g_step) - print("Total step = %d in each epoch"%step_epoch) - current_epoch_index += 1 - - for step in range(step_epoch): - self.dis.train() - self.gen.train() - - # ================== Train D ================== # - # Compute loss with real images - for _ in range(d_step): - content_images,style_images,label = self.train_loader.next() - label = label.long() - - d_out = self.dis(style_images,label) - d_loss_real = 0 - for i in range(output_size): - temp = Hinge_loss(1 - d_out[i]).mean() - d_loss_real += temp - - d_loss_photo = 0 - d_out = self.dis(content_images,label) - for i in range(output_size): - temp = Hinge_loss(1 + d_out[i]).mean() - d_loss_photo += temp - - fake_image,_= self.gen(content_images,label) - d_out = self.dis(fake_image.detach(),label) - d_loss_fake = 0 - for i in range(output_size): - temp = Hinge_loss(1 + d_out[i]).mean() - # temp *= prep_weights[i] - d_loss_fake += temp - - # Backward + Optimize - d_loss = d_loss_real + d_loss_photo + d_loss_fake - self.d_optimizer.zero_grad() - d_loss.backward() - self.d_optimizer.step() - - # ================== Train G ================== # - for _ in range(g_step): - - content_images,_,_ = self.train_loader.next() - fake_image,real_feature = self.gen(content_images,label) - fake_feature = self.gen(fake_image, get_feature=True) - d_out = self.dis(fake_image,label.long()) - - g_feature_loss = L1_loss(fake_feature,real_feature) - g_transform_loss = MSE_loss(Transform(content_images), Transform(fake_image)) - g_loss_fake = 0 - for i in range(output_size): - temp = -d_out[i].mean() - # temp *= prep_weights[i] - g_loss_fake += temp - - # backward & optimize - g_loss = g_loss_fake + g_feature_loss* feature_w + g_transform_loss* transform_w - self.g_optimizer.zero_grad() - g_loss.backward() - self.g_optimizer.step() - - - # Print out log info - if (step + 1) % log_frep == 0: - elapsed = time.time() - start_time - elapsed = str(datetime.timedelta(seconds=elapsed)) - - # cumulative steps - cum_step = (step_epoch * epoch + step + 1) - - epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, d_loss_real: {:.4f}, \\\ - d_loss_photo: {:.4f}, d_loss_fake: {:.4f}, g_loss: {:.4f}, g_loss_fake: {:.4f}, \\\ - g_feature_loss: {:.4f}, g_transform_loss: {:.4f}".format(self.config["version"], - epoch + 1, total_epoch, elapsed, step + 1, step_epoch, - d_loss.item(), d_loss_real.item(), d_loss_photo.item(), - d_loss_fake.item(), g_loss.item(), g_loss_fake.item(),\ - g_feature_loss.item(), g_transform_loss.item()) - print(epochinformation) - self.reporter.writeRawInfo(epochinformation) - - if self.config["use_tensorboard"]: - self.tensorboard_writer.add_scalar('data/d_loss', d_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/d_loss_real', d_loss_real.item(), cum_step) - self.tensorboard_writer.add_scalar('data/d_loss_photo', d_loss_photo.item(), cum_step) - self.tensorboard_writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), cum_step) - self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step) - self.tensorboard_writer.add_scalar('data/g_loss_fake', g_loss_fake.item(), cum_step) - 
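# The objectives optimized above are the standard hinge GAN losses: the
# discriminator terms are ReLU(1 - D(real)) and ReLU(1 + D(fake)), and the
# generator term is -D(fake). A minimal equivalent sketch (the multi-scale
# summation over discriminator outputs is omitted):
import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    # push real logits above +1 and fake logits below -1
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge_loss(fake_logits):
    # the generator maximizes the critic's score on fakes
    return -fake_logits.mean()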
self.tensorboard_writer.add_scalar('data/g_feature_loss', g_feature_loss, cum_step) - self.tensorboard_writer.add_scalar('data/g_transform_loss', g_transform_loss, cum_step) - - #===============adjust learning rate============# - if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: - print("Learning rate decay") - for p in self.optimizer.param_groups: - p['lr'] *= self.config["lr_decay"] - print("Current learning rate is %f"%p['lr']) - - #===============save checkpoints================# - if (epoch+1) % model_freq==0: - print("Save epoch %d model checkpoint!"%(epoch+1)) - torch.save(self.gen.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["generator_name"]))) - torch.save(self.dis.state_dict(), - os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1, - self.config["checkpoint_names"]["discriminator_name"]))) - - torch.cuda.empty_cache() - print('Sample images {}_fake.jpg'.format(step + 1)) - self.gen.eval() - with torch.no_grad(): - sample = content_images[0, :, :, :].unsqueeze(0) - saved_image1 = denorm(sample.cpu().data) - for index in range(n_class): - fake_images,_ = self.gen(sample, fix_label[index].unsqueeze(0)) - saved_image1 = torch.cat((saved_image1, denorm(fake_images.cpu().data)), 0) - save_image(saved_image1, - os.path.join(sample_dir, '{}_fake.jpg'.format(step + 1)),nrow=3) \ No newline at end of file diff --git a/train_scripts/trainer_naiv512.py b/train_scripts/trainer_naiv512.py index 28acd41..5a0823c 100644 --- a/train_scripts/trainer_naiv512.py +++ b/train_scripts/trainer_naiv512.py @@ -5,15 +5,17 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Sunday, 9th January 2022 12:31:03 am +# Last Modified: Tuesday, 11th January 2022 3:06:14 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# import os import time +import random import torch +import torch.nn.functional as F from torchvision.utils import save_image from utilities.utilities import denorm @@ -182,12 +184,10 @@ class Trainer(object): model_freq = self.config["model_save_epoch"] total_epoch = self.config["total_epoch"] batch_size = self.config["batch_size"] - style_img = self.config["style_img_path"] # prep_weights= self.config["layersWeight"] content_w = self.config["content_weight"] style_w = self.config["style_weight"] - crop_size = self.config["imcrop_size"] sample_dir = self.config["project_samples"] @@ -231,32 +231,30 @@ class Trainer(object): step_epoch = step_epoch // batch_size print("Total step = %d in each epoch"%step_epoch) + randindex = [i for i in range(batch_size)] + + # step_epoch = 2 for epoch in range(start, total_epoch): for step in range(step_epoch): - self.gen.train() + image1, image2 = self.train_loader.next() + random.shuffle(randindex) - src_image1, src_image2 = self.train_loader.next() - - - img_att = src_image1 + img_att = image1 if step%2 == 0: - img_id = src_image2 + img_id = image2 # swap with same id, different pose else: - img_id = src_image2[randindex] + img_id = image2[randindex] # swap with different face - src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic') img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') - img_id_112_norm = spnorm(img_id_112) - - latent_id = model.netArc(img_id_112_norm) + latent_id = self.arcface(img_id_112) latent_id = F.normalize(latent_id, p=2, dim=1) - losses, img_fake= 
self.gen(src_image1, latent_id) + losses, img_fake= self.gen(image1, latent_id) # update Generator weights losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] @@ -275,18 +273,6 @@ class Trainer(object): loss_D.backward() optimizer_D.step() - self.gen.train() - - content_images = self.train_loader.next() - fake_image = self.gen(content_images) - generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11)) - content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11)) - content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2']) - - style_loss = 0.0 - for key, value in generated_features.items(): - s_loss = MSE_loss(Gram(value), style_gram[key]) - style_loss += s_loss # backward & optimize g_loss = content_loss* content_w + style_loss* style_w diff --git a/train_yamls/train_512FM.yaml b/train_yamls/train_512FM.yaml new file mode 100644 index 0000000..1d14b6f --- /dev/null +++ b/train_yamls/train_512FM.yaml @@ -0,0 +1,62 @@ +# Related scripts +train_script_name: FM + +# models' scripts +model_configs: + g_model: + script: Generator + class_name: Generator + module_params: + g_conv_dim: 512 + g_kernel_size: 3 + res_num: 9 + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 1 + +# Dataset +dataloader: VGGFace2HQ +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 8 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 10.0 +reconstruct_weight: 1.0 +feature_match_weight: 5.0 + +# Log +log_step: 10 +model_save_step: 20 +total_step: 1000000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_FastNST.yaml b/train_yamls/train_FastNST.yaml deleted file mode 100644 index 4f98a08..0000000 --- a/train_yamls/train_FastNST.yaml +++ /dev/null @@ -1,83 +0,0 @@ -# Related scripts -train_script_name: FastNST - -# models' scripts -model_configs: - g_model: - script: FastNST - class_name: Generator - module_params: - g_conv_dim: 32 - g_kernel_size: 3 - res_num: 6 - n_class: 11 - image_size: 256 - window_size: 8 - -# Training information -total_epoch: 120 -batch_size: 16 -imcrop_size: 256 -max2Keep: 10 - -# Dataset -style_img_path: "G:\\UltraHighStyleTransfer\\reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\images\\mosaic.jpg" -dataloader: place365 -dataset_name: styletransfer -dataset_params: - random_seed: 1234 - dataloader_workers: 8 - color_jitter: Enable - color_config: - brightness: 0.05 - contrast: 0.05 - saturation: 0.05 - hue: 0.05 - - selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor', - 'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field', - 'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse', - 'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus', - 'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet', - 'c/cliff', 
'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse', - 'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway', - 'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor', - 'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior', - 'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge', - 'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe', - 'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah', - 'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse', - 'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain', - 'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park', - 'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track', - 'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge', - 'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole', - 's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house', - 'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct', - 'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave', - 'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft', - 'b/building_facade', 'c/cemetery'] - - -eval_dataloader: DIV2K_hdf5 -eval_dataset_name: DF2K_H5_Eval -eval_batch_size: 2 - -# Dataset - -# Optimizer -optim_type: Adam -g_optim_config: - lr: !!float 2e-4 - betas: [0.9, 0.99] - -content_weight: 2.0 -style_weight: 1.0 -layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0] - -# Log -log_step: 100 -model_save_epoch: 1 -use_tensorboard: True -checkpoint_names: - generator_name: Generator \ No newline at end of file diff --git a/train_yamls/train_FastNST_CNN.yaml b/train_yamls/train_FastNST_CNN.yaml deleted file mode 100644 index 0b20b16..0000000 --- a/train_yamls/train_FastNST_CNN.yaml +++ /dev/null @@ -1,108 +0,0 @@ -# Related scripts -train_script_name: FastNST_CNN - -# models' scripts -model_configs: - g_model: - script: FastNST_CNN - class_name: Generator - module_params: - g_conv_dim: 32 - g_kernel_size: 3 - res_num: 6 - n_class: 11 - image_size: 256 - window_size: 8 - -# Training information -total_epoch: 120 -batch_size: 16 -imcrop_size: 256 -max2Keep: 10 - -# Dataset -style_img_path: "images/mosaic.jpg" -dataloader: place365 -dataset_name: styletransfer -dataset_params: - random_seed: 1234 - dataloader_workers: 8 - color_jitter: Enable - color_config: - brightness: 0.05 - contrast: 0.05 - saturation: 0.05 - hue: 0.05 - - selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor', - 'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field', - 'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse', - 'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus', - 'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet', - 'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse', - 'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 
'f/farm', 'f/fairway', - 'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor', - 'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior', - 'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge', - 'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe', - 'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah', - 'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse', - 'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain', - 'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park', - 'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track', - 'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge', - 'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole', - 's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house', - 'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct', - 'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave', - 'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft', - 'b/building_facade', 'c/cemetery'] - # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor', - # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field', - # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse', - # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus', - # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet', - # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse', - # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway', - # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor', - # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior', - # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge', - # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe', - # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah', - # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse', - # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain', - # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park', - # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track', - # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge', - # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole', - # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house', - # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct', - # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave', - # 'w_wheat_field', 
'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft', - # 'b_building_facade', - # 'c_cemetery' - # ] - - -eval_dataloader: DIV2K_hdf5 -eval_dataset_name: DF2K_H5_Eval -eval_batch_size: 2 - -# Dataset - -# Optimizer -optim_type: Adam -g_optim_config: - lr: !!float 2e-4 - betas: [0.9, 0.99] - -content_weight: 2.0 -style_weight: 1.0 -layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0] - -# Log -log_step: 100 -model_save_epoch: 1 -use_tensorboard: True -checkpoint_names: - generator_name: Generator \ No newline at end of file diff --git a/train_yamls/train_FastNST_CNN_Resblock.yaml b/train_yamls/train_FastNST_CNN_Resblock.yaml deleted file mode 100644 index 99c2c09..0000000 --- a/train_yamls/train_FastNST_CNN_Resblock.yaml +++ /dev/null @@ -1,108 +0,0 @@ -# Related scripts -train_script_name: FastNST_CNN - -# models' scripts -model_configs: - g_model: - script: FastNST_CNN_Resblock - class_name: Generator - module_params: - g_conv_dim: 32 - g_kernel_size: 3 - res_num: 6 - n_class: 11 - image_size: 256 - window_size: 8 - -# Training information -total_epoch: 120 -batch_size: 16 -imcrop_size: 256 -max2Keep: 10 - -# Dataset -style_img_path: "images/mosaic.jpg" -dataloader: place365 -dataset_name: styletransfer -dataset_params: - random_seed: 1234 - dataloader_workers: 8 - color_jitter: Enable - color_config: - brightness: 0.05 - contrast: 0.05 - saturation: 0.05 - hue: 0.05 - - selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor', - 'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field', - 'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse', - 'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus', - 'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet', - 'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse', - 'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway', - 'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor', - 'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior', - 'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge', - 'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe', - 'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah', - 'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse', - 'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain', - 'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park', - 'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track', - 'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge', - 'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole', - 's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house', - 'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct', - 'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave', - 'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft', - 'b/building_facade', 
-  # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
-  # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
-  # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
-  # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
-  # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
-  # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
-  # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
-  # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
-  # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
-  # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
-  # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
-  # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
-  # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
-  # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
-  # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
-  # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
-  # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
-  # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
-  # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
-  # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
-  # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
-  # 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
-  # 'b_building_facade',
-  # 'c_cemetery'
-  # ]
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-content_weight: 2.0
-style_weight: 1.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-
-# Log
-log_step: 100
-model_save_epoch: 1
-use_tensorboard: True
-checkpoint_names:
-  generator_name: Generator
\ No newline at end of file
diff --git a/train_yamls/train_FastNST_Liif.yaml b/train_yamls/train_FastNST_Liif.yaml
deleted file mode 100644
index 3a1325c..0000000
--- a/train_yamls/train_FastNST_Liif.yaml
+++ /dev/null
@@ -1,110 +0,0 @@
-# Related scripts
-train_script_name: FastNST_Liif
-
-# models' scripts
-model_configs:
-  g_model:
-    script: FastNST_Liif
-    class_name: Generator
-    module_params:
-      g_conv_dim: 32
-      g_kernel_size: 3
-      res_num: 6
-      n_class: 11
-      image_size: 256
-      mlp_hidden_list: [32,32]
-      batch_size: 10
-      window_size: 8
-
-# Training information
-total_epoch: 120
-batch_size: 10
-imcrop_size: 256
-max2Keep: 10
-
-# Dataset
-style_img_path: "images/mosaic.jpg"
-dataloader: place365
-dataset_name: styletransfer
-dataset_params:
-  random_seed: 1234
-  dataloader_workers: 8
-  color_jitter: Enable
-  color_config:
-    brightness: 0.05
-    contrast: 0.05
-    saturation: 0.05
-    hue: 0.05
-
-  selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
-    'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
-    'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
-    'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
-    'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
-    'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
-    'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
-    'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
-    'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
-    'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
-    'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
-    'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
-    'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
-    'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
-    'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
-    'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
-    'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
-    'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
-    's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
-    'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-    'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-    'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-    'b/building_facade', 'c/cemetery']
-  # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
-  # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
-  # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
-  # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
-  # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
-  # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
-  # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
-  # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
-  # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
-  # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
-  # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
-  # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
-  # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
-  # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
-  # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
-  # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
-  # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
-  # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
-  # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
-  # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
-  # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
-  # 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
-  # 'b_building_facade',
-  # 'c_cemetery'
-  # ]
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-content_weight: 2.0
-style_weight: 1.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-
-# Log
-log_step: 100
-model_save_epoch: 1
-use_tensorboard: True
-checkpoint_names:
-  generator_name: Generator
\ No newline at end of file
diff --git a/train_yamls/train_FastNST_Liif_warp.yaml b/train_yamls/train_FastNST_Liif_warp.yaml
deleted file mode 100644
index 7fa4ded..0000000
--- a/train_yamls/train_FastNST_Liif_warp.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# Related scripts
-train_script_name: FastNST_Liif
-
-# models' scripts
-model_configs:
-  g_model:
-    script: FastNST_Liif_warp
-    class_name: Generator
-    module_params:
-      g_conv_dim: 32
-      g_kernel_size: 3
-      res_num: 6
-      n_class: 11
-      image_size: 256
-      batch_size: 16
-      window_size: 8
-
-# Training information
-total_epoch: 120
-batch_size: 16
-imcrop_size: 256
-max2Keep: 10
-
-# Dataset
-style_img_path: "images/mosaic.jpg"
-dataloader: place365
-dataset_name: styletransfer
-dataset_params:
-  random_seed: 1234
-  dataloader_workers: 8
-  color_jitter: Enable
-  color_config:
-    brightness: 0.05
-    contrast: 0.05
-    saturation: 0.05
-    hue: 0.05
-
-  selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
-    'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
-    'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
-    'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
-    'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
-    'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
-    'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
-    'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
-    'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
-    'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
-    'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
-    'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
-    'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
-    'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
-    'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
-    'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
-    'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
-    'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
-    's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
-    'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-    'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-    'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-    'b/building_facade', 'c/cemetery']
-  # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
-  # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
-  # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
-  # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
-  # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
-  # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
-  # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
-  # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
-  # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
-  # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
-  # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
-  # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
-  # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
-  # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
-  # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
-  # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
-  # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
-  # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
-  # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
-  # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
-  # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
-  # 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
-  # 'b_building_facade',
-  # 'c_cemetery'
-  # ]
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-content_weight: 2.0
-style_weight: 1.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-
-# Log
-log_step: 100
-model_save_epoch: 1
-use_tensorboard: True
-checkpoint_names:
-  generator_name: Generator
\ No newline at end of file
diff --git a/train_yamls/train_FastNST_Liif_warpinvo.yaml b/train_yamls/train_FastNST_Liif_warpinvo.yaml
deleted file mode 100644
index 7fa4ded..0000000
--- a/train_yamls/train_FastNST_Liif_warpinvo.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# Related scripts
-train_script_name: FastNST_Liif
-
-# models' scripts
-model_configs:
-  g_model:
-    script: FastNST_Liif_warp
-    class_name: Generator
-    module_params:
-      g_conv_dim: 32
-      g_kernel_size: 3
-      res_num: 6
-      n_class: 11
-      image_size: 256
-      batch_size: 16
-      window_size: 8
-
-# Training information
-total_epoch: 120
-batch_size: 16
-imcrop_size: 256
-max2Keep: 10
-
-# Dataset
-style_img_path: "images/mosaic.jpg"
-dataloader: place365
-dataset_name: styletransfer
-dataset_params:
-  random_seed: 1234
-  dataloader_workers: 8
-  color_jitter: Enable
-  color_config:
-    brightness: 0.05
-    contrast: 0.05
-    saturation: 0.05
-    hue: 0.05
-
-  selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
-    'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
-    'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
-    'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
-    'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
-    'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
-    'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
-    'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
-    'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
-    'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
-    'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
-    'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
-    'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
-    'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
-    'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
-    'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
-    'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
-    'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
-    's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
-    'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-    'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-    'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-    'b/building_facade', 'c/cemetery']
-  # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
-  # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
-  # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
-  # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
-  # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
-  # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
-  # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
-  # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
-  # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
-  # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
-  # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
-  # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
-  # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
-  # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
-  # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
-  # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
-  # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
-  # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
-  # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
-  # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
-  # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
-  # 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
-  # 'b_building_facade',
-  # 'c_cemetery'
-  # ]
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-content_weight: 2.0
-style_weight: 1.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-
-# Log
-log_step: 100
-model_save_epoch: 1
-use_tensorboard: True
-checkpoint_names:
-  generator_name: Generator
\ No newline at end of file
diff --git a/train_yamls/train_FastNST_SWD.yaml b/train_yamls/train_FastNST_SWD.yaml
deleted file mode 100644
index 7084d18..0000000
--- a/train_yamls/train_FastNST_SWD.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# Related scripts
-train_script_name: FastNST_SWD
-
-# models' scripts
-model_configs:
-  g_model:
-    script: FastNST_CNN_Resblock
-    class_name: Generator
-    module_params:
-      g_conv_dim: 32
-      g_kernel_size: 3
-      res_num: 6
-      n_class: 11
-      image_size: 256
-      window_size: 8
-
-# Training information
-total_epoch: 120
-batch_size: 16
-imcrop_size: 256
-max2Keep: 10
-
-# Dataset
-style_img_path: "images/mosaic.jpg"
-dataloader: place365
-dataset_name: styletransfer
-dataset_params:
-  random_seed: 1234
-  dataloader_workers: 8
-  color_jitter: Enable
-  color_config:
-    brightness: 0.05
-    contrast: 0.05
-    saturation: 0.05
-    hue: 0.05
-
-  selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
-    'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
-    'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
-    'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
-    'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
-    'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
-    'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
-    'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
-    'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
-    'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
-    'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
-    'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
-    'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
-    'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
-    'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
-    'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
-    'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
-    'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
-    's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
-    'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-    'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-    'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-    'b/building_facade', 'c/cemetery']
-  # selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
-  # 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
-  # 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
-  # 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
-  # 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
-  # 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
-  # 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
-  # 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
-  # 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
-  # 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
-  # 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
-  # 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
-  # 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
-  # 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
-  # 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
-  # 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
-  # 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
-  # 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
-  # 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
-  # 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
-  # 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
-  # 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
-  # 'b_building_facade',
-  # 'c_cemetery'
-  # ]
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-content_weight: 1.0
-style_weight: 10.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-swd_dim: 32
-
-# Log
-log_step: 100
-model_save_epoch: 1
-use_tensorboard: True
-checkpoint_names:
-  generator_name: Generator
\ No newline at end of file
diff --git a/train_yamls/train_noskip.yaml b/train_yamls/train_noskip.yaml
deleted file mode 100644
index a8c76f6..0000000
--- a/train_yamls/train_noskip.yaml
+++ /dev/null
@@ -1,98 +0,0 @@
-# Related scripts
-train_script_name: gan
-
-# models' scripts
-model_configs:
-  g_model:
-    script: Conditional_Generator_Noskip
-    class_name: Generator
-    module_params:
-      g_conv_dim: 32
-      g_kernel_size: 3
-      res_num: 8
-      n_class: 11
-  d_model:
-    script: Conditional_Discriminator_Projection_big
-    class_name: Discriminator
-    module_params:
-      d_conv_dim: 32
-      d_kernel_size: 5
-
-# Training information
-total_epoch: 120
-batch_size_list: [8, 4, 2]
-switch_epoch_list: [0, 5, 10]
-imcrop_size_list: [256, 512, 768]
-max2Keep: 10
-movingAverage: 0.05
-d_success_threshold: 0.8
-d_step: 3
-g_step: 1
-
-# Dataset
-dataloader: condition
-dataset_name: styletransfer
-dataset_params:
-  random_seed: 1234
-  dataloader_workers: 8
-  color_jitter: Enable
-  color_config:
-    brightness: 0.05
-    contrast: 0.05
-    saturation: 0.05
-    hue: 0.05
-  selected_style_dir: ['berthe-morisot','edvard-munch',
-    'ernst-ludwig-kirchner','jackson-pollock','kandinsky','monet',
-    'nicholas','paul-cezanne','picasso','samuel','vangogh']
-  selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
-    'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
-    'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
-    'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
-    'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
-    'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
-    'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
-    'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
-    'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
-    'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
-    'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
-    'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
-    'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
-    'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
-    'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
-    'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
-    'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
-    'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
-    's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
-    'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
-    'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
-    'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
-    'b/building_facade', 'c/cemetery']
-
-
-eval_dataloader: DIV2K_hdf5
-eval_dataset_name: DF2K_H5_Eval
-eval_batch_size: 2
-
-# Dataset
-
-# Optimizer
-optim_type: Adam
-g_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-d_optim_config:
-  lr: !!float 2e-4
-  betas: [0.9, 0.99]
-
-feature_weight: 50.0
-transform_weight: 50.0
-layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
-
-# Log
-log_step: 1000
-sampleStep: 2000
-model_save_epoch: 1
-useTensorboard: True
-checkpoint_names:
-  generator_name: Generator
-  discriminator_name: Discriminator
\ No newline at end of file
diff --git a/utilities/plot.py b/utilities/plot.py
new file mode 100644
index 0000000..0da1c75
--- /dev/null
+++ b/utilities/plot.py
@@ -0,0 +1,37 @@
+import numpy as np
+import math
+import PIL.Image
+
+def postprocess(x):
+    """Map images from float [0, 1] to uint8 [0, 255]."""
+    x = np.clip(255 * x, 0, 255)
+    x = x.astype(np.uint8)
+    return x
+
+def tile(X, rows, cols):
+    """Tile a batch of images (N, H, W, C) into a (rows*H, cols*W, C) grid."""
+    tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype=X.dtype)
+    for i in range(rows):
+        for j in range(cols):
+            idx = i * cols + j
+            if idx < X.shape[0]:
+                img = X[idx, ...]
+                tiling[
+                    i * X.shape[1]:(i + 1) * X.shape[1],
+                    j * X.shape[2]:(j + 1) * X.shape[2],
+                    :] = img
+    return tiling
+
+
+def plot_batch(X, out_path):
+    """Save a batch of images (N, H, W, C) as a tiled grid image."""
+    n_channels = X.shape[3]
+    if n_channels > 3:
+        # Keep three distinct channels so the canvas can be saved as RGB.
+        X = X[:, :, :, np.random.choice(n_channels, size=3, replace=False)]
+    X = postprocess(X)
+    # Lay the N images out on a roughly square rows x cols grid.
+    rows = cols = math.ceil(math.sqrt(X.shape[0]))
+    canvas = tile(X, rows, cols)
+    canvas = np.squeeze(canvas)
+    PIL.Image.fromarray(canvas).save(out_path)
\ No newline at end of file
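---

Note on the new utilities/plot.py: a minimal smoke test sketch, assuming the `utilities` package is importable (e.g., run from the repository root); the random batch and the output filename below are hypothetical stand-ins, not part of this patch:

    # Hypothetical check: tile a batch of 16 random RGB images in [0, 1].
    import numpy as np
    from utilities.plot import plot_batch

    batch = np.random.rand(16, 64, 64, 3)  # (N, H, W, C) floats in [0, 1]
    plot_batch(batch, "sample_grid.png")   # hypothetical output path

With 16 images, plot_batch picks rows = cols = 4, tiles the batch into a 256x256x3 uint8 canvas, and writes it out with PIL.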