update
This commit is contained in:
@@ -1,13 +1,9 @@
|
||||
{
|
||||
"GUI.py": 1642351532.4558506,
|
||||
"test.py": 1634039043.4872007,
|
||||
"train.py": 1642351831.0061252,
|
||||
"train.py": 1642408973.489742,
|
||||
"components\\Generator.py": 1642347735.351465,
|
||||
"components\\Involution.py": 1626748553.9503577,
|
||||
"components\\projected_discriminator.py": 1642348101.4661522,
|
||||
"components\\ResBlock.py": 1625415499.383468,
|
||||
"components\\Transform.py": 1624954083.0098498,
|
||||
"components\\warp_invo.py": 1634614033.6983366,
|
||||
"components\\pg_modules\\blocks.py": 1640773190.0,
|
||||
"components\\pg_modules\\diffaug.py": 1640773190.0,
|
||||
"components\\pg_modules\\discriminator.py": 1642349784.9407308,
|
||||
@@ -17,22 +13,18 @@
|
||||
"data_tools\\data_loader.py": 1611123530.660446,
|
||||
"data_tools\\data_loader_condition.py": 1625411562.8217106,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1642349144.749807,
|
||||
"data_tools\\StyleResize.py": 1624954084.7176485,
|
||||
"data_tools\\test_dataloader_dir.py": 1634041792.6743984,
|
||||
"losses\\PerceptualLoss.py": 1615020169.668723,
|
||||
"losses\\SliceWassersteinDistance.py": 1634022704.6082795,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\experimental.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\stylize.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\train.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\transformer.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\utils.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\vgg.py": 1633868477.988523,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\video.py": 1583468787.0,
|
||||
"reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\webcam.py": 1583468787.0,
|
||||
"models\\arcface_models.py": 1642390690.623,
|
||||
"models\\config.py": 1632643596.2908099,
|
||||
"models\\__init__.py": 1642390864.8828168,
|
||||
"test_scripts\\tester_common.py": 1625369535.199175,
|
||||
"test_scripts\\tester_FastNST.py": 1634041357.607633,
|
||||
"train_scripts\\trainer_base.py": 1642347616.205689,
|
||||
"train_scripts\\trainer_FastNST_SWD.py": 1634581704.2218158,
|
||||
"train_scripts\\trainer_FM.py": 1642350579.8586667,
|
||||
"train_scripts\\trainer_gan.py": 1625571403.080787,
|
||||
"train_scripts\\trainer_base.py": 1642396105.3868554,
|
||||
"train_scripts\\trainer_FM.py": 1642396334.407562,
|
||||
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
|
||||
"utilities\\checkpoint_manager.py": 1611123530.6624403,
|
||||
"utilities\\figure.py": 1611123530.6634378,
|
||||
"utilities\\json_config.py": 1611123530.6614666,
|
||||
@@ -42,8 +34,10 @@
|
||||
"utilities\\reporter.py": 1625413813.7213495,
|
||||
"utilities\\save_heatmap.py": 1611123530.679439,
|
||||
"utilities\\sshupload.py": 1611123530.6624403,
|
||||
"utilities\\transfer_checkpoint.py": 1612416429.5316093,
|
||||
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
|
||||
"utilities\\utilities.py": 1634019485.0783668,
|
||||
"utilities\\yaml_config.py": 1611123530.6614666,
|
||||
"train_yamls\\train_512FM.yaml": 1642351806.754128
|
||||
"train_yamls\\train_512FM.yaml": 1642408586.2747152,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642408727.3950334,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642408813.9488862
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:32:14 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
|
||||
import torch
|
||||
from utilities.utilities import tensor2img
|
||||
|
||||
# from utilities.Reporter import Reporter
|
||||
from tqdm import tqdm
|
||||
|
||||
class Tester(object):
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
#============build evaluation dataloader==============#
|
||||
print("Prepare the test dataloader...")
|
||||
dlModulename = config["test_dataloader"]
|
||||
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'TestDataset')
|
||||
dataloader = dataloaderClass(config["test_data_path"],
|
||||
config["batch_size"],
|
||||
["png","jpg"])
|
||||
self.test_loader= dataloader
|
||||
|
||||
self.test_iter = len(dataloader)//config["batch_size"]
|
||||
if len(dataloader)%config["batch_size"]>0:
|
||||
self.test_iter+=1
|
||||
|
||||
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
script_name = "components."+self.config["module_script_name"]
|
||||
class_name = self.config["class_name"]
|
||||
package = __import__(script_name, fromlist=True)
|
||||
network_class = getattr(package, class_name)
|
||||
n_class = len(self.config["selectedStyleDir"])
|
||||
|
||||
# TODO replace below lines to define the model framework
|
||||
self.network = network_class(self.config["GConvDim"],
|
||||
self.config["GKS"],
|
||||
self.config["resNum"],
|
||||
n_class
|
||||
#**self.config["module_params"]
|
||||
)
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Model structure:")
|
||||
self.reporter.writeModel(self.network.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
|
||||
# print(loader1.key())
|
||||
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
|
||||
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
|
||||
# self.network.load_state_dict(torch.load(pathwocao))
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
|
||||
|
||||
def test(self):
|
||||
|
||||
# save_result = self.config["saveTestResult"]
|
||||
save_dir = self.config["test_samples_path"]
|
||||
ckp_epoch = self.config["checkpoint_epoch"]
|
||||
version = self.config["version"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_names = self.config["selectedStyleDir"]
|
||||
n_class = len(style_names)
|
||||
|
||||
# models
|
||||
self.__init_framework__()
|
||||
|
||||
condition_labels = torch.ones((n_class, batch_size, 1)).long()
|
||||
for i in range(n_class):
|
||||
condition_labels[i,:,:] = condition_labels[i,:,:]*i
|
||||
if self.config["cuda"] >=0:
|
||||
condition_labels = condition_labels.cuda()
|
||||
total = len(self.test_loader)
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
print('Start =================================== test...')
|
||||
start_time = time.time()
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
for _ in tqdm(range(total//batch_size)):
|
||||
contents, img_names = self.test_loader()
|
||||
for i in range(n_class):
|
||||
if self.config["cuda"] >=0:
|
||||
contents = contents.cuda()
|
||||
res, _ = self.network(contents, condition_labels[i, 0, :])
|
||||
res = tensor2img(res.cpu())
|
||||
for t in range(batch_size):
|
||||
temp_img = res[t,:,:,:]
|
||||
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
|
||||
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
print("Elapsed [{}]".format(elapsed))
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 17th January 2022 1:00:00 pm
|
||||
# Last Modified: Monday, 17th January 2022 4:42:53 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -31,24 +31,24 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='FM',
|
||||
parser.add_argument('-v', '--version', type=str, default='2layerFM',
|
||||
help="version name for train, test, finetune")
|
||||
parser.add_argument('-t', '--tag', type=str, default='test',
|
||||
parser.add_argument('-t', '--tag', type=str, default='Feature_match',
|
||||
help="tag for current experiment")
|
||||
|
||||
parser.add_argument('-p', '--phase', type=str, default="train",
|
||||
choices=['train', 'finetune','debug'],
|
||||
help="The phase of current project")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # <0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-c', '--cuda', type=int, default=1) # <0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-e', '--ckpt', type=int, default=74,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# training
|
||||
parser.add_argument('--experiment_description', type=str,
|
||||
default="尝试使用Liif+Invo作为上采样和降采样的算子,降采样两个DSF算子,上采样两个DSF算子")
|
||||
default="减小重建和feature match的权重,使用2和3的feature作为feature")
|
||||
|
||||
parser.add_argument('--train_yaml', type=str, default="train_512FM.yaml")
|
||||
parser.add_argument('--train_yaml', type=str, default="train_2layer_FM.yaml")
|
||||
|
||||
# system logger
|
||||
parser.add_argument('--logger', type=str,
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 17th January 2022 9:27:48 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utilities.plot import plot_batch
|
||||
|
||||
from train_scripts.trainer_base import TrainerBase
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
self.dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
self.dis.feature_network.requires_grad_(False)
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Discriminator structure:")
|
||||
self.reporter.writeModel(self.dis.__str__())
|
||||
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
|
||||
self.arcface = arcface1['model'].module
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
self.dis = self.dis.cuda()
|
||||
self.arcface= self.arcface.cuda()
|
||||
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["discriminator_name"]))
|
||||
self.dis.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
d_train_opt = self.config['d_optim_config']
|
||||
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in self.dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
if self.config["phase"] == "finetune":
|
||||
opt_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["optimizer_names"]["generator_name"]))
|
||||
self.g_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
opt_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["optimizer_names"]["discriminator_name"]))
|
||||
self.d_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
# Evaluate the checkpoint
|
||||
def __evaluation__(self,
|
||||
step = 0,
|
||||
**kwargs
|
||||
):
|
||||
src_image1 = kwargs["src1"]
|
||||
src_image2 = kwargs["src2"]
|
||||
batch_size = self.batch_size
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
imgs = []
|
||||
zero_img = (torch.zeros_like(src_image1[0,...]))
|
||||
imgs.append(zero_img.cpu().numpy())
|
||||
save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy()
|
||||
for r in range(batch_size):
|
||||
imgs.append(save_img[r,...])
|
||||
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
|
||||
id_vector_src1 = self.arcface(arcface_112)
|
||||
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
|
||||
|
||||
for i in range(batch_size):
|
||||
|
||||
imgs.append(save_img[i,...])
|
||||
image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
|
||||
img_fake = self.gen(image_infer, id_vector_src1).cpu()
|
||||
|
||||
img_fake = img_fake * self.img_std
|
||||
img_fake = img_fake + self.img_mean
|
||||
img_fake = img_fake.numpy()
|
||||
for j in range(batch_size):
|
||||
imgs.append(img_fake[j,...])
|
||||
print("Save test data")
|
||||
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
|
||||
plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'))
|
||||
|
||||
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_step"]
|
||||
sample_freq = self.config["sample_step"]
|
||||
total_step = self.config["total_step"]
|
||||
random_seed = self.config["dataset_params"]["random_seed"]
|
||||
|
||||
self.batch_size = self.config["batch_size"]
|
||||
self.sample_dir = self.config["project_samples"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
id_w = self.config["id_weight"]
|
||||
rec_w = self.config["reconstruct_weight"]
|
||||
feat_w = self.config["feature_match_weight"]
|
||||
|
||||
|
||||
|
||||
super().train()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
# MSE_loss = torch.nn.MSELoss()
|
||||
l1_loss = torch.nn.L1Loss()
|
||||
cos_loss = torch.nn.CosineSimilarity()
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
print("Total step = %d"%total_step)
|
||||
random.seed(random_seed)
|
||||
randindex = [i for i in range(self.batch_size)]
|
||||
random.shuffle(randindex)
|
||||
import datetime
|
||||
for step in range(self.start, total_step):
|
||||
self.gen.train()
|
||||
self.dis.train()
|
||||
for interval in range(2):
|
||||
random.shuffle(randindex)
|
||||
src_image1, src_image2 = self.train_loader.next()
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
latent_id = self.arcface(img_id_112)
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
if interval:
|
||||
|
||||
img_fake = self.gen(src_image1, latent_id)
|
||||
gen_logits,_ = self.dis(img_fake.detach(), None)
|
||||
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
|
||||
|
||||
real_logits,_ = self.dis(src_image2,None)
|
||||
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
|
||||
|
||||
loss_D = loss_Dgen + loss_Dreal
|
||||
self.d_optimizer.zero_grad()
|
||||
loss_D.backward()
|
||||
self.d_optimizer.step()
|
||||
else:
|
||||
|
||||
# model.netD.requires_grad_(True)
|
||||
img_fake = self.gen(src_image1, latent_id)
|
||||
# G loss
|
||||
gen_logits,feat = self.dis(img_fake, None)
|
||||
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
|
||||
latent_fake = self.arcface(img_fake_down)
|
||||
latent_fake = F.normalize(latent_fake, p=2, dim=1)
|
||||
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
|
||||
real_feat = self.dis.get_feature(src_image1)
|
||||
feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"])
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w + \
|
||||
feat_match_loss * feat_w
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
loss_G += loss_G_Rec * rec_w
|
||||
|
||||
self.g_optimizer.zero_grad()
|
||||
loss_G.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["logger"] == "tensorboard":
|
||||
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif self.config["logger"] == "wandb":
|
||||
self.logger.log({"G_loss": loss_G.item()}, step = step)
|
||||
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
self.logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
if (step + 1) % sample_freq == 0:
|
||||
self.__evaluation__(
|
||||
step = step,
|
||||
**{
|
||||
"src1": src_image1,
|
||||
"src2": src_image2
|
||||
})
|
||||
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (step+1) % model_freq==0:
|
||||
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(self.dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.save(self.g_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.save(self.d_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
print("Save step %d model checkpoint!"%(step+1))
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.__evaluation__(
|
||||
step = step,
|
||||
**{
|
||||
"src1": src_image1,
|
||||
"src2": src_image2
|
||||
})
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 17th January 2022 1:12:08 pm
|
||||
# Last Modified: Monday, 17th January 2022 5:31:43 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -180,8 +180,9 @@ class Trainer(TrainerBase):
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
log_freq = self.config["log_step"]
|
||||
model_freq = self.config["model_save_step"]
|
||||
sample_freq = self.config["sample_step"]
|
||||
total_step = self.config["total_step"]
|
||||
random_seed = self.config["dataset_params"]["random_seed"]
|
||||
|
||||
@@ -268,7 +269,7 @@ class Trainer(TrainerBase):
|
||||
self.g_optimizer.step()
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
if (step + 1) % log_freq == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
@@ -295,6 +296,14 @@ class Trainer(TrainerBase):
|
||||
self.logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
|
||||
if (step + 1) % sample_freq == 0:
|
||||
self.__evaluation__(
|
||||
step = step,
|
||||
**{
|
||||
"src1": src_image1,
|
||||
"src2": src_image2
|
||||
})
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 19th January 2022 4:21:03 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utilities.plot import plot_batch
|
||||
|
||||
from train_scripts.trainer_base import TrainerBase
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
self.dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
self.dis.feature_network.requires_grad_(False)
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Discriminator structure:")
|
||||
self.reporter.writeModel(self.dis.__str__())
|
||||
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
|
||||
self.arcface = arcface1['model'].module
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
self.dis = self.dis.cuda()
|
||||
self.arcface= self.arcface.cuda()
|
||||
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["discriminator_name"]))
|
||||
self.dis.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
d_train_opt = self.config['d_optim_config']
|
||||
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in self.dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
if self.config["phase"] == "finetune":
|
||||
opt_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["optimizer_names"]["generator_name"]))
|
||||
self.g_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
opt_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["optimizer_names"]["discriminator_name"]))
|
||||
self.d_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
# Evaluate the checkpoint
|
||||
def __evaluation__(self,
|
||||
step = 0,
|
||||
**kwargs
|
||||
):
|
||||
src_image1 = kwargs["src1"]
|
||||
src_image2 = kwargs["src2"]
|
||||
batch_size = self.batch_size
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
imgs = []
|
||||
zero_img = (torch.zeros_like(src_image1[0,...]))
|
||||
imgs.append(zero_img.cpu().numpy())
|
||||
save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy()
|
||||
for r in range(batch_size):
|
||||
imgs.append(save_img[r,...])
|
||||
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
|
||||
id_vector_src1 = self.arcface(arcface_112)
|
||||
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
|
||||
|
||||
for i in range(batch_size):
|
||||
|
||||
imgs.append(save_img[i,...])
|
||||
image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
|
||||
img_fake = self.gen(image_infer, id_vector_src1).cpu()
|
||||
|
||||
img_fake = img_fake * self.img_std
|
||||
img_fake = img_fake + self.img_mean
|
||||
img_fake = img_fake.numpy()
|
||||
for j in range(batch_size):
|
||||
imgs.append(img_fake[j,...])
|
||||
print("Save sample image at step = % ..............."%step)
|
||||
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
|
||||
plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'))
|
||||
|
||||
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_step"]
|
||||
sample_freq = self.config["sample_step"]
|
||||
total_step = self.config["total_step"]
|
||||
random_seed = self.config["dataset_params"]["random_seed"]
|
||||
|
||||
self.batch_size = self.config["batch_size"]
|
||||
self.sample_dir = self.config["project_samples"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
id_w = self.config["id_weight"]
|
||||
rec_w = self.config["reconstruct_weight"]
|
||||
feat_w = self.config["feature_match_weight"]
|
||||
cyc_w = self.config["cycle_weight"]
|
||||
|
||||
|
||||
|
||||
super().train()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
# MSE_loss = torch.nn.MSELoss()
|
||||
l1_loss = torch.nn.L1Loss()
|
||||
cos_loss = torch.nn.CosineSimilarity()
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
print("Total step = %d"%total_step)
|
||||
random.seed(random_seed)
|
||||
randindex = [i for i in range(self.batch_size)]
|
||||
random.shuffle(randindex)
|
||||
import datetime
|
||||
for step in range(self.start, total_step):
|
||||
self.gen.train()
|
||||
self.dis.train()
|
||||
for interval in range(2):
|
||||
random.shuffle(randindex)
|
||||
src_image1, src_image2 = self.train_loader.next()
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
latent_id = self.arcface(img_id_112)
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
if interval:
|
||||
|
||||
img_fake = self.gen(src_image1, latent_id)
|
||||
gen_logits,_ = self.dis(img_fake.detach(), None)
|
||||
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
|
||||
|
||||
real_logits,_ = self.dis(src_image2,None)
|
||||
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
|
||||
|
||||
loss_D = loss_Dgen + loss_Dreal
|
||||
self.d_optimizer.zero_grad()
|
||||
loss_D.backward()
|
||||
self.d_optimizer.step()
|
||||
else:
|
||||
|
||||
# model.netD.requires_grad_(True)
|
||||
img_fake = self.gen(src_image1, latent_id)
|
||||
# G loss
|
||||
gen_logits,feat = self.dis(img_fake, None)
|
||||
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
|
||||
latent_fake = self.arcface(img_fake_down)
|
||||
latent_fake = F.normalize(latent_fake, p=2, dim=1)
|
||||
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
|
||||
real_feat = self.dis.get_feature(src_image1)
|
||||
feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"])
|
||||
|
||||
src1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic')
|
||||
src1_id = self.arcface(src1_down)
|
||||
cyc_fake = self.gen(img_fake, src1_id)
|
||||
loss_cyc = l1_loss(cyc_fake, src_image1)
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w + \
|
||||
feat_match_loss * feat_w + cyc_w * loss_cyc
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
loss_G += loss_G_Rec * rec_w
|
||||
|
||||
self.g_optimizer.zero_grad()
|
||||
loss_G.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["logger"] == "tensorboard":
|
||||
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
self.logger.add_scalar('G/Cycle', loss_cyc.item(), step)
|
||||
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif self.config["logger"] == "wandb":
|
||||
self.logger.log({"G_loss": loss_G.item()}, step = step)
|
||||
self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
self.logger.log({"Cycle": loss_cyc.item()}, step = step)
|
||||
self.logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
if (step + 1) % sample_freq == 0:
|
||||
self.__evaluation__(
|
||||
step = step,
|
||||
**{
|
||||
"src1": src_image1,
|
||||
"src2": src_image2
|
||||
})
|
||||
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (step+1) % model_freq==0:
|
||||
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(self.dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.save(self.g_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.save(self.d_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
print("Save step %d model checkpoint!"%(step+1))
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.__evaluation__(
|
||||
step = step,
|
||||
**{
|
||||
"src1": src_image1,
|
||||
"src2": src_image2
|
||||
})
|
||||
@@ -0,0 +1,63 @@
|
||||
# Related scripts
|
||||
train_script_name: 2layer_FM
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: Generator
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 512
|
||||
g_kernel_size: 3
|
||||
res_num: 9
|
||||
|
||||
d_model:
|
||||
script: projected_discriminator
|
||||
class_name: ProjectedDiscriminator
|
||||
module_params:
|
||||
diffaug: False
|
||||
interp224: False
|
||||
backbone_kwargs: {}
|
||||
|
||||
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
|
||||
|
||||
# Training information
|
||||
batch_size: 12
|
||||
|
||||
# Dataset
|
||||
dataloader: VGGFace2HQ
|
||||
dataset_name: vggface2_hq
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: 0.0004
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
d_optim_config:
|
||||
lr: 0.0004
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 1.0
|
||||
feature_match_weight: 1.0
|
||||
|
||||
# Log
|
||||
log_step: 300
|
||||
model_save_step: 10000
|
||||
total_step: 1000000
|
||||
sample_step: 1000
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
discriminator_name: Discriminator
|
||||
@@ -22,7 +22,7 @@ model_configs:
|
||||
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
|
||||
|
||||
# Training information
|
||||
batch_size: 1
|
||||
batch_size: 12
|
||||
|
||||
# Dataset
|
||||
dataloader: VGGFace2HQ
|
||||
@@ -49,13 +49,14 @@ d_optim_config:
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
id_weight: 10.0
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 1.0
|
||||
feature_match_weight: 5.0
|
||||
feature_match_weight: 10.0
|
||||
|
||||
# Log
|
||||
log_step: 10
|
||||
model_save_step: 20
|
||||
log_step: 300
|
||||
model_save_step: 10000
|
||||
sample_step: 1000
|
||||
total_step: 1000000
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
# Related scripts
|
||||
train_script_name: cycleloss
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: Generator
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 512
|
||||
g_kernel_size: 3
|
||||
res_num: 9
|
||||
|
||||
d_model:
|
||||
script: projected_discriminator
|
||||
class_name: ProjectedDiscriminator
|
||||
module_params:
|
||||
diffaug: False
|
||||
interp224: False
|
||||
backbone_kwargs: {}
|
||||
|
||||
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
|
||||
|
||||
# Training information
|
||||
batch_size: 12
|
||||
|
||||
# Dataset
|
||||
dataloader: VGGFace2HQ
|
||||
dataset_name: vggface2_hq
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: 0.0004
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
d_optim_config:
|
||||
lr: 0.0004
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 0.1
|
||||
feature_match_weight: 0.1
|
||||
cycle_weight: 10.0
|
||||
|
||||
# Log
|
||||
log_step: 400
|
||||
model_save_step: 10000
|
||||
total_step: 1000000
|
||||
sample_step: 1000
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
discriminator_name: Discriminator
|
||||
Reference in New Issue
Block a user