training scripts released

This commit is contained in:
chenxuanhong
2022-04-20 18:36:26 +08:00
parent 9492873690
commit f48dc8cf62
16 changed files with 1688 additions and 3 deletions
+4 -1
View File
@@ -135,4 +135,7 @@ checkpoints/
*.zip
*.avi
*.pdf
*.pptx
*.pptx
*.pth
*.onnx
+127
View File
@@ -0,0 +1,127 @@
import os
import glob
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
from torchvision import transforms as T
# from StyleResize import StyleResize
class data_prefetcher():
def __init__(self, loader):
self.loader = loader
self.dataiter = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.num_images = len(loader)
self.preload()
def preload(self):
try:
self.src_image1, self.src_image2 = next(self.dataiter)
except StopIteration:
self.dataiter = iter(self.loader)
self.src_image1, self.src_image2 = next(self.dataiter)
with torch.cuda.stream(self.stream):
self.src_image1 = self.src_image1.cuda(non_blocking=True)
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
self.src_image2 = self.src_image2.cuda(non_blocking=True)
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
src_image1 = self.src_image1
src_image2 = self.src_image2
self.preload()
return src_image1, src_image2
def __len__(self):
"""Return the number of images."""
return self.num_images
class SwappingDataset(data.Dataset):
"""Dataset class for the Artworks dataset and content dataset."""
def __init__(self,
image_dir,
img_transform,
subffix='jpg',
random_seed=1234):
"""Initialize and preprocess the Swapping dataset."""
self.image_dir = image_dir
self.img_transform = img_transform
self.subffix = subffix
self.dataset = []
self.random_seed = random_seed
self.preprocess()
self.num_images = len(self.dataset)
def preprocess(self):
"""Preprocess the Swapping dataset."""
print("processing Swapping dataset images...")
temp_path = os.path.join(self.image_dir,'*/')
pathes = glob.glob(temp_path)
self.dataset = []
for dir_item in pathes:
join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
print("processing %s"%dir_item,end='\r')
temp_list = []
for item in join_path:
temp_list.append(item)
self.dataset.append(temp_list)
random.seed(self.random_seed)
random.shuffle(self.dataset)
print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset))
def __getitem__(self, index):
"""Return two src domain images and two dst domain images."""
dir_tmp1 = self.dataset[index]
dir_tmp1_len = len(dir_tmp1)
filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
image1 = self.img_transform(Image.open(filename1))
image2 = self.img_transform(Image.open(filename2))
return image1, image2
def __len__(self):
"""Return the number of images."""
return self.num_images
def GetLoader( dataset_roots,
batch_size=16,
dataloader_workers=8,
random_seed = 1234
):
"""Build and return a data loader."""
num_workers = dataloader_workers
data_root = dataset_roots
random_seed = random_seed
c_transforms = []
c_transforms.append(T.ToTensor())
c_transforms = T.Compose(c_transforms)
content_dataset = SwappingDataset(
data_root,
c_transforms,
"jpg",
random_seed)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
prefetcher = data_prefetcher(content_data_loader)
return prefetcher
def denorm(x):
out = (x + 1) / 2
return out.clamp_(0, 1)
+54
View File
@@ -37,6 +37,19 @@ class BaseModel(torch.nn.Module):
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if torch.cuda.is_available():
network.cuda()
def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.state_dict(), save_path)
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
@@ -63,6 +76,47 @@ class BaseModel(torch.nn.Module):
except:
pretrained_dict = torch.load(save_path)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
if self.opt.verbose:
print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
if sys.version_info >= (3,0):
not_initialized = set()
else:
from sets import Set
not_initialized = Set()
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split('.')[0])
print(sorted(not_initialized))
network.load_state_dict(model_dict)
# helper loading function that can be used by subclasses
def load_optim(self, network, network_label, epoch_label, save_dir=''):
save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
print('%s not exists yet!' % save_path)
if network_label == 'G':
raise('Generator must exist!')
else:
#network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
except:
pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
+169
View File
@@ -0,0 +1,169 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch
import torch.nn as nn
class InstanceNorm(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(InstanceNorm, self).__init__()
self.epsilon = epsilon
def forward(self, x):
x = x - torch.mean(x, (2, 3), True)
tmp = torch.mul(x, x) # or x ** 2
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
return x * tmp
class ApplyStyle(nn.Module):
"""
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
"""
def __init__(self, latent_size, channels):
super(ApplyStyle, self).__init__()
self.linear = nn.Linear(latent_size, channels * 2)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, 2, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
#x = x * (style[:, 0] + 1.) + style[:, 1]
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
return x
class ResnetBlock_Adain(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
super(ResnetBlock_Adain, self).__init__()
p = 0
conv1 = []
if padding_type == 'reflect':
conv1 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv1 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = ApplyStyle(latent_size, dim)
self.act1 = activation
p = 0
conv2 = []
if padding_type == 'reflect':
conv2 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv2 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = ApplyStyle(latent_size, dim)
def forward(self, x, dlatents_in_slice):
y = self.conv1(x)
y = self.style1(y, dlatents_in_slice)
y = self.act1(y)
y = self.conv2(y)
y = self.style2(y, dlatents_in_slice)
out = x + y
return out
class Generator_Adain_Upsample(nn.Module):
def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
assert (n_blocks >= 0)
super(Generator_Adain_Upsample, self).__init__()
activation = nn.ReLU(True)
self.deep = deep
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
norm_layer(64), activation)
### downsample
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
norm_layer(128), activation)
self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
norm_layer(256), activation)
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
if self.deep:
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
### resnet blocks
BN = []
for i in range(n_blocks):
BN += [
ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
self.BottleNeck = nn.Sequential(*BN)
if self.deep:
self.up4 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512), activation
)
self.up3 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256), activation
)
self.up2 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128), activation
)
self.up1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64), activation
)
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))
def forward(self, input, dlatents):
x = input # 3*224*224
skip1 = self.first_layer(x)
skip2 = self.down1(skip1)
skip3 = self.down2(skip2)
if self.deep:
skip4 = self.down3(skip3)
x = self.down4(skip4)
else:
x = self.down3(skip3)
bot = []
bot.append(x)
features = []
for i in range(len(self.BottleNeck)):
x = self.BottleNeck[i](x, dlatents)
bot.append(x)
if self.deep:
x = self.up4(x)
features.append(x)
x = self.up3(x)
features.append(x)
x = self.up2(x)
features.append(x)
x = self.up1(x)
features.append(x)
x = self.last_layer(x)
# x = (x + 1) / 2
# return x, bot, features, dlatents
return x
+122
View File
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 20th April 2022 6:34:47 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
from .base_model import BaseModel
from .fs_networks_fix import Generator_Adain_Upsample
from pg_modules.projected_discriminator import ProjectedDiscriminator
def compute_grad2(d_out, x_in):
batch_size = x_in.size(0)
grad_dout = torch.autograd.grad(
outputs=d_out.sum(), inputs=x_in,
create_graph=True, retain_graph=True, only_inputs=True
)[0]
grad_dout2 = grad_dout.pow(2)
assert(grad_dout2.size() == x_in.size())
reg = grad_dout2.view(batch_size, -1).sum(1)
return reg
class fsModel(BaseModel):
def name(self):
return 'fsModel'
def initialize(self, opt):
BaseModel.initialize(self, opt)
# if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
self.isTrain = opt.isTrain
# Generator network
self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
self.netG.cuda()
# Id network
netArc_checkpoint = opt.Arc_path
netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
self.netArc = netArc_checkpoint['model'].module
self.netArc = self.netArc.cuda()
self.netArc.eval()
self.netArc.requires_grad_(False)
if not self.isTrain:
pretrained_path = opt.checkpoints_dir
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
return
self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
# self.netD.feature_network.requires_grad_(False)
self.netD.cuda()
if self.isTrain:
# define loss functions
self.criterionFeat = nn.L1Loss()
self.criterionRec = nn.L1Loss()
# initialize optimizers
# optimizer G
params = list(self.netG.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8)
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8)
# load networks
if opt.continue_train:
pretrained_path = '' if not self.isTrain else opt.load_pretrain
# print (pretrained_path)
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
torch.cuda.empty_cache()
def cosin_metric(self, x1, x2):
#return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch)
self.save_network(self.netD, 'D', which_epoch)
self.save_optim(self.optimizer_G, 'G', which_epoch,)
self.save_optim(self.optimizer_D, 'D', which_epoch)
'''if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
def update_fixed_params(self):
# after fixing the global generator for a number of iterations, also start finetuning it
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
+14
View File
@@ -0,0 +1,14 @@
import torch.nn as nn
class ProjectionHead(nn.Module):
def __init__(self, proj_dim=256):
super(ProjectionHead, self).__init__()
self.proj = nn.Sequential(
nn.Linear(proj_dim, proj_dim),
nn.ReLU(),
nn.Linear(proj_dim, proj_dim),
)
def forward(self, x):
return self.proj(x)
+2 -2
View File
@@ -15,7 +15,7 @@ class TestOptions(BaseOptions):
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--which_epoch', type=str, default='9900000', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
@@ -35,4 +35,4 @@ class TestOptions(BaseOptions):
self.parser.add_argument('--use_mask', action='store_true', help='Use mask for better result')
self.parser.add_argument('--crop_size', type=int, default=224, help='Crop of size of input image')
self.isTrain = False
self.isTrain = False
+325
View File
@@ -0,0 +1,325 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
### single layers
def conv2d(*args, **kwargs):
return spectral_norm(nn.Conv2d(*args, **kwargs))
def convTranspose2d(*args, **kwargs):
return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))
def embedding(*args, **kwargs):
return spectral_norm(nn.Embedding(*args, **kwargs))
def linear(*args, **kwargs):
return spectral_norm(nn.Linear(*args, **kwargs))
def NormLayer(c, mode='batch'):
if mode == 'group':
return nn.GroupNorm(c//2, c)
elif mode == 'batch':
return nn.BatchNorm2d(c)
### Activations
class GLU(nn.Module):
def forward(self, x):
nc = x.size(1)
assert nc % 2 == 0, 'channels dont divide 2!'
nc = int(nc/2)
return x[:, :nc] * torch.sigmoid(x[:, nc:])
class Swish(nn.Module):
def forward(self, feat):
return feat * torch.sigmoid(feat)
### Upblocks
class InitLayer(nn.Module):
def __init__(self, nz, channel, sz=4):
super().__init__()
self.init = nn.Sequential(
convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
NormLayer(channel*2),
GLU(),
)
def forward(self, noise):
noise = noise.view(noise.shape[0], -1, 1, 1)
return self.init(noise)
def UpBlockSmall(in_planes, out_planes):
block = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
NormLayer(out_planes*2), GLU())
return block
class UpBlockSmallCond(nn.Module):
def __init__(self, in_planes, out_planes, z_dim):
super().__init__()
self.in_planes = in_planes
self.out_planes = out_planes
self.up = nn.Upsample(scale_factor=2, mode='nearest')
self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
self.bn = which_bn(2*out_planes)
self.act = GLU()
def forward(self, x, c):
x = self.up(x)
x = self.conv(x)
x = self.bn(x, c)
x = self.act(x)
return x
def UpBlockBig(in_planes, out_planes):
block = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
NoiseInjection(),
NormLayer(out_planes*2), GLU(),
conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
NoiseInjection(),
NormLayer(out_planes*2), GLU()
)
return block
class UpBlockBigCond(nn.Module):
def __init__(self, in_planes, out_planes, z_dim):
super().__init__()
self.in_planes = in_planes
self.out_planes = out_planes
self.up = nn.Upsample(scale_factor=2, mode='nearest')
self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)
which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
self.bn1 = which_bn(2*out_planes)
self.bn2 = which_bn(2*out_planes)
self.act = GLU()
self.noise = NoiseInjection()
def forward(self, x, c):
# block 1
x = self.up(x)
x = self.conv1(x)
x = self.noise(x)
x = self.bn1(x, c)
x = self.act(x)
# block 2
x = self.conv2(x)
x = self.noise(x)
x = self.bn2(x, c)
x = self.act(x)
return x
class SEBlock(nn.Module):
def __init__(self, ch_in, ch_out):
super().__init__()
self.main = nn.Sequential(
nn.AdaptiveAvgPool2d(4),
conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
Swish(),
conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
nn.Sigmoid(),
)
def forward(self, feat_small, feat_big):
return feat_big * self.main(feat_small)
### Downblocks
class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=False):
super(SeparableConv2d, self).__init__()
self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
groups=in_channels, bias=bias, padding=1)
self.pointwise = conv2d(in_channels, out_channels,
kernel_size=1, bias=bias)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class DownBlock(nn.Module):
def __init__(self, in_planes, out_planes, separable=False):
super().__init__()
if not separable:
self.main = nn.Sequential(
conv2d(in_planes, out_planes, 4, 2, 1),
NormLayer(out_planes),
nn.LeakyReLU(0.2, inplace=True),
)
else:
self.main = nn.Sequential(
SeparableConv2d(in_planes, out_planes, 3),
NormLayer(out_planes),
nn.LeakyReLU(0.2, inplace=True),
nn.AvgPool2d(2, 2),
)
def forward(self, feat):
return self.main(feat)
class DownBlockPatch(nn.Module):
def __init__(self, in_planes, out_planes, separable=False):
super().__init__()
self.main = nn.Sequential(
DownBlock(in_planes, out_planes, separable),
conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
NormLayer(out_planes),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, feat):
return self.main(feat)
### CSM
class ResidualConvUnit(nn.Module):
def __init__(self, cin, activation, bn):
super().__init__()
self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
return self.skip_add.add(self.conv(x), x)
class FeatureFusionBlock(nn.Module):
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
super().__init__()
self.deconv = deconv
self.align_corners = align_corners
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
output = xs[0]
if len(xs) == 2:
output = self.skip_add.add(output, xs[1])
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
### Misc
class NoiseInjection(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
def forward(self, feat, noise=None):
if noise is None:
batch, _, height, width = feat.shape
noise = torch.randn(batch, 1, height, width).to(feat.device)
return feat + self.weight * noise
class CCBN(nn.Module):
''' conditional batchnorm '''
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
super().__init__()
self.output_size, self.input_size = output_size, input_size
# Prepare gain and bias layers
self.gain = which_linear(input_size, output_size)
self.bias = which_linear(input_size, output_size)
# epsilon to avoid dividing by 0
self.eps = eps
# Momentum
self.momentum = momentum
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
def forward(self, x, y):
# Calculate class-conditional gains and biases
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
bias = self.bias(y).view(y.size(0), -1, 1, 1)
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
self.training, 0.1, self.eps)
return out * gain + bias
class Interpolate(nn.Module):
"""Interpolation module."""
def __init__(self, size, mode='bilinear', align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x,
size=self.size,
mode=self.mode,
align_corners=self.align_corners,
)
return x
+76
View File
@@ -0,0 +1,76 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
import torch
import torch.nn.functional as F
def DiffAugment(x, policy='', channels_first=True):
if policy:
if not channels_first:
x = x.permute(0, 3, 1, 2)
for p in policy.split(','):
for f in AUGMENT_FNS[p]:
x = f(x)
if not channels_first:
x = x.permute(0, 2, 3, 1)
x = x.contiguous()
return x
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x
def rand_cutout(x, ratio=0.2):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
+191
View File
@@ -0,0 +1,191 @@
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from pg_modules.projector import F_RandomProj
from pg_modules.diffaug import DiffAugment
class SingleDisc(nn.Module):
def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
super().__init__()
channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
256: 32, 512: 16, 1024: 8}
# interpolate for start sz that are not powers of two
if start_sz not in channel_dict.keys():
sizes = np.array(list(channel_dict.keys()))
start_sz = sizes[np.argmin(abs(sizes - start_sz))]
self.start_sz = start_sz
# if given ndf, allocate all layers with the same ndf
if ndf is None:
nfc = channel_dict
else:
nfc = {k: ndf for k, v in channel_dict.items()}
# for feature map discriminators with nfc not in channel_dict
# this is the case for the pretrained backbone (midas.pretrained)
if nc is not None and head is None:
nfc[start_sz] = nc
layers = []
# Head if the initial input is the full modality
if head:
layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True)]
# Down Blocks
DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
while start_sz > end_sz:
layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
start_sz = start_sz // 2
layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
self.main = nn.Sequential(*layers)
def forward(self, x, c):
return self.main(x)
class SingleDiscCond(nn.Module):
def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
super().__init__()
self.cmap_dim = cmap_dim
# midas channels
channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
256: 32, 512: 16, 1024: 8}
# interpolate for start sz that are not powers of two
if start_sz not in channel_dict.keys():
sizes = np.array(list(channel_dict.keys()))
start_sz = sizes[np.argmin(abs(sizes - start_sz))]
self.start_sz = start_sz
# if given ndf, allocate all layers with the same ndf
if ndf is None:
nfc = channel_dict
else:
nfc = {k: ndf for k, v in channel_dict.items()}
# for feature map discriminators with nfc not in channel_dict
# this is the case for the pretrained backbone (midas.pretrained)
if nc is not None and head is None:
nfc[start_sz] = nc
layers = []
# Head if the initial input is the full modality
if head:
layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True)]
# Down Blocks
DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
while start_sz > end_sz:
layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
start_sz = start_sz // 2
self.main = nn.Sequential(*layers)
# additions for conditioning on class information
self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
self.embed_proj = nn.Sequential(
nn.Linear(self.embed.embedding_dim, self.cmap_dim),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x, c):
h = self.main(x)
out = self.cls(h)
# conditioning via projection
cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
return out
class MultiScaleD(nn.Module):
def __init__(
self,
channels,
resolutions,
num_discs=4,
proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
cond=0,
separable=False,
patch=False,
**kwargs,
):
super().__init__()
assert num_discs in [1, 2, 3, 4]
# the first disc is on the lowest level of the backbone
self.disc_in_channels = channels[:num_discs]
self.disc_in_res = resolutions[:num_discs]
Disc = SingleDiscCond if cond else SingleDisc
mini_discs = []
for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
start_sz = res if not patch else 16
mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)],
self.mini_discs = nn.ModuleDict(mini_discs)
def forward(self, features, c):
all_logits = []
for k, disc in self.mini_discs.items():
res = disc(features[k], c).view(features[k].size(0), -1)
all_logits.append(res)
all_logits = torch.cat(all_logits, dim=1)
return all_logits
class ProjectedDiscriminator(torch.nn.Module):
def __init__(
self,
diffaug=True,
interp224=True,
backbone_kwargs={},
**kwargs
):
super().__init__()
self.diffaug = diffaug
self.interp224 = interp224
self.feature_network = F_RandomProj(**backbone_kwargs)
self.discriminator = MultiScaleD(
channels=self.feature_network.CHANNELS,
resolutions=self.feature_network.RESOLUTIONS,
**backbone_kwargs,
)
def train(self, mode=True):
self.feature_network = self.feature_network.train(False)
self.discriminator = self.discriminator.train(mode)
return self
def eval(self):
return self.train(False)
def get_feature(self, x):
features = self.feature_network(x, get_features=True)
return features
def forward(self, x, c):
# if self.diffaug:
# x = DiffAugment(x, policy='color,translation,cutout')
# if self.interp224:
# x = F.interpolate(x, 224, mode='bilinear', align_corners=False)
features,backbone_features = self.feature_network(x)
logits = self.discriminator(features, c)
return logits,backbone_features
+158
View File
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
import timm
from pg_modules.blocks import FeatureFusionBlock
def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
# shapes
out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
scratch.CHANNELS = out_channels
return scratch
def _make_scratch_csm(scratch, in_channels, cout, expand):
scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
# last refinenet does not expand to save channels in higher dimensions
scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
return scratch
def _make_efficientnet(model):
pretrained = nn.Module()
pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
return pretrained
def calc_channels(pretrained, inp_res=224):
channels = []
tmp = torch.zeros(1, 3, inp_res, inp_res)
# forward pass
tmp = pretrained.layer0(tmp)
channels.append(tmp.shape[1])
tmp = pretrained.layer1(tmp)
channels.append(tmp.shape[1])
tmp = pretrained.layer2(tmp)
channels.append(tmp.shape[1])
tmp = pretrained.layer3(tmp)
channels.append(tmp.shape[1])
return channels
def _make_projector(im_res, cout, proj_type, expand=False):
assert proj_type in [0, 1, 2], "Invalid projection type"
### Build pretrained feature network
model = timm.create_model('tf_efficientnet_lite0', pretrained=True)
pretrained = _make_efficientnet(model)
# determine resolution of feature maps, this is later used to calculate the number
# of down blocks in the discriminators. Interestingly, the best results are achieved
# by fixing this to 256, ie., we use the same number of down blocks per discriminator
# independent of the dataset resolution
im_res = 256
pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]
pretrained.CHANNELS = calc_channels(pretrained)
if proj_type == 0: return pretrained, None
### Build CCM
scratch = nn.Module()
scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
pretrained.CHANNELS = scratch.CHANNELS
if proj_type == 1: return pretrained, scratch
### build CSM
scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)
# CSM upsamples x2 so the feature map resolution doubles
pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
pretrained.CHANNELS = scratch.CHANNELS
return pretrained, scratch
class F_RandomProj(nn.Module):
def __init__(
self,
im_res=256,
cout=64,
expand=True,
proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
**kwargs,
):
super().__init__()
self.proj_type = proj_type
self.cout = cout
self.expand = expand
# build pretrained feature network and random decoder (scratch)
self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand)
self.CHANNELS = self.pretrained.CHANNELS
self.RESOLUTIONS = self.pretrained.RESOLUTIONS
def forward(self, x, get_features=False):
# predict feature maps
out0 = self.pretrained.layer0(x)
out1 = self.pretrained.layer1(out0)
out2 = self.pretrained.layer2(out1)
out3 = self.pretrained.layer3(out2)
# start enumerating at the lowest layer (this is where we put the first discriminator)
backbone_features = {
'0': out0,
'1': out1,
'2': out2,
'3': out3,
}
if get_features:
return backbone_features
if self.proj_type == 0: return backbone_features
out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0'])
out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1'])
out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2'])
out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3'])
out = {
'0': out0_channel_mixed,
'1': out1_channel_mixed,
'2': out2_channel_mixed,
'3': out3_channel_mixed,
}
if self.proj_type == 1: return out
# from bottom to top
out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)
out = {
'0': out0_scale_mixed,
'1': out1_scale_mixed,
'2': out2_scale_mixed,
'3': out3_scale_mixed,
}
return out, backbone_features
+293
View File
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: train.py
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 20th April 2022 6:33:30 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import time
import wandb
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.backends import cudnn
import torch.utils.tensorboard as tensorboard
from util import util
from util.plot import plot_batch
from models.projected_model import fsModel
from data.data_loader_Swapping import GetLoader
class TrainOptions:
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
self.parser.add_argument('--name', type=str, default='simswap', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', default='0')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--isTrain', type=bool, default=True)
# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
# for displays
self.parser.add_argument('--tag', type=str, default='simswap')
# for training
self.parser.add_argument('--dataset', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan", help='path to the face swapping dataset')
self.parser.add_argument('--continue_train', type=bool, default=False, help='continue training: load the latest model')
self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location')
self.parser.add_argument('--which_epoch', type=str, default='800000', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
self.parser.add_argument("--Gdeep",type=bool,default=False)
self.parser.add_argument("--train_simswap",type=bool,default=True)
# for discriminators
self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
self.parser.add_argument("--log_frep", type=int, default=250, help='frequence for printing log information')
self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling')
self.parser.add_argument("--model_freq", type=int, default=10000, help='frequence for saving the model')
self.isTrain = True
def parse(self, save=True):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
if self.opt.isTrain:
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
if save and not self.opt.continue_train:
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
if __name__ == '__main__':
opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples')
if not os.path.exists(sample_path):
os.makedirs(sample_path)
log_path = os.path.join(opt.checkpoints_dir, opt.name, 'summary')
if not os.path.exists(log_path):
os.makedirs(log_path)
if opt.continue_train:
try:
start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
except:
start_epoch, epoch_iter = 1, 0
print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
start_epoch, epoch_iter = 1, 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids)
print("GPU used : ", str(opt.gpu_ids))
cudnn.benchmark = True
model = fsModel()
model.initialize(opt)
#####################################################
tensorboard_writer = tensorboard.SummaryWriter(log_path)
logger = tensorboard_writer
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
optimizer_G, optimizer_D = model.optimizer_G, model.optimizer_D
loss_avg = 0
refresh_count = 0
imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)
randindex = [i for i in range(opt.batchSize)]
random.shuffle(randindex)
if not opt.continue_train:
start = 0
else:
start = int(opt.which_epoch)
total_step = opt.total_step
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from util.logo_class import logo_class
logo_class.print_start_training()
model.netD.feature_network.requires_grad_(False)
# Training Cycle
for step in range(start, total_step):
model.netG.train()
for interval in range(2):
random.shuffle(randindex)
src_image1, src_image2 = train_loader.next()
if opt.train_simswap:
src_image1 = F.interpolate(src_image1,size=(256,256), mode='bicubic')
src_image2 = F.interpolate(src_image2,size=(256,256), mode='bicubic')
if step%2 == 0:
img_id = src_image2
else:
img_id = src_image2[randindex]
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
latent_id = model.netArc(img_id_112)
latent_id = F.normalize(latent_id, p=2, dim=1)
if interval:
img_fake = model.netG(src_image1, latent_id)
gen_logits,_ = model.netD(img_fake.detach(), None)
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
real_logits,_ = model.netD(src_image2,None)
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
loss_D = loss_Dgen + loss_Dreal
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
else:
# model.netD.requires_grad_(True)
img_fake = model.netG(src_image1, latent_id)
# G loss
gen_logits,feat = model.netD(img_fake, None)
loss_Gmain = (-gen_logits).mean()
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
latent_fake = model.netArc(img_fake_down)
latent_fake = F.normalize(latent_fake, p=2, dim=1)
loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean()
real_feat = model.netD.get_feature(src_image1)
feat_match_loss = model.criterionFeat(feat["3"],real_feat["3"])
loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat
if step%2 == 0:
#G_Rec
loss_G_Rec = model.criterionRec(img_fake, src_image1) * opt.lambda_rec
loss_G += loss_G_Rec
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
############## Display results and errors ##########
### print out errors
# Print out log info
if (step + 1) % opt.log_frep == 0:
# errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
errors = {
"G_Loss":loss_Gmain.item(),
"G_ID":loss_G_ID.item(),
"G_Rec":loss_G_Rec.item(),
"G_feat_match":feat_match_loss.item(),
"D_fake":loss_Dgen.item(),
"D_real":loss_Dreal.item(),
"D_loss":loss_D.item()
}
for tag, value in errors.items():
logger.add_scalar(tag, value, step)
message = '( step: %d, ) ' % (step)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)
print(message)
with open(log_name, "a") as log_file:
log_file.write('%s\n' % message)
### display output images
if (step + 1) % opt.sample_freq == 0:
model.netG.eval()
with torch.no_grad():
imgs = list()
zero_img = (torch.zeros_like(src_image1[0,...]))
imgs.append(zero_img.cpu().numpy())
save_img = ((src_image1.cpu())* imagenet_std + imagenet_mean).numpy()
for r in range(opt.batchSize):
imgs.append(save_img[r,...])
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
id_vector_src1 = model.netArc(arcface_112)
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
for i in range(opt.batchSize):
imgs.append(save_img[i,...])
image_infer = src_image1[i, ...].repeat(opt.batchSize, 1, 1, 1)
img_fake = model.netG(image_infer, id_vector_src1).cpu()
img_fake = img_fake * imagenet_std
img_fake = img_fake + imagenet_mean
img_fake = img_fake.numpy()
for j in range(opt.batchSize):
imgs.append(img_fake[j,...])
print("Save test data")
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
plot_batch(imgs, os.path.join(sample_path, 'step_'+str(step+1)+'.jpg'))
### save latest model
if (step+1) % opt.model_freq==0:
print('saving the latest model (steps %d)' % (step+1))
model.save(step+1)
np.savetxt(iter_path, (step+1, total_step), delimiter=',', fmt='%d')
wandb.finish()
+15
View File
@@ -0,0 +1,15 @@
import json
def readConfig(path):
with open(path,'r') as cf:
nodelocaltionstr = cf.read()
nodelocaltioninf = json.loads(nodelocaltionstr)
if isinstance(nodelocaltioninf,str):
nodelocaltioninf = json.loads(nodelocaltioninf)
return nodelocaltioninf
def writeConfig(path, info):
with open(path, 'w') as cf:
configjson = json.dumps(info, indent=4)
cf.writelines(configjson)
+44
View File
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: logo_class.py
# Created Date: Tuesday June 29th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 11th October 2021 12:39:55 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
class logo_class:
@staticmethod
def print_group_logo():
logo_str = """
███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗
████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║
██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║
██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║
██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝
╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝
Neural Rendering Special Interesting Group of SJTU
"""
print(logo_str)
@staticmethod
def print_start_training():
logo_str = """
_____ __ __ ______ _ _
/ ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
\__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
/____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
/____/
"""
print(logo_str)
if __name__=="__main__":
# logo_class.print_group_logo()
logo_class.print_start_training()
+37
View File
@@ -0,0 +1,37 @@
import numpy as np
import math
import PIL
def postprocess(x):
"""[0,1] to uint8."""
x = np.clip(255 * x, 0, 255)
x = np.cast[np.uint8](x)
return x
def tile(X, rows, cols):
"""Tile images for display."""
tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype)
for i in range(rows):
for j in range(cols):
idx = i * cols + j
if idx < X.shape[0]:
img = X[idx,...]
tiling[
i*X.shape[1]:(i+1)*X.shape[1],
j*X.shape[2]:(j+1)*X.shape[2],
:] = img
return tiling
def plot_batch(X, out_path):
"""Save batch of images tiled."""
n_channels = X.shape[3]
if n_channels > 3:
X = X[:,:,:,np.random.choice(n_channels, size = 3)]
X = postprocess(X)
rc = math.sqrt(X.shape[0])
rows = cols = math.ceil(rc)
canvas = tile(X, rows, cols)
canvas = np.squeeze(canvas)
PIL.Image.fromarray(canvas).save(out_path)
+57
View File
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: save_heatmap.py
# Created Date: Friday January 15th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 19th January 2022 1:22:47 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import shutil
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
import numpy as np
def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
"""
The input tensor must be B X 1 X H X W
"""
batch_size = heatmaps.shape[0]
temp_path = ".temp/"
if not os.path.exists(temp_path):
os.makedirs(temp_path)
final_img = None
if row < 1:
col = batch_size
row = 1
else:
col = batch_size // row
if row * col <batch_size:
col +=1
row_i = 0
col_i = 0
for i in range(batch_size):
img_path = os.path.join(temp_path,'temp_batch_{}.png'.format(i))
sns.heatmap(heatmaps[i,0,:,:],vmin=0,vmax=heatmaps[i,0,:,:].max(),cbar=False)
plt.savefig(img_path, dpi=dpi, bbox_inches = 'tight', pad_inches = 0)
img = cv2.imread(img_path)
if i == 0:
H,W,C = img.shape
final_img = np.zeros((H*row,W*col,C))
final_img[H*row_i:H*(row_i+1),W*col_i:W*(col_i+1),:] = img
col_i += 1
if col_i >= col:
col_i = 0
row_i += 1
cv2.imwrite(path,final_img)
if __name__ == "__main__":
random_map = np.random.randn(16,1,10,10)
SaveHeatmap(random_map,"./wocao.png",1)