#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_naiv512.py
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 11th January 2022 3:06:14 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import os
import time
import random
import datetime

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from utilities.utilities import denorm


class Trainer(object):

    def __init__(self, config, reporter):

        self.config = config
        # logger
        self.reporter = reporter

        # Data loader
        #============build train dataloader==============#
        # TODO: set config["dataset_name"] to the key of your train dataset
        #       path inside config["dataset_paths"]
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                           config["batch_size"],
                                           config["imcrop_size"],
                                           **config["dataset_params"])

        self.train_loader = dataloader
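
        # NOTE (assumed interface, not spelled out in this file): the rest of
        # the trainer only relies on two behaviours of the loader object, so a
        # custom GetLoader is expected to look roughly like this sketch:
        #
        #   class GetLoader:
        #       def __init__(self, dataset_root, batch_size, img_size, **dataset_params):
        #           ...
        #       def __len__(self):
        #           # total number of samples; used below to derive steps per epoch
        #           ...
        #       def next(self):
        #           # returns (image1, image2): two batch-aligned image tensors
        #           ...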

        #========build evaluation dataloader=============#
        # TODO: set config["eval_dataset_name"] to the key of your evaluation
        #       dataset path inside config["dataset_paths"], then enable this block
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        This function defines the framework and prints the framework
        information into the log file.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]

        model_config = self.config["model_configs"]

        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]
        elif self.config["phase"] == "finetune":
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        else:
            raise ValueError("Unknown phase: %s" % self.config["phase"])

        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])
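        # The "model_configs" entry is assumed to look roughly like this
        # (illustrative sketch; module and key names depend on your config):
        #
        #   "model_configs": {
        #       "g_model": {
        #           "script": "Generator_naiv512",   # hypothetical module under components/
        #           "class_name": "Generator",
        #           "module_params": {...}
        #       }
        #   }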

        # print and record the model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())

        # id extractor network; the checkpoint stores a DataParallel-wrapped
        # model, hence the .module access
        arcface_ckpt = self.config["arcface_ckpt"]
        arcface_ckpt = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface_ckpt['model'].module

        # train on GPU
        if self.config["cuda"] >= 0:
            self.gen = self.gen.cuda()
            self.arcface = self.arcface.cuda()

        # the identity extractor is frozen; it only provides supervision
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # if in the finetune phase, load the pretrained checkpoint,
        # e.g. epoch10_<generator_name>.pth
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                                      "epoch%d_%s.pth" % (self.config["checkpoint_step"],
                                                          self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))
            print('loaded trained backbone model epoch %d...!' % self.config["checkpoint_step"])

    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step=0):
        # Evaluate the checkpoint. This PSNR-based template assumes an
        # SR-style eval loader yielding (hr, lr) pairs; enable and adapt the
        # evaluation dataloader block in __init__ before using it.
        self.gen.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()

                if self.config["cuda"] >= 0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map images from [-1, 1] to [0, 255]
                hr = (hr + 1.0) / 2.0 * 255.0
                hr = torch.clamp(hr, 0.0, 255.0)
                lr = (lr + 1.0) / 2.0 * 255.0
                lr = torch.clamp(lr, 0.0, 255.0)
                res = self.gen(lr)
                # if the network outputs [-1, 1], rescale it as well:
                # res = (res + 1.0)/2.0 * 255.0
                res = torch.clamp(res, 0.0, 255.0)
                # per-image RMSE, then PSNR = 20 * log10(255 / RMSE)
                diff = (res - hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr += psnrs.sum()
                total_num += res.shape[0]
        final_psnr = total_psnr / total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                                                      epoch, final_psnr))
        self.reporter.writeTrainLog(epoch, step, "psnr: {:.4f}".format(final_psnr))
        if self.config["use_tensorboard"]:
            self.tensorboard_writer.add_scalar('metric/psnr', final_psnr, epoch)
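    # PSNR sanity check for the computation above (informal): with images in
    # [0, 255], psnr = 20 * log10(255 / rmse), so an RMSE of 2.55 corresponds
    # to 20 * log10(100) = 40 dB.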

    # TODO modify this function to configure the optimizer of your pipeline
    def __setup_optimizers__(self):
        g_train_opt = self.config['g_optim_config']
        g_optim_params = []
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')

        optim_type = self.config['optim_type']

        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')
        # self.optimizers.append(self.optimizer_g)
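    # A typical "g_optim_config" for Adam might look like this (illustrative
    # values, not taken from the project's config files):
    #
    #   "g_optim_config": {"lr": 1e-4, "betas": [0.0, 0.999], "weight_decay": 0}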

    def train(self):

        ckpt_dir = self.config["project_checkpoints"]
        log_freq = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]

        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]

        sample_dir = self.config["project_samples"]

        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace the lines below to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace the lines below to build your losses
        MSE_loss = torch.nn.MSELoss()

        # set the start point of the training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Calculate the number of steps per epoch
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch" % step_epoch)

        randindex = [i for i in range(batch_size)]
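
        # Pairing scheme used in the loop below (as the code stands, an
        # assumption about the loader): image1/image2 are two batch-aligned
        # views of the same identities. On even steps the identity source is
        # the aligned image2 (same id, different pose), so a reconstruction
        # loss is meaningful; on odd steps image2 is shuffled via `randindex`
        # to form cross-identity swap pairs.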

        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()
                image1, image2 = self.train_loader.next()
                random.shuffle(randindex)

                img_att = image1

                if step % 2 == 0:
                    img_id = image2             # swap with same id, different pose
                else:
                    img_id = image2[randindex]  # swap with a different face

                # ArcFace expects 112x112 inputs; the identity latent is
                # L2-normalized before conditioning the generator
                img_id_112 = F.interpolate(img_id, size=(112, 112), mode='bicubic')
                latent_id = self.arcface(img_id_112)
                latent_id = F.normalize(latent_id, p=2, dim=1)

                losses, img_fake = self.gen(img_att, latent_id)

                # update Generator weights; the generator is assumed to
                # return its unweighted loss terms in the order given by
                # self.gen.loss_names (e.g. 'G_GAN', 'G_GAN_Feat', 'G_ID',
                # 'G_Rec'), which is how loss_dict is built below
                losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
                loss_dict = dict(zip(self.gen.loss_names, losses))

                # weight the loss terms; treating the identity term as the
                # "content" loss and the feature-matching term as the "style"
                # loss is an assumption of this naive trainer
                content_loss = loss_dict['G_ID']
                style_loss = loss_dict.get('G_GAN_Feat', 0)
                g_loss = loss_dict['G_GAN'] + content_loss * content_w + style_loss * style_w
                # the reconstruction term only makes sense when source and
                # target share an identity (even steps, see above)
                if step % 2 == 0:
                    g_loss = g_loss + loss_dict['G_Rec']

                # backward & optimize; only the generator is built in
                # __init_framework__, so only self.g_optimizer is stepped here
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # if your framework also builds a discriminator with its own
                # optimizer, its update would follow the same pattern:
                # loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_dict['D_GP']
                # optimizer_D.zero_grad()
                # loss_D.backward()
                # optimizer_D.step()

                # Print out log info
                if (step + 1) % log_freq == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation = ("[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], "
                                        "content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}").format(
                                            self.config["version"], elapsed, epoch + 1, total_epoch,
                                            step + 1, step_epoch,
                                            content_loss.item(), float(style_loss), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/style_loss', float(style_loss), cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch + 1) % model_freq == 0:
                print("Save epoch %d model checkpoint!" % (epoch + 1))
                torch.save(self.gen.state_dict(),
                           os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                                        self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    # save the last fake batch of this epoch as a sample
                    sample = img_fake
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                               os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)), nrow=4)