diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json index 309731d..3f05a17 100644 --- a/GUI/file_sync/filestate_machine0.json +++ b/GUI/file_sync/filestate_machine0.json @@ -1,6 +1,6 @@ { - "GUI.py": 1644423287.9844918, - "test.py": 1645015463.4072468, + "GUI.py": 1645109256.0056663, + "test.py": 1645344802.7112515, "train.py": 1643397924.974299, "components\\Generator.py": 1644689001.9005148, "components\\projected_discriminator.py": 1642348101.4661522, @@ -33,7 +33,7 @@ "utilities\\plot.py": 1641911100.7995758, "utilities\\reporter.py": 1625413813.7213495, "utilities\\save_heatmap.py": 1611123530.679439, - "utilities\\sshupload.py": 1611123530.6624403, + "utilities\\sshupload.py": 1645168814.6421573, "utilities\\transfer_checkpoint.py": 1642397157.0163105, "utilities\\utilities.py": 1634019485.0783668, "utilities\\yaml_config.py": 1611123530.6614666, @@ -60,13 +60,13 @@ "face_crop.py": 1643789609.1834445, "face_crop_video.py": 1643815024.5516832, "similarity.py": 1643269705.1073737, - "train_multigpu.py": 1645035569.415791, + "train_multigpu.py": 1645548174.898882, "components\\arcface_decoder.py": 1643396144.2575414, "components\\Generator_nobias.py": 1643179001.810856, "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644861019.9044807, "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898, "test_scripts\\tester_arcface_Rec.py": 1643431261.9333818, - "test_scripts\\tester_image.py": 1644934851.442447, + "test_scripts\\tester_image.py": 1645547412.8218117, "torch_utils\\custom_ops.py": 1640773190.0, "torch_utils\\misc.py": 1640773190.0, "torch_utils\\persistence.py": 1640773190.0, @@ -105,23 +105,32 @@ "components\\Generator_ori.py": 1644689174.414655, "losses\\cos.py": 1644229583.4023254, "data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826, - "speed_test.py": 1645034614.282678, + "speed_test.py": 1645266259.6685307, "components\\DeConv_Invo.py": 1644426607.1588645, "components\\Generator_reduce_up.py": 1644688655.2096283, "components\\Generator_upsample.py": 1644689723.8293872, "components\\misc\\Involution.py": 1644509321.5267963, "train_yamls\\train_Invoup.yaml": 1644689981.9794765, - "flops.py": 1645034657.122085, + "flops.py": 1645540971.0513766, "detection_test.py": 1644935512.6830947, - "components\\DeConv_Depthwise.py": 1645027608.8040042, + "components\\DeConv_Depthwise.py": 1645064447.4379447, "components\\DeConv_Depthwise1.py": 1644946969.5054545, "components\\Generator_modulation_depthwise.py": 1644861291.4467516, - "components\\Generator_modulation_depthwise_config.py": 1645034769.4103642, + "components\\Generator_modulation_depthwise_config.py": 1645262162.9779513, "components\\Generator_modulation_up.py": 1644946498.7005584, "components\\Generator_oriae_modulation.py": 1644897798.1987727, "components\\Generator_ori_config.py": 1644946742.3635018, "train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593, "train_yamls\\train_Depthwise.yaml": 1644860961.099242, "train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077, - "train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747 + "train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747, + "train_distillation_mgpu.py": 1645553439.948758, + "components\\DeConv.py": 1645263338.9001615, + "components\\DeConv_Depthwise_ECA.py": 1645265769.1076133, + "components\\ECA.py": 1614848426.9604986, + "components\\ECA_Depthwise_Conv.py": 1645265754.2023985, + "components\\Generator_eca_depthwise.py": 1645266338.9750814, + "losses\\KA.py": 1645546325.331715, + "train_scripts\\trainer_distillation_mgpu.py": 1645553282.8011973, + "train_yamls\\train_distillation.yaml": 1645553621.3982964 } \ No newline at end of file diff --git a/flops.py b/flops.py index d16385d..a5c45e1 100644 --- a/flops.py +++ b/flops.py @@ -5,7 +5,7 @@ # Created Date: Sunday February 13th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th February 2022 2:32:48 am +# Last Modified: Tuesday, 22nd February 2022 10:42:51 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -29,7 +29,7 @@ if __name__ == '__main__': model_config={ "id_dim": 512, "g_kernel_size": 3, - "in_channel":16, + "in_channel":32, "res_num": 9, # "up_mode": "nearest", "up_mode": "bilinear", diff --git a/losses/KA.py b/losses/KA.py new file mode 100644 index 0000000..d6d907e --- /dev/null +++ b/losses/KA.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: KA.py +# Created Date: Wednesday February 23rd 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 23rd February 2022 12:12:05 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +def KA(X, Y): + X_ = X.view(X.size(0), -1) + Y_ = Y.view(Y.size(0), -1) + assert X_.shape[0] == Y_.shape[ + 0], f'X_ and Y_ must have the same shape on dim 0, but got {X_.shape[0]} for X_ and {Y_.shape[0]} for Y_.' + X_vec = X_ @ X_.T + Y_vec = Y_ @ Y_.T + ret = (X_vec * Y_vec).sum() / ((X_vec**2).sum() * (Y_vec**2).sum())**0.5 + return ret \ No newline at end of file diff --git a/test.py b/test.py index f795e4e..c9e5ee3 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Saturday, 19th February 2022 11:46:06 am +# Last Modified: Sunday, 20th February 2022 4:13:22 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -34,7 +34,7 @@ def getParameters(): help="version name for train, test, finetune") parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU - parser.add_argument('-s', '--checkpoint_step', type=int, default=170000, + parser.add_argument('-s', '--checkpoint_step', type=int, default=250000, help="checkpoint epoch for test phase or finetune phase") # test diff --git a/test_scripts/tester_image.py b/test_scripts/tester_image.py index 6604843..bbdad11 100644 --- a/test_scripts/tester_image.py +++ b/test_scripts/tester_image.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 18th February 2022 5:00:28 pm +# Last Modified: Wednesday, 23rd February 2022 12:30:12 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -59,6 +59,17 @@ class Tester(object): # TODO replace below lines to define the model framework self.network = gen_class(**model_config["g_model"]["module_params"]) self.network = self.network.eval() + # for name in self.network.state_dict(): + # print(name) + self.features = {} + mapping_layers = [ + "first_layer", + "down4", + "BottleNeck.2" + ] + + + # print and recorde model structure self.reporter.writeInfo("Model structure:") self.reporter.writeModel(self.network.__str__()) diff --git a/train_distillation_mgpu.py b/train_distillation_mgpu.py new file mode 100644 index 0000000..3ad79f1 --- /dev/null +++ b/train_distillation_mgpu.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: train.py +# Created Date: Tuesday April 28th 2020 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 23rd February 2022 2:30:03 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +from curses.panel import version +import os +import shutil +import argparse +from torch.backends import cudnn +from utilities.json_config import readConfig, writeConfig +from utilities.reporter import Reporter +from utilities.yaml_config import getConfigYaml + + + +def str2bool(v): + return v.lower() in ('true') + +#################################################################################### +# To configure the seting of training\finetune\test +# +#################################################################################### +def getParameters(): + + parser = argparse.ArgumentParser() + # general settings + parser.add_argument('-v', '--version', type=str, default='distillation', + help="version name for train, test, finetune") + parser.add_argument('-t', '--tag', type=str, default='distillation', + help="tag for current experiment") + + parser.add_argument('-p', '--phase', type=str, default="train", + choices=['train', 'finetune','debug'], + help="The phase of current project") + + parser.add_argument('-c', '--gpus', type=int, nargs='+', default=[0,1]) # <0 if it is set as -1, program will use CPU + parser.add_argument('-e', '--ckpt', type=int, default=74, + help="checkpoint epoch for test phase or finetune phase") + + # training + parser.add_argument('--experiment_description', type=str, + default="测试蒸馏代码") + + parser.add_argument('--train_yaml', type=str, default="train_distillation.yaml") + + # system logger + parser.add_argument('--logger', type=str, + default="none", choices=['tensorboard', 'wandb','none'], help='system logger') + + # # logs (does not to be changed in most time) + # parser.add_argument('--dataloader_workers', type=int, default=6) + # parser.add_argument('--use_tensorboard', type=str2bool, default='True', + # choices=['True', 'False'], help='enable the tensorboard') + # parser.add_argument('--log_step', type=int, default=100) + # parser.add_argument('--sample_step', type=int, default=100) + + # # template (onece editing finished, it should be deleted) + # parser.add_argument('--str_parameter', type=str, default="default", help='str parameter') + # parser.add_argument('--str_parameter_choices', type=str, + # default="default", choices=['choice1', 'choice2','choice3'], help='str parameter with choices list') + # parser.add_argument('--int_parameter', type=int, default=0, help='int parameter') + # parser.add_argument('--float_parameter', type=float, default=0.0, help='float parameter') + # parser.add_argument('--bool_parameter', type=str2bool, default='True', choices=['True', 'False'], help='bool parameter') + # parser.add_argument('--list_str_parameter', type=str, nargs='+', default=["element1","element2"], help='str list parameter') + # parser.add_argument('--list_int_parameter', type=int, nargs='+', default=[0,1], help='int list parameter') + return parser.parse_args() + +ignoreKey = [ + "dataloader_workers", + "log_root_path", + "project_root", + "project_summary", + "project_checkpoints", + "project_samples", + "project_scripts", + "reporter_path", + "use_specified_data", + "specified_data_paths", + "dataset_path","cuda", + "test_script_name", + "test_dataloader", + "test_dataset_path", + "save_test_result", + "test_batch_size", + "node_name", + "checkpoint_epoch", + "test_dataset_path", + "test_dataset_name", + "use_my_test_date"] + +#################################################################################### +# This function will create the related directories before the +# training\fintune\test starts +# Your_log_root (version name) +# |---summary/... +# |---samples/... (save evaluated images) +# |---checkpoints/... +# |---scripts/... +# +#################################################################################### +def createDirs(sys_state): + # the base dir + if not os.path.exists(sys_state["log_root_path"]): + os.makedirs(sys_state["log_root_path"]) + + # create dirs + sys_state["project_root"] = os.path.join(sys_state["log_root_path"], + sys_state["version"]) + + project_root = sys_state["project_root"] + if not os.path.exists(project_root): + os.makedirs(project_root) + + sys_state["project_summary"] = os.path.join(project_root, "summary") + if not os.path.exists(sys_state["project_summary"]): + os.makedirs(sys_state["project_summary"]) + + sys_state["project_checkpoints"] = os.path.join(project_root, "checkpoints") + if not os.path.exists(sys_state["project_checkpoints"]): + os.makedirs(sys_state["project_checkpoints"]) + + sys_state["project_samples"] = os.path.join(project_root, "samples") + if not os.path.exists(sys_state["project_samples"]): + os.makedirs(sys_state["project_samples"]) + + sys_state["project_scripts"] = os.path.join(project_root, "scripts") + if not os.path.exists(sys_state["project_scripts"]): + os.makedirs(sys_state["project_scripts"]) + + sys_state["reporter_path"] = os.path.join(project_root,sys_state["version"]+"_report") + +def fetch_teacher_files(sys_state, env_config): + + version = sys_state["teacher_model"]["version"] + if not os.path.exists(sys_state["log_root_path"]): + os.makedirs(sys_state["log_root_path"]) + # create dirs + sys_state["teacher_model"]["project_root"] = os.path.join(sys_state["log_root_path"], version) + + project_root = sys_state["teacher_model"]["project_root"] + if not os.path.exists(project_root): + os.makedirs(project_root) + + sys_state["teacher_model"]["project_checkpoints"] = os.path.join(project_root, "checkpoints") + if not os.path.exists(sys_state["teacher_model"]["project_checkpoints"]): + os.makedirs(sys_state["teacher_model"]["project_checkpoints"]) + + sys_state["teacher_model"]["project_scripts"] = os.path.join(project_root, "scripts") + if not os.path.exists(sys_state["teacher_model"]["project_scripts"]): + os.makedirs(sys_state["teacher_model"]["project_scripts"]) + if sys_state["teacher_model"]["node_ip"] != "localhost": + from utilities.sshupload import fileUploaderClass + machine_config = env_config["machine_config"] + machine_config = readConfig(machine_config) + nodeinf = None + for item in machine_config: + if item["ip"] == sys_state["teacher_model"]["node_ip"]: + nodeinf = item + break + if not nodeinf: + raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"])) + print("ready to fetch related files from server: %s ......"%nodeinf["ip"]) + uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"]) + + remotebase = os.path.join(nodeinf['path'],"train_logs",version).replace('\\','/') + + # Get the config.json + print("ready to get the teacher's config.json...") + remoteFile = os.path.join(remotebase, env_config["config_json_name"]).replace('\\','/') + localFile = os.path.join(project_root, env_config["config_json_name"]) + + ssh_state = uploader.sshScpGet(remoteFile, localFile) + if not ssh_state: + raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile)) + print("success get the teacher's config.json from server %s"%nodeinf['ip']) + + # Get scripts + remoteDir = os.path.join(remotebase, "scripts").replace('\\','/') + localDir = os.path.join(sys_state["teacher_model"]["project_scripts"]) + ssh_state = uploader.sshScpGetDir(remoteDir, localDir) + if not ssh_state: + raise Exception(print("Get file %s failed! Program exists!"%remoteFile)) + print("Get the teacher's scripts successful!") + # Read model_config.json + config_json = os.path.join(project_root, env_config["config_json_name"]) + json_obj = readConfig(config_json) + for item in json_obj.items(): + if item[0] in ignoreKey: + pass + else: + sys_state["teacher_model"][item[0]] = item[1] + + # Get checkpoints + if sys_state["teacher_model"]["node_ip"] != "localhost": + ckpt_name = "step%d_%s.pth"%(sys_state["teacher_model"]["checkpoint_step"], + sys_state["teacher_model"]["checkpoint_names"]["generator_name"]) + localFile = os.path.join(sys_state["teacher_model"]["project_checkpoints"],ckpt_name) + if not os.path.exists(localFile): + remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/') + ssh_state = uploader.sshScpGet(remoteFile, localFile, True) + if not ssh_state: + raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile)) + print("Get the teacher's checkpoint %s successfully!"%(ckpt_name)) + else: + print("%s exists!"%(ckpt_name)) + +def main(): + + config = getParameters() + # speed up the program + cudnn.benchmark = True + cudnn.enabled = True + + from utilities.logo_class import logo_class + logo_class.print_group_logo() + + sys_state = {} + + # set the GPU number + gpus = [str(i) for i in config.gpus] + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus) + + # read system environment paths + env_config = readConfig('env/env.json') + env_config = env_config["path"] + + # obtain all configurations in argparse + config_dic = vars(config) + for config_key in config_dic.keys(): + sys_state[config_key] = config_dic[config_key] + + #=======================Train Phase=========================# + if config.phase == "train": + # read training configurations from yaml file + ymal_config = getConfigYaml(os.path.join(env_config["train_config_path"], config.train_yaml)) + for item in ymal_config.items(): + sys_state[item[0]] = item[1] + + # create related dirs + sys_state["log_root_path"] = env_config["train_log_root"] + createDirs(sys_state) + + # create reporter file + reporter = Reporter(sys_state["reporter_path"]) + + # save the config json + config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"]) + writeConfig(config_json, sys_state) + + # save the dependent scripts + # TODO and copy the scripts to the project dir + + # save the trainer script into [train_logs_root]\[version name]\scripts\ + file1 = os.path.join(env_config["train_scripts_path"], + "trainer_%s.py"%sys_state["train_script_name"]) + tgtfile1 = os.path.join(sys_state["project_scripts"], + "trainer_%s.py"%sys_state["train_script_name"]) + shutil.copyfile(file1,tgtfile1) + + # save the yaml file + file1 = os.path.join(env_config["train_config_path"], config.train_yaml) + tgtfile1 = os.path.join(sys_state["project_scripts"], config.train_yaml) + shutil.copyfile(file1,tgtfile1) + + # TODO replace below lines, here to save the critical scripts + + #=====================Finetune Phase=====================# + elif config.phase == "finetune": + sys_state["log_root_path"] = env_config["train_log_root"] + sys_state["project_root"] = os.path.join(sys_state["log_root_path"], sys_state["version"]) + + config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"]) + train_config = readConfig(config_json) + for item in train_config.items(): + if item[0] in ignoreKey: + pass + else: + sys_state[item[0]] = item[1] + + createDirs(sys_state) + reporter = Reporter(sys_state["reporter_path"]) + sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"] + + fetch_teacher_files(sys_state,env_config) + # get the dataset path + sys_state["dataset_paths"] = {} + for data_key in env_config["dataset_paths"].keys(): + sys_state["dataset_paths"][data_key] = env_config["dataset_paths"][data_key] + + # display the training information + moduleName = "train_scripts.trainer_" + sys_state["train_script_name"] + if config.phase == "finetune": + moduleName = sys_state["com_base"] + "trainer_" + sys_state["train_script_name"] + + # print some important information + # TODO + # print("Start to run training script: {}".format(moduleName)) + # print("Traning version: %s"%sys_state["version"]) + # print("Dataloader Name: %s"%sys_state["dataloader"]) + # # print("Image Size: %d"%sys_state["imsize"]) + # print("Batch size: %d"%(sys_state["batch_size"])) + # print("GPUs:", gpus) + print("\n========================================================================\n") + print(sys_state) + for data_key in sys_state.keys(): + print("[%s]---[%s]"%(data_key,sys_state[data_key])) + print("\n========================================================================\n") + + + # Load the training script and start to train + reporter.writeConfig(sys_state) + + package = __import__(moduleName, fromlist=True) + trainerClass= getattr(package, 'Trainer') + trainer = trainerClass(sys_state, reporter) + trainer.train() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/train_multigpu.py b/train_multigpu.py index d7f965a..f31e020 100644 --- a/train_multigpu.py +++ b/train_multigpu.py @@ -5,7 +5,7 @@ # Created Date: Tuesday April 28th 2020 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th February 2022 2:19:29 am +# Last Modified: Wednesday, 23rd February 2022 12:42:54 am # Modified By: Chen Xuanhong # Copyright (c) 2020 Shanghai Jiao Tong University ############################################################# diff --git a/train_scripts/trainer_distillation_mgpu.py b/train_scripts/trainer_distillation_mgpu.py new file mode 100644 index 0000000..d660200 --- /dev/null +++ b/train_scripts/trainer_distillation_mgpu.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 23rd February 2022 2:36:05 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from losses.KA import KA +from utilities.plot import plot_batch +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +def add_mapping_hook(network, features,mapping_layers): + mapping_hooks = [] + + def get_activation(mem, name): + def get_output_hook(module, input, output): + mem[name] = output + + return get_output_hook + + def add_hook(net, mem, mapping_layers): + for n, m in net.named_modules(): + if n in mapping_layers: + mapping_hooks.append( + m.register_forward_hook(get_activation(mem, n))) + + add_hook(network, features, mapping_layers) + + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." + model_config["d_model"]["script"] + file1 = os.path.join("components", model_config["d_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + + elif config["phase"] == "finetune": + gscript_name = config["com_base"] + model_config["g_model"]["script"] + dscript_name = config["com_base"] + model_config["d_model"]["script"] + com_base = "train_logs."+config["teacher_model"]["version"]+".scripts" + tscript_name = com_base +"."+ config["teacher_model"]["model_configs"]["g_model"]["script"] + class_name = config["teacher_model"]["model_configs"]["g_model"]["class_name"] + package = __import__(tscript_name, fromlist=True) + gen_class = getattr(package, class_name) + tgen = gen_class(**config["teacher_model"]["model_configs"]["g_model"]["module_params"]) + tgen = tgen.eval() + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + gen = gen_class(**model_config["g_model"]["module_params"]) + + + + # print and recorde model structure + reporter.writeInfo("Generator structure:") + reporter.writeModel(gen.__str__()) + reporter.writeInfo("Teacher structure:") + reporter.writeModel(tgen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + dis = dis_class(**model_config["d_model"]["module_params"]) + + + # print and recorde model structure + reporter.writeInfo("Discriminator structure:") + reporter.writeModel(dis.__str__()) + + arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + arcface = arcface1['model'].module + + # train in GPU + + # if in finetune phase, load the pretrained checkpoint + if config["phase"] == "finetune": + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"])) + model_path = os.path.join(config["teacher_model"]["project_checkpoints"], + "step%d_%s.pth"%(config["teacher_model"]["model_step"], + config["teacher_model"]["checkpoint_names"]["generator_name"])) + tgen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + print('loaded trained teacher backbone model step {}...!'.format(config["teacher_model"]["model_step"])) + tgen = tgen.to(device) + tgen.requires_grad_(False) + gen = gen.to(device) + dis = dis.to(device) + arcface= arcface.to(device) + arcface.requires_grad_(False) + arcface.eval() + + t_features = {} + s_features = {} + add_mapping_hook(tgen,t_features,config["feature_list"]) + add_mapping_hook(gen,s_features,config["feature_list"]) + + return tgen, gen, dis, arcface, t_features, s_features + +# TODO modify this function to configurate the optimizer of your pipeline +def setup_optimizers(config, reporter, gen, dis, rank): + + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + g_train_opt = config['g_optim_config'] + d_train_opt = config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = config['optim_type'] + + if optim_type == 'Adam': + g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supperted yet.') + # self.optimizers.append(self.optimizer_g) + if config["phase"] == "finetune": + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s.pth"%(config["ckpt"], + config["optimizer_names"]["generator_name"])) + g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s.pth"%(config["ckpt"], + config["optimizer_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"])) + return g_optimizer, d_optimizer + + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + feat_w = config["feature_match_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader. + if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class= dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + tgen, gen, dis, arcface, t_feat, s_feat = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface, tgen]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. + if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + dis.feature_network.requires_grad_(False) + + for step in range(start, total_step): + gen.train() + dis.train() + for interval in range(2): + random.shuffle(randindex) + src_image1, src_image2 = dataloader.next() + # if rank ==0: + + # elapsed = time.time() - start_time + # elapsed = str(datetime.timedelta(seconds=elapsed)) + # print("dataloader:",elapsed) + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + + img_t = tgen(src_image1, latent_id) + img_fake = gen(src_image1, latent_id) + gen_logits,_ = dis(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = dis(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + # if rank ==0: + + # elapsed = time.time() - start_time + # elapsed = str(datetime.timedelta(seconds=elapsed)) + # print("Discriminator training:",elapsed) + else: + + # model.netD.requires_grad_(True) + img_fake = gen(src_image1, latent_id) + # G loss + gen_logits,feat = dis(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() + real_feat = dis.get_feature(src_image1) + feat_match_loss = l1_loss(feat["3"],real_feat["3"]) + loss_G = loss_Gmain + loss_G_ID * id_w + \ + feat_match_loss * feat_w + if step%2 == 0: + #G_Rec + loss_G_Rec = l1_loss(img_fake, src_image1) + loss_G += loss_G_Rec * rec_w + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + # if rank ==0: + + # elapsed = time.time() - start_time + # elapsed = str(datetime.timedelta(seconds=elapsed)) + # print("Generator training:",elapsed) + + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # print("ready to report losses") + # ID_Total= loss_G_ID + # torch.distributed.all_reduce(ID_Total) + + epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ + G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(version, elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + print(epochinformation) + reporter.writeInfo(epochinformation) + + if config["logger"] == "tensorboard": + logger.add_scalar('G/G_loss', loss_G.item(), step) + logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step) + logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + logger.add_scalar('D/D_loss', loss_D.item(), step) + logger.add_scalar('D/D_fake', loss_Dgen.item(), step) + logger.add_scalar('D/D_real', loss_Dreal.item(), step) + elif config["logger"] == "wandb": + logger.log({"G_Loss": loss_G.item()}, step = step) + logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + logger.log({"G_feat_match": feat_match_loss.item()}, step = step) + logger.log({"G_ID": loss_G_ID.item()}, step = step) + logger.log({"D_loss": loss_D.item()}, step = step) + logger.log({"D_fake": loss_Dgen.item()}, step = step) + logger.log({"D_real": loss_Dreal.item()}, step = step) + torch.cuda.empty_cache() + + if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0): + gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* img_std + img_mean).numpy() + for r in range(batch_gpu): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_gpu): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1) + img_fake = gen(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * img_std + img_fake = img_fake + img_mean + img_fake = img_fake.numpy() + for j in range(batch_gpu): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg')) + torch.cuda.empty_cache() + + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if rank == 0 and (step+1) % model_freq==0: + + torch.save(gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + torch.save(dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + + torch.save(g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + + torch.save(d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + print("Rank %d process done!"%rank) + torch.distributed.barrier() \ No newline at end of file diff --git a/train_yamls/train_distillation.yaml b/train_yamls/train_distillation.yaml new file mode 100644 index 0000000..bdb0960 --- /dev/null +++ b/train_yamls/train_distillation.yaml @@ -0,0 +1,72 @@ +# Related scripts +train_script_name: distillation_mgpu + +# models' scripts +model_configs: + g_model: + script: Generator_modulation_depthwise_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 16 + res_num: 9 + up_mode: bilinear + res_mode: depthwise + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +teacher_model: + node_ip: localhost + version: depthwise + model_step: 430000 + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 64 +feature_list: ["down4","BN1"] + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 10.0 +feature_match_weight: 10.0 + +# Log +log_step: 300 +model_save_step: 10000 +sample_step: 1000 +total_step: 1000000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file