This commit is contained in:
chenxuanhong
2022-01-24 19:01:00 +08:00
parent 0f8c2f929e
commit 94534e2e30
10 changed files with 765 additions and 79 deletions
+22 -6
View File
@@ -1,7 +1,7 @@
{
"GUI.py": 1642351532.4558506,
"test.py": 1634039043.4872007,
"train.py": 1642408973.489742,
"test.py": 1642733759.8015468,
"train.py": 1643009568.7253726,
"components\\Generator.py": 1642347735.351465,
"components\\projected_discriminator.py": 1642348101.4661522,
"components\\pg_modules\\blocks.py": 1640773190.0,
@@ -23,7 +23,7 @@
"test_scripts\\tester_common.py": 1625369535.199175,
"test_scripts\\tester_FastNST.py": 1634041357.607633,
"train_scripts\\trainer_base.py": 1642396105.3868554,
"train_scripts\\trainer_FM.py": 1642396334.407562,
"train_scripts\\trainer_FM.py": 1642826710.3532298,
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
"utilities\\checkpoint_manager.py": 1611123530.6624403,
"utilities\\figure.py": 1611123530.6634378,
@@ -37,7 +37,23 @@
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
"utilities\\utilities.py": 1634019485.0783668,
"utilities\\yaml_config.py": 1611123530.6614666,
"train_yamls\\train_512FM.yaml": 1642408586.2747152,
"train_scripts\\trainer_2layer_FM.py": 1642408727.3950334,
"train_yamls\\train_2layer_FM.yaml": 1642408813.9488862
"train_yamls\\train_512FM.yaml": 1642412254.0831068,
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
"components\\Generator_reduce.py": 1642690262.572149,
"insightface_func\\face_detect_crop_multi.py": 1638370471.789609,
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
"insightface_func\\__init__.py": 1624197300.011183,
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
"losses\\PatchNCE.py": 1642989384.9713614,
"parsing_model\\model.py": 1626745709.554252,
"parsing_model\\resnet.py": 1626745709.554252,
"test_scripts\\tester_common copy.py": 1625369535.199175,
"test_scripts\\tester_video.py": 1642734397.3307388,
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
"train_scripts\\trainer_GramFM.py": 1643010471.1077821,
"utilities\\ImagenetNorm.py": 1642732910.5280058,
"utilities\\reverse2original.py": 1642733688.7976837,
"train_yamls\\train_cycleloss.yaml": 1642577741.345273,
"train_yamls\\train_GramFM.yaml": 1643011210.82505
}
+25 -56
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th January 2022 10:51:02 pm
# Last Modified: Monday, 24th January 2022 6:47:22 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -15,17 +15,16 @@ from torch import nn
from torch.nn import init
from torch.nn import functional as F
class InstanceNorm(nn.Module):
class Demodule(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(InstanceNorm, self).__init__()
super(Demodule, self).__init__()
self.epsilon = epsilon
def forward(self, x):
x = x - torch.mean(x, (2, 3), True)
tmp = torch.mul(x, x) # or x ** 2
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
return x * tmp
@@ -46,9 +45,22 @@ class ApplyStyle(nn.Module):
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
return x
class ResnetBlock_Adain(nn.Module):
class Modulation(nn.Module):
def __init__(self, latent_size, channels):
super(Modulation, self).__init__()
self.linear = nn.Linear(latent_size, channels)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
#x = x * (style[:, 0] + 1.) + style[:, 1]
x = x * style
return x
class ResnetBlock_Modulation(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
super(ResnetBlock_Adain, self).__init__()
super(ResnetBlock_Modulation, self).__init__()
p = 0
conv1 = []
@@ -60,9 +72,9 @@ class ResnetBlock_Adain(nn.Module):
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), Demodule()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = ApplyStyle(latent_size, dim)
self.style1 = Modulation(latent_size, dim)
self.act1 = activation
p = 0
@@ -75,52 +87,9 @@ class ResnetBlock_Adain(nn.Module):
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = ApplyStyle(latent_size, dim)
def forward(self, x, dlatents_in_slice):
y = self.conv1(x)
y = self.style1(y, dlatents_in_slice)
y = self.act1(y)
y = self.conv2(y)
y = self.style2(y, dlatents_in_slice)
out = x + y
return out
class newres(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
super(ResnetBlock_Adain, self).__init__()
p = 0
conv1 = []
if padding_type == 'reflect':
conv1 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv1 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = ApplyStyle(latent_size, dim)
self.act1 = activation
p = 0
conv2 = []
if padding_type == 'reflect':
conv2 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv2 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = ApplyStyle(latent_size, dim)
self.style2 = Modulation(latent_size, dim)
def forward(self, x, dlatents_in_slice):
@@ -148,7 +117,7 @@ class Generator(nn.Module):
activation = nn.ReLU(True)
self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64), activation)
### downsample
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
@@ -167,7 +136,7 @@ class Generator(nn.Module):
BN = []
for i in range(res_num):
BN += [
ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
ResnetBlock_Modulation(512, latent_size=chn, padding_type=padding_type, activation=activation)]
self.BottleNeck = nn.Sequential(*BN)
self.up4 = nn.Sequential(
@@ -210,7 +179,7 @@ class Generator(nn.Module):
def forward(self, input, id):
x = input # 3*224*224
skip1 = self.first_layer(x)
skip1 = self.stem(x)
skip2 = self.down1(skip1)
skip3 = self.down2(skip2)
skip4 = self.down3(skip3)
+217 -1
View File
@@ -5,9 +5,225 @@
# Created Date: Friday January 21st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 5:04:43 pm
# Last Modified: Monday, 24th January 2022 9:56:24 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from packaging import version
import numpy as np
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Place a network on its device and (optionally) initialize its weights.

    Parameters:
        net (network)             -- the network to set up
        init_type (str)           -- initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal
        gpu_ids (int list)        -- which GPUs the network runs on, e.g. [0, 1, 2]
        debug (bool)              -- forwarded to init_weights (prints touched layer names)
        initialize_weights (bool) -- when False, only device placement is performed

    Returns the (possibly CUDA-resident) network.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        # multi-GPU DataParallel wrapping is intentionally left disabled here
        # (it conflicts with AMP training).
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights in place.

    Conv*/Linear weights follow *init_type* (normal | xavier | kaiming |
    orthogonal) scaled by *init_gain*, and their biases are zeroed.
    BatchNorm2d weights are drawn from N(1, init_gain) with zero bias.
    'normal' matches the original pix2pix/CycleGAN setup; xavier/kaiming
    may work better for some applications.
    """
    def _initialize(module):
        name = module.__class__.__name__
        if hasattr(module, 'weight') and ('Conv' in name or 'Linear' in name):
            if debug:
                print(name)
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(module, 'bias', None) is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm weight is not a matrix; only a normal distribution applies.
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    net.apply(_initialize)  # recursively visits every submodule
class Normalize(nn.Module):
    """Lp-normalize a tensor along dim 1 (default p=2), with a 1e-7
    denominator offset for numerical safety."""

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        magnitude = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        return x / (magnitude + 1e-7)
class PatchSampleF(nn.Module):
    """Samples spatial patches from a list of feature maps (CUT-style).

    Optionally projects every sampled patch through a small per-layer MLP
    (built lazily on the first forward pass) and L2-normalizes the result.
    """

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        # NOTE(review): `gpu_ids=[]` is a mutable default argument; it is never
        # mutated here so it is harmless, but worth cleaning up.
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        self.mlp_init = False  # MLPs are created lazily in forward()
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):
        # Build one 2-layer MLP per feature level; the input width is taken
        # from the channel dimension of the corresponding feature map.
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample `num_patches` spatial locations from every map in `feats`.

        When `patch_ids` is given, the same locations are reused (so query
        and key features are sampled at identical positions); otherwise
        random locations are drawn.  Returns (sampled_features, patch_ids).
        """
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            # (B, C, H, W) -> (B, H*W, C): one row per spatial location
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
                    #patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = np.random.permutation(feat_reshape.shape[1])
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
                # gather the chosen locations and fold batch+patch into one axis
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # reshape(-1, x.shape[1])
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)
            if num_patches == 0:
                # no sampling was done: restore the (B, C, H, W) spatial layout
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids
# def calculate_NCE_loss(src, tgt):
# n_layers = len(self.nce_layers)
# feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
# if self.opt.flip_equivariance and self.flipped_for_equivariance:
# feat_q = [torch.flip(fq, [3]) for fq in feat_q]
# feat_k = self.netG(src, self.nce_layers, encode_only=True)
# feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
# feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
# total_nce_loss = 0.0
# for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
# loss = crit(f_q, f_k) * self.opt.lambda_NCE
# total_nce_loss += loss.mean()
# return total_nce_loss / n_layers
class PatchNCELoss(nn.Module):
    """PatchNCE (InfoNCE) contrastive loss over sampled feature patches.

    For each query patch, the positive is the key patch at the same spatial
    location, and the negatives are the other patches of the SAME image
    (not the whole minibatch -- CUT/FastCUT found that works better).

    Parameters:
        batch_size (int)  -- number of images per batch; the flat patch axis
                             of the inputs must be divisible by it
        nce_T (float)     -- softmax temperature (default 0.07)
    """

    def __init__(self, batch_size, nce_T=0.07):
        super().__init__()
        self.nce_T = nce_T
        self.batch_size = batch_size
        # per-patch losses are returned unreduced so callers can weight/mean them
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # torch < 1.2 expects uint8 masks for masked_fill_; newer releases use
        # bool.  `packaging` is only needed for that legacy gate, so degrade
        # gracefully to the modern dtype when it is unavailable.
        try:
            from packaging import version
            legacy_torch = version.parse(torch.__version__) < version.parse('1.2.0')
        except ImportError:
            legacy_torch = False
        self.mask_dtype = torch.uint8 if legacy_torch else torch.bool

    def forward(self, feat_q, feat_k):
        """Compute the per-patch NCE loss.

        feat_q, feat_k -- (total_patches, dim) query/key features; feat_k is
                          detached so gradients flow through the query only.
        Returns a (total_patches,) tensor of unreduced losses.
        """
        num_patches = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()

        # positive logit: similarity between corresponding q/k patches
        l_pos = torch.bmm(
            feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
        l_pos = l_pos.view(num_patches, 1)

        # negative logits: all-pairs similarities within each image
        batch_dim_for_bmm = self.batch_size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # diagonal entries are similarity between same features, and hence
        # meaningless; fill them with -10 so exp(-10/T) is effectively zero
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # class 0 (the positive logit) is the target for every row
        out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T
        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))
        return loss
