diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json index cccc24e..e65df0a 100644 --- a/GUI/file_sync/filestate_machine0.json +++ b/GUI/file_sync/filestate_machine0.json @@ -1,13 +1,9 @@ { "GUI.py": 1642351532.4558506, "test.py": 1634039043.4872007, - "train.py": 1642351831.0061252, + "train.py": 1642408973.489742, "components\\Generator.py": 1642347735.351465, - "components\\Involution.py": 1626748553.9503577, "components\\projected_discriminator.py": 1642348101.4661522, - "components\\ResBlock.py": 1625415499.383468, - "components\\Transform.py": 1624954083.0098498, - "components\\warp_invo.py": 1634614033.6983366, "components\\pg_modules\\blocks.py": 1640773190.0, "components\\pg_modules\\diffaug.py": 1640773190.0, "components\\pg_modules\\discriminator.py": 1642349784.9407308, @@ -17,22 +13,18 @@ "data_tools\\data_loader.py": 1611123530.660446, "data_tools\\data_loader_condition.py": 1625411562.8217106, "data_tools\\data_loader_VGGFace2HQ.py": 1642349144.749807, + "data_tools\\StyleResize.py": 1624954084.7176485, + "data_tools\\test_dataloader_dir.py": 1634041792.6743984, "losses\\PerceptualLoss.py": 1615020169.668723, "losses\\SliceWassersteinDistance.py": 1634022704.6082795, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\experimental.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\stylize.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\train.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\transformer.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\utils.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\vgg.py": 1633868477.988523, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\video.py": 1583468787.0, - "reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\webcam.py": 1583468787.0, + "models\\arcface_models.py": 1642390690.623, + "models\\config.py": 1632643596.2908099, + "models\\__init__.py": 1642390864.8828168, "test_scripts\\tester_common.py": 1625369535.199175, "test_scripts\\tester_FastNST.py": 1634041357.607633, - "train_scripts\\trainer_base.py": 1642347616.205689, - "train_scripts\\trainer_FastNST_SWD.py": 1634581704.2218158, - "train_scripts\\trainer_FM.py": 1642350579.8586667, - "train_scripts\\trainer_gan.py": 1625571403.080787, + "train_scripts\\trainer_base.py": 1642396105.3868554, + "train_scripts\\trainer_FM.py": 1642396334.407562, + "train_scripts\\trainer_naiv512.py": 1642315674.9740853, "utilities\\checkpoint_manager.py": 1611123530.6624403, "utilities\\figure.py": 1611123530.6634378, "utilities\\json_config.py": 1611123530.6614666, @@ -42,8 +34,10 @@ "utilities\\reporter.py": 1625413813.7213495, "utilities\\save_heatmap.py": 1611123530.679439, "utilities\\sshupload.py": 1611123530.6624403, - "utilities\\transfer_checkpoint.py": 1612416429.5316093, + "utilities\\transfer_checkpoint.py": 1642397157.0163105, "utilities\\utilities.py": 1634019485.0783668, "utilities\\yaml_config.py": 1611123530.6614666, - "train_yamls\\train_512FM.yaml": 1642351806.754128 + "train_yamls\\train_512FM.yaml": 1642408586.2747152, + "train_scripts\\trainer_2layer_FM.py": 1642408727.3950334, + "train_yamls\\train_2layer_FM.yaml": 1642408813.9488862 } \ No newline at end of file diff --git a/test_scripts/tester_video.py b/test_scripts/tester_video.py new file mode 100644 index 0000000..30ec590 --- /dev/null +++ b/test_scripts/tester_video.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: tester_commonn.py +# Created Date: Saturday July 3rd 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 4th July 2021 11:32:14 am +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + + + +import os +import cv2 +import time + +import torch +from utilities.utilities import tensor2img + +# from utilities.Reporter import Reporter +from tqdm import tqdm + +class Tester(object): + def __init__(self, config, reporter): + + self.config = config + # logger + self.reporter = reporter + + #============build evaluation dataloader==============# + print("Prepare the test dataloader...") + dlModulename = config["test_dataloader"] + package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'TestDataset') + dataloader = dataloaderClass(config["test_data_path"], + config["batch_size"], + ["png","jpg"]) + self.test_loader= dataloader + + self.test_iter = len(dataloader)//config["batch_size"] + if len(dataloader)%config["batch_size"]>0: + self.test_iter+=1 + + + def __init_framework__(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + script_name = "components."+self.config["module_script_name"] + class_name = self.config["class_name"] + package = __import__(script_name, fromlist=True) + network_class = getattr(package, class_name) + n_class = len(self.config["selectedStyleDir"]) + + # TODO replace below lines to define the model framework + self.network = network_class(self.config["GConvDim"], + self.config["GKS"], + self.config["resNum"], + n_class + #**self.config["module_params"] + ) + + # print and recorde model structure + self.reporter.writeInfo("Model structure:") + self.reporter.writeModel(self.network.__str__()) + + # train in GPU + if self.config["cuda"] >=0: + self.network = self.network.cuda() + # loader1 = torch.load(self.config["ckp_name"]["generator_name"]) + # print(loader1.key()) + # pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"] + self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"]) + # self.network.load_state_dict(torch.load(pathwocao)) + print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"])) + + def test(self): + + # save_result = self.config["saveTestResult"] + save_dir = self.config["test_samples_path"] + ckp_epoch = self.config["checkpoint_epoch"] + version = self.config["version"] + batch_size = self.config["batch_size"] + style_names = self.config["selectedStyleDir"] + n_class = len(style_names) + + # models + self.__init_framework__() + + condition_labels = torch.ones((n_class, batch_size, 1)).long() + for i in range(n_class): + condition_labels[i,:,:] = condition_labels[i,:,:]*i + if self.config["cuda"] >=0: + condition_labels = condition_labels.cuda() + total = len(self.test_loader) + # Start time + import datetime + print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + print('Start =================================== test...') + start_time = time.time() + self.network.eval() + with torch.no_grad(): + for _ in tqdm(range(total//batch_size)): + contents, img_names = self.test_loader() + for i in range(n_class): + if self.config["cuda"] >=0: + contents = contents.cuda() + res, _ = self.network(contents, condition_labels[i, 0, :]) + res = tensor2img(res.cpu()) + for t in range(batch_size): + temp_img = res[t,:,:,:] + temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format( + img_names[t], version, ckp_epoch, style_names[i])),temp_img) + + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + print("Elapsed [{}]".format(elapsed)) \ No newline at end of file diff --git a/train.py b/train.py index 9f3bb4d..dbcd33e 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Tuesday April 28th 2020 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 17th January 2022 1:00:00 pm +# Last Modified: Monday, 17th January 2022 4:42:53 pm # Modified By: Chen Xuanhong # Copyright (c) 2020 Shanghai Jiao Tong University ############################################################# @@ -31,24 +31,24 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='FM', + parser.add_argument('-v', '--version', type=str, default='2layerFM', help="version name for train, test, finetune") - parser.add_argument('-t', '--tag', type=str, default='test', + parser.add_argument('-t', '--tag', type=str, default='Feature_match', help="tag for current experiment") parser.add_argument('-p', '--phase', type=str, default="train", choices=['train', 'finetune','debug'], help="The phase of current project") - parser.add_argument('-c', '--cuda', type=int, default=0) # <0 if it is set as -1, program will use CPU + parser.add_argument('-c', '--cuda', type=int, default=1) # <0 if it is set as -1, program will use CPU parser.add_argument('-e', '--ckpt', type=int, default=74, help="checkpoint epoch for test phase or finetune phase") # training parser.add_argument('--experiment_description', type=str, - default="尝试使用Liif+Invo作为上采样和降采样的算子,降采样两个DSF算子,上采样两个DSF算子") + default="减小重建和feature match的权重,使用2和3的feature作为feature") - parser.add_argument('--train_yaml', type=str, default="train_512FM.yaml") + parser.add_argument('--train_yaml', type=str, default="train_2layer_FM.yaml") # system logger parser.add_argument('--logger', type=str, diff --git a/train_scripts/trainer_2layer_FM.py b/train_scripts/trainer_2layer_FM.py new file mode 100644 index 0000000..35d8537 --- /dev/null +++ b/train_scripts/trainer_2layer_FM.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 17th January 2022 9:27:48 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random + +import numpy as np + +import torch +import torch.nn.functional as F +from utilities.plot import plot_batch + +from train_scripts.trainer_base import TrainerBase + +class Trainer(TrainerBase): + + def __init__(self, config, reporter): + super(Trainer, self).__init__(config, reporter) + + self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + # TODO modify this function to build your models + def init_framework(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + + model_config = self.config["model_configs"] + + if self.config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + dscript_name = "components." + model_config["d_model"]["script"] + + elif self.config["phase"] == "finetune": + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + dscript_name = self.config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and recorde model structure + self.reporter.writeInfo("Generator structure:") + self.reporter.writeModel(self.gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + self.dis = dis_class(**model_config["d_model"]["module_params"]) + self.dis.feature_network.requires_grad_(False) + + # print and recorde model structure + self.reporter.writeInfo("Discriminator structure:") + self.reporter.writeModel(self.dis.__str__()) + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + + # train in GPU + if self.config["cuda"] >=0: + self.gen = self.gen.cuda() + self.dis = self.dis.cuda() + self.arcface= self.arcface.cuda() + + self.arcface.eval() + self.arcface.requires_grad_(False) + + # if in finetune phase, load the pretrained checkpoint + if self.config["phase"] == "finetune": + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.gen.load_state_dict(torch.load(model_path)) + + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["discriminator_name"])) + self.dis.load_state_dict(torch.load(model_path)) + + print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"])) + + # TODO modify this function to configurate the optimizer of your pipeline + def __setup_optimizers__(self): + g_train_opt = self.config['g_optim_config'] + d_train_opt = self.config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in self.gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in self.dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = self.config['optim_type'] + + if optim_type == 'Adam': + self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supperted yet.') + # self.optimizers.append(self.optimizer_g) + if self.config["phase"] == "finetune": + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["generator_name"])) + self.g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["discriminator_name"])) + self.d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"])) + + + # TODO modify this function to evaluate your model + # Evaluate the checkpoint + def __evaluation__(self, + step = 0, + **kwargs + ): + src_image1 = kwargs["src1"] + src_image2 = kwargs["src2"] + batch_size = self.batch_size + self.gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy() + for r in range(batch_size): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = self.arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_size): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1) + img_fake = self.gen(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * self.img_std + img_fake = img_fake + self.img_mean + img_fake = img_fake.numpy() + for j in range(batch_size): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg')) + + + + + def train(self): + + ckpt_dir = self.config["project_checkpoints"] + log_frep = self.config["log_step"] + model_freq = self.config["model_save_step"] + sample_freq = self.config["sample_step"] + total_step = self.config["total_step"] + random_seed = self.config["dataset_params"]["random_seed"] + + self.batch_size = self.config["batch_size"] + self.sample_dir = self.config["project_samples"] + self.arcface_ckpt= self.config["arcface_ckpt"] + + + # prep_weights= self.config["layersWeight"] + id_w = self.config["id_weight"] + rec_w = self.config["reconstruct_weight"] + feat_w = self.config["feature_match_weight"] + + + + super().train() + + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + random.seed(random_seed) + randindex = [i for i in range(self.batch_size)] + random.shuffle(randindex) + import datetime + for step in range(self.start, total_step): + self.gen.train() + self.dis.train() + for interval in range(2): + random.shuffle(randindex) + src_image1, src_image2 = self.train_loader.next() + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = self.arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + if interval: + + img_fake = self.gen(src_image1, latent_id) + gen_logits,_ = self.dis(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = self.dis(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + self.d_optimizer.zero_grad() + loss_D.backward() + self.d_optimizer.step() + else: + + # model.netD.requires_grad_(True) + img_fake = self.gen(src_image1, latent_id) + # G loss + gen_logits,feat = self.dis(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = self.arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() + real_feat = self.dis.get_feature(src_image1) + feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"]) + loss_G = loss_Gmain + loss_G_ID * id_w + \ + feat_match_loss * feat_w + if step%2 == 0: + #G_Rec + loss_G_Rec = l1_loss(img_fake, src_image1) + loss_G += loss_G_Rec * rec_w + + self.g_optimizer.zero_grad() + loss_G.backward() + self.g_optimizer.step() + + # Print out log info + if (step + 1) % log_frep == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + + epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ + G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(self.config["version"], elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + print(epochinformation) + self.reporter.writeInfo(epochinformation) + + if self.config["logger"] == "tensorboard": + self.logger.add_scalar('G/G_loss', loss_G.item(), step) + self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step) + self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step) + self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + self.logger.add_scalar('D/D_loss', loss_D.item(), step) + self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step) + self.logger.add_scalar('D/D_real', loss_Dreal.item(), step) + elif self.config["logger"] == "wandb": + self.logger.log({"G_loss": loss_G.item()}, step = step) + self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step) + self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step) + self.logger.log({"G_ID": loss_G_ID.item()}, step = step) + self.logger.log({"D_loss": loss_D.item()}, step = step) + self.logger.log({"D_fake": loss_Dgen.item()}, step = step) + self.logger.log({"D_real": loss_Dreal.item()}, step = step) + if (step + 1) % sample_freq == 0: + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if (step+1) % model_freq==0: + + torch.save(self.gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + torch.save(self.dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + + torch.save(self.g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + + torch.save(self.d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) \ No newline at end of file diff --git a/train_scripts/trainer_FM.py b/train_scripts/trainer_FM.py index 40e0fcc..9cac031 100644 --- a/train_scripts/trainer_FM.py +++ b/train_scripts/trainer_FM.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 17th January 2022 1:12:08 pm +# Last Modified: Monday, 17th January 2022 5:31:43 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -180,8 +180,9 @@ class Trainer(TrainerBase): def train(self): ckpt_dir = self.config["project_checkpoints"] - log_frep = self.config["log_step"] + log_freq = self.config["log_step"] model_freq = self.config["model_save_step"] + sample_freq = self.config["sample_step"] total_step = self.config["total_step"] random_seed = self.config["dataset_params"]["random_seed"] @@ -268,7 +269,7 @@ class Trainer(TrainerBase): self.g_optimizer.step() # Print out log info - if (step + 1) % log_frep == 0: + if (step + 1) % log_freq == 0: elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) @@ -295,6 +296,14 @@ class Trainer(TrainerBase): self.logger.log({"D_loss": loss_D.item()}, step = step) self.logger.log({"D_fake": loss_Dgen.item()}, step = step) self.logger.log({"D_real": loss_Dreal.item()}, step = step) + + if (step + 1) % sample_freq == 0: + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) diff --git a/train_scripts/trainer_cycleloss.py b/train_scripts/trainer_cycleloss.py new file mode 100644 index 0000000..d1191bd --- /dev/null +++ b/train_scripts/trainer_cycleloss.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 19th January 2022 4:21:03 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random + +import numpy as np + +import torch +import torch.nn.functional as F +from utilities.plot import plot_batch + +from train_scripts.trainer_base import TrainerBase + +class Trainer(TrainerBase): + + def __init__(self, config, reporter): + super(Trainer, self).__init__(config, reporter) + + self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + # TODO modify this function to build your models + def init_framework(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + + model_config = self.config["model_configs"] + + if self.config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + dscript_name = "components." + model_config["d_model"]["script"] + + elif self.config["phase"] == "finetune": + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + dscript_name = self.config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and recorde model structure + self.reporter.writeInfo("Generator structure:") + self.reporter.writeModel(self.gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + self.dis = dis_class(**model_config["d_model"]["module_params"]) + self.dis.feature_network.requires_grad_(False) + + # print and recorde model structure + self.reporter.writeInfo("Discriminator structure:") + self.reporter.writeModel(self.dis.__str__()) + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + + # train in GPU + if self.config["cuda"] >=0: + self.gen = self.gen.cuda() + self.dis = self.dis.cuda() + self.arcface= self.arcface.cuda() + + self.arcface.eval() + self.arcface.requires_grad_(False) + + # if in finetune phase, load the pretrained checkpoint + if self.config["phase"] == "finetune": + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.gen.load_state_dict(torch.load(model_path)) + + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["discriminator_name"])) + self.dis.load_state_dict(torch.load(model_path)) + + print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"])) + + # TODO modify this function to configurate the optimizer of your pipeline + def __setup_optimizers__(self): + g_train_opt = self.config['g_optim_config'] + d_train_opt = self.config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in self.gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in self.dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = self.config['optim_type'] + + if optim_type == 'Adam': + self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supperted yet.') + # self.optimizers.append(self.optimizer_g) + if self.config["phase"] == "finetune": + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["generator_name"])) + self.g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["discriminator_name"])) + self.d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"])) + + + # TODO modify this function to evaluate your model + # Evaluate the checkpoint + def __evaluation__(self, + step = 0, + **kwargs + ): + src_image1 = kwargs["src1"] + src_image2 = kwargs["src2"] + batch_size = self.batch_size + self.gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy() + for r in range(batch_size): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = self.arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_size): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1) + img_fake = self.gen(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * self.img_std + img_fake = img_fake + self.img_mean + img_fake = img_fake.numpy() + for j in range(batch_size): + imgs.append(img_fake[j,...]) + print("Save sample image at step = % ..............."%step) + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg')) + + + + + def train(self): + + ckpt_dir = self.config["project_checkpoints"] + log_frep = self.config["log_step"] + model_freq = self.config["model_save_step"] + sample_freq = self.config["sample_step"] + total_step = self.config["total_step"] + random_seed = self.config["dataset_params"]["random_seed"] + + self.batch_size = self.config["batch_size"] + self.sample_dir = self.config["project_samples"] + self.arcface_ckpt= self.config["arcface_ckpt"] + + + # prep_weights= self.config["layersWeight"] + id_w = self.config["id_weight"] + rec_w = self.config["reconstruct_weight"] + feat_w = self.config["feature_match_weight"] + cyc_w = self.config["cycle_weight"] + + + + super().train() + + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + random.seed(random_seed) + randindex = [i for i in range(self.batch_size)] + random.shuffle(randindex) + import datetime + for step in range(self.start, total_step): + self.gen.train() + self.dis.train() + for interval in range(2): + random.shuffle(randindex) + src_image1, src_image2 = self.train_loader.next() + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = self.arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + if interval: + + img_fake = self.gen(src_image1, latent_id) + gen_logits,_ = self.dis(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = self.dis(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + self.d_optimizer.zero_grad() + loss_D.backward() + self.d_optimizer.step() + else: + + # model.netD.requires_grad_(True) + img_fake = self.gen(src_image1, latent_id) + # G loss + gen_logits,feat = self.dis(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = self.arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() + real_feat = self.dis.get_feature(src_image1) + feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"]) + + src1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + src1_id = self.arcface(src1_down) + cyc_fake = self.gen(img_fake, src1_id) + loss_cyc = l1_loss(cyc_fake, src_image1) + loss_G = loss_Gmain + loss_G_ID * id_w + \ + feat_match_loss * feat_w + cyc_w * loss_cyc + if step%2 == 0: + #G_Rec + loss_G_Rec = l1_loss(img_fake, src_image1) + loss_G += loss_G_Rec * rec_w + + self.g_optimizer.zero_grad() + loss_G.backward() + self.g_optimizer.step() + + # Print out log info + if (step + 1) % log_frep == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + + epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ + G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(self.config["version"], elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + print(epochinformation) + self.reporter.writeInfo(epochinformation) + + if self.config["logger"] == "tensorboard": + self.logger.add_scalar('G/G_loss', loss_G.item(), step) + self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step) + self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + self.logger.add_scalar('G/Cycle', loss_cyc.item(), step) + self.logger.add_scalar('D/D_loss', loss_D.item(), step) + self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step) + self.logger.add_scalar('D/D_real', loss_Dreal.item(), step) + elif self.config["logger"] == "wandb": + self.logger.log({"G_loss": loss_G.item()}, step = step) + self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step) + self.logger.log({"G_ID": loss_G_ID.item()}, step = step) + self.logger.log({"Cycle": loss_cyc.item()}, step = step) + self.logger.log({"D_loss": loss_D.item()}, step = step) + self.logger.log({"D_fake": loss_Dgen.item()}, step = step) + self.logger.log({"D_real": loss_Dreal.item()}, step = step) + if (step + 1) % sample_freq == 0: + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if (step+1) % model_freq==0: + + torch.save(self.gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + torch.save(self.dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + + torch.save(self.g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + + torch.save(self.d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) \ No newline at end of file diff --git a/train_yamls/train_2layer_FM.yaml b/train_yamls/train_2layer_FM.yaml new file mode 100644 index 0000000..d073535 --- /dev/null +++ b/train_yamls/train_2layer_FM.yaml @@ -0,0 +1,63 @@ +# Related scripts +train_script_name: 2layer_FM + +# models' scripts +model_configs: + g_model: + script: Generator + class_name: Generator + module_params: + g_conv_dim: 512 + g_kernel_size: 3 + res_num: 9 + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 12 + +# Dataset +dataloader: VGGFace2HQ +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 8 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 1.0 +feature_match_weight: 1.0 + +# Log +log_step: 300 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_512FM.yaml b/train_yamls/train_512FM.yaml index 1d14b6f..59dc115 100644 --- a/train_yamls/train_512FM.yaml +++ b/train_yamls/train_512FM.yaml @@ -22,7 +22,7 @@ model_configs: arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar # Training information -batch_size: 1 +batch_size: 12 # Dataset dataloader: VGGFace2HQ @@ -49,13 +49,14 @@ d_optim_config: betas: [ 0, 0.99] eps: !!float 1e-8 -id_weight: 10.0 +id_weight: 20.0 reconstruct_weight: 1.0 -feature_match_weight: 5.0 +feature_match_weight: 10.0 # Log -log_step: 10 -model_save_step: 20 +log_step: 300 +model_save_step: 10000 +sample_step: 1000 total_step: 1000000 checkpoint_names: generator_name: Generator diff --git a/train_yamls/train_cycleloss.yaml b/train_yamls/train_cycleloss.yaml new file mode 100644 index 0000000..9c9de89 --- /dev/null +++ b/train_yamls/train_cycleloss.yaml @@ -0,0 +1,64 @@ +# Related scripts +train_script_name: cycleloss + +# models' scripts +model_configs: + g_model: + script: Generator + class_name: Generator + module_params: + g_conv_dim: 512 + g_kernel_size: 3 + res_num: 9 + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 12 + +# Dataset +dataloader: VGGFace2HQ +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 8 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 0.1 +feature_match_weight: 0.1 +cycle_weight: 10.0 + +# Log +log_step: 400 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file