From f48dc8cf622c42874095f01c22426e5fc810991c Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Wed, 20 Apr 2022 18:36:26 +0800 Subject: [PATCH] training scripts released --- .gitignore | 5 +- data/data_loader_Swapping.py | 127 ++++++++++ models/base_model.py | 54 +++++ models/fs_networks_fix.py | 169 ++++++++++++++ models/projected_model.py | 122 ++++++++++ models/projectionhead.py | 14 ++ options/test_options.py | 4 +- pg_modules/blocks.py | 325 ++++++++++++++++++++++++++ pg_modules/diffaug.py | 76 ++++++ pg_modules/projected_discriminator.py | 191 +++++++++++++++ pg_modules/projector.py | 158 +++++++++++++ train.py | 293 +++++++++++++++++++++++ util/json_config.py | 15 ++ util/logo_class.py | 44 ++++ util/plot.py | 37 +++ util/save_heatmap.py | 57 +++++ 16 files changed, 1688 insertions(+), 3 deletions(-) create mode 100644 data/data_loader_Swapping.py create mode 100644 models/fs_networks_fix.py create mode 100644 models/projected_model.py create mode 100644 models/projectionhead.py create mode 100644 pg_modules/blocks.py create mode 100644 pg_modules/diffaug.py create mode 100644 pg_modules/projected_discriminator.py create mode 100644 pg_modules/projector.py create mode 100644 train.py create mode 100644 util/json_config.py create mode 100644 util/logo_class.py create mode 100644 util/plot.py create mode 100644 util/save_heatmap.py diff --git a/.gitignore b/.gitignore index fd548c0..f0e744d 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,7 @@ checkpoints/ *.zip *.avi *.pdf -*.pptx \ No newline at end of file +*.pptx + +*.pth +*.onnx \ No newline at end of file diff --git a/data/data_loader_Swapping.py b/data/data_loader_Swapping.py new file mode 100644 index 0000000..f254928 --- /dev/null +++ b/data/data_loader_Swapping.py @@ -0,0 +1,127 @@ +import os +import glob +import torch +import random +from PIL import Image +from pathlib import Path +from torch.utils import data +from torchvision import transforms as T +# from StyleResize import StyleResize + +class data_prefetcher(): + def __init__(self, loader): + self.loader = loader + self.dataiter = iter(loader) + self.stream = torch.cuda.Stream() + self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1) + self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1) + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.mean = self.mean.half() + # self.std = self.std.half() + self.num_images = len(loader) + self.preload() + + def preload(self): + try: + self.src_image1, self.src_image2 = next(self.dataiter) + except StopIteration: + self.dataiter = iter(self.loader) + self.src_image1, self.src_image2 = next(self.dataiter) + + with torch.cuda.stream(self.stream): + self.src_image1 = self.src_image1.cuda(non_blocking=True) + self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std) + self.src_image2 = self.src_image2.cuda(non_blocking=True) + self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + src_image1 = self.src_image1 + src_image2 = self.src_image2 + self.preload() + return src_image1, src_image2 + + def __len__(self): + """Return the number of images.""" + return self.num_images + +class SwappingDataset(data.Dataset): + """Dataset class for the Artworks dataset and content dataset.""" + + def __init__(self, + image_dir, + img_transform, + subffix='jpg', + random_seed=1234): + """Initialize and preprocess the Swapping dataset.""" + self.image_dir = image_dir + self.img_transform = img_transform + self.subffix = subffix + self.dataset = [] + self.random_seed = random_seed + self.preprocess() + self.num_images = len(self.dataset) + + def preprocess(self): + """Preprocess the Swapping dataset.""" + print("processing Swapping dataset images...") + + temp_path = os.path.join(self.image_dir,'*/') + pathes = glob.glob(temp_path) + self.dataset = [] + for dir_item in pathes: + join_path = glob.glob(os.path.join(dir_item,'*.jpg')) + print("processing %s"%dir_item,end='\r') + temp_list = [] + for item in join_path: + temp_list.append(item) + self.dataset.append(temp_list) + random.seed(self.random_seed) + random.shuffle(self.dataset) + print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset)) + + def __getitem__(self, index): + """Return two src domain images and two dst domain images.""" + dir_tmp1 = self.dataset[index] + dir_tmp1_len = len(dir_tmp1) + + filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + image1 = self.img_transform(Image.open(filename1)) + image2 = self.img_transform(Image.open(filename2)) + return image1, image2 + + def __len__(self): + """Return the number of images.""" + return self.num_images + +def GetLoader( dataset_roots, + batch_size=16, + dataloader_workers=8, + random_seed = 1234 + ): + """Build and return a data loader.""" + + num_workers = dataloader_workers + data_root = dataset_roots + random_seed = random_seed + + c_transforms = [] + + c_transforms.append(T.ToTensor()) + c_transforms = T.Compose(c_transforms) + + content_dataset = SwappingDataset( + data_root, + c_transforms, + "jpg", + random_seed) + content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, + drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True) + prefetcher = data_prefetcher(content_data_loader) + return prefetcher + +def denorm(x): + out = (x + 1) / 2 + return out.clamp_(0, 1) \ No newline at end of file diff --git a/models/base_model.py b/models/base_model.py index f3f6b53..1799129 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -37,6 +37,19 @@ class BaseModel(torch.nn.Module): def save(self, label): pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids=None): + save_filename = '{}_net_{}.pth'.format(epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if torch.cuda.is_available(): + network.cuda() + + def save_optim(self, network, network_label, epoch_label, gpu_ids=None): + save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.state_dict(), save_path) # helper saving function that can be used by subclasses def save_network(self, network, network_label, epoch_label, gpu_ids): @@ -63,6 +76,47 @@ class BaseModel(torch.nn.Module): except: pretrained_dict = torch.load(save_path) model_dict = network.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + network.load_state_dict(pretrained_dict) + if self.opt.verbose: + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + except: + print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + if sys.version_info >= (3,0): + not_initialized = set() + else: + from sets import Set + not_initialized = Set() + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + not_initialized.add(k.split('.')[0]) + + print(sorted(not_initialized)) + network.load_state_dict(model_dict) + + # helper loading function that can be used by subclasses + def load_optim(self, network, network_label, epoch_label, save_dir=''): + save_filename = '%s_optim_%s.pth' % (epoch_label, network_label) + if not save_dir: + save_dir = self.save_dir + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + print('%s not exists yet!' % save_path) + if network_label == 'G': + raise('Generator must exist!') + else: + #network.load_state_dict(torch.load(save_path)) + try: + network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu"))) + except: + pretrained_dict = torch.load(save_path, map_location=torch.device("cpu")) + model_dict = network.state_dict() try: pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} network.load_state_dict(pretrained_dict) diff --git a/models/fs_networks_fix.py b/models/fs_networks_fix.py new file mode 100644 index 0000000..af641c8 --- /dev/null +++ b/models/fs_networks_fix.py @@ -0,0 +1,169 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" + +import torch +import torch.nn as nn + + +class InstanceNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. + https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 + """ + super(InstanceNorm, self).__init__() + self.epsilon = epsilon + + def forward(self, x): + x = x - torch.mean(x, (2, 3), True) + tmp = torch.mul(x, x) # or x ** 2 + tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) + return x * tmp + +class ApplyStyle(nn.Module): + """ + @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb + """ + def __init__(self, latent_size, channels): + super(ApplyStyle, self).__init__() + self.linear = nn.Linear(latent_size, channels * 2) + + def forward(self, x, latent): + style = self.linear(latent) # style => [batch_size, n_channels*2] + shape = [-1, 2, x.size(1), 1, 1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + #x = x * (style[:, 0] + 1.) + style[:, 1] + x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1 + return x + +class ResnetBlock_Adain(nn.Module): + def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): + super(ResnetBlock_Adain, self).__init__() + + p = 0 + conv1 = [] + if padding_type == 'reflect': + conv1 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv1 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] + self.conv1 = nn.Sequential(*conv1) + self.style1 = ApplyStyle(latent_size, dim) + self.act1 = activation + + p = 0 + conv2 = [] + if padding_type == 'reflect': + conv2 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv2 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] + self.conv2 = nn.Sequential(*conv2) + self.style2 = ApplyStyle(latent_size, dim) + + + def forward(self, x, dlatents_in_slice): + y = self.conv1(x) + y = self.style1(y, dlatents_in_slice) + y = self.act1(y) + y = self.conv2(y) + y = self.style2(y, dlatents_in_slice) + out = x + y + return out + + + +class Generator_Adain_Upsample(nn.Module): + def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False, + norm_layer=nn.BatchNorm2d, + padding_type='reflect'): + assert (n_blocks >= 0) + super(Generator_Adain_Upsample, self).__init__() + activation = nn.ReLU(True) + self.deep = deep + + self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0), + norm_layer(64), activation) + ### downsample + self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + norm_layer(128), activation) + self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), + norm_layer(256), activation) + self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), + norm_layer(512), activation) + if self.deep: + self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + norm_layer(512), activation) + + ### resnet blocks + BN = [] + for i in range(n_blocks): + BN += [ + ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)] + self.BottleNeck = nn.Sequential(*BN) + + if self.deep: + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(512), activation + ) + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), + nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(256), activation + ) + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), + nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(128), activation + ) + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(64), activation + ) + self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0)) + + def forward(self, input, dlatents): + x = input # 3*224*224 + + skip1 = self.first_layer(x) + skip2 = self.down1(skip1) + skip3 = self.down2(skip2) + if self.deep: + skip4 = self.down3(skip3) + x = self.down4(skip4) + else: + x = self.down3(skip3) + bot = [] + bot.append(x) + features = [] + for i in range(len(self.BottleNeck)): + x = self.BottleNeck[i](x, dlatents) + bot.append(x) + + if self.deep: + x = self.up4(x) + features.append(x) + x = self.up3(x) + features.append(x) + x = self.up2(x) + features.append(x) + x = self.up1(x) + features.append(x) + x = self.last_layer(x) + # x = (x + 1) / 2 + + # return x, bot, features, dlatents + return x \ No newline at end of file diff --git a/models/projected_model.py b/models/projected_model.py new file mode 100644 index 0000000..5c6e81d --- /dev/null +++ b/models/projected_model.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: fs_model_fix_idnorm_donggp_saveoptim copy.py +# Created Date: Wednesday January 12th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 20th April 2022 6:34:47 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .fs_networks_fix import Generator_Adain_Upsample + +from pg_modules.projected_discriminator import ProjectedDiscriminator + +def compute_grad2(d_out, x_in): + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = grad_dout2.view(batch_size, -1).sum(1) + return reg + +class fsModel(BaseModel): + def name(self): + return 'fsModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + # if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM + self.isTrain = opt.isTrain + + # Generator network + self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep) + self.netG.cuda() + + # Id network + netArc_checkpoint = opt.Arc_path + netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu")) + self.netArc = netArc_checkpoint['model'].module + self.netArc = self.netArc.cuda() + self.netArc.eval() + self.netArc.requires_grad_(False) + if not self.isTrain: + pretrained_path = opt.checkpoints_dir + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + return + self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{}) + # self.netD.feature_network.requires_grad_(False) + self.netD.cuda() + + + if self.isTrain: + # define loss functions + self.criterionFeat = nn.L1Loss() + self.criterionRec = nn.L1Loss() + + + # initialize optimizers + + # optimizer G + params = list(self.netG.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8) + + # load networks + if opt.continue_train: + pretrained_path = '' if not self.isTrain else opt.load_pretrain + # print (pretrained_path) + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path) + self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path) + torch.cuda.empty_cache() + + def cosin_metric(self, x1, x2): + #return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2)) + return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1)) + + + + def save(self, which_epoch): + self.save_network(self.netG, 'G', which_epoch) + self.save_network(self.netD, 'D', which_epoch) + self.save_optim(self.optimizer_G, 'G', which_epoch,) + self.save_optim(self.optimizer_D, 'D', which_epoch) + '''if self.gen_features: + self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)''' + + def update_fixed_params(self): + # after fixing the global generator for a number of iterations, also start finetuning it + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr + + diff --git a/models/projectionhead.py b/models/projectionhead.py new file mode 100644 index 0000000..a0d18f1 --- /dev/null +++ b/models/projectionhead.py @@ -0,0 +1,14 @@ +import torch.nn as nn + +class ProjectionHead(nn.Module): + def __init__(self, proj_dim=256): + super(ProjectionHead, self).__init__() + + self.proj = nn.Sequential( + nn.Linear(proj_dim, proj_dim), + nn.ReLU(), + nn.Linear(proj_dim, proj_dim), + ) + + def forward(self, x): + return self.proj(x) \ No newline at end of file diff --git a/options/test_options.py b/options/test_options.py index 18c5527..bbbee5a 100644 --- a/options/test_options.py +++ b/options/test_options.py @@ -15,7 +15,7 @@ class TestOptions(BaseOptions): self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--which_epoch', type=str, default='9900000', help='which epoch to load? set to latest to use latest cached model') self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map') @@ -35,4 +35,4 @@ class TestOptions(BaseOptions): self.parser.add_argument('--use_mask', action='store_true', help='Use mask for better result') self.parser.add_argument('--crop_size', type=int, default=224, help='Crop of size of input image') - self.isTrain = False + self.isTrain = False \ No newline at end of file diff --git a/pg_modules/blocks.py b/pg_modules/blocks.py new file mode 100644 index 0000000..78bd113 --- /dev/null +++ b/pg_modules/blocks.py @@ -0,0 +1,325 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm + + +### single layers + + +def conv2d(*args, **kwargs): + return spectral_norm(nn.Conv2d(*args, **kwargs)) + + +def convTranspose2d(*args, **kwargs): + return spectral_norm(nn.ConvTranspose2d(*args, **kwargs)) + + +def embedding(*args, **kwargs): + return spectral_norm(nn.Embedding(*args, **kwargs)) + + +def linear(*args, **kwargs): + return spectral_norm(nn.Linear(*args, **kwargs)) + + +def NormLayer(c, mode='batch'): + if mode == 'group': + return nn.GroupNorm(c//2, c) + elif mode == 'batch': + return nn.BatchNorm2d(c) + + +### Activations + + +class GLU(nn.Module): + def forward(self, x): + nc = x.size(1) + assert nc % 2 == 0, 'channels dont divide 2!' + nc = int(nc/2) + return x[:, :nc] * torch.sigmoid(x[:, nc:]) + + +class Swish(nn.Module): + def forward(self, feat): + return feat * torch.sigmoid(feat) + + +### Upblocks + + +class InitLayer(nn.Module): + def __init__(self, nz, channel, sz=4): + super().__init__() + + self.init = nn.Sequential( + convTranspose2d(nz, channel*2, sz, 1, 0, bias=False), + NormLayer(channel*2), + GLU(), + ) + + def forward(self, noise): + noise = noise.view(noise.shape[0], -1, 1, 1) + return self.init(noise) + + +def UpBlockSmall(in_planes, out_planes): + block = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), + NormLayer(out_planes*2), GLU()) + return block + + +class UpBlockSmallCond(nn.Module): + def __init__(self, in_planes, out_planes, z_dim): + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.up = nn.Upsample(scale_factor=2, mode='nearest') + self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) + + which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) + self.bn = which_bn(2*out_planes) + self.act = GLU() + + def forward(self, x, c): + x = self.up(x) + x = self.conv(x) + x = self.bn(x, c) + x = self.act(x) + return x + + +def UpBlockBig(in_planes, out_planes): + block = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), + NoiseInjection(), + NormLayer(out_planes*2), GLU(), + conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False), + NoiseInjection(), + NormLayer(out_planes*2), GLU() + ) + return block + + +class UpBlockBigCond(nn.Module): + def __init__(self, in_planes, out_planes, z_dim): + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.up = nn.Upsample(scale_factor=2, mode='nearest') + self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) + self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False) + + which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) + self.bn1 = which_bn(2*out_planes) + self.bn2 = which_bn(2*out_planes) + self.act = GLU() + self.noise = NoiseInjection() + + def forward(self, x, c): + # block 1 + x = self.up(x) + x = self.conv1(x) + x = self.noise(x) + x = self.bn1(x, c) + x = self.act(x) + + # block 2 + x = self.conv2(x) + x = self.noise(x) + x = self.bn2(x, c) + x = self.act(x) + + return x + + +class SEBlock(nn.Module): + def __init__(self, ch_in, ch_out): + super().__init__() + self.main = nn.Sequential( + nn.AdaptiveAvgPool2d(4), + conv2d(ch_in, ch_out, 4, 1, 0, bias=False), + Swish(), + conv2d(ch_out, ch_out, 1, 1, 0, bias=False), + nn.Sigmoid(), + ) + + def forward(self, feat_small, feat_big): + return feat_big * self.main(feat_small) + + +### Downblocks + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size, + groups=in_channels, bias=bias, padding=1) + self.pointwise = conv2d(in_channels, out_channels, + kernel_size=1, bias=bias) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class DownBlock(nn.Module): + def __init__(self, in_planes, out_planes, separable=False): + super().__init__() + if not separable: + self.main = nn.Sequential( + conv2d(in_planes, out_planes, 4, 2, 1), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + ) + else: + self.main = nn.Sequential( + SeparableConv2d(in_planes, out_planes, 3), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + nn.AvgPool2d(2, 2), + ) + + def forward(self, feat): + return self.main(feat) + + +class DownBlockPatch(nn.Module): + def __init__(self, in_planes, out_planes, separable=False): + super().__init__() + self.main = nn.Sequential( + DownBlock(in_planes, out_planes, separable), + conv2d(out_planes, out_planes, 1, 1, 0, bias=False), + NormLayer(out_planes), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, feat): + return self.main(feat) + + +### CSM + + +class ResidualConvUnit(nn.Module): + def __init__(self, cin, activation, bn): + super().__init__() + self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + return self.skip_add.add(self.conv(x), x) + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False): + super().__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + output = xs[0] + + if len(xs) == 2: + output = self.skip_add.add(output, xs[1]) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +### Misc + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.zeros(1), requires_grad=True) + + def forward(self, feat, noise=None): + if noise is None: + batch, _, height, width = feat.shape + noise = torch.randn(batch, 1, height, width).to(feat.device) + + return feat + self.weight * noise + + +class CCBN(nn.Module): + ''' conditional batchnorm ''' + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1): + super().__init__() + self.output_size, self.input_size = output_size, input_size + + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + return out * gain + bias + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, size, mode='bilinear', align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.size = size + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + size=self.size, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x diff --git a/pg_modules/diffaug.py b/pg_modules/diffaug.py new file mode 100644 index 0000000..54020be --- /dev/null +++ b/pg_modules/diffaug.py @@ -0,0 +1,76 @@ +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 + +import torch +import torch.nn.functional as F + + +def DiffAugment(x, policy='', channels_first=True): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) + return x + + +def rand_saturation(x): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + + +def rand_cutout(x, ratio=0.2): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'cutout': [rand_cutout], +} diff --git a/pg_modules/projected_discriminator.py b/pg_modules/projected_discriminator.py new file mode 100644 index 0000000..d0c879f --- /dev/null +++ b/pg_modules/projected_discriminator.py @@ -0,0 +1,191 @@ +from functools import partial +import numpy as np +import torch +import torch.nn as nn + +from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d +from pg_modules.projector import F_RandomProj +from pg_modules.diffaug import DiffAugment + + +class SingleDisc(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False): + super().__init__() + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + + layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False)) + self.main = nn.Sequential(*layers) + + def forward(self, x, c): + return self.main(x) + + +class SingleDiscCond(nn.Module): + def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128): + super().__init__() + self.cmap_dim = cmap_dim + + # midas channels + channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, + 256: 32, 512: 16, 1024: 8} + + # interpolate for start sz that are not powers of two + if start_sz not in channel_dict.keys(): + sizes = np.array(list(channel_dict.keys())) + start_sz = sizes[np.argmin(abs(sizes - start_sz))] + self.start_sz = start_sz + + # if given ndf, allocate all layers with the same ndf + if ndf is None: + nfc = channel_dict + else: + nfc = {k: ndf for k, v in channel_dict.items()} + + # for feature map discriminators with nfc not in channel_dict + # this is the case for the pretrained backbone (midas.pretrained) + if nc is not None and head is None: + nfc[start_sz] = nc + + layers = [] + + # Head if the initial input is the full modality + if head: + layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True)] + + # Down Blocks + DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) + while start_sz > end_sz: + layers.append(DB(nfc[start_sz], nfc[start_sz//2])) + start_sz = start_sz // 2 + self.main = nn.Sequential(*layers) + + # additions for conditioning on class information + self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False) + self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim) + self.embed_proj = nn.Sequential( + nn.Linear(self.embed.embedding_dim, self.cmap_dim), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, x, c): + h = self.main(x) + out = self.cls(h) + + # conditioning via projection + cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1) + out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + return out + + +class MultiScaleD(nn.Module): + def __init__( + self, + channels, + resolutions, + num_discs=4, + proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing + cond=0, + separable=False, + patch=False, + **kwargs, + ): + super().__init__() + + assert num_discs in [1, 2, 3, 4] + + # the first disc is on the lowest level of the backbone + self.disc_in_channels = channels[:num_discs] + self.disc_in_res = resolutions[:num_discs] + Disc = SingleDiscCond if cond else SingleDisc + + mini_discs = [] + for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)): + start_sz = res if not patch else 16 + mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)], + self.mini_discs = nn.ModuleDict(mini_discs) + + def forward(self, features, c): + all_logits = [] + for k, disc in self.mini_discs.items(): + res = disc(features[k], c).view(features[k].size(0), -1) + all_logits.append(res) + + all_logits = torch.cat(all_logits, dim=1) + return all_logits + + +class ProjectedDiscriminator(torch.nn.Module): + def __init__( + self, + diffaug=True, + interp224=True, + backbone_kwargs={}, + **kwargs + ): + super().__init__() + self.diffaug = diffaug + self.interp224 = interp224 + self.feature_network = F_RandomProj(**backbone_kwargs) + self.discriminator = MultiScaleD( + channels=self.feature_network.CHANNELS, + resolutions=self.feature_network.RESOLUTIONS, + **backbone_kwargs, + ) + + def train(self, mode=True): + self.feature_network = self.feature_network.train(False) + self.discriminator = self.discriminator.train(mode) + return self + + def eval(self): + return self.train(False) + + def get_feature(self, x): + features = self.feature_network(x, get_features=True) + return features + + def forward(self, x, c): + # if self.diffaug: + # x = DiffAugment(x, policy='color,translation,cutout') + + # if self.interp224: + # x = F.interpolate(x, 224, mode='bilinear', align_corners=False) + + features,backbone_features = self.feature_network(x) + logits = self.discriminator(features, c) + + return logits,backbone_features + diff --git a/pg_modules/projector.py b/pg_modules/projector.py new file mode 100644 index 0000000..610a482 --- /dev/null +++ b/pg_modules/projector.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import timm +from pg_modules.blocks import FeatureFusionBlock + + +def _make_scratch_ccm(scratch, in_channels, cout, expand=False): + # shapes + out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4 + + scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True) + scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True) + + scratch.CHANNELS = out_channels + + return scratch + + +def _make_scratch_csm(scratch, in_channels, cout, expand): + scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True) + scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand) + scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand) + scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False)) + + # last refinenet does not expand to save channels in higher dimensions + scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4 + + return scratch + + +def _make_efficientnet(model): + pretrained = nn.Module() + pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]) + pretrained.layer1 = nn.Sequential(*model.blocks[2:3]) + pretrained.layer2 = nn.Sequential(*model.blocks[3:5]) + pretrained.layer3 = nn.Sequential(*model.blocks[5:9]) + return pretrained + + +def calc_channels(pretrained, inp_res=224): + channels = [] + tmp = torch.zeros(1, 3, inp_res, inp_res) + + # forward pass + tmp = pretrained.layer0(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer1(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer2(tmp) + channels.append(tmp.shape[1]) + tmp = pretrained.layer3(tmp) + channels.append(tmp.shape[1]) + + return channels + + +def _make_projector(im_res, cout, proj_type, expand=False): + assert proj_type in [0, 1, 2], "Invalid projection type" + + ### Build pretrained feature network + model = timm.create_model('tf_efficientnet_lite0', pretrained=True) + pretrained = _make_efficientnet(model) + + # determine resolution of feature maps, this is later used to calculate the number + # of down blocks in the discriminators. Interestingly, the best results are achieved + # by fixing this to 256, ie., we use the same number of down blocks per discriminator + # independent of the dataset resolution + im_res = 256 + pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] + pretrained.CHANNELS = calc_channels(pretrained) + + if proj_type == 0: return pretrained, None + + ### Build CCM + scratch = nn.Module() + scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) + pretrained.CHANNELS = scratch.CHANNELS + + if proj_type == 1: return pretrained, scratch + + ### build CSM + scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) + + # CSM upsamples x2 so the feature map resolution doubles + pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] + pretrained.CHANNELS = scratch.CHANNELS + + return pretrained, scratch + + +class F_RandomProj(nn.Module): + def __init__( + self, + im_res=256, + cout=64, + expand=True, + proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing + **kwargs, + ): + super().__init__() + self.proj_type = proj_type + self.cout = cout + self.expand = expand + + # build pretrained feature network and random decoder (scratch) + self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand) + self.CHANNELS = self.pretrained.CHANNELS + self.RESOLUTIONS = self.pretrained.RESOLUTIONS + + def forward(self, x, get_features=False): + # predict feature maps + out0 = self.pretrained.layer0(x) + out1 = self.pretrained.layer1(out0) + out2 = self.pretrained.layer2(out1) + out3 = self.pretrained.layer3(out2) + + # start enumerating at the lowest layer (this is where we put the first discriminator) + backbone_features = { + '0': out0, + '1': out1, + '2': out2, + '3': out3, + } + if get_features: + return backbone_features + + if self.proj_type == 0: return backbone_features + + out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0']) + out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1']) + out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2']) + out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3']) + + out = { + '0': out0_channel_mixed, + '1': out1_channel_mixed, + '2': out2_channel_mixed, + '3': out3_channel_mixed, + } + + if self.proj_type == 1: return out + + # from bottom to top + out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) + out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) + out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) + out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) + + out = { + '0': out0_scale_mixed, + '1': out1_scale_mixed, + '2': out2_scale_mixed, + '3': out3_scale_mixed, + } + + return out, backbone_features diff --git a/train.py b/train.py new file mode 100644 index 0000000..9e77695 --- /dev/null +++ b/train.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: train.py +# Created Date: Monday December 27th 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 20th April 2022 6:33:30 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + +import os +import time +import wandb +import random +import argparse +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.backends import cudnn +import torch.utils.tensorboard as tensorboard + +from util import util +from util.plot import plot_batch + +from models.projected_model import fsModel +from data.data_loader_Swapping import GetLoader + + +class TrainOptions: + def __init__(self): + self.parser = argparse.ArgumentParser() + self.initialized = False + + def initialize(self): + self.parser.add_argument('--name', type=str, default='simswap', help='name of the experiment. It decides where to store samples and models') + self.parser.add_argument('--gpu_ids', default='0') + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + self.parser.add_argument('--isTrain', type=bool, default=True) + + # input/output sizes + self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size') + + # for displays + self.parser.add_argument('--tag', type=str, default='simswap') + + # for training + self.parser.add_argument('--dataset', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan", help='path to the face swapping dataset') + self.parser.add_argument('--continue_train', type=bool, default=False, help='continue training: load the latest model') + self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location') + self.parser.add_argument('--which_epoch', type=str, default='800000', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate') + self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero') + self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') + self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam') + self.parser.add_argument("--Gdeep",type=bool,default=False) + self.parser.add_argument("--train_simswap",type=bool,default=True) + + # for discriminators + self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss') + self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss') + + self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT") + self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step') + self.parser.add_argument("--log_frep", type=int, default=250, help='frequence for printing log information') + self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling') + self.parser.add_argument("--model_freq", type=int, default=10000, help='frequence for saving the model') + + + + + self.isTrain = True + + def parse(self, save=True): + if not self.initialized: + self.initialize() + self.opt = self.parser.parse_args() + self.opt.isTrain = self.isTrain # train or test + + args = vars(self.opt) + + print('------------ Options -------------') + for k, v in sorted(args.items()): + print('%s: %s' % (str(k), str(v))) + print('-------------- End ----------------') + + # save to the disk + if self.opt.isTrain: + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + util.mkdirs(expr_dir) + if save and not self.opt.continue_train: + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') + return self.opt + + +if __name__ == '__main__': + + opt = TrainOptions().parse() + iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') + + sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples') + + if not os.path.exists(sample_path): + os.makedirs(sample_path) + + log_path = os.path.join(opt.checkpoints_dir, opt.name, 'summary') + + if not os.path.exists(log_path): + os.makedirs(log_path) + + if opt.continue_train: + try: + start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + except: + start_epoch, epoch_iter = 1, 0 + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + else: + start_epoch, epoch_iter = 1, 0 + + os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids) + print("GPU used : ", str(opt.gpu_ids)) + + + cudnn.benchmark = True + + + + model = fsModel() + + model.initialize(opt) + + ##################################################### + + tensorboard_writer = tensorboard.SummaryWriter(log_path) + logger = tensorboard_writer + log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + + with open(log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + optimizer_G, optimizer_D = model.optimizer_G, model.optimizer_D + + loss_avg = 0 + refresh_count = 0 + imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234) + + randindex = [i for i in range(opt.batchSize)] + random.shuffle(randindex) + + if not opt.continue_train: + start = 0 + else: + start = int(opt.which_epoch) + total_step = opt.total_step + import datetime + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from util.logo_class import logo_class + logo_class.print_start_training() + model.netD.feature_network.requires_grad_(False) + + # Training Cycle + for step in range(start, total_step): + model.netG.train() + for interval in range(2): + random.shuffle(randindex) + src_image1, src_image2 = train_loader.next() + if opt.train_simswap: + src_image1 = F.interpolate(src_image1,size=(256,256), mode='bicubic') + src_image2 = F.interpolate(src_image2,size=(256,256), mode='bicubic') + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = model.netArc(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + if interval: + + img_fake = model.netG(src_image1, latent_id) + gen_logits,_ = model.netD(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = model.netD(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + optimizer_D.zero_grad() + loss_D.backward() + optimizer_D.step() + else: + + # model.netD.requires_grad_(True) + img_fake = model.netG(src_image1, latent_id) + # G loss + gen_logits,feat = model.netD(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = model.netArc(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean() + real_feat = model.netD.get_feature(src_image1) + feat_match_loss = model.criterionFeat(feat["3"],real_feat["3"]) + loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat + + + if step%2 == 0: + #G_Rec + loss_G_Rec = model.criterionRec(img_fake, src_image1) * opt.lambda_rec + loss_G += loss_G_Rec + + optimizer_G.zero_grad() + loss_G.backward() + optimizer_G.step() + + + ############## Display results and errors ########## + ### print out errors + # Print out log info + if (step + 1) % opt.log_frep == 0: + # errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()} + errors = { + "G_Loss":loss_Gmain.item(), + "G_ID":loss_G_ID.item(), + "G_Rec":loss_G_Rec.item(), + "G_feat_match":feat_match_loss.item(), + "D_fake":loss_Dgen.item(), + "D_real":loss_Dreal.item(), + "D_loss":loss_D.item() + } + + for tag, value in errors.items(): + logger.add_scalar(tag, value, step) + message = '( step: %d, ) ' % (step) + for k, v in errors.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + with open(log_name, "a") as log_file: + log_file.write('%s\n' % message) + + ### display output images + if (step + 1) % opt.sample_freq == 0: + model.netG.eval() + with torch.no_grad(): + imgs = list() + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* imagenet_std + imagenet_mean).numpy() + for r in range(opt.batchSize): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = model.netArc(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(opt.batchSize): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(opt.batchSize, 1, 1, 1) + img_fake = model.netG(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * imagenet_std + img_fake = img_fake + imagenet_mean + img_fake = img_fake.numpy() + for j in range(opt.batchSize): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(sample_path, 'step_'+str(step+1)+'.jpg')) + + ### save latest model + if (step+1) % opt.model_freq==0: + print('saving the latest model (steps %d)' % (step+1)) + model.save(step+1) + np.savetxt(iter_path, (step+1, total_step), delimiter=',', fmt='%d') + wandb.finish() \ No newline at end of file diff --git a/util/json_config.py b/util/json_config.py new file mode 100644 index 0000000..c68fbff --- /dev/null +++ b/util/json_config.py @@ -0,0 +1,15 @@ +import json + + +def readConfig(path): + with open(path,'r') as cf: + nodelocaltionstr = cf.read() + nodelocaltioninf = json.loads(nodelocaltionstr) + if isinstance(nodelocaltioninf,str): + nodelocaltioninf = json.loads(nodelocaltioninf) + return nodelocaltioninf + +def writeConfig(path, info): + with open(path, 'w') as cf: + configjson = json.dumps(info, indent=4) + cf.writelines(configjson) \ No newline at end of file diff --git a/util/logo_class.py b/util/logo_class.py new file mode 100644 index 0000000..044dce3 --- /dev/null +++ b/util/logo_class.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: logo_class.py +# Created Date: Tuesday June 29th 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 11th October 2021 12:39:55 am +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + +class logo_class: + + @staticmethod + def print_group_logo(): + logo_str = """ + +███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗ +████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║ +██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║ +██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║ +██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝ +╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝ +Neural Rendering Special Interesting Group of SJTU + + """ + print(logo_str) + + @staticmethod + def print_start_training(): + logo_str = """ + _____ __ __ ______ _ _ + / ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _ + \__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/ + ___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ / +/____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, / + /____/ + """ + print(logo_str) + +if __name__=="__main__": + # logo_class.print_group_logo() + logo_class.print_start_training() \ No newline at end of file diff --git a/util/plot.py b/util/plot.py new file mode 100644 index 0000000..0da1c75 --- /dev/null +++ b/util/plot.py @@ -0,0 +1,37 @@ +import numpy as np +import math +import PIL + +def postprocess(x): + """[0,1] to uint8.""" + + x = np.clip(255 * x, 0, 255) + x = np.cast[np.uint8](x) + return x + +def tile(X, rows, cols): + """Tile images for display.""" + tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype) + for i in range(rows): + for j in range(cols): + idx = i * cols + j + if idx < X.shape[0]: + img = X[idx,...] + tiling[ + i*X.shape[1]:(i+1)*X.shape[1], + j*X.shape[2]:(j+1)*X.shape[2], + :] = img + return tiling + + +def plot_batch(X, out_path): + """Save batch of images tiled.""" + n_channels = X.shape[3] + if n_channels > 3: + X = X[:,:,:,np.random.choice(n_channels, size = 3)] + X = postprocess(X) + rc = math.sqrt(X.shape[0]) + rows = cols = math.ceil(rc) + canvas = tile(X, rows, cols) + canvas = np.squeeze(canvas) + PIL.Image.fromarray(canvas).save(out_path) \ No newline at end of file diff --git a/util/save_heatmap.py b/util/save_heatmap.py new file mode 100644 index 0000000..71ce4c9 --- /dev/null +++ b/util/save_heatmap.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: save_heatmap.py +# Created Date: Friday January 15th 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 19th January 2022 1:22:47 am +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + +import os +import shutil +import seaborn as sns +import matplotlib.pyplot as plt +import cv2 +import numpy as np + +def SaveHeatmap(heatmaps, path, row=-1, dpi=72): + """ + The input tensor must be B X 1 X H X W + """ + batch_size = heatmaps.shape[0] + temp_path = ".temp/" + if not os.path.exists(temp_path): + os.makedirs(temp_path) + final_img = None + if row < 1: + col = batch_size + row = 1 + else: + col = batch_size // row + if row * col = col: + col_i = 0 + row_i += 1 + cv2.imwrite(path,final_img) + +if __name__ == "__main__": + random_map = np.random.randn(16,1,10,10) + SaveHeatmap(random_map,"./wocao.png",1)