if __name__ == "__main__":
    # Smoke test: sample patches from identical all-ones feature maps and
    # compute the PatchNCE loss between them.  Requires a CUDA device
    # (inputs are moved with .cuda() and gpu_ids triggers mlp.cuda()).
    batch = 16
    nc = 256
    num_patches = 64
    netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=["cuda:0"], nc=nc)
    feat_q = [torch.ones((batch,nc,32,32)).cuda()]
    # if self.opt.flip_equivariance and self.flipped_for_equivariance:
    #     feat_q = [torch.flip(fq, [3]) for fq in feat_q]
    crit = PatchNCELoss(batch)
    feat_k = [torch.ones((batch,nc,32,32)).cuda()]
    # sample the keys first, then reuse the same patch ids for the queries
    feat_k_pool, sample_ids = netF(feat_k, num_patches, None)
    print(feat_k_pool[0].shape)
    feat_q_pool, _ = netF(feat_q, num_patches, sample_ids)
    print(feat_q_pool[0].shape)
    loss = crit(feat_q_pool[0], feat_k_pool[0])
    print(loss)
+3
View File
@@ -0,0 +1,3 @@
nohup python train.py > GramFM.log 2>&1 &
+5 -5
View File
@@ -5,7 +5,7 @@
# Created Date: Tuesday April 28th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 17th January 2022 4:42:53 pm
# Last Modified: Monday, 24th January 2022 3:32:47 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
@@ -31,9 +31,9 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='2layerFM',
parser.add_argument('-v', '--version', type=str, default='GramFM',
help="version name for train, test, finetune")
parser.add_argument('-t', '--tag', type=str, default='Feature_match',
parser.add_argument('-t', '--tag', type=str, default='Gram_Feature_match',
help="tag for current experiment")
parser.add_argument('-p', '--phase', type=str, default="train",
@@ -46,9 +46,9 @@ def getParameters():
# training
parser.add_argument('--experiment_description', type=str,
default="减小重建和feature match的权重,使用2和3的feature作为feature")
default="使用3作为feature, 尝试使用gram矩阵来计算feature matching")
parser.add_argument('--train_yaml', type=str, default="train_2layer_FM.yaml")
parser.add_argument('--train_yaml', type=str, default="train_GramFM.yaml")
# system logger
parser.add_argument('--logger', type=str,
+15 -10
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 22nd January 2022 12:45:09 pm
# Last Modified: Monday, 24th January 2022 6:56:17 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -28,6 +28,9 @@ class Trainer(TrainerBase):
config,
reporter):
super(Trainer, self).__init__(config, reporter)
import inspect
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
@@ -276,25 +279,27 @@ class Trainer(TrainerBase):
elapsed = str(datetime.timedelta(seconds=elapsed))
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(self.config["version"], elapsed, step, total_step, \
loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(self.config["version"], elapsed, step, total_step, \
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["logger"] == "tensorboard":
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
elif self.config["logger"] == "wandb":
self.logger.log({"G_loss": loss_G.item()}, step = step)
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
self.logger.log({"D_loss": loss_D.item()}, step = step)
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
+351
View File
@@ -0,0 +1,351 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_naiv512.py
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 24th January 2022 6:23:16 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from utilities.plot import plot_batch
from train_scripts.trainer_base import TrainerBase
from utilities.utilities import Gram
class Trainer(TrainerBase):
    """Face-swap trainer whose feature-matching loss compares Gram matrices
    of discriminator features.

    Each step alternates a generator update (adversarial + ArcFace identity +
    Gram feature-matching, plus L1 reconstruction on even steps where source
    and identity come from the same image) with a hinge-loss discriminator
    update.  All hyper-parameters are read from ``self.config``.
    """

    def __init__(self,
                config,
                reporter):
        # config   -- dict-like training configuration (weights, paths, steps)
        # reporter -- logging helper used to record info and model structure
        super(Trainer, self).__init__(config, reporter)
        import inspect
        print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
        # Per-channel statistics (the standard ImageNet values) used to
        # de-normalize images before visualization.
        self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
        self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)

    # TODO modify this function to build your models
    def init_framework(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        # Resolve the generator/discriminator module paths; finetune mode
        # loads the scripts from the project's checkpoint copy instead.
        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]
            dscript_name = "components." + model_config["d_model"]["script"]
        elif self.config["phase"] == "finetune":
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
            dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])
        # print and record model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())
        class_name = model_config["d_model"]["class_name"]
        package = __import__(dscript_name, fromlist=True)
        dis_class = getattr(package, class_name)
        self.dis = dis_class(**model_config["d_model"]["module_params"])
        # the discriminator's pretrained feature extractor stays frozen
        self.dis.feature_network.requires_grad_(False)
        # print and record model structure
        self.reporter.writeInfo("Discriminator structure:")
        self.reporter.writeModel(self.dis.__str__())
        # ArcFace identity encoder, loaded on CPU first and kept frozen.
        # NOTE(review): self.arcface_ckpt is assigned in train(); this relies
        # on train() running before init_framework() -- confirm the call
        # order in TrainerBase.
        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        # train in GPU
        if self.config["cuda"] >=0:
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()
            self.arcface= self.arcface.cuda()
        self.arcface.eval()
        self.arcface.requires_grad_(False)
        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_%s.pth"%(self.config["checkpoint_step"],
                                self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))
            model_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_%s.pth"%(self.config["checkpoint_step"],
                                self.config["checkpoint_names"]["discriminator_name"]))
            self.dis.load_state_dict(torch.load(model_path))
            print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"]))

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        g_train_opt = self.config['g_optim_config']
        d_train_opt = self.config['d_optim_config']
        g_optim_params = []
        d_optim_params = []
        # collect only trainable parameters; log the frozen ones
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')
        for k, v in self.dis.named_parameters():
            if v.requires_grad:
                d_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')
        optim_type = self.config['optim_type']
        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
            self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supperted yet.')
        # self.optimizers.append(self.optimizer_g)
        if self.config["phase"] == "finetune":
            # NOTE(review): these paths end in ".pth" but train() saves
            # optimizer state as 'step{}_optim_{}' WITHOUT the extension --
            # the two should agree or finetune resume will fail. Verify.
            opt_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_optim_%s.pth"%(self.config["checkpoint_step"],
                                self.config["optimizer_names"]["generator_name"]))
            self.g_optimizer.load_state_dict(torch.load(opt_path))
            opt_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_optim_%s.pth"%(self.config["checkpoint_step"],
                                self.config["optimizer_names"]["discriminator_name"]))
            self.d_optimizer.load_state_dict(torch.load(opt_path))
            print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"]))

    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self,
                step = 0,
                **kwargs
                ):
        # Renders a swap grid: first row = source images, then one row per
        # source with the generator driven by each identity vector in turn.
        src_image1 = kwargs["src1"]
        src_image2 = kwargs["src2"]
        batch_size = self.batch_size
        self.gen.eval()
        with torch.no_grad():
            imgs = []
            # top-left corner of the grid is an empty tile
            zero_img = (torch.zeros_like(src_image1[0,...]))
            imgs.append(zero_img.cpu().numpy())
            # de-normalize sources for display
            save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy()
            for r in range(batch_size):
                imgs.append(save_img[r,...])
            # identity vectors extracted from the second image set at 112x112
            arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
            id_vector_src1 = self.arcface(arcface_112)
            id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
            for i in range(batch_size):
                imgs.append(save_img[i,...])
                # swap source i with every identity in the batch at once
                image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
                img_fake = self.gen(image_infer, id_vector_src1).cpu()
                img_fake = img_fake * self.img_std
                img_fake = img_fake + self.img_mean
                img_fake = img_fake.numpy()
                for j in range(batch_size):
                    imgs.append(img_fake[j,...])
            print("Save test data")
            imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
            plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'))

    def train(self):
        """Main optimization loop: one G update and one D update per step."""
        ckpt_dir = self.config["project_checkpoints"]
        log_freq = self.config["log_step"]
        model_freq = self.config["model_save_step"]
        sample_freq = self.config["sample_step"]
        total_step = self.config["total_step"]
        random_seed = self.config["dataset_params"]["random_seed"]
        self.batch_size = self.config["batch_size"]
        self.sample_dir = self.config["project_samples"]
        self.arcface_ckpt= self.config["arcface_ckpt"]
        # prep_weights= self.config["layersWeight"]
        id_w = self.config["id_weight"]
        rec_w = self.config["reconstruct_weight"]
        feat_w = self.config["feature_match_weight"]
        super().train()
        #===============build losses===================#
        # TODO replace below lines to build your losses
        # NOTE(review): MSE_loss is constructed but never used below.
        MSE_loss = torch.nn.MSELoss()
        l1_loss = torch.nn.L1Loss()
        cos_loss = torch.nn.CosineSimilarity()
        start_time = time.time()
        # Caculate the epoch number
        print("Total step = %d"%total_step)
        random.seed(random_seed)
        # randindex shuffles identities within the batch on odd steps
        randindex = [i for i in range(self.batch_size)]
        random.shuffle(randindex)
        import datetime
        for step in range(self.start, total_step):
            self.gen.train()
            self.dis.train()
            # interval 0 -> generator update, interval 1 -> discriminator update
            for interval in range(2):
                random.shuffle(randindex)
                src_image1, src_image2 = self.train_loader.next()
                # even steps: identity comes from the same image (enables the
                # reconstruction loss); odd steps: shuffled identities
                if step%2 == 0:
                    img_id = src_image2
                else:
                    img_id = src_image2[randindex]
                img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
                latent_id = self.arcface(img_id_112)
                latent_id = F.normalize(latent_id, p=2, dim=1)
                if interval:
                    # ---- discriminator step (hinge loss) ----
                    img_fake = self.gen(src_image1, latent_id)
                    gen_logits,_ = self.dis(img_fake.detach(), None)
                    loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
                    real_logits,_ = self.dis(src_image2,None)
                    loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
                    loss_D = loss_Dgen + loss_Dreal
                    self.d_optimizer.zero_grad()
                    loss_D.backward()
                    self.d_optimizer.step()
                else:
                    # ---- generator step ----
                    # model.netD.requires_grad_(True)
                    img_fake = self.gen(src_image1, latent_id)
                    # G loss (non-saturating hinge form)
                    gen_logits,feat = self.dis(img_fake, None)
                    loss_Gmain = (-gen_logits).mean()
                    # identity loss: fake should carry the injected identity
                    img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
                    latent_fake = self.arcface(img_fake_down)
                    latent_fake = F.normalize(latent_fake, p=2, dim=1)
                    loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
                    # Gram feature-matching against the SOURCE image's
                    # discriminator features at levels "2" and "3"
                    real_feat = self.dis.get_feature(src_image1)
                    feat_match_loss = l1_loss(Gram(feat["3"]), Gram(real_feat["3"])) + \
                                        l1_loss(Gram(feat["2"]), Gram(real_feat["2"]))
                    loss_G = loss_Gmain + loss_G_ID * id_w + \
                                feat_match_loss * feat_w
                    if step%2 == 0:
                        #G_Rec
                        loss_G_Rec = l1_loss(img_fake, src_image1)
                        loss_G += loss_G_Rec * rec_w
                    self.g_optimizer.zero_grad()
                    loss_G.backward()
                    self.g_optimizer.step()

            # Print out log info
            # NOTE(review): loss_G_Rec is only assigned on even steps; if the
            # first logged step is odd and no even step has run yet, the
            # .item() below raises NameError -- verify self.start/log_step.
            if (step + 1) % log_freq == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
                        G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
                        D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
                        format(self.config["version"], elapsed, step, total_step, \
                        loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
                        loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
                print(epochinformation)
                self.reporter.writeInfo(epochinformation)
                if self.config["logger"] == "tensorboard":
                    self.logger.add_scalar('G/G_loss', loss_G.item(), step)
                    self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
                    self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
                    self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                    self.logger.add_scalar('D/D_loss', loss_D.item(), step)
                    self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
                    self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
                elif self.config["logger"] == "wandb":
                    self.logger.log({"G_loss": loss_G.item()}, step = step)
                    self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
                    self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
                    self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
                    self.logger.log({"D_loss": loss_D.item()}, step = step)
                    self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
                    self.logger.log({"D_real": loss_Dreal.item()}, step = step)

            # periodically render a swap grid from the current batch
            if (step + 1) % sample_freq == 0:
                self.__evaluation__(
                    step = step,
                    **{
                        "src1": src_image1,
                        "src2": src_image2
                    })

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (step+1) % model_freq==0:
                # NOTE(review): optimizer files are saved WITHOUT ".pth" but
                # __setup_optimizers__ loads them WITH it -- see note there.
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                                    self.config["checkpoint_names"]["generator_name"])))
                torch.save(self.dis.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                                    self.config["checkpoint_names"]["discriminator_name"])))
                torch.save(self.g_optimizer.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                                    self.config["checkpoint_names"]["generator_name"])))
                torch.save(self.d_optimizer.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                                    self.config["checkpoint_names"]["discriminator_name"])))
                print("Save step %d model checkpoint!"%(step+1))
                torch.cuda.empty_cache()
        # final sample grid after the loop completes
        self.__evaluation__(
            step = step,
            **{
                "src1": src_image1,
                "src2": src_image2
            })
