This commit is contained in:
chenxuanhong
2022-01-17 13:17:49 +08:00
parent bf2df5c5a6
commit 601d2ee43d
58 changed files with 2748 additions and 5696 deletions
+333
View File
@@ -0,0 +1,333 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_naiv512.py
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 17th January 2022 1:12:08 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import random
import numpy as np
import torch
import torch.nn.functional as F
from utilities.plot import plot_batch
from train_scripts.trainer_base import TrainerBase
class Trainer(TrainerBase):
    """Face-swap trainer (naive 512 pipeline) specializing TrainerBase."""

    def __init__(self, config, reporter):
        super().__init__(config, reporter)
        # ImageNet channel statistics shaped (3, 1, 1) so they broadcast over
        # CHW image tensors when de-normalizing images for visualization.
        self.img_std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        self.img_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
# TODO modify this function to build your models
def init_framework(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
dscript_name = "components." + model_config["d_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and recorde model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
class_name = model_config["d_model"]["class_name"]
package = __import__(dscript_name, fromlist=True)
dis_class = getattr(package, class_name)
self.dis = dis_class(**model_config["d_model"]["module_params"])
self.dis.feature_network.requires_grad_(False)
# print and recorde model structure
self.reporter.writeInfo("Discriminator structure:")
self.reporter.writeModel(self.dis.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
self.dis = self.dis.cuda()
self.arcface= self.arcface.cuda()
self.arcface.eval()
self.arcface.requires_grad_(False)
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["discriminator_name"]))
self.dis.load_state_dict(torch.load(model_path))
print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
d_train_opt = self.config['d_optim_config']
g_optim_params = []
d_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
for k, v in self.dis.named_parameters():
if v.requires_grad:
d_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
if self.config["phase"] == "finetune":
opt_path = os.path.join(self.config["project_checkpoints"],
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
self.config["optimizer_names"]["generator_name"]))
self.g_optimizer.load_state_dict(torch.load(opt_path))
opt_path = os.path.join(self.config["project_checkpoints"],
"step%d_optim_%s.pth"%(self.config["checkpoint_step"],
self.config["optimizer_names"]["discriminator_name"]))
self.d_optimizer.load_state_dict(torch.load(opt_path))
print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"]))
    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self,
                step = 0,
                **kwargs
                ):
        """Render a qualitative face-swap grid for the current generator.

        kwargs must provide "src1" and "src2": two normalized image batches
        (assumed CHW, ImageNet-normalized — see self.img_std/img_mean; TODO
        confirm against the dataloader). The grid has a blank corner tile,
        the src1 images along the top, and for each src1 image a row of swaps
        driven by the src2 identities. Output file: <sample_dir>/step_<step+1>.jpg.
        """
        src_image1 = kwargs["src1"]
        src_image2 = kwargs["src2"]
        batch_size = self.batch_size
        self.gen.eval()
        with torch.no_grad():
            imgs = []
            # Blank top-left corner tile of the grid.
            zero_img = (torch.zeros_like(src_image1[0, ...]))
            imgs.append(zero_img.cpu().numpy())
            # De-normalize src1 with the stored ImageNet statistics.
            save_img = ((src_image1.cpu()) * self.img_std + self.img_mean).numpy()
            for r in range(batch_size):
                imgs.append(save_img[r, ...])
            # Identity vectors of src2 from the frozen ArcFace encoder,
            # which takes 112x112 inputs.
            arcface_112 = F.interpolate(src_image2, size=(112, 112), mode='bicubic')
            id_vector_src1 = self.arcface(arcface_112)
            id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
            for i in range(batch_size):
                # Row header: the source image being swapped.
                imgs.append(save_img[i, ...])
                # Repeat src1[i] so one generator call swaps it with every
                # identity in the batch at once.
                image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1)
                img_fake = self.gen(image_infer, id_vector_src1).cpu()
                # De-normalize the fakes for saving.
                img_fake = img_fake * self.img_std
                img_fake = img_fake + self.img_mean
                img_fake = img_fake.numpy()
                for j in range(batch_size):
                    imgs.append(img_fake[j, ...])
            print("Save test data")
            # Stack to NHWC layout for plotting.
            imgs = np.stack(imgs, axis = 0).transpose(0, 2, 3, 1)
            plot_batch(imgs, os.path.join(self.sample_dir, 'step_' + str(step + 1) + '.jpg'))
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_step"]
total_step = self.config["total_step"]
random_seed = self.config["dataset_params"]["random_seed"]
self.batch_size = self.config["batch_size"]
self.sample_dir = self.config["project_samples"]
self.arcface_ckpt= self.config["arcface_ckpt"]
# prep_weights= self.config["layersWeight"]
id_w = self.config["id_weight"]
rec_w = self.config["reconstruct_weight"]
feat_w = self.config["feature_match_weight"]
super().train()
#===============build losses===================#
# TODO replace below lines to build your losses
# MSE_loss = torch.nn.MSELoss()
l1_loss = torch.nn.L1Loss()
cos_loss = torch.nn.CosineSimilarity()
start_time = time.time()
# Caculate the epoch number
print("Total step = %d"%total_step)
random.seed(random_seed)
randindex = [i for i in range(self.batch_size)]
random.shuffle(randindex)
import datetime
for step in range(self.start, total_step):
self.gen.train()
self.dis.train()
for interval in range(2):
random.shuffle(randindex)
src_image1, src_image2 = self.train_loader.next()
if step%2 == 0:
img_id = src_image2
else:
img_id = src_image2[randindex]
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
latent_id = self.arcface(img_id_112)
latent_id = F.normalize(latent_id, p=2, dim=1)
if interval:
img_fake = self.gen(src_image1, latent_id)
gen_logits,_ = self.dis(img_fake.detach(), None)
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
real_logits,_ = self.dis(src_image2,None)
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
loss_D = loss_Dgen + loss_Dreal
self.d_optimizer.zero_grad()
loss_D.backward()
self.d_optimizer.step()
else:
# model.netD.requires_grad_(True)
img_fake = self.gen(src_image1, latent_id)
# G loss
gen_logits,feat = self.dis(img_fake, None)
loss_Gmain = (-gen_logits).mean()
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
latent_fake = self.arcface(img_fake_down)
latent_fake = F.normalize(latent_fake, p=2, dim=1)
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
real_feat = self.dis.get_feature(src_image1)
feat_match_loss = l1_loss(feat["3"],real_feat["3"])
loss_G = loss_Gmain + loss_G_ID * id_w + \
feat_match_loss * feat_w
if step%2 == 0:
#G_Rec
loss_G_Rec = l1_loss(img_fake, src_image1)
loss_G += loss_G_Rec * rec_w
self.g_optimizer.zero_grad()
loss_G.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_frep == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(self.config["version"], elapsed, step, total_step, \
loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["logger"] == "tensorboard":
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
elif self.config["logger"] == "wandb":
self.logger.log({"G_loss": loss_G.item()}, step = step)
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
self.logger.log({"D_loss": loss_D.item()}, step = step)
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (step+1) % model_freq==0:
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.save(self.dis.state_dict(),
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
self.config["checkpoint_names"]["discriminator_name"])))
torch.save(self.g_optimizer.state_dict(),
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.save(self.d_optimizer.state_dict(),
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
self.config["checkpoint_names"]["discriminator_name"])))
print("Save step %d model checkpoint!"%(step+1))
torch.cuda.empty_cache()
self.__evaluation__(
step = step,
**{
"src1": src_image1,
"src2": src_image2
})
-307
View File
@@ -1,307 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 2:18:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm, Gram, img2tensor255
from pretrained_weights.vgg import VGG16
class Trainer(object):
    """Neural style-transfer trainer.

    Construction wires up the config, the reporter/logger, the training
    dataloader (resolved dynamically from config["dataloader"]), and — when
    enabled — a tensorboard writer. Model building is deferred to
    __init_framework__, invoked from train().
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        print("Prepare the train dataloader...")
        # The loader class is imported dynamically so the dataloader flavor
        # can be switched from the config alone.
        loader_module = "data_tools.data_loader_%s" % config["dataloader"]
        package = __import__(loader_module, fromlist=True)
        self.dataloader_class = getattr(package, 'GetLoader')
        self.train_loader = self.dataloader_class(
            self.train_dataset,
            config["batch_size"],
            config["imcrop_size"],
            **config["dataset_params"])

        #========build evaluation dataloader=============#
        # NOTE(review): the evaluation dataloader construction was commented
        # out in the original (eval_dataset / eval_loader / eval_iter were
        # never created), which leaves __evaluation__ uncallable as-is.

        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        Build the generator network.

        The generator script is imported dynamically: from "components." in
        train phase (and a copy of the script is archived into the project
        folder), or from the archived copy (config["com_base"]) when
        finetuning, in which case checkpointed weights are restored too.
        Sets self.gen as a side effect.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]
            # Archive the generator script so the exact model definition used
            # for this run can be recovered later (e.g. for finetuning).
            import shutil
            file1 = os.path.join("components", "%s.py" % model_config["g_model"]["script"])
            tgtfile1 = os.path.join(self.config["project_scripts"], "%s.py" % model_config["g_model"]["script"])
            shutil.copyfile(file1, tgtfile1)
        elif self.config["phase"] == "finetune":
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        # NOTE(review): any other phase value leaves gscript_name unbound and
        # raises NameError below — confirm callers only pass train/finetune.
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])
        # print and record the model structure in the log file
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())
        # train in GPU
        if self.config["cuda"] >= 0:
            self.gen = self.gen.cuda()
        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                        "epoch%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))
            print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step = 0):
        """Compute mean PSNR over the evaluation set and log it.

        NOTE(review): this references self.network, self.eval_loader and
        self.eval_iter, none of which are created in this class (the eval
        dataloader build in __init__ is commented out). It looks like leftover
        super-resolution code and would raise AttributeError if called.
        """
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()
                if self.config["cuda"] >= 0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # Map [-1, 1] tensors into the [0, 255] pixel range.
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr, 0.0, 255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr, 0.0, 255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res, 0.0, 255.0)
                # Per-image RMSE, then PSNR = 20 * log10(255 / RMSE).
                diff = (res - hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr += psnrs.sum()
                total_num += res.shape[0]
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                                    epoch, final_psnr))
        self.reporter.writeTrainLog(epoch, step, "psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
    def train(self):
        """Epoch-based style-transfer training loop.

        Builds the generator and its optimizer, precomputes Gram matrices of
        the style image's VGG16 features, then for each batch minimizes
        content loss (MSE on relu2_2 features) plus style loss (MSE between
        Gram matrices over all returned VGG layers). Logs every log_step
        steps; every model_save_epoch epochs saves the generator and writes a
        sample image grid.
        """
        ckpt_dir = self.config["project_checkpoints"]
        log_frep = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]
        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        # NOTE(review): crop_size is read but never used in this variant
        # (img2tensor255 below takes no crop argument).
        crop_size = self.config["imcrop_size"]
        sample_dir = self.config["project_samples"]
        #===============build framework================#
        self.__init_framework__()
        #===============build optimizer================#
        print("build the optimizer...")
        self.__setup_optimizers__()
        #===============build losses===================#
        MSE_loss = torch.nn.MSELoss()
        # set the start point for training loop (resume epoch when finetuning)
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0
        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()
        # Number of optimization steps per epoch.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch" % step_epoch)
        VGG = VGG16().cuda()
        # Scale/shift constants mapping generator output (presumably in
        # [-1, 1] — TODO confirm against the dataloader) to 0..255 pixels.
        MEAN_VAL = 127.5
        SCALE_VAL = 127.5
        # VGG16 expects 0..255 BGR input minus the ImageNet mean; the "_11"
        # variant folds the [-1,1] -> [0,255] shift into the same add.
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        imagenet_neg_mean_11 = torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        # Precompute the style image's Gram matrices once, expanded to the
        # batch size so they match the per-batch generated features.
        style_tensor = img2tensor255(style_img).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_tensor = VGG(style_tensor.expand([batch_size, C, H, W]))
        style_gram = {}
        for key, value in style_tensor.items():
            style_gram[key] = Gram(value)
        # Free the raw style features; only the Gram matrices are needed.
        del style_tensor
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()
                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # Extract VGG features of both generated and content images
                # in the network's expected 0..255-minus-mean space.
                generated_features = VGG((fake_image * SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images * SCALE_VAL).add(imagenet_neg_mean_11))
                # Content preserved at relu2_2; style matched via Gram
                # matrices across every returned layer.
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss
                # backward & optimize
                g_loss = content_loss * content_w + style_loss * style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    # cumulative steps across epochs, used as the x-axis
                    cum_step = (step_epoch * epoch + step + 1)
                    epochinformation = "[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)
                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
            #===============save checkpoints================#
            if (epoch + 1) % model_freq == 0:
                print("Save epoch %d model checkpoint!" % (epoch + 1))
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                            self.config["checkpoint_names"]["generator_name"])))
                torch.cuda.empty_cache()
                # Save a sample grid from the last generated batch.
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                            os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)), nrow=4)
-297
View File
@@ -1,297 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 7:38:36 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
class Trainer(object):
    """Neural style-transfer trainer (crop variant).

    Construction wires up config, reporter, the training dataloader (resolved
    dynamically from config["dataloader"]), and optionally tensorboard.
    Model building is deferred to __init_framework__, called from train().
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter
        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # The loader class is imported dynamically so the dataloader flavor
        # can be switched from the config alone.
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                    config["batch_size"],
                                    config["imcrop_size"],
                                    **config["dataset_params"])
        self.train_loader = dataloader
        #========build evaluation dataloader=============#
        # NOTE(review): the evaluation dataloader below is disabled, so
        # __evaluation__ (which needs eval_loader/eval_iter) cannot run as-is.
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader
        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1
        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        Build the generator network.

        The generator script is imported dynamically: from "components." in
        train phase, or from the archived copy (config["com_base"]) when
        finetuning, in which case the checkpointed weights are restored too.
        Sets self.gen as a side effect.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]
        elif self.config["phase"] == "finetune":
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        # NOTE(review): any other phase value leaves gscript_name unbound and
        # raises NameError below — confirm callers only pass train/finetune.
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])
        # print and record the model structure in the log file
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())
        # train in GPU
        if self.config["cuda"] >= 0:
            self.gen = self.gen.cuda()
        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                        "epoch%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))
            print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step = 0):
        """Compute mean PSNR over the evaluation set and log it.

        NOTE(review): this references self.network, self.eval_loader and
        self.eval_iter, none of which are created in this class (the eval
        dataloader build in __init__ is commented out). It looks like leftover
        super-resolution code and would raise AttributeError if called.
        """
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()
                if self.config["cuda"] >= 0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # Map [-1, 1] tensors into the [0, 255] pixel range.
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr, 0.0, 255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr, 0.0, 255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res, 0.0, 255.0)
                # Per-image RMSE, then PSNR = 20 * log10(255 / RMSE).
                diff = (res - hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr += psnrs.sum()
                total_num += res.shape[0]
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                                    epoch, final_psnr))
        self.reporter.writeTrainLog(epoch, step, "psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
    def train(self):
        """Epoch-based style-transfer training loop (cropped-style variant).

        Builds the generator and its optimizer, precomputes Gram matrices of
        the cropped style image's VGG16 features, then for each batch
        minimizes content loss (MSE on relu2_2 features) plus style loss (MSE
        between Gram matrices over all returned VGG layers). Logs every
        log_step steps; every model_save_epoch epochs saves the generator and
        writes a sample image grid.
        """
        ckpt_dir = self.config["project_checkpoints"]
        log_frep = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]
        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        crop_size = self.config["imcrop_size"]
        sample_dir = self.config["project_samples"]
        #===============build framework================#
        self.__init_framework__()
        #===============build optimizer================#
        print("build the optimizer...")
        self.__setup_optimizers__()
        #===============build losses===================#
        MSE_loss = torch.nn.MSELoss()
        # set the start point for training loop (resume epoch when finetuning)
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0
        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()
        # Number of optimization steps per epoch.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch" % step_epoch)
        VGG = VGG16().cuda()
        # Scale/shift constants mapping generator output (presumably in
        # [-1, 1] — TODO confirm against the dataloader) to 0..255 pixels.
        MEAN_VAL = 127.5
        SCALE_VAL = 127.5
        # VGG16 expects 0..255 BGR input minus the ImageNet mean; the "_11"
        # variant folds the [-1,1] -> [0,255] shift into the same add.
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        imagenet_neg_mean_11 = torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        # Precompute the style image's Gram matrices once, from a crop of the
        # style image, expanded to the batch size so they match the per-batch
        # generated features.
        style_tensor = img2tensor255crop(style_img, crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        style_gram = {}
        for key, value in style_features.items():
            style_gram[key] = Gram(value)
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()
                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # Extract VGG features of both generated and content images
                # in the network's expected 0..255-minus-mean space.
                generated_features = VGG((fake_image * SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images * SCALE_VAL).add(imagenet_neg_mean_11))
                # Content preserved at relu2_2; style matched via Gram
                # matrices across every returned layer.
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss
                # backward & optimize
                g_loss = content_loss * content_w + style_loss * style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    # cumulative steps across epochs, used as the x-axis
                    cum_step = (step_epoch * epoch + step + 1)
                    epochinformation = "[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)
                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
            #===============save checkpoints================#
            if (epoch + 1) % model_freq == 0:
                print("Save epoch %d model checkpoint!" % (epoch + 1))
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                            self.config["checkpoint_names"]["generator_name"])))
                torch.cuda.empty_cache()
                # Save a sample grid from the last generated batch.
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                            os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)), nrow=4)
-296
View File
@@ -1,296 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 9:25:13 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
class Trainer(object):
    """Neural style-transfer trainer (earlier variant of this file).

    Construction wires up config, reporter, the training dataloader (resolved
    dynamically from config["dataloader"]), and optionally tensorboard.
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter
        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # The loader class is imported dynamically so the dataloader flavor
        # can be switched from the config alone.
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                    config["batch_size"],
                                    config["imcrop_size"],
                                    **config["dataset_params"])
        self.train_loader = dataloader
        #========build evaluation dataloader=============#
        # NOTE(review): the evaluation dataloader below is disabled, so any
        # method depending on eval_loader/eval_iter cannot run as-is.
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader
        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1
        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and recorde model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Caculate the epoch number
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_frep == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
-300
View File
@@ -1,300 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 2:28:24 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, img2tensor255crop
from losses.SliceWassersteinDistance import SWD
from pretrained_weights.vgg import VGG16
class Trainer(object):
    """Feed-forward style-transfer trainer using Sliced Wasserstein Distance.

    Like the Gram-matrix variant, but the style loss compares sliced
    (randomly projected) VGG16 feature distributions of the generated
    image against those of one fixed style image, via per-layer SWD
    projection modules.
    """

    def __init__(self, config, reporter):
        """Keep config/logger handles and build the training dataloader.

        Args:
            config (dict): experiment configuration (paths, hyper-parameters).
            reporter: project logger exposing writeInfo/writeModel/writeTrainLog.
        """
        self.config = config
        # logger
        self.reporter = reporter
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # the dataloader module is selected by name at runtime
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                           config["batch_size"],
                                           config["imcrop_size"],
                                           **config["dataset_params"])
        self.train_loader = dataloader
        # NOTE(review): the evaluation dataloader was commented out in the
        # original; __evaluation__ below still expects self.eval_loader,
        # self.eval_iter and self.network and cannot run as-is.
        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        Build the generator from the configured component script, log its
        structure, move it to GPU, and in the finetune phase load the
        pretrained checkpoint.
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        model_config = self.config["model_configs"]
        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]
        elif self.config["phase"] == "finetune":
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])
        # print and record model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())
        # train in GPU
        if self.config["cuda"] >= 0:
            self.gen = self.gen.cuda()
        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                            "epoch%d_%s.pth" % (self.config["checkpoint_epoch"],
                            self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))
            # BUG FIX: the message advertises an epoch number, so log the
            # epoch used to build the filename, not the checkpoint directory.
            print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))

    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step=0):
        """PSNR evaluation over the evaluation dataloader.

        NOTE(review): references self.network / self.eval_loader /
        self.eval_iter, none of which are created in this trainer (the
        eval-dataloader block in __init__ is disabled) — looks like
        copy-pasted super-resolution code; confirm before calling.
        """
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()
                if self.config["cuda"] >= 0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map [-1,1] tensors to the 0-255 range and clamp
                hr = (hr + 1.0) / 2.0 * 255.0
                hr = torch.clamp(hr, 0.0, 255.0)
                lr = (lr + 1.0) / 2.0 * 255.0
                lr = torch.clamp(lr, 0.0, 255.0)
                res = self.network(lr)
                res = torch.clamp(res, 0.0, 255.0)
                # per-image RMSE -> PSNR
                diff = (res - hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr += psnrs.sum()
                total_num += res.shape[0]
        final_psnr = total_psnr / total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
              epoch, final_psnr))
        self.reporter.writeTrainLog(epoch, step, "psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        """Build the Adam optimizer over the trainable generator parameters."""
        g_train_opt = self.config['g_optim_config']
        g_optim_params = []
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')
        optim_type = self.config['optim_type']
        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
        else:
            # fixed typo in the error message ("supperted" -> "supported")
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')

    def train(self):
        """Main training loop: perceptual content loss + SWD style loss.

        Saves a generator checkpoint every `model_save_epoch` epochs and
        writes one batch of stylized samples each epoch.
        """
        ckpt_dir = self.config["project_checkpoints"]
        log_frep = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        crop_size = self.config["imcrop_size"]
        swd_dim = self.config["swd_dim"]
        sample_dir = self.config["project_samples"]
        #===============build framework================#
        self.__init_framework__()
        #===============build optimizer================#
        print("build the optimizer...")
        self.__setup_optimizers__()
        #===============build losses===================#
        MSE_loss = torch.nn.MSELoss()
        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0
        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()
        # Calculate the step number of each epoch
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch" % step_epoch)
        VGG = VGG16().cuda()
        # inputs in [-1,1] are mapped to 0-255 via x*SCALE_VAL + MEAN_VAL
        MEAN_VAL = 127.5
        SCALE_VAL = 127.5
        # Get Style Features (computed once for the single style image)
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68],
                                         dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        imagenet_neg_mean_11 = torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL],
                                            dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        style_tensor = img2tensor255crop(style_img, crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        # one SWD projection module per VGG feature layer
        swd_list = {}
        for key, value in style_features.items():
            swd_list[key] = SWD(value.shape[1], swd_dim).cuda()
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()
                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # VGG16 expects 0-255 input with the ImageNet mean subtracted
                generated_features = VGG((fake_image * SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images * SCALE_VAL).add(imagenet_neg_mean_11))
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
                style_loss = 0.0
                for key, value in generated_features.items():
                    # resample the random projections, then compare sliced
                    # distributions of generated vs. style features
                    swd_list[key].update()
                    s_loss = MSE_loss(swd_list[key](value), swd_list[key](style_features[key]))
                    style_loss += s_loss
                # backward & optimize
                g_loss = content_loss * content_w + style_loss * style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)
                    epochinformation = "[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)
                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        # consistency: log a python float like the other scalars
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss.item(), cum_step)
            #===============save checkpoints================#
            if (epoch + 1) % model_freq == 0:
                print("Save epoch %d model checkpoint!" % (epoch + 1))
                torch.save(self.gen.state_dict(),
                           os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                                        self.config["checkpoint_names"]["generator_name"])))
                torch.cuda.empty_cache()
            # NOTE(review): the source rendering lost indentation, so it is
            # ambiguous whether sampling was tied to the checkpoint save;
            # it is performed every epoch here — confirm against history.
            print('Sample images {}_fake.jpg'.format(epoch + 1))
            self.gen.eval()
            with torch.no_grad():
                sample = fake_image
                saved_image1 = denorm(sample.cpu().data)
                save_image(saved_image1,
                           os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)), nrow=4)
+114
View File
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_base.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 17th January 2022 1:08:25 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
class TrainerBase(object):
    """Common scaffolding shared by the project trainers.

    Builds the training dataloader and the experiment logger
    (tensorboard or wandb). Subclasses override the hook methods
    (init_framework / __setup_optimizers__ / __evaluation__) and extend
    train() with the actual optimization loop.
    """

    def __init__(self, config, reporter):
        """Keep config/logger handles, build the dataloader and the logger.

        Args:
            config (dict): experiment configuration (paths, hyper-parameters).
            reporter: project logger exposing writeInfo/writeModel/writeTrainLog.
        """
        self.config = config
        # logger
        self.reporter = reporter
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # the dataloader module is selected by name at runtime
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                           config["batch_size"],
                                           **config["dataset_params"])
        self.train_loader = dataloader
        #==============build the experiment logger=================#
        if self.config["logger"] == "tensorboard":
            from utilities.utilities import build_tensorboard
            tensorboard_writer = build_tensorboard(self.config["project_summary"])
            self.logger = tensorboard_writer
        elif self.config["logger"] == "wandb":
            import wandb
            wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
                       tags=[self.config["tag"]], name=self.config["version"])
            wandb.config = {
                "total_step": self.config["total_step"],
                "batch_size": self.config["batch_size"]
            }
            self.logger = wandb

    # TODO modify this function to build your models
    def init_framework(self):
        '''
        Define the framework and print its information into the log file.

        BUG FIX: renamed from __init_framework__ so that the hook name
        matches the self.init_framework() call in train() below and the
        overrides defined in subclasses; the old name was never called.
        '''
        #===============build models================#
        pass

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        # hook: subclasses create their optimizers here
        pass

    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self,
                       step=0,
                       **kwargs
                       ):
        # hook: subclasses implement checkpoint evaluation here
        pass

    def train(self):
        """Template entry point: build the framework and optimizer, set the
        start step, and print the start-of-training banner."""
        #===============build framework================#
        self.init_framework()
        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()
        # set the start point for training loop
        if self.config["phase"] == "finetune":
            self.start = self.config["checkpoint_step"]
        else:
            self.start = 0
        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()
-382
View File
@@ -1,382 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 6th July 2021 7:36:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm
class Trainer(object):
def __init__(self, config, reporter):
    """Store the configuration/logger and prepare the dataloader class.

    Note: the dataloader *instance* is not created here (see the
    commented block below) — train() rebuilds it per progressive stage.

    Args:
        config (dict): experiment configuration (paths, hyper-parameters).
        reporter: project logger exposing writeInfo/writeModel/writeTrainLog.
    """
    self.config = config
    # logger
    self.reporter = reporter
    # Data loader
    #============build train dataloader==============#
    # TODO to modify the key: "your_train_dataset" to get your train dataset path
    self.train_dataset = config["dataset_paths"][config["dataset_name"]]
    #================================================#
    print("Prepare the train dataloader...")
    # the dataloader module is selected by name at runtime
    dlModulename = config["dataloader"]
    package = __import__("data_tools.dataloader_%s"%dlModulename, fromlist=True)
    dataloaderClass = getattr(package, 'GetLoader')
    self.dataloader_class = dataloaderClass
    # dataloader = self.dataloader_class(self.train_dataset,
    #                                    config["batch_size_list"][0],
    #                                    config["imcrop_size_list"][0],
    #                                    **config["dataset_params"])
    # self.train_loader= dataloader
    #========build evaluation dataloader=============#
    # NOTE(review): the evaluation-dataloader construction was commented
    # out in the original; __evaluation__ still depends on it
    # (self.eval_loader / self.eval_iter) and cannot run as-is.
    #==============build tensorboard=================#
    if self.config["use_tensorboard"]:
        from utilities.utilities import build_tensorboard
        self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
    '''
    Build the generator and the discriminator from the configured
    component scripts, log their structures, move them to GPU, and in
    the finetune phase load both pretrained checkpoints.
    '''
    #===============build models================#
    print("build models...")
    # TODO [import models here]
    model_config = self.config["model_configs"]
    if self.config["phase"] == "train":
        gscript_name = "components." + model_config["g_model"]["script"]
        dscript_name = "components." + model_config["d_model"]["script"]
    elif self.config["phase"] == "finetune":
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
    class_name = model_config["g_model"]["class_name"]
    package = __import__(gscript_name, fromlist=True)
    gen_class = getattr(package, class_name)
    self.gen = gen_class(**model_config["g_model"]["module_params"])
    class_name = model_config["d_model"]["class_name"]
    package = __import__(dscript_name, fromlist=True)
    dis_class = getattr(package, class_name)
    self.dis = dis_class(**model_config["d_model"]["module_params"])
    # print and record model structure
    self.reporter.writeInfo("Generator structure:")
    self.reporter.writeModel(self.gen.__str__())
    self.reporter.writeInfo("Discriminator structure:")
    self.reporter.writeModel(self.dis.__str__())
    # train in GPU
    if self.config["cuda"] >= 0:
        self.gen = self.gen.cuda()
        self.dis = self.dis.cuda()
    # if in finetune phase, load the pretrained checkpoint
    if self.config["phase"] == "finetune":
        model_path = os.path.join(self.config["project_checkpoints"],
                        "epoch%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.gen.load_state_dict(torch.load(model_path))
        model_path = os.path.join(self.config["project_checkpoints"],
                        "epoch%d_%s.pth" % (self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["discriminator_name"]))
        self.dis.load_state_dict(torch.load(model_path))
        # BUG FIX: the message advertises an epoch number, so log the
        # epoch used to build the filenames, not the checkpoint directory.
        print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_step"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
    """PSNR evaluation over the evaluation dataloader.

    NOTE(review): references self.network / self.eval_loader /
    self.eval_iter, none of which are defined in this trainer (the
    eval-dataloader block in __init__ is commented out) — this looks
    like copy-pasted super-resolution code; confirm before calling.
    """
    # Evaluate the checkpoint
    self.network.eval()
    total_psnr = 0
    total_num = 0
    with torch.no_grad():
        for _ in range(self.eval_iter):
            hr, lr = self.eval_loader()
            if self.config["cuda"] >=0:
                hr = hr.cuda()
                lr = lr.cuda()
            # map [-1,1] tensors to the 0-255 range and clamp
            hr = (hr + 1.0)/2.0 * 255.0
            hr = torch.clamp(hr,0.0,255.0)
            lr = (lr + 1.0)/2.0 * 255.0
            lr = torch.clamp(lr,0.0,255.0)
            res = self.network(lr)
            res = torch.clamp(res,0.0,255.0)
            # per-image RMSE -> PSNR
            diff = (res-hr) ** 2
            diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
            psnrs = 20. * (255. / diff).log10()
            total_psnr+= psnrs.sum()
            total_num+=res.shape[0]
    final_psnr = total_psnr/total_num
    print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
        epoch, final_psnr))
    self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
    self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
d_train_opt = self.config['d_optim_config']
g_optim_params = []
d_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
for k, v in self.dis.named_parameters():
if v.requires_grad:
d_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
n_class = len(self.config["selected_style_dir"])
# prep_weights= self.config["layersWeight"]
feature_w = self.config["feature_weight"]
transform_w = self.config["transform_weight"]
d_step = self.config["d_step"]
g_step = self.config["g_step"]
batch_size_list = self.config["batch_size_list"]
switch_epoch_list = self.config["switch_epoch_list"]
imcrop_size_list = self.config["imcrop_size_list"]
sample_dir = self.config["project_samples"]
current_epoch_index = 0
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
Transform = Transform_block().cuda()
L1_loss = torch.nn.L1Loss()
MSE_loss = torch.nn.MSELoss()
Hinge_loss = torch.nn.ReLU().cuda()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
output_size = self.dis.get_outputs_len()
print("prepare the fixed labels...")
fix_label = [i for i in range(n_class)]
fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
for epoch in range(start, total_epoch):
# switch training image size
if epoch in switch_epoch_list:
print('Current epoch: {}'.format(epoch))
print('***Redefining the dataloader for progressive training.***')
print('***Current spatial size is {} and batch size is {}.***'.format(
imcrop_size_list[current_epoch_index], batch_size_list[current_epoch_index]))
del self.train_loader
self.train_loader = self.dataloader_class(self.train_dataset,
batch_size_list[current_epoch_index],
imcrop_size_list[current_epoch_index],
**self.config["dataset_params"])
# Caculate the epoch number
step_epoch = len(self.train_loader)
step_epoch = step_epoch // (d_step + g_step)
print("Total step = %d in each epoch"%step_epoch)
current_epoch_index += 1
for step in range(step_epoch):
self.dis.train()
self.gen.train()
# ================== Train D ================== #
# Compute loss with real images
for _ in range(d_step):
content_images,style_images,label = self.train_loader.next()
label = label.long()
d_out = self.dis(style_images,label)
d_loss_real = 0
for i in range(output_size):
temp = Hinge_loss(1 - d_out[i]).mean()
d_loss_real += temp
d_loss_photo = 0
d_out = self.dis(content_images,label)
for i in range(output_size):
temp = Hinge_loss(1 + d_out[i]).mean()
d_loss_photo += temp
fake_image,_= self.gen(content_images,label)
d_out = self.dis(fake_image.detach(),label)
d_loss_fake = 0
for i in range(output_size):
temp = Hinge_loss(1 + d_out[i]).mean()
# temp *= prep_weights[i]
d_loss_fake += temp
# Backward + Optimize
d_loss = d_loss_real + d_loss_photo + d_loss_fake
self.d_optimizer.zero_grad()
d_loss.backward()
self.d_optimizer.step()
# ================== Train G ================== #
for _ in range(g_step):
content_images,_,_ = self.train_loader.next()
fake_image,real_feature = self.gen(content_images,label)
fake_feature = self.gen(fake_image, get_feature=True)
d_out = self.dis(fake_image,label.long())
g_feature_loss = L1_loss(fake_feature,real_feature)
g_transform_loss = MSE_loss(Transform(content_images), Transform(fake_image))
g_loss_fake = 0
for i in range(output_size):
temp = -d_out[i].mean()
# temp *= prep_weights[i]
g_loss_fake += temp
# backward & optimize
g_loss = g_loss_fake + g_feature_loss* feature_w + g_transform_loss* transform_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_frep == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, d_loss_real: {:.4f}, \\\
d_loss_photo: {:.4f}, d_loss_fake: {:.4f}, g_loss: {:.4f}, g_loss_fake: {:.4f}, \\\
g_feature_loss: {:.4f}, g_transform_loss: {:.4f}".format(self.config["version"],
epoch + 1, total_epoch, elapsed, step + 1, step_epoch,
d_loss.item(), d_loss_real.item(), d_loss_photo.item(),
d_loss_fake.item(), g_loss.item(), g_loss_fake.item(),\
g_feature_loss.item(), g_transform_loss.item())
print(epochinformation)
self.reporter.writeRawInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/d_loss', d_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_real', d_loss_real.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_photo', d_loss_photo.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_loss_fake', g_loss_fake.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_feature_loss', g_feature_loss, cum_step)
self.tensorboard_writer.add_scalar('data/g_transform_loss', g_transform_loss, cum_step)
#===============adjust learning rate============#
if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
print("Learning rate decay")
for p in self.optimizer.param_groups:
p['lr'] *= self.config["lr_decay"]
print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.save(self.dis.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["discriminator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(step + 1))
self.gen.eval()
with torch.no_grad():
sample = content_images[0, :, :, :].unsqueeze(0)
saved_image1 = denorm(sample.cpu().data)
for index in range(n_class):
fake_images,_ = self.gen(sample, fix_label[index].unsqueeze(0))
saved_image1 = torch.cat((saved_image1, denorm(fake_images.cpu().data)), 0)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(step + 1)),nrow=3)
+13 -27
View File
@@ -5,15 +5,17 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 9th January 2022 12:31:03 am
# Last Modified: Tuesday, 11th January 2022 3:06:14 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import random
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from utilities.utilities import denorm
@@ -182,12 +184,10 @@ class Trainer(object):
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
@@ -231,32 +231,30 @@ class Trainer(object):
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
randindex = [i for i in range(batch_size)]
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
image1, image2 = self.train_loader.next()
random.shuffle(randindex)
src_image1, src_image2 = self.train_loader.next()
img_att = src_image1
img_att = image1
if step%2 == 0:
img_id = src_image2
img_id = image2 # swap with same id, different pose
else:
img_id = src_image2[randindex]
img_id = image2[randindex] # swap with different face
src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic')
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
img_id_112_norm = spnorm(img_id_112)
latent_id = model.netArc(img_id_112_norm)
latent_id = self.arcface(img_id_112)
latent_id = F.normalize(latent_id, p=2, dim=1)
losses, img_fake= self.gen(src_image1, latent_id)
losses, img_fake= self.gen(image1, latent_id)
# update Generator weights
losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
@@ -275,18 +273,6 @@ class Trainer(object):
loss_D.backward()
optimizer_D.step()
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w