update
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"GUI.py": 1642351532.4558506,
|
||||
"test.py": 1634039043.4872007,
|
||||
"train.py": 1642408973.489742,
|
||||
"test.py": 1642733759.8015468,
|
||||
"train.py": 1643009568.7253726,
|
||||
"components\\Generator.py": 1642347735.351465,
|
||||
"components\\projected_discriminator.py": 1642348101.4661522,
|
||||
"components\\pg_modules\\blocks.py": 1640773190.0,
|
||||
@@ -23,7 +23,7 @@
|
||||
"test_scripts\\tester_common.py": 1625369535.199175,
|
||||
"test_scripts\\tester_FastNST.py": 1634041357.607633,
|
||||
"train_scripts\\trainer_base.py": 1642396105.3868554,
|
||||
"train_scripts\\trainer_FM.py": 1642396334.407562,
|
||||
"train_scripts\\trainer_FM.py": 1642826710.3532298,
|
||||
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
|
||||
"utilities\\checkpoint_manager.py": 1611123530.6624403,
|
||||
"utilities\\figure.py": 1611123530.6634378,
|
||||
@@ -37,7 +37,23 @@
|
||||
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
|
||||
"utilities\\utilities.py": 1634019485.0783668,
|
||||
"utilities\\yaml_config.py": 1611123530.6614666,
|
||||
"train_yamls\\train_512FM.yaml": 1642408586.2747152,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642408727.3950334,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642408813.9488862
|
||||
"train_yamls\\train_512FM.yaml": 1642412254.0831068,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
|
||||
"components\\Generator_reduce.py": 1642690262.572149,
|
||||
"insightface_func\\face_detect_crop_multi.py": 1638370471.789609,
|
||||
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
|
||||
"insightface_func\\__init__.py": 1624197300.011183,
|
||||
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
|
||||
"losses\\PatchNCE.py": 1642989384.9713614,
|
||||
"parsing_model\\model.py": 1626745709.554252,
|
||||
"parsing_model\\resnet.py": 1626745709.554252,
|
||||
"test_scripts\\tester_common copy.py": 1625369535.199175,
|
||||
"test_scripts\\tester_video.py": 1642734397.3307388,
|
||||
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
|
||||
"train_scripts\\trainer_GramFM.py": 1643010471.1077821,
|
||||
"utilities\\ImagenetNorm.py": 1642732910.5280058,
|
||||
"utilities\\reverse2original.py": 1642733688.7976837,
|
||||
"train_yamls\\train_cycleloss.yaml": 1642577741.345273,
|
||||
"train_yamls\\train_GramFM.yaml": 1643011210.82505
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th January 2022 10:51:02 pm
|
||||
# Last Modified: Monday, 24th January 2022 6:47:22 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -15,17 +15,16 @@ from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class InstanceNorm(nn.Module):
|
||||
class Demodule(nn.Module):
|
||||
def __init__(self, epsilon=1e-8):
|
||||
"""
|
||||
@notice: avoid in-place ops.
|
||||
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
||||
"""
|
||||
super(InstanceNorm, self).__init__()
|
||||
super(Demodule, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
x = x - torch.mean(x, (2, 3), True)
|
||||
tmp = torch.mul(x, x) # or x ** 2
|
||||
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
||||
return x * tmp
|
||||
@@ -46,9 +45,22 @@ class ApplyStyle(nn.Module):
|
||||
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
|
||||
return x
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, latent_size, channels):
|
||||
super(Modulation, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * style
|
||||
return x
|
||||
|
||||
class ResnetBlock_Modulation(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Adain, self).__init__()
|
||||
super(ResnetBlock_Modulation, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
@@ -60,9 +72,9 @@ class ResnetBlock_Adain(nn.Module):
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), Demodule()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = ApplyStyle(latent_size, dim)
|
||||
self.style1 = Modulation(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
@@ -75,52 +87,9 @@ class ResnetBlock_Adain(nn.Module):
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
||||
self.style2 = ApplyStyle(latent_size, dim)
|
||||
|
||||
|
||||
def forward(self, x, dlatents_in_slice):
|
||||
y = self.conv1(x)
|
||||
y = self.style1(y, dlatents_in_slice)
|
||||
y = self.act1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.style2(y, dlatents_in_slice)
|
||||
out = x + y
|
||||
return out
|
||||
|
||||
class newres(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Adain, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
if padding_type == 'reflect':
|
||||
conv1 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv1 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = ApplyStyle(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
conv2 = []
|
||||
if padding_type == 'reflect':
|
||||
conv2 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv2 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
||||
self.style2 = ApplyStyle(latent_size, dim)
|
||||
self.style2 = Modulation(latent_size, dim)
|
||||
|
||||
|
||||
def forward(self, x, dlatents_in_slice):
|
||||
@@ -148,7 +117,7 @@ class Generator(nn.Module):
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
||||
self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
@@ -167,7 +136,7 @@ class Generator(nn.Module):
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
|
||||
ResnetBlock_Modulation(512, latent_size=chn, padding_type=padding_type, activation=activation)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
@@ -210,7 +179,7 @@ class Generator(nn.Module):
|
||||
|
||||
def forward(self, input, id):
|
||||
x = input # 3*224*224
|
||||
skip1 = self.first_layer(x)
|
||||
skip1 = self.stem(x)
|
||||
skip2 = self.down1(skip1)
|
||||
skip3 = self.down2(skip2)
|
||||
skip4 = self.down3(skip3)
|
||||
|
||||
+217
-1
@@ -5,9 +5,225 @@
|
||||
# Created Date: Friday January 21st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 21st January 2022 5:04:43 pm
|
||||
# Last Modified: Monday, 24th January 2022 9:56:24 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None, debug=False, initialize_weights=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights.

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal
        gpu_ids (list of int)     -- which GPUs the network runs on: e.g., [0, 1, 2]; None/empty means CPU only
        debug (bool)              -- forwarded to init_weights; prints each initialized layer's class name
        initialize_weights (bool) -- if False, skip weight initialization entirely

    Return an initialized network.
    """
    # Fix: the default used to be the mutable literal `gpu_ids=[]`, which is shared
    # across every call of the function; use a None sentinel with the same behavior.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        # if not amp:
        #     net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs for non-AMP training
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
        debug (bool)      -- if True, print the class name of every layer being initialized

    'normal' is what the original pix2pix and CycleGAN papers used, but xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """
    def _initialize(module):
        # Runs once per submodule via net.apply().
        cls_name = module.__class__.__name__
        has_matrix_weight = hasattr(module, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name)
        if has_matrix_weight:
            if debug:
                print(cls_name)
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(module, 'bias') and module.bias is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm's weight is a vector, not a matrix; only a normal
            # distribution applies here regardless of init_type.
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    net.apply(_initialize)  # recursively visit every submodule
|
||||
|
||||
class Normalize(nn.Module):
    """Lp-normalize a tensor along dim 1 (the channel/feature dimension).

    power=2 (the default used elsewhere in this file) gives standard L2
    normalization. A small epsilon keeps the division stable for zero vectors.
    """

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        # Exponent of the norm; 2 -> Euclidean.
        self.power = power

    def forward(self, x):
        # ||x||_p computed along dim 1, kept broadcastable via keepdim.
        magnitude = (x ** self.power).sum(1, keepdim=True) ** (1.0 / self.power)
        return x / (magnitude + 1e-7)
|
||||
|
||||
class PatchSampleF(nn.Module):
    """Sample spatial patches (feature vectors) from a list of feature maps,
    optionally project each layer's samples through a lazily-created 2-layer MLP,
    and L2-normalize the result. Used to build the feature pools for PatchNCELoss.
    """

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded output channel count of each per-layer MLP
        self.mlp_init = False  # MLPs are created on first forward (input sizes unknown until then)
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):
        """Create one MLP per feature map (sized from its channel count) and
        initialize all of them; called lazily from forward()."""
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            # Register under a per-layer attribute name so forward() can look it up.
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample up to num_patches locations from each (B, C, H, W) feature map.

        Pass patch_ids=None to sample random locations (and get them back in
        return_ids); pass the returned ids on a second call to sample the SAME
        locations from another set of feature maps. num_patches == 0 returns the
        full maps (reshaped back to B x C x H x W after normalization).
        Returns (return_feats, return_ids), one entry per input feature map.
        """
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
                    #patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = np.random.permutation(feat_reshape.shape[1])
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
                # (B, num_patches, C) -> (B*num_patches, C)
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)

            if num_patches == 0:
                # Restore the spatial layout when the whole map was kept.
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids
|
||||
|
||||
|
||||
|
||||
# def calculate_NCE_loss(src, tgt):
|
||||
# n_layers = len(self.nce_layers)
|
||||
# feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
|
||||
|
||||
# if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
||||
# feat_q = [torch.flip(fq, [3]) for fq in feat_q]
|
||||
|
||||
# feat_k = self.netG(src, self.nce_layers, encode_only=True)
|
||||
# feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
|
||||
# feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
|
||||
|
||||
# total_nce_loss = 0.0
|
||||
# for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
|
||||
# loss = crit(f_q, f_k) * self.opt.lambda_NCE
|
||||
# total_nce_loss += loss.mean()
|
||||
|
||||
# return total_nce_loss / n_layers
|
||||
|
||||
class PatchNCELoss(nn.Module):
    """Patch-wise InfoNCE loss (from CUT): each query patch's positive is the
    key patch at the same location; its negatives are the other patches of the
    same image. Inputs are the pooled features produced by PatchSampleF.
    """

    def __init__(self,batch_size, nce_T = 0.07):
        super().__init__()
        self.nce_T = nce_T  # temperature for the softmax over logits
        self.batch_size = batch_size
        # reduction='none' keeps a per-patch loss vector; callers aggregate.
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # masked_fill_ requires a uint8 mask before PyTorch 1.2, bool after.
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        """Compute the per-patch NCE loss.

        feat_q, feat_k: (batch_size * npatches, dim), L2-normalized rows
        (row i of feat_q and row i of feat_k are the positive pair).
        Returns an unreduced loss tensor of length batch_size * npatches.
        """
        num_patches = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()  # gradients flow through the query side only

        # pos logit: dot product of each query with its own key.
        l_pos = torch.bmm(
            feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
        l_pos = l_pos.view(num_patches, 1)

        # neg logit

        # Should the negatives from the other samples of a minibatch be utilized?
        # In CUT and FastCUT, we found that it's best to only include negatives
        # from the same image. Therefore, we set
        # --nce_includes_all_negatives_from_minibatch as False
        # However, for single-image translation, the minibatch consists of
        # crops from the "same" high-resolution image.
        # Therefore, we will include the negatives from the entire minibatch.
        # if self.opt.nce_includes_all_negatives_from_minibatch:
        #     # reshape features as if they are all negatives of minibatch of size 1.
        #     batch_dim_for_bmm = 1
        # else:
        batch_dim_for_bmm = self.batch_size

        # reshape features to batch size: (batch, npatches, dim)
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        # All pairwise query/key dot products within each image.
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # diagonal entries are similarity between same features, and hence meaningless.
        # just fill the diagonal with very small number, which is exp(-10) and almost zero
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # Column 0 holds the positive logit, so the target class is always 0.
        out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test (requires a CUDA device): pool patches from two constant
    # feature maps with shared sample_ids, then compute PatchNCELoss on them.
    batch = 16
    nc = 256
    num_patches = 64
    netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=["cuda:0"], nc=nc)

    feat_q = [torch.ones((batch,nc,32,32)).cuda()]

    # if self.opt.flip_equivariance and self.flipped_for_equivariance:
    #     feat_q = [torch.flip(fq, [3]) for fq in feat_q]
    crit = PatchNCELoss(batch)
    feat_k = [torch.ones((batch,nc,32,32)).cuda()]
    # First call samples random patch ids; reuse them for the query side so
    # both pools index the same spatial locations.
    feat_k_pool, sample_ids = netF(feat_k, num_patches, None)
    print(feat_k_pool[0].shape)
    feat_q_pool, _ = netF(feat_q, num_patches, sample_ids)
    print(feat_q_pool[0].shape)
    loss = crit(feat_q_pool[0], feat_k_pool[0])
    print(loss)
|
||||
@@ -0,0 +1,3 @@
|
||||
|
||||
|
||||
nohup python train.py > GramFM.log 2>&1 &
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 17th January 2022 4:42:53 pm
|
||||
# Last Modified: Monday, 24th January 2022 3:32:47 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -31,9 +31,9 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='2layerFM',
|
||||
parser.add_argument('-v', '--version', type=str, default='GramFM',
|
||||
help="version name for train, test, finetune")
|
||||
parser.add_argument('-t', '--tag', type=str, default='Feature_match',
|
||||
parser.add_argument('-t', '--tag', type=str, default='Gram_Feature_match',
|
||||
help="tag for current experiment")
|
||||
|
||||
parser.add_argument('-p', '--phase', type=str, default="train",
|
||||
@@ -46,9 +46,9 @@ def getParameters():
|
||||
|
||||
# training
|
||||
parser.add_argument('--experiment_description', type=str,
|
||||
default="减小重建和feature match的权重,使用2和3的feature作为feature")
|
||||
default="使用3作为feature, 尝试使用gram矩阵来计算feature matching")
|
||||
|
||||
parser.add_argument('--train_yaml', type=str, default="train_2layer_FM.yaml")
|
||||
parser.add_argument('--train_yaml', type=str, default="train_GramFM.yaml")
|
||||
|
||||
# system logger
|
||||
parser.add_argument('--logger', type=str,
|
||||
|
||||
+15
-10
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 22nd January 2022 12:45:09 pm
|
||||
# Last Modified: Monday, 24th January 2022 6:56:17 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -28,6 +28,9 @@ class Trainer(TrainerBase):
|
||||
config,
|
||||
reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
import inspect
|
||||
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
|
||||
|
||||
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
@@ -276,25 +279,27 @@ class Trainer(TrainerBase):
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["logger"] == "tensorboard":
|
||||
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif self.config["logger"] == "wandb":
|
||||
self.logger.log({"G_loss": loss_G.item()}, step = step)
|
||||
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
self.logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 24th January 2022 6:23:16 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utilities.plot import plot_batch
|
||||
|
||||
from train_scripts.trainer_base import TrainerBase
|
||||
|
||||
from utilities.utilities import Gram
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self,
            config,
            reporter):
    """Set up the trainer.

    Parameters:
        config   -- training configuration mapping (keys read later include
                    "model_configs", "phase", "cuda", "batch_size", ...)
        reporter -- project logging helper (writeInfo / writeModel)
    """
    super(Trainer, self).__init__(config, reporter)
    import inspect
    # Log which trainer script variant is actually running.
    print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))

    # Per-channel std/mean used by __evaluation__ to de-normalize images
    # before saving (standard ImageNet statistics, shaped for broadcasting).
    self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
    self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(self):
    '''
    Build the generator, discriminator and (frozen) ArcFace identity model
    from the config, log their structures, move them to GPU when configured,
    and, in the finetune phase, restore both networks from checkpoints.
    '''
    #===============build models================#
    print("build models...")
    # Model classes are resolved dynamically: the config names a script and a
    # class inside it, imported below with __import__/getattr.

    model_config = self.config["model_configs"]

    if self.config["phase"] == "train":
        gscript_name = "components." + model_config["g_model"]["script"]
        dscript_name = "components." + model_config["d_model"]["script"]

    elif self.config["phase"] == "finetune":
        # NOTE(review): phases other than train/finetune leave gscript_name
        # undefined and would raise NameError below — confirm intended.
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        dscript_name = self.config["com_base"] + model_config["d_model"]["script"]

    class_name = model_config["g_model"]["class_name"]
    package = __import__(gscript_name, fromlist=True)
    gen_class = getattr(package, class_name)
    self.gen = gen_class(**model_config["g_model"]["module_params"])

    # print and record the generator structure
    self.reporter.writeInfo("Generator structure:")
    self.reporter.writeModel(self.gen.__str__())

    class_name = model_config["d_model"]["class_name"]
    package = __import__(dscript_name, fromlist=True)
    dis_class = getattr(package, class_name)
    self.dis = dis_class(**model_config["d_model"]["module_params"])
    # The discriminator's backbone feature extractor stays frozen.
    self.dis.feature_network.requires_grad_(False)

    # print and record the discriminator structure
    self.reporter.writeInfo("Discriminator structure:")
    self.reporter.writeModel(self.dis.__str__())
    # ArcFace checkpoint stores a DataParallel-wrapped model; unwrap .module.
    arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
    self.arcface = arcface1['model'].module

    # train in GPU
    if self.config["cuda"] >=0:
        self.gen = self.gen.cuda()
        self.dis = self.dis.cuda()
        self.arcface= self.arcface.cuda()

    # ArcFace is only used as a fixed identity-feature extractor.
    self.arcface.eval()
    self.arcface.requires_grad_(False)

    # if in finetune phase, load the pretrained checkpoint
    if self.config["phase"] == "finetune":
        model_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_%s.pth"%(self.config["checkpoint_step"],
                                self.config["checkpoint_names"]["generator_name"]))
        self.gen.load_state_dict(torch.load(model_path))

        model_path = os.path.join(self.config["project_checkpoints"],
                                "step%d_%s.pth"%(self.config["checkpoint_step"],
                                self.config["checkpoint_names"]["discriminator_name"]))
        self.dis.load_state_dict(torch.load(model_path))

        print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
    """Create the generator/discriminator optimizers from the config
    (only parameters with requires_grad=True are optimized), and in the
    finetune phase restore both optimizer states from checkpoints.
    """
    g_train_opt = self.config['g_optim_config']
    d_train_opt = self.config['d_optim_config']

    g_optim_params = []
    d_optim_params = []
    # Collect trainable parameters; frozen ones are logged and skipped.
    for k, v in self.gen.named_parameters():
        if v.requires_grad:
            g_optim_params.append(v)
        else:
            self.reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    for k, v in self.dis.named_parameters():
        if v.requires_grad:
            d_optim_params.append(v)
        else:
            self.reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    optim_type = self.config['optim_type']

    if optim_type == 'Adam':
        self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
        self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
    else:
        # NOTE(review): "supperted" is a typo in this runtime message;
        # left as-is here since this edit changes comments only.
        raise NotImplementedError(
            f'optimizer {optim_type} is not supperted yet.')
    # self.optimizers.append(self.optimizer_g)
    # In the finetune phase, restore optimizer states saved alongside the models.
    if self.config["phase"] == "finetune":
        opt_path = os.path.join(self.config["project_checkpoints"],
                            "step%d_optim_%s.pth"%(self.config["checkpoint_step"],
                            self.config["optimizer_names"]["generator_name"]))
        self.g_optimizer.load_state_dict(torch.load(opt_path))

        opt_path = os.path.join(self.config["project_checkpoints"],
                            "step%d_optim_%s.pth"%(self.config["checkpoint_step"],
                            self.config["optimizer_names"]["discriminator_name"]))
        self.d_optimizer.load_state_dict(torch.load(opt_path))

        print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
# Evaluate the checkpoint
|
||||
def __evaluation__(self,
            step = 0,
            **kwargs
            ):
    """Render a qualitative face-swap grid and save it to the sample dir.

    Parameters:
        step   -- current training step; used in the output filename
        kwargs -- must contain "src1" (source images) and "src2" (identity
                  images), both normalized batches of shape (B, 3, H, W)

    Grid layout: row 0 is a zero placeholder followed by the de-normalized
    src1 batch; each following row is one src1 image followed by its swaps
    against every identity extracted from src2.
    """
    src_image1 = kwargs["src1"]
    src_image2 = kwargs["src2"]
    batch_size = self.batch_size
    self.gen.eval()
    with torch.no_grad():
        imgs = []
        # Top-left corner placeholder so rows/columns align in the grid.
        zero_img = (torch.zeros_like(src_image1[0,...]))
        imgs.append(zero_img.cpu().numpy())
        # Undo ImageNet normalization before saving.
        save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy()
        for r in range(batch_size):
            imgs.append(save_img[r,...])
        # ArcFace expects 112x112 inputs; identity vectors are L2-normalized.
        arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
        id_vector_src1 = self.arcface(arcface_112)
        id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)

        for i in range(batch_size):

            imgs.append(save_img[i,...])
            # Swap one source image against all identities in a single batch.
            image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
            img_fake = self.gen(image_infer, id_vector_src1).cpu()

            img_fake = img_fake * self.img_std
            img_fake = img_fake + self.img_mean
            img_fake = img_fake.numpy()
            for j in range(batch_size):
                imgs.append(img_fake[j,...])
        print("Save test data")
        # (N, C, H, W) -> (N, H, W, C) for image plotting.
        imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
        plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'))
|
||||
|
||||
|
||||
|
||||
|
||||
def train(self):
    """Main training loop for the identity-swap GAN.

    Alternates a generator update and a discriminator update every step
    (inner ``interval`` loop: interval==0 -> G, interval==1 -> D), using a
    hinge adversarial loss, an ArcFace cosine identity loss, a Gram-matrix
    feature-matching loss and (on even steps only) an L1 reconstruction loss.
    Periodically logs, samples via ``self.__evaluation__`` and checkpoints
    models and optimizers to the configured directories.
    """
    # ---- read run configuration ----
    ckpt_dir = self.config["project_checkpoints"]
    log_freq = self.config["log_step"]
    model_freq = self.config["model_save_step"]
    sample_freq = self.config["sample_step"]
    total_step = self.config["total_step"]
    random_seed = self.config["dataset_params"]["random_seed"]

    self.batch_size = self.config["batch_size"]
    self.sample_dir = self.config["project_samples"]
    self.arcface_ckpt = self.config["arcface_ckpt"]

    # prep_weights= self.config["layersWeight"]
    # Loss weights.
    id_w = self.config["id_weight"]                  # identity (ArcFace cosine) loss weight
    rec_w = self.config["reconstruct_weight"]        # L1 self-reconstruction loss weight
    feat_w = self.config["feature_match_weight"]     # Gram feature-matching loss weight

    # Delegate model/optimizer/dataloader construction to the base trainer.
    super().train()

    #===============build losses===================#
    # TODO replace below lines to build your losses
    MSE_loss = torch.nn.MSELoss()                    # NOTE(review): built but unused in this loop
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity()

    start_time = time.time()

    # Calculate the epoch number
    print("Total step = %d"%total_step)
    # Fixed seed so the identity-shuffle permutation is reproducible.
    random.seed(random_seed)
    randindex = [i for i in range(self.batch_size)]
    random.shuffle(randindex)
    import datetime  # used only for elapsed-time formatting below
    for step in range(self.start, total_step):
        self.gen.train()
        self.dis.train()
        # interval==0: generator update; interval==1: discriminator update.
        # Each interval draws a fresh batch from the loader.
        for interval in range(2):
            random.shuffle(randindex)
            src_image1, src_image2 = self.train_loader.next()

            # Even steps: self-identity (reconstruction is meaningful);
            # odd steps: shuffled identities (cross-identity swap).
            if step%2 == 0:
                img_id = src_image2
            else:
                img_id = src_image2[randindex]

            # ArcFace expects 112x112 input; embedding is L2-normalized.
            img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
            latent_id = self.arcface(img_id_112)
            latent_id = F.normalize(latent_id, p=2, dim=1)
            if interval:
                # ---- discriminator update (hinge loss) ----
                img_fake = self.gen(src_image1, latent_id)
                gen_logits,_ = self.dis(img_fake.detach(), None)
                # hinge loss on fakes: max(0, 1 + D(fake))
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

                real_logits,_ = self.dis(src_image2,None)
                # hinge loss on reals: max(0, 1 - D(real))
                loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                loss_D = loss_Dgen + loss_Dreal
                self.d_optimizer.zero_grad()
                loss_D.backward()
                self.d_optimizer.step()
            else:
                # ---- generator update ----
                # model.netD.requires_grad_(True)
                img_fake = self.gen(src_image1, latent_id)
                # G loss
                gen_logits,feat = self.dis(img_fake, None)

                loss_Gmain = (-gen_logits).mean()
                # Identity loss: cosine distance between fake and target embeddings.
                img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
                latent_fake = self.arcface(img_fake_down)
                latent_fake = F.normalize(latent_fake, p=2, dim=1)
                loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
                # Feature matching on Gram matrices of discriminator features.
                # NOTE(review): matches against src_image1 (the attribute source),
                # while the D real loss uses src_image2 — confirm this is intended.
                real_feat = self.dis.get_feature(src_image1)
                feat_match_loss = l1_loss(Gram(feat["3"]), Gram(real_feat["3"])) + \
                    l1_loss(Gram(feat["2"]), Gram(real_feat["2"]))
                loss_G = loss_Gmain + loss_G_ID * id_w + \
                    feat_match_loss * feat_w
                if step%2 == 0:
                    #G_Rec: only meaningful when identity == source (even steps)
                    loss_G_Rec = l1_loss(img_fake, src_image1)
                    loss_G += loss_G_Rec * rec_w

                self.g_optimizer.zero_grad()
                loss_G.backward()
                self.g_optimizer.step()

        # Print out log info
        if (step + 1) % log_freq == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))

            # NOTE(review): loss_G_Rec is only assigned on even steps; if the
            # first logged step precedes any even-step G pass (e.g. resuming at
            # an odd self.start), this raises NameError — confirm/guard.
            epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
                G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
                D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
                format(self.config["version"], elapsed, step, total_step, \
                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
                loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
            print(epochinformation)
            self.reporter.writeInfo(epochinformation)

            # Mirror the scalars to the configured experiment logger.
            if self.config["logger"] == "tensorboard":
                self.logger.add_scalar('G/G_loss', loss_G.item(), step)
                self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
                self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
                self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                self.logger.add_scalar('D/D_loss', loss_D.item(), step)
                self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
                self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
            elif self.config["logger"] == "wandb":
                self.logger.log({"G_loss": loss_G.item()}, step = step)
                self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
                self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
                self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
                self.logger.log({"D_loss": loss_D.item()}, step = step)
                self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
                self.logger.log({"D_real": loss_Dreal.item()}, step = step)

        # Periodically render a sample grid from the current batch.
        if (step + 1) % sample_freq == 0:
            self.__evaluation__(
                step = step,
                **{
                    "src1": src_image1,
                    "src2": src_image2
                })

        #===============adjust learning rate============#
        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
        #     print("Learning rate decay")
        #     for p in self.optimizer.param_groups:
        #         p['lr'] *= self.config["lr_decay"]
        #         print("Current learning rate is %f"%p['lr'])

        #===============save checkpoints================#
        # Model weights and optimizer states are saved as separate files,
        # both tagged with the (1-based) step number.
        if (step+1) % model_freq==0:
            torch.save(self.gen.state_dict(),
                os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                    self.config["checkpoint_names"]["generator_name"])))
            torch.save(self.dis.state_dict(),
                os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                    self.config["checkpoint_names"]["discriminator_name"])))

            torch.save(self.g_optimizer.state_dict(),
                os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                    self.config["checkpoint_names"]["generator_name"])))

            torch.save(self.d_optimizer.state_dict(),
                os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                    self.config["checkpoint_names"]["discriminator_name"])))
            print("Save step %d model checkpoint!"%(step+1))
            torch.cuda.empty_cache()

    # Final sample after the loop completes.
    self.__evaluation__(
        step = step,
        **{
            "src1": src_image1,
            "src2": src_image2
        })
|
||||
@@ -50,7 +50,7 @@ d_optim_config:
|
||||
eps: !!float 1e-8
|
||||
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 1.0
|
||||
reconstruct_weight: 10.0
|
||||
feature_match_weight: 10.0
|
||||
|
||||
# Log
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
# Related scripts
# Selects train_scripts/trainer_<name>.py as the training entry point.
train_script_name: FM

# models' scripts
model_configs:
  g_model:
    script: Generator_reduce        # components/Generator_reduce.py
    class_name: Generator
    module_params:
      g_conv_dim: 512
      g_kernel_size: 3
      res_num: 9                    # number of residual blocks

  d_model:
    script: projected_discriminator # components/projected_discriminator.py
    class_name: ProjectedDiscriminator
    module_params:
      diffaug: False
      interp224: False
      backbone_kwargs: {}

# Pretrained ArcFace checkpoint used for the identity loss.
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar

# Training information
batch_size: 12

# Dataset
dataloader: VGGFace2HQ
dataset_name: vggface2_hq
dataset_params:
  random_seed: 1234
  dataloader_workers: 8

# NOTE(review): eval dataset names look inherited from a super-resolution
# config (DIV2K/DF2K) — confirm they are intended for this face-swap run.
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2

# Optimizer
optim_type: Adam
g_optim_config:
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

d_optim_config:
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

# Loss weights (see trainer: id/reconstruct/feature-match terms).
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 10.0

# Log
log_step: 300            # console/logger frequency (steps)
model_save_step: 10000   # checkpoint frequency (steps)
sample_step: 1000        # sample-grid frequency (steps)
total_step: 1000000
checkpoint_names:
  generator_name: Generator
  discriminator_name: Discriminator
|
||||
@@ -0,0 +1,63 @@
|
||||
# Related scripts
# Selects train_scripts/trainer_<name>.py as the training entry point.
train_script_name: GramFM

# models' scripts
model_configs:
  g_model:
    script: Generator               # components/Generator.py (full model)
    class_name: Generator
    module_params:
      g_conv_dim: 512
      g_kernel_size: 3
      res_num: 9                    # number of residual blocks

  d_model:
    script: projected_discriminator # components/projected_discriminator.py
    class_name: ProjectedDiscriminator
    module_params:
      diffaug: False
      interp224: False
      backbone_kwargs: {}

# Pretrained ArcFace checkpoint used for the identity loss.
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar

# Training information
batch_size: 12

# Dataset
dataloader: VGGFace2HQ
dataset_name: vggface2_hq
dataset_params:
  random_seed: 1234
  dataloader_workers: 8

# NOTE(review): eval dataset names look inherited from a super-resolution
# config (DIV2K/DF2K) — confirm they are intended for this face-swap run.
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2

# Optimizer
optim_type: Adam
g_optim_config:
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

d_optim_config:
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

# Loss weights (see trainer: id/reconstruct/feature-match terms).
# Gram-based feature matching uses a 10x larger weight than the FM config.
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 100.0

# Log
log_step: 300            # console/logger frequency (steps)
model_save_step: 10000   # checkpoint frequency (steps)
total_step: 1000000
sample_step: 1000        # sample-grid frequency (steps)
checkpoint_names:
  generator_name: Generator
  discriminator_name: Discriminator
|
||||
Reference in New Issue
Block a user