+1 -1
View File
@@ -50,7 +50,7 @@ d_optim_config:
eps: !!float 1e-8
id_weight: 20.0
reconstruct_weight: 1.0
reconstruct_weight: 10.0
feature_match_weight: 10.0
# Log
+63
View File
@@ -0,0 +1,63 @@
# Related scripts
train_script_name: FM
# models' scripts
model_configs:
g_model:
script: Generator_reduce
class_name: Generator
module_params:
g_conv_dim: 512
g_kernel_size: 3
res_num: 9
d_model:
script: projected_discriminator
class_name: ProjectedDiscriminator
module_params:
diffaug: False
interp224: False
backbone_kwargs: {}
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
# Training information
batch_size: 12
# Dataset
dataloader: VGGFace2HQ
dataset_name: vggface2_hq
dataset_params:
random_seed: 1234
dataloader_workers: 8
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Dataset
# Optimizer
optim_type: Adam
g_optim_config:
lr: 0.0004
betas: [ 0, 0.99]
eps: !!float 1e-8
d_optim_config:
lr: 0.0004
betas: [ 0, 0.99]
eps: !!float 1e-8
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 10.0
# Log
log_step: 300
model_save_step: 10000
sample_step: 1000
total_step: 1000000
checkpoint_names:
generator_name: Generator
discriminator_name: Discriminator
+63
View File
@@ -0,0 +1,63 @@
# Related scripts
train_script_name: GramFM
# models' scripts
model_configs:
g_model:
script: Generator
class_name: Generator
module_params:
g_conv_dim: 512
g_kernel_size: 3
res_num: 9
d_model:
script: projected_discriminator
class_name: ProjectedDiscriminator
module_params:
diffaug: False
interp224: False
backbone_kwargs: {}
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
# Training information
batch_size: 12
# Dataset
dataloader: VGGFace2HQ
dataset_name: vggface2_hq
dataset_params:
random_seed: 1234
dataloader_workers: 8
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Dataset
# Optimizer
optim_type: Adam
g_optim_config:
lr: 0.0004
betas: [ 0, 0.99]
eps: !!float 1e-8
d_optim_config:
lr: 0.0004
betas: [ 0, 0.99]
eps: !!float 1e-8
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 100.0
# Log
log_step: 300
model_save_step: 10000
total_step: 1000000
sample_step: 1000
checkpoint_names:
generator_name: Generator
discriminator_name: Discriminator