init
This commit is contained in:
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 2:18:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm, Gram, img2tensor255
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
# TODO To save the important scripts
|
||||
# save the yaml file
|
||||
import shutil
|
||||
file1 = os.path.join("components", "%s.py"%model_config["g_model"]["script"])
|
||||
tgtfile1 = os.path.join(self.config["project_scripts"], "%s.py"%model_config["g_model"]["script"])
|
||||
shutil.copyfile(file1,tgtfile1)
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
# print("prepare the fixed labels...")
|
||||
# fix_label = [i for i in range(n_class)]
|
||||
# fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
style_tensor = img2tensor255(style_img).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_tensor = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
# style_features = VGG(style_tensor)
|
||||
style_gram = {}
|
||||
for key, value in style_tensor.items():
|
||||
style_gram[key] = Gram(value)
|
||||
del style_tensor
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
s_loss = MSE_loss(Gram(value), style_gram[key])
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(epoch + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = fake_image
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 7:38:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
# print("prepare the fixed labels...")
|
||||
# fix_label = [i for i in range(n_class)]
|
||||
# fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
style_gram = {}
|
||||
for key, value in style_features.items():
|
||||
style_gram[key] = Gram(value)
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
s_loss = MSE_loss(Gram(value), style_gram[key])
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(epoch + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = fake_image
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 9:25:13 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
# print("prepare the fixed labels...")
|
||||
# fix_label = [i for i in range(n_class)]
|
||||
# fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
style_gram = {}
|
||||
for key, value in style_features.items():
|
||||
style_gram[key] = Gram(value)
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
s_loss = MSE_loss(Gram(value), style_gram[key])
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(epoch + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = fake_image
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 2:28:24 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, img2tensor255crop
|
||||
from losses.SliceWassersteinDistance import SWD
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
swd_dim = self.config["swd_dim"]
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
# print("prepare the fixed labels...")
|
||||
# fix_label = [i for i in range(n_class)]
|
||||
# fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
# swd = SWD()
|
||||
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
swd_list = {}
|
||||
for key, value in style_features.items():
|
||||
|
||||
swd_list[key] = SWD(value.shape[1],swd_dim).cuda()
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
swd_list[key].update()
|
||||
s_loss = MSE_loss(swd_list[key](value), swd_list[key](style_features[key]))
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(epoch + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = fake_image
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,382 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 6th July 2021 7:36:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.dataloader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
# dataloader = self.dataloader_class(self.train_dataset,
|
||||
# config["batch_size_list"][0],
|
||||
# config["imcrop_size_list"][0],
|
||||
# **config["dataset_params"])
|
||||
|
||||
# self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
self.dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
self.reporter.writeInfo("Discriminator structure:")
|
||||
self.reporter.writeModel(self.dis.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
self.dis = self.dis.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["discriminator_name"]))
|
||||
self.dis.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
d_train_opt = self.config['d_optim_config']
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in self.dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
|
||||
n_class = len(self.config["selected_style_dir"])
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
feature_w = self.config["feature_weight"]
|
||||
transform_w = self.config["transform_weight"]
|
||||
d_step = self.config["d_step"]
|
||||
g_step = self.config["g_step"]
|
||||
|
||||
batch_size_list = self.config["batch_size_list"]
|
||||
switch_epoch_list = self.config["switch_epoch_list"]
|
||||
imcrop_size_list = self.config["imcrop_size_list"]
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
current_epoch_index = 0
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
Transform = Transform_block().cuda()
|
||||
L1_loss = torch.nn.L1Loss()
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
Hinge_loss = torch.nn.ReLU().cuda()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
|
||||
output_size = self.dis.get_outputs_len()
|
||||
|
||||
print("prepare the fixed labels...")
|
||||
fix_label = [i for i in range(n_class)]
|
||||
fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(start, total_epoch):
|
||||
|
||||
# switch training image size
|
||||
if epoch in switch_epoch_list:
|
||||
print('Current epoch: {}'.format(epoch))
|
||||
print('***Redefining the dataloader for progressive training.***')
|
||||
print('***Current spatial size is {} and batch size is {}.***'.format(
|
||||
imcrop_size_list[current_epoch_index], batch_size_list[current_epoch_index]))
|
||||
del self.train_loader
|
||||
self.train_loader = self.dataloader_class(self.train_dataset,
|
||||
batch_size_list[current_epoch_index],
|
||||
imcrop_size_list[current_epoch_index],
|
||||
**self.config["dataset_params"])
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // (d_step + g_step)
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
current_epoch_index += 1
|
||||
|
||||
for step in range(step_epoch):
|
||||
self.dis.train()
|
||||
self.gen.train()
|
||||
|
||||
# ================== Train D ================== #
|
||||
# Compute loss with real images
|
||||
for _ in range(d_step):
|
||||
content_images,style_images,label = self.train_loader.next()
|
||||
label = label.long()
|
||||
|
||||
d_out = self.dis(style_images,label)
|
||||
d_loss_real = 0
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 - d_out[i]).mean()
|
||||
d_loss_real += temp
|
||||
|
||||
d_loss_photo = 0
|
||||
d_out = self.dis(content_images,label)
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 + d_out[i]).mean()
|
||||
d_loss_photo += temp
|
||||
|
||||
fake_image,_= self.gen(content_images,label)
|
||||
d_out = self.dis(fake_image.detach(),label)
|
||||
d_loss_fake = 0
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 + d_out[i]).mean()
|
||||
# temp *= prep_weights[i]
|
||||
d_loss_fake += temp
|
||||
|
||||
# Backward + Optimize
|
||||
d_loss = d_loss_real + d_loss_photo + d_loss_fake
|
||||
self.d_optimizer.zero_grad()
|
||||
d_loss.backward()
|
||||
self.d_optimizer.step()
|
||||
|
||||
# ================== Train G ================== #
|
||||
for _ in range(g_step):
|
||||
|
||||
content_images,_,_ = self.train_loader.next()
|
||||
fake_image,real_feature = self.gen(content_images,label)
|
||||
fake_feature = self.gen(fake_image, get_feature=True)
|
||||
d_out = self.dis(fake_image,label.long())
|
||||
|
||||
g_feature_loss = L1_loss(fake_feature,real_feature)
|
||||
g_transform_loss = MSE_loss(Transform(content_images), Transform(fake_image))
|
||||
g_loss_fake = 0
|
||||
for i in range(output_size):
|
||||
temp = -d_out[i].mean()
|
||||
# temp *= prep_weights[i]
|
||||
g_loss_fake += temp
|
||||
|
||||
# backward & optimize
|
||||
g_loss = g_loss_fake + g_feature_loss* feature_w + g_transform_loss* transform_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, d_loss_real: {:.4f}, \\\
|
||||
d_loss_photo: {:.4f}, d_loss_fake: {:.4f}, g_loss: {:.4f}, g_loss_fake: {:.4f}, \\\
|
||||
g_feature_loss: {:.4f}, g_transform_loss: {:.4f}".format(self.config["version"],
|
||||
epoch + 1, total_epoch, elapsed, step + 1, step_epoch,
|
||||
d_loss.item(), d_loss_real.item(), d_loss_photo.item(),
|
||||
d_loss_fake.item(), g_loss.item(), g_loss_fake.item(),\
|
||||
g_feature_loss.item(), g_transform_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeRawInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/d_loss', d_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_real', d_loss_real.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_photo', d_loss_photo.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_loss_fake', g_loss_fake.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_feature_loss', g_feature_loss, cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_transform_loss', g_transform_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
print("Learning rate decay")
|
||||
for p in self.optimizer.param_groups:
|
||||
p['lr'] *= self.config["lr_decay"]
|
||||
print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(self.dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(step + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = content_images[0, :, :, :].unsqueeze(0)
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
for index in range(n_class):
|
||||
fake_images,_ = self.gen(sample, fix_label[index].unsqueeze(0))
|
||||
saved_image1 = torch.cat((saved_image1, denorm(fake_images.cpu().data)), 0)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(step + 1)),nrow=3)
|
||||
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 9th January 2022 12:31:03 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
# print("prepare the fixed labels...")
|
||||
# fix_label = [i for i in range(n_class)]
|
||||
# fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
style_gram = {}
|
||||
for key, value in style_features.items():
|
||||
style_gram[key] = Gram(value)
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
s_loss = MSE_loss(Gram(value), style_gram[key])
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(epoch + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = fake_image
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
Reference in New Issue
Block a user