From 9efbe03d3f74e13df0bb9a7708989b5d6d3e3a55 Mon Sep 17 00:00:00 2001 From: XHChen0528 Date: Sat, 19 Mar 2022 22:47:08 +0800 Subject: [PATCH] update --- GUI/file_sync/filestate_machine0.json | 287 ++++++------ components/Generator_LSTU_config.py | 218 +++++++++ components/LSTU.py | 47 ++ losses/PatchNCE.py | 4 +- train_multigpu.py | 8 +- train_scripts/trainer_multi_gpu_CUT.py | 525 ++++++++++++++++++++++ train_scripts/trainer_multi_gpu_cycle.py | 534 +++++++++++++++++++++++ train_yamls/train_cycleloss.yaml | 17 +- 8 files changed, 1503 insertions(+), 137 deletions(-) create mode 100644 components/Generator_LSTU_config.py create mode 100644 components/LSTU.py create mode 100644 train_scripts/trainer_multi_gpu_CUT.py create mode 100644 train_scripts/trainer_multi_gpu_cycle.py diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json index 5d788aa..56bff43 100644 --- a/GUI/file_sync/filestate_machine0.json +++ b/GUI/file_sync/filestate_machine0.json @@ -1,91 +1,91 @@ { - "GUI.py": 1645109256.0056663, - "test.py": 1646330130.1009316, - "train.py": 1643397924.974299, - "components\\Generator.py": 1644689001.9005148, - "components\\projected_discriminator.py": 1642348101.4661522, - "components\\pg_modules\\blocks.py": 1640773190.0, - "components\\pg_modules\\diffaug.py": 1640773190.0, - "components\\pg_modules\\discriminator.py": 1642349784.9407308, - "components\\pg_modules\\networks_fastgan.py": 1640773190.0, - "components\\pg_modules\\networks_stylegan2.py": 1640773190.0, - "components\\pg_modules\\projector.py": 1642349764.3896568, - "data_tools\\data_loader.py": 1611123530.660446, - "data_tools\\data_loader_condition.py": 1625411562.8217106, - "data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877, - "data_tools\\StyleResize.py": 1624954084.7176485, - "data_tools\\test_dataloader_dir.py": 1634041792.6743984, - "losses\\PerceptualLoss.py": 1615020169.668723, - "losses\\SliceWassersteinDistance.py": 1634022704.6082795, - 
"models\\arcface_models.py": 1642390690.623, - "models\\config.py": 1632643596.2908099, - "models\\__init__.py": 1642390864.8828168, - "test_scripts\\tester_common.py": 1625369535.199175, + "GUI.py": 1647657822.9152665, + "test.py": 1647657822.945273, + "train.py": 1647657822.9562755, + "components\\Generator.py": 1647657822.93127, + "components\\projected_discriminator.py": 1647657822.938272, + "components\\pg_modules\\blocks.py": 1647657822.9362714, + "components\\pg_modules\\diffaug.py": 1647657822.9362714, + "components\\pg_modules\\discriminator.py": 1647657822.937271, + "components\\pg_modules\\networks_fastgan.py": 1647657822.937271, + "components\\pg_modules\\networks_stylegan2.py": 1647657822.937271, + "components\\pg_modules\\projector.py": 1647657822.938272, + "data_tools\\data_loader.py": 1647657822.9392715, + "data_tools\\data_loader_condition.py": 1647657822.9402719, + "data_tools\\data_loader_VGGFace2HQ.py": 1647657822.9392715, + "data_tools\\StyleResize.py": 1647657822.9392715, + "data_tools\\test_dataloader_dir.py": 1647657822.941272, + "losses\\PerceptualLoss.py": 1647657822.9432724, + "losses\\SliceWassersteinDistance.py": 1647657822.9432724, + "models\\arcface_models.py": 1647657822.9442725, + "models\\config.py": 1647657822.9442725, + "models\\__init__.py": 1647657822.9442725, + "test_scripts\\tester_common.py": 1647657822.9472733, "test_scripts\\tester_FastNST.py": 1634041357.607633, - "train_scripts\\trainer_base.py": 1642396105.3868554, - "train_scripts\\trainer_FM.py": 1643021959.3577182, - "train_scripts\\trainer_naiv512.py": 1642315674.9740853, - "utilities\\checkpoint_manager.py": 1611123530.6624403, - "utilities\\figure.py": 1611123530.6634378, - "utilities\\json_config.py": 1611123530.6614666, - "utilities\\learningrate_scheduler.py": 1611123530.675422, - "utilities\\logo_class.py": 1633883995.3093486, - "utilities\\plot.py": 1641911100.7995758, - "utilities\\reporter.py": 1646311333.3067005, - "utilities\\save_heatmap.py": 
1611123530.679439, - "utilities\\sshupload.py": 1645168814.6421573, - "utilities\\transfer_checkpoint.py": 1642397157.0163105, - "utilities\\utilities.py": 1634019485.0783668, - "utilities\\yaml_config.py": 1611123530.6614666, - "train_yamls\\train_512FM.yaml": 1643021615.8106658, - "train_scripts\\trainer_2layer_FM.py": 1642826548.2530458, - "train_yamls\\train_2layer_FM.yaml": 1642411635.5534878, - "components\\Generator_reduce.py": 1645020911.0651233, + "train_scripts\\trainer_base.py": 1647657822.9582758, + "train_scripts\\trainer_FM.py": 1647657822.957276, + "train_scripts\\trainer_naiv512.py": 1647657822.9602764, + "utilities\\checkpoint_manager.py": 1647657822.9652774, + "utilities\\figure.py": 1647657822.9652774, + "utilities\\json_config.py": 1647657822.9652774, + "utilities\\learningrate_scheduler.py": 1647657822.9652774, + "utilities\\logo_class.py": 1647657822.9662776, + "utilities\\plot.py": 1647657822.9662776, + "utilities\\reporter.py": 1647657822.9662776, + "utilities\\save_heatmap.py": 1647657822.967278, + "utilities\\sshupload.py": 1647657822.967278, + "utilities\\transfer_checkpoint.py": 1647657822.967278, + "utilities\\utilities.py": 1647657822.9682784, + "utilities\\yaml_config.py": 1647657822.9682784, + "train_yamls\\train_512FM.yaml": 1647657822.961277, + "train_scripts\\trainer_2layer_FM.py": 1647657822.957276, + "train_yamls\\train_2layer_FM.yaml": 1647657822.961277, + "components\\Generator_reduce.py": 1647657822.934271, "insightface_func\\face_detect_crop_multi.py": 1643796928.6362474, "insightface_func\\face_detect_crop_single.py": 1638370471.7967434, "insightface_func\\__init__.py": 1624197300.011183, "insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638, - "losses\\PatchNCE.py": 1642989384.9713614, + "losses\\PatchNCE.py": 1647677173.0239084, "parsing_model\\model.py": 1626745709.554252, "parsing_model\\resnet.py": 1626745709.554252, - "test_scripts\\tester_common copy.py": 1625369535.199175, - 
"test_scripts\\tester_video.py": 1642734397.3307388, - "train_scripts\\trainer_cycleloss.py": 1642580463.495596, - "train_scripts\\trainer_GramFM.py": 1643095575.2628715, - "utilities\\ImagenetNorm.py": 1642732910.5280058, - "utilities\\reverse2original.py": 1642733688.7976837, - "train_yamls\\train_cycleloss.yaml": 1642577741.345273, - "train_yamls\\train_GramFM.yaml": 1643398791.363959, - "train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789, - "face_crop.py": 1643789609.1834445, - "face_crop_video.py": 1643815024.5516832, - "similarity.py": 1643269705.1073737, - "train_multigpu.py": 1646329983.38444, - "components\\arcface_decoder.py": 1643396144.2575414, + "test_scripts\\tester_common copy.py": 1647657822.9472733, + "test_scripts\\tester_video.py": 1647657822.9482737, + "train_scripts\\trainer_cycleloss.py": 1647657822.9592762, + "train_scripts\\trainer_GramFM.py": 1647657822.9582758, + "utilities\\ImagenetNorm.py": 1647657822.9642777, + "utilities\\reverse2original.py": 1647657822.9662776, + "train_yamls\\train_cycleloss.yaml": 1647700399.7641768, + "train_yamls\\train_GramFM.yaml": 1647657822.9622767, + "train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277, + "face_crop.py": 1647657822.9422722, + "face_crop_video.py": 1647657822.9422722, + "similarity.py": 1647657822.945273, + "train_multigpu.py": 1647700474.445049, + "components\\arcface_decoder.py": 1647657822.9352713, "components\\Generator_nobias.py": 1643179001.810856, - "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644861019.9044807, - "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898, - "test_scripts\\tester_arcface_Rec.py": 1643431261.9333818, - "test_scripts\\tester_image.py": 1645547412.8218117, - "torch_utils\\custom_ops.py": 1640773190.0, - "torch_utils\\misc.py": 1640773190.0, - "torch_utils\\persistence.py": 1640773190.0, - "torch_utils\\training_stats.py": 1640773190.0, - "torch_utils\\utils_spectrum.py": 1640773190.0, - "torch_utils\\__init__.py": 
1640773190.0, - "torch_utils\\ops\\bias_act.py": 1640773190.0, - "torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0, - "torch_utils\\ops\\conv2d_resample.py": 1640773190.0, - "torch_utils\\ops\\filtered_lrelu.py": 1640773190.0, - "torch_utils\\ops\\fma.py": 1640773190.0, - "torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0, - "torch_utils\\ops\\upfirdn2d.py": 1640773190.0, - "torch_utils\\ops\\__init__.py": 1640773190.0, - "train_scripts\\trainer_arcface_rec.py": 1643399647.0182135, - "train_scripts\\trainer_multigpu_base.py": 1644131205.772292, - "train_scripts\\trainer_multi_gpu.py": 1644854424.6483445, - "train_yamls\\train_arcface_rec.yaml": 1643398807.3434353, - "train_yamls\\train_multigpu.yaml": 1644549590.0652373, + "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1647657822.9402719, + "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1647657822.9392715, + "test_scripts\\tester_arcface_Rec.py": 1647657822.946273, + "test_scripts\\tester_image.py": 1647657822.9472733, + "torch_utils\\custom_ops.py": 1647657822.9482737, + "torch_utils\\misc.py": 1647657822.9492736, + "torch_utils\\persistence.py": 1647657822.9552753, + "torch_utils\\training_stats.py": 1647657822.9562755, + "torch_utils\\utils_spectrum.py": 1647657822.9562755, + "torch_utils\\__init__.py": 1647657822.9482737, + "torch_utils\\ops\\bias_act.py": 1647657822.9502747, + "torch_utils\\ops\\conv2d_gradfix.py": 1647657822.9512744, + "torch_utils\\ops\\conv2d_resample.py": 1647657822.9512744, + "torch_utils\\ops\\filtered_lrelu.py": 1647657822.9532747, + "torch_utils\\ops\\fma.py": 1647657822.9532747, + "torch_utils\\ops\\grid_sample_gradfix.py": 1647657822.9542756, + "torch_utils\\ops\\upfirdn2d.py": 1647657822.9552753, + "torch_utils\\ops\\__init__.py": 1647657822.9492736, + "train_scripts\\trainer_arcface_rec.py": 1647657822.9582758, + "train_scripts\\trainer_multigpu_base.py": 1647657822.9602764, + "train_scripts\\trainer_multi_gpu.py": 1647657822.9592762, + 
"train_yamls\\train_arcface_rec.yaml": 1647657822.9622767, + "train_yamls\\train_multigpu.yaml": 1647657822.963277, "wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959, "wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955, "wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548, @@ -100,50 +100,91 @@ "wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678, "wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708, "wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357, - "dnnlib\\util.py": 1640773190.0, - "dnnlib\\__init__.py": 1640773190.0, - "components\\Generator_ori.py": 1644689174.414655, - "losses\\cos.py": 1644229583.4023254, - "data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826, - "speed_test.py": 1646304298.3483005, - "components\\DeConv_Invo.py": 1644426607.1588645, + "dnnlib\\util.py": 1647657822.941272, + "dnnlib\\__init__.py": 1647657822.941272, + "components\\Generator_ori.py": 1647657822.9332705, + "losses\\cos.py": 1647657822.9442725, + "data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1647657822.9402719, + "speed_test.py": 1647657822.945273, + "components\\DeConv_Invo.py": 1647657822.9302697, "components\\Generator_reduce_up.py": 1644688655.2096283, - "components\\Generator_upsample.py": 1644689723.8293872, - "components\\misc\\Involution.py": 1644509321.5267963, - "train_yamls\\train_Invoup.yaml": 1644689981.9794765, - "flops.py": 1646330033.710075, - "detection_test.py": 1644935512.6830947, - "components\\DeConv_Depthwise.py": 1645064447.4379447, - "components\\DeConv_Depthwise1.py": 1644946969.5054545, + "components\\Generator_upsample.py": 1647657822.9352713, + "components\\misc\\Involution.py": 1647657822.9352713, + "train_yamls\\train_Invoup.yaml": 1647657822.9622767, + "flops.py": 1647657822.9422722, + "detection_test.py": 1647657822.941272, + 
"components\\DeConv_Depthwise.py": 1647657822.9292698, + "components\\DeConv_Depthwise1.py": 1647657822.9292698, "components\\Generator_modulation_depthwise.py": 1644861291.4467516, - "components\\Generator_modulation_depthwise_config.py": 1645262162.9779513, - "components\\Generator_modulation_up.py": 1644946498.7005584, - "components\\Generator_oriae_modulation.py": 1644897798.1987727, - "components\\Generator_ori_config.py": 1646329319.6131227, - "train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593, - "train_yamls\\train_Depthwise.yaml": 1644860961.099242, - "train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077, - "train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747, - "train_distillation_mgpu.py": 1645554603.908166, - "components\\DeConv.py": 1645263338.9001615, - "components\\DeConv_Depthwise_ECA.py": 1645265769.1076133, - "components\\ECA.py": 1614848426.9604986, - "components\\ECA_Depthwise_Conv.py": 1645265754.2023985, - "components\\Generator_eca_depthwise.py": 1645266338.9750814, - "losses\\KA.py": 1645546325.331715, - "train_scripts\\trainer_distillation_mgpu.py": 1645601961.4139585, - "train_yamls\\train_distillation.yaml": 1645600099.540936, - "annotation.py": 1645931038.719335, - "components\\DeConv_ECA_Invo.py": 1645869347.379311, - "components\\DeConv_Invobn.py": 1645862876.018001, - "components\\Generator_Invobn_config.py": 1645929418.6924264, - "components\\Generator_Invobn_config1.py": 1645862695.8743145, - "components\\misc\\Involution_BN.py": 1645867197.3984175, - "components\\misc\\Involution_ECA.py": 1645869012.4927464, - "train_yamls\\train_Invobn_config.yaml": 1646101598.499709, - "components\\Generator_Invobn_config2.py": 1645962618.7056074, - "components\\Generator_Invobn_config3.py": 1646302561.1984286, - "components\\Generator_ori_modulation_config.py": 1646329636.719998, - "test_scripts\\tester_image_allstep.py": 1646312637.9363256, - "train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162 + 
"components\\Generator_modulation_depthwise_config.py": 1647657822.93227, + "components\\Generator_modulation_up.py": 1647657822.9332705, + "components\\Generator_oriae_modulation.py": 1647657822.934271, + "components\\Generator_ori_config.py": 1647657822.934271, + "train_scripts\\trainer_multi_gpu1.py": 1647657822.9602764, + "train_yamls\\train_Depthwise.yaml": 1647657822.961277, + "train_yamls\\train_depthwise_modulation.yaml": 1647657822.963277, + "train_yamls\\train_oriae_modulation.yaml": 1647657822.9642777, + "train_distillation_mgpu.py": 1647657822.9562755, + "components\\DeConv.py": 1647657822.9292698, + "components\\DeConv_Depthwise_ECA.py": 1647657822.9292698, + "components\\ECA.py": 1647657822.9302697, + "components\\ECA_Depthwise_Conv.py": 1647657822.93127, + "components\\Generator_eca_depthwise.py": 1647657822.93227, + "losses\\KA.py": 1647657822.9432724, + "train_scripts\\trainer_distillation_mgpu.py": 1647657822.9592762, + "train_yamls\\train_distillation.yaml": 1647657822.963277, + "annotation.py": 1647657822.9172668, + "components\\DeConv_ECA_Invo.py": 1647657822.9302697, + "components\\DeConv_Invobn.py": 1647657822.9302697, + "components\\Generator_Invobn_config.py": 1647657822.93127, + "components\\Generator_Invobn_config1.py": 1647657822.93127, + "components\\misc\\Involution_BN.py": 1647657822.9362714, + "components\\misc\\Involution_ECA.py": 1647657822.9362714, + "train_yamls\\train_Invobn_config.yaml": 1647657822.9622767, + "components\\Generator_Invobn_config2.py": 1647657822.93227, + "components\\Generator_Invobn_config3.py": 1647657822.93227, + "components\\Generator_ori_modulation_config.py": 1647657822.934271, + "test_scripts\\tester_image_allstep.py": 1647657822.9482737, + "train_yamls\\train_ori_modulation_config.yaml": 1647657822.9642777, + "test_arcface.py": 1647657822.946273, + "arcface_torch\\dataset.py": 1647657822.9222684, + "arcface_torch\\eval_ijbc.py": 1647657822.9242685, + "arcface_torch\\inference.py": 1647657822.9242685, + 
"arcface_torch\\losses.py": 1647657822.9242685, + "arcface_torch\\lr_scheduler.py": 1647657822.9242685, + "arcface_torch\\onnx_helper.py": 1647657822.9252684, + "arcface_torch\\onnx_ijbc.py": 1647657822.9252684, + "arcface_torch\\partial_fc.py": 1647657822.9252684, + "arcface_torch\\torch2onnx.py": 1647657822.9262686, + "arcface_torch\\train.py": 1647657822.9262686, + "arcface_torch\\backbones\\iresnet.py": 1647657822.918267, + "arcface_torch\\backbones\\iresnet2060.py": 1647657822.9192681, + "arcface_torch\\backbones\\mobilefacenet.py": 1647657822.9192681, + "arcface_torch\\backbones\\__init__.py": 1647657822.918267, + "arcface_torch\\configs\\3millions.py": 1647657822.9192681, + "arcface_torch\\configs\\base.py": 1647657822.9192681, + "arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647657822.9202676, + "arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647657822.9202676, + "arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647657822.9202676, + "arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647657822.9202676, + "arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647657822.9202676, + "arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647657822.9212687, + "arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647657822.9212687, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647657822.9212687, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647657822.9212687, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647657822.9212687, + "arcface_torch\\configs\\__init__.py": 1647657822.9192681, + "arcface_torch\\eval\\verification.py": 1647657822.923268, + "arcface_torch\\eval\\__init__.py": 1647657822.923268, + "arcface_torch\\utils\\plot.py": 1647657822.927269, + "arcface_torch\\utils\\utils_callbacks.py": 1647657822.927269, + "arcface_torch\\utils\\utils_config.py": 1647657822.927269, + "arcface_torch\\utils\\utils_logging.py": 
1647657822.927269, + "arcface_torch\\utils\\__init__.py": 1647657822.9262686, + "components\\LSTU.py": 1647697688.593807, + "test_scripts\\tester_ID_Pose.py": 1647657822.946273, + "train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762, + "train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475, + "train_scripts\\trainer_multi_gpu_cycle.py": 1647699496.9083836, + "components\\Generator_LSTU_config.py": 1647697793.0348723 } \ No newline at end of file diff --git a/components/Generator_LSTU_config.py b/components/Generator_LSTU_config.py new file mode 100644 index 0000000..d154e35 --- /dev/null +++ b/components/Generator_LSTU_config.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 27th February 2022 7:50:18 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +import torch +from torch import nn +from LSTU import LSTU + +# from components.DeConv_Invo import DeConv +class InstanceNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. 
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 + """ + super(InstanceNorm, self).__init__() + self.epsilon = epsilon + + def forward(self, x): + x = x - torch.mean(x, (2, 3), True) + tmp = torch.mul(x, x) # or x ** 2 + tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) + return x * tmp + +class ApplyStyle(nn.Module): + """ + @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb + """ + def __init__(self, latent_size, channels): + super(ApplyStyle, self).__init__() + self.linear = nn.Linear(latent_size, channels * 2) + + def forward(self, x, latent): + style = self.linear(latent) # style => [batch_size, n_channels*2] + shape = [-1, 2, x.size(1), 1, 1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + #x = x * (style[:, 0] + 1.) + style[:, 1] + x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1 + return x + +class ResnetBlock_Adain(nn.Module): + def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): + super(ResnetBlock_Adain, self).__init__() + + p = 0 + conv1 = [] + if padding_type == 'reflect': + conv1 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv1 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] + self.conv1 = nn.Sequential(*conv1) + self.style1 = ApplyStyle(latent_size, dim) + self.act1 = activation + + p = 0 + conv2 = [] + if padding_type == 'reflect': + conv2 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv2 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] + 
self.conv2 = nn.Sequential(*conv2) + self.style2 = ApplyStyle(latent_size, dim) + + + def forward(self, x, dlatents_in_slice): + y = self.conv1(x) + y = self.style1(y, dlatents_in_slice) + y = self.act1(y) + y = self.conv2(y) + y = self.style2(y, dlatents_in_slice) + out = x + y + return out + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + + aggregator = kwargs["aggregator"] + res_mode = aggregator + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + activation = nn.ReLU(True) + from components.DeConv_Depthwise import DeConv + + # self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False), + # nn.BatchNorm2d(64), activation) + self.first_layer = nn.Sequential(nn.ReflectionPad2d(1), + nn.Conv2d(3, in_channel, kernel_size=3, padding=0, bias=False), + nn.BatchNorm2d(in_channel), + activation) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = nn.Sequential( + nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, groups=in_channel), + nn.Conv2d(in_channel, in_channel*2, kernel_size=1, bias=False), + nn.BatchNorm2d(in_channel*2), + activation) + + self.down2 = nn.Sequential( + nn.Conv2d(in_channel*2, in_channel*2, kernel_size=3, padding=1, groups=in_channel*2), + nn.Conv2d(in_channel*2, in_channel*4, kernel_size=1, bias=False), + nn.BatchNorm2d(in_channel*4), + activation) + + self.lstu = LSTU(in_channel*4,in_channel*4,in_channel*8,4) + + self.down3 = nn.Sequential( + nn.Conv2d(in_channel*4, in_channel*4, kernel_size=3, padding=1, groups=in_channel*4), + nn.Conv2d(in_channel*4, in_channel*8, kernel_size=1, bias=False), + nn.BatchNorm2d(in_channel*8), + activation) + + self.down4 = 
nn.Sequential( + nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, padding=1, groups=in_channel*8), + nn.Conv2d(in_channel*8, in_channel*8, kernel_size=1, bias=False), + nn.BatchNorm2d(in_channel*8), + activation) + + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + ResnetBlock_Adain(in_channel*8, latent_size=id_dim, + padding_type=padding_type, activation=activation, res_mode=res_mode)] + self.BottleNeck = nn.Sequential(*BN) + + self.up4 = nn.Sequential( + DeConv(in_channel*8,in_channel*8,3,up_mode=up_mode), + nn.BatchNorm2d(in_channel*8), + activation + ) + + self.up3 = nn.Sequential( + DeConv(in_channel*8,in_channel*4,3,up_mode=up_mode), + nn.BatchNorm2d(in_channel*4), + activation + ) + + self.up2 = nn.Sequential( + DeConv(in_channel*4,in_channel*2,3,up_mode=up_mode), + nn.BatchNorm2d(in_channel*2), + activation + ) + + self.up1 = nn.Sequential( + DeConv(in_channel*2,in_channel,3,up_mode=up_mode), + nn.BatchNorm2d(in_channel), + activation + ) + # self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1)) + self.last_layer = nn.Sequential(nn.ReflectionPad2d(1), + nn.Conv2d(in_channel, 3, kernel_size=3, padding=0)) + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.first_layer(img) + res = self.down1(res) + res1 = self.down2(res) + res = self.down3(res1) + res = self.down4(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + + res = self.up4(res) + res = self.up3(res) + skip = self.lstu(res1) + res = self.up2(res + skip) + res = self.up1(res) + res = self.last_layer(res) + + return res \ No newline at end of 
file diff --git a/components/LSTU.py b/components/LSTU.py new file mode 100644 index 0000000..de457bd --- /dev/null +++ b/components/LSTU.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator.py +# Created Date: Sunday January 16th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 13th February 2022 2:03:21 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch +from torch import nn + + +class LSTU(nn.Module): + def __init__( + self, + in_channel, + out_channel, + latent_channel, + scale = 4 + ): + super().__init__() + sig = nn.Sigmoid() + self.relu = nn.Relu() + + self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False), + nn.BatchNorm2d(out_channel), sig) + + self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channel), sig) + + self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channel), sig) + + self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True)) + + def forward(self, encoder_in, bottleneck_in): + h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel` + h_bar_l = self.conv11(h_hat_l_1) + f_l = self.forget_gate(h_hat_l_1) + r_l = self.reset_gate (h_hat_l_1) + h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in + x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1 + return x_hat_l \ No newline at end of file diff --git a/losses/PatchNCE.py b/losses/PatchNCE.py index aabdedb..ed8db2c 100644 --- a/losses/PatchNCE.py +++ b/losses/PatchNCE.py @@ -167,8 +167,8 @@ class PatchNCELoss(nn.Module): def forward(self, feat_q, 
feat_k): num_patches = feat_q.shape[0] - dim = feat_q.shape[1] - feat_k = feat_k.detach() + dim = feat_q.shape[1] + feat_k = feat_k.detach() # pos logit l_pos = torch.bmm( diff --git a/train_multigpu.py b/train_multigpu.py index 4f33805..41fe225 100644 --- a/train_multigpu.py +++ b/train_multigpu.py @@ -31,9 +31,9 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='ori_tiny', + parser.add_argument('-v', '--version', type=str, default='cycle_lstu1', help="version name for train, test, finetune") - parser.add_argument('-t', '--tag', type=str, default='tiny', + parser.add_argument('-t', '--tag', type=str, default='cycle', help="tag for current experiment") parser.add_argument('-p', '--phase', type=str, default="train", @@ -46,9 +46,9 @@ def getParameters(): # training parser.add_argument('--experiment_description', type=str, - default="只用conv,训练最小的模型") + default="cycle配合LSTU") - parser.add_argument('--train_yaml', type=str, default="train_ori_modulation_config.yaml") + parser.add_argument('--train_yaml', type=str, default="train_cycleloss.yaml") # system logger parser.add_argument('--logger', type=str, diff --git a/train_scripts/trainer_multi_gpu_CUT.py b/train_scripts/trainer_multi_gpu_CUT.py new file mode 100644 index 0000000..2763ac4 --- /dev/null +++ b/train_scripts/trainer_multi_gpu_CUT.py @@ -0,0 +1,525 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 17th March 2022 1:01:52 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import numpy as np + +import torch +import torch.nn.functional 
as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from arcface_torch.backbones.iresnet import iresnet100 + +from utilities.plot import plot_batch +from losses.cos import cosin_metric +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." 
+ model_config["d_model"]["script"] + file1 = os.path.join("components", model_config["d_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + + elif config["phase"] == "finetune": + gscript_name = config["com_base"] + model_config["g_model"]["script"] + dscript_name = config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and recorde model structure + reporter.writeInfo("Generator structure:") + reporter.writeModel(gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + dis = dis_class(**model_config["d_model"]["module_params"]) + + + # print and recorde model structure + reporter.writeInfo("Discriminator structure:") + reporter.writeModel(dis.__str__()) + + # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + # arcface = arcface1['model'].module + + arcface = iresnet100(pretrained=False, fp16=False) + arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu')) + arcface.eval() + + # train in GPU + + # if in finetune phase, load the pretrained checkpoint + if config["phase"] == "finetune": + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + gen.load_state_dict(torch.load(model_path), map_location=torch.device("cpu")) + + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + dis.load_state_dict(torch.load(model_path), map_location=torch.device("cpu")) + + print('loaded trained backbone model step 
# TODO modify this function to configurate the optimizer of your pipeline
def setup_optimizers(config, reporter, gen, dis, rank):
    """Build Adam optimizers for generator and discriminator.

    Only parameters with ``requires_grad=True`` are optimized; frozen ones are
    reported to the log.  In the "finetune" phase the optimizer states are
    restored from ``step<ckpt>_optim_<name>.pth`` in the checkpoint folder.

    Returns:
        (g_optimizer, d_optimizer)
    """
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    g_train_opt = config['g_optim_config']
    d_train_opt = config['d_optim_config']

    g_optim_params = []
    for k, v in gen.named_parameters():
        if v.requires_grad:
            g_optim_params.append(v)
        else:
            reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    d_optim_params = []
    for k, v in dis.named_parameters():
        if v.requires_grad:
            d_optim_params.append(v)
        else:
            reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    optim_type = config['optim_type']

    if optim_type == 'Adam':
        g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
        d_optimizer = torch.optim.Adam(d_optim_params, **d_train_opt)
    else:
        # BUGFIX: typo "supperted" -> "supported" in the error message.
        raise NotImplementedError(
            f'optimizer {optim_type} is not supported yet.')

    if config["phase"] == "finetune":
        opt_path = os.path.join(
            config["project_checkpoints"],
            "step%d_optim_%s.pth" % (config["ckpt"],
                                     config["optimizer_names"]["generator_name"]))
        g_optimizer.load_state_dict(torch.load(opt_path))

        opt_path = os.path.join(
            config["project_checkpoints"],
            "step%d_optim_%s.pth" % (config["ckpt"],
                                     config["optimizer_names"]["discriminator_name"]))
        d_optimizer.load_state_dict(torch.load(opt_path))

        print('loaded trained optimizer step {}...!'.format(
            config["project_checkpoints"]))
    return g_optimizer, d_optimizer
def train_loop(rank, config, reporter, temp_dir):
    """Per-process training loop for the CUT/feature-match SimSwap trainer.

    Spawned once per GPU by ``Trainer.train``.  Even steps train
    self-reenactment (identity image == attribute image, so a reconstruction
    loss applies); odd steps swap identities within the batch.

    Args:
        rank:     process / GPU index assigned by torch.multiprocessing.spawn.
        config:   global configuration dict.
        reporter: logging helper.
        temp_dir: shared temp dir used for the file:// distributed rendezvous.
    """
    import datetime  # hoisted: used in logging, must exist whenever rank == 0

    version = config["version"]

    ckpt_dir = config["project_checkpoints"]
    sample_dir = config["project_samples"]

    log_freq = config["log_step"]
    model_freq = config["model_save_step"]
    sample_freq = config["sample_step"]
    total_step = config["total_step"]
    random_seed = config["dataset_params"]["random_seed"]

    id_w = config["id_weight"]
    rec_w = config["reconstruct_weight"]
    feat_w = config["feature_match_weight"]
    num_gpus = len(config["gpus"])
    batch_gpu = config["batch_size"] // num_gpus

    # Distributed rendezvous through a shared file (gloo on Windows).
    init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
    if os.name == 'nt':
        init_method = 'file:///' + init_file.replace('\\', '/')
        torch.distributed.init_process_group(
            backend='gloo', init_method=init_method,
            rank=rank, world_size=num_gpus)
    else:
        init_method = f'file://{init_file}'
        torch.distributed.init_process_group(
            backend='nccl', init_method=init_method,
            rank=rank, world_size=num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank)
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

    if rank == 0:
        # ImageNet statistics used to de-normalize images for visualization.
        img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    # Initialize.
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cuda.matmul.allow_tf32 = False   # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False         # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                   # Improves training speed.
    grid_sample_gradfix.enabled = True              # Avoids errors with the augmentation pipe.

    # Create dataloader.
    if rank == 0:
        print('Loading training set...')
    dataset = config["dataset_paths"][config["dataset_name"]]
    print("Prepare the train dataloader...")
    dlModulename = config["dataloader"]
    package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
    dataloader_class = getattr(package, 'GetLoader')
    dataloader = dataloader_class(dataset, rank, num_gpus, batch_gpu,
                                  **config["dataset_params"])

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    gen, dis, arcface = init_framework(config, reporter, device, rank)

    # Distribute across GPUs: broadcast rank-0 weights so all replicas start equal.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [gen, dis, arcface]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    #===============build losses===================#
    if rank == 0:
        print('Setting up training phases...')
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity()

    g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)

    # Initialize logs — only rank 0 logs, so only rank 0 builds a logger.
    # BUGFIX: wandb.init on every rank would create num_gpus duplicate runs.
    logger = None
    if rank == 0:
        print('Initializing logs...')
        if config["logger"] == "tensorboard":
            import torch.utils.tensorboard as tensorboard
            logger = tensorboard.SummaryWriter(config["project_summary"])
        elif config["logger"] == "wandb":
            import wandb
            wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
                       tags=[config["tag"]], name=version)
            wandb.config = {
                "total_step": config["total_step"],
                "batch_size": config["batch_size"]
            }
            logger = wandb

    random.seed(random_seed)
    randindex = [i for i in range(batch_gpu)]

    # set the start point for training loop
    if config["phase"] == "finetune":
        start = config["ckpt"]
    else:
        start = 0

    if rank == 0:
        start_time = time.time()
        print("Total step = %d" % total_step)
        print("Start to train at %s"
              % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()

    # The projected discriminator's pretrained feature extractor stays frozen.
    dis.feature_network.requires_grad_(False)

    for step in range(start, total_step):
        gen.train()
        dis.train()
        for interval in range(2):
            random.shuffle(randindex)
            src_image1, src_image2 = dataloader.next()

            # Even steps: id source == attribute source (self-reenactment);
            # odd steps: shuffle identities inside the batch.
            if step % 2 == 0:
                img_id = src_image2
            else:
                img_id = src_image2[randindex]

            img_id_112 = F.interpolate(img_id, size=(112, 112), mode='bicubic')
            latent_id = arcface(img_id_112)
            latent_id = F.normalize(latent_id, p=2, dim=1)

            if interval == 0:
                # ---- Discriminator step (hinge loss) ----
                img_fake = gen(src_image1, latent_id)
                gen_logits, _ = dis(img_fake.detach(), None)
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

                real_logits, _ = dis(src_image2, None)
                loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                loss_D = loss_Dgen + loss_Dreal
                d_optimizer.zero_grad(set_to_none=True)
                loss_D.backward()
                with torch.autograd.profiler.record_function('discriminator_opt'):
                    # Manual gradient all-reduce across processes.
                    params = [p for p in dis.parameters() if p.grad is not None]
                    flat = torch.cat([p.grad.flatten() for p in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([p.numel() for p in params])
                    for p, g in zip(params, grads):
                        p.grad = g.reshape(p.shape)
                    d_optimizer.step()
            else:
                # ---- Generator step ----
                img_fake = gen(src_image1, latent_id)
                gen_logits, feat = dis(img_fake, None)

                loss_Gmain = (-gen_logits).mean()
                img_fake_down = F.interpolate(img_fake, size=(112, 112), mode='bicubic')
                latent_fake = arcface(img_fake_down)
                latent_fake = F.normalize(latent_fake, p=2, dim=1)
                loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
                real_feat = dis.get_feature(src_image1)
                feat_match_loss = l1_loss(feat["3"], real_feat["3"])
                loss_G = loss_Gmain + loss_G_ID * id_w + feat_match_loss * feat_w
                if step % 2 == 0:
                    # Reconstruction only valid when id and attribute coincide.
                    loss_G_Rec = l1_loss(img_fake, src_image1)
                    loss_G += loss_G_Rec * rec_w

                g_optimizer.zero_grad(set_to_none=True)
                loss_G.backward()
                with torch.autograd.profiler.record_function('generator_opt'):
                    params = [p for p in gen.parameters() if p.grad is not None]
                    flat = torch.cat([p.grad.flatten() for p in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([p.numel() for p in params])
                    for p, g in zip(params, grads):
                        p.grad = g.reshape(p.shape)
                    g_optimizer.step()

        # Print out log info
        if rank == 0 and (step + 1) % log_freq == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            # NOTE(review): on odd steps loss_G_Rec holds the stale value from
            # the previous even step (step 0 is even, so it is always defined).
            epochinformation = "[{}], Elapsed [{}], Step [{}/{}], \
                G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
                D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}".format(
                version, elapsed, step, total_step,
                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(),
                loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
            print(epochinformation)
            reporter.writeInfo(epochinformation)

            if config["logger"] == "tensorboard":
                logger.add_scalar('G/G_loss', loss_G.item(), step)
                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
                logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                logger.add_scalar('D/D_loss', loss_D.item(), step)
                logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
                logger.add_scalar('D/D_real', loss_Dreal.item(), step)
            elif config["logger"] == "wandb":
                logger.log({"G_Loss": loss_G.item()}, step=step)
                logger.log({"G_Rec": loss_G_Rec.item()}, step=step)
                logger.log({"G_feat_match": feat_match_loss.item()}, step=step)
                logger.log({"G_ID": loss_G_ID.item()}, step=step)
                logger.log({"D_loss": loss_D.item()}, step=step)
                logger.log({"D_fake": loss_Dgen.item()}, step=step)
                logger.log({"D_real": loss_Dreal.item()}, step=step)
            torch.cuda.empty_cache()

        # Sample a visualization grid: each row is one attribute image swapped
        # with every identity in the batch.
        if rank == 0 and ((step + 1) % sample_freq == 0 or (step + 1) % model_freq == 0):
            gen.eval()
            with torch.no_grad():
                imgs = []
                zero_img = (torch.zeros_like(src_image1[0, ...]))
                imgs.append(zero_img.cpu().numpy())
                save_img = ((src_image1.cpu()) * img_std + img_mean).numpy()
                for r in range(batch_gpu):
                    imgs.append(save_img[r, ...])
                arcface_112 = F.interpolate(src_image2, size=(112, 112), mode='bicubic')
                id_vector_src1 = arcface(arcface_112)
                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)

                for i in range(batch_gpu):
                    imgs.append(save_img[i, ...])
                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
                    img_fake = gen(image_infer, id_vector_src1).cpu()
                    img_fake = img_fake * img_std
                    img_fake = img_fake + img_mean
                    img_fake = img_fake.numpy()
                    for j in range(batch_gpu):
                        imgs.append(img_fake[j, ...])
                print("Save test data")
                imgs = np.stack(imgs, axis=0).transpose(0, 2, 3, 1)
                plot_batch(imgs, os.path.join(sample_dir, 'step_' + str(step + 1) + '.jpg'))
            torch.cuda.empty_cache()

        #===============save checkpoints================#
        if rank == 0 and (step + 1) % model_freq == 0:
            torch.save(gen.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_{}.pth'.format(
                           step + 1, config["checkpoint_names"]["generator_name"])))
            torch.save(dis.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_{}.pth'.format(
                           step + 1, config["checkpoint_names"]["discriminator_name"])))
            # BUGFIX: optimizer states were saved without the ".pth" suffix and
            # under checkpoint_names, while setup_optimizers() loads
            # "step%d_optim_%s.pth" under optimizer_names — finetune resume
            # could never find them.
            torch.save(g_optimizer.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(
                           step + 1, config["optimizer_names"]["generator_name"])))
            torch.save(d_optimizer.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(
                           step + 1, config["optimizer_names"]["discriminator_name"])))
            print("Save step %d model checkpoint!" % (step + 1))
            torch.cuda.empty_cache()

    print("Rank %d process done!" % rank)
    torch.distributed.barrier()
class Trainer(TrainerBase):
    """Multi-GPU trainer entry point: spawns one training process per GPU."""

    def __init__(self, config, reporter):
        super(Trainer, self).__init__(config, reporter)

        import inspect
        print("Current training script -----------> %s"
              % inspect.getfile(inspect.currentframe()))

    def train(self):
        """Launch one ``train_loop`` process per configured GPU.

        A shared temporary directory is used as the file:// rendezvous point
        for ``torch.distributed`` inside ``train_loop``.
        """
        num_gpus = len(self.config["gpus"])
        print('Launching processes...')
        torch.multiprocessing.set_start_method('spawn')
        with tempfile.TemporaryDirectory() as temp_dir:
            torch.multiprocessing.spawn(
                fn=train_loop,
                args=(self.config, self.reporter, temp_dir),
                nprocs=num_gpus)


# TODO modify this function to build your models
def init_framework(config, reporter, device, rank):
    '''
    Build the generator, the discriminator and the frozen ArcFace id encoder,
    and record the model structures into the log file.

    Args:
        config:   global configuration dict.
        reporter: logging helper (writeInfo / writeModel).
        device:   target CUDA device for the models.
        rank:     process / GPU index.

    Returns:
        (gen, dis, arcface) — all moved to ``device``; ArcFace is frozen.
    '''
    #===============build models================#
    print("build models...")
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    model_config = config["model_configs"]

    if config["phase"] == "train":
        # Archive the exact model scripts into the project folder so the run
        # stays reproducible even if components/ changes later.
        gscript_name = "components." + model_config["g_model"]["script"]
        file1 = os.path.join("components", model_config["g_model"]["script"] + ".py")
        tgtfile1 = os.path.join(config["project_scripts"],
                                model_config["g_model"]["script"] + ".py")
        shutil.copyfile(file1, tgtfile1)
        dscript_name = "components." + model_config["d_model"]["script"]
        file1 = os.path.join("components", model_config["d_model"]["script"] + ".py")
        tgtfile1 = os.path.join(config["project_scripts"],
                                model_config["d_model"]["script"] + ".py")
        shutil.copyfile(file1, tgtfile1)

    elif config["phase"] == "finetune":
        # Finetune re-imports the archived scripts from the project folder.
        gscript_name = config["com_base"] + model_config["g_model"]["script"]
        dscript_name = config["com_base"] + model_config["d_model"]["script"]

    class_name = model_config["g_model"]["class_name"]
    package = __import__(gscript_name, fromlist=True)
    gen_class = getattr(package, class_name)
    gen = gen_class(**model_config["g_model"]["module_params"])

    # print and record model structure
    reporter.writeInfo("Generator structure:")
    reporter.writeModel(gen.__str__())

    class_name = model_config["d_model"]["class_name"]
    package = __import__(dscript_name, fromlist=True)
    dis_class = getattr(package, class_name)
    dis = dis_class(**model_config["d_model"]["module_params"])

    # print and record model structure
    reporter.writeInfo("Discriminator structure:")
    reporter.writeModel(dis.__str__())

    arcface = iresnet100(pretrained=False, fp16=False)
    arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
    arcface.eval()

    # if in finetune phase, load the pretrained checkpoint
    if config["phase"] == "finetune":
        model_path = os.path.join(
            config["project_checkpoints"],
            "step%d_%s.pth" % (config["ckpt"],
                               config["checkpoint_names"]["generator_name"]))
        # BUGFIX: ``map_location`` is an argument of torch.load, not of
        # nn.Module.load_state_dict — the original call raised TypeError.
        gen.load_state_dict(torch.load(model_path, map_location="cpu"))

        model_path = os.path.join(
            config["project_checkpoints"],
            "step%d_%s.pth" % (config["ckpt"],
                               config["checkpoint_names"]["discriminator_name"]))
        dis.load_state_dict(torch.load(model_path, map_location="cpu"))

        print('loaded trained backbone model step {}...!'.format(
            config["project_checkpoints"]))

    gen = gen.to(device)
    dis = dis.to(device)
    arcface = arcface.to(device)
    # ArcFace is a fixed identity encoder; never trained here.
    arcface.requires_grad_(False)
    arcface.eval()

    return gen, dis, arcface
# TODO modify this function to configurate the optimizer of your pipeline
def setup_optimizers(config, reporter, gen, dis, rank):
    """Build Adam optimizers for generator and discriminator.

    Only parameters with ``requires_grad=True`` are optimized; frozen ones are
    reported to the log.  In the "finetune" phase the optimizer states are
    restored from ``step<ckpt>_optim_<name>.pth`` in the checkpoint folder.

    Returns:
        (g_optimizer, d_optimizer)
    """
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    g_train_opt = config['g_optim_config']
    d_train_opt = config['d_optim_config']

    g_optim_params = []
    for k, v in gen.named_parameters():
        if v.requires_grad:
            g_optim_params.append(v)
        else:
            reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    d_optim_params = []
    for k, v in dis.named_parameters():
        if v.requires_grad:
            d_optim_params.append(v)
        else:
            reporter.writeInfo(f'Params {k} will not be optimized.')
            print(f'Params {k} will not be optimized.')

    optim_type = config['optim_type']

    if optim_type == 'Adam':
        g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
        d_optimizer = torch.optim.Adam(d_optim_params, **d_train_opt)
    else:
        # BUGFIX: typo "supperted" -> "supported" in the error message.
        raise NotImplementedError(
            f'optimizer {optim_type} is not supported yet.')

    if config["phase"] == "finetune":
        opt_path = os.path.join(
            config["project_checkpoints"],
            "step%d_optim_%s.pth" % (config["ckpt"],
                                     config["optimizer_names"]["generator_name"]))
        g_optimizer.load_state_dict(torch.load(opt_path))

        opt_path = os.path.join(
            config["project_checkpoints"],
            "step%d_optim_%s.pth" % (config["ckpt"],
                                     config["optimizer_names"]["discriminator_name"]))
        d_optimizer.load_state_dict(torch.load(opt_path))

        print('loaded trained optimizer step {}...!'.format(
            config["project_checkpoints"]))
    return g_optimizer, d_optimizer
def train_loop(rank, config, reporter, temp_dir):
    """Per-process training loop for the cycle-loss SimSwap trainer.

    Even steps: self-reenactment (reconstruction + rec feature-match terms);
    odd steps: identity swap with a cycle consistency + cycle feature-match
    penalty (swap back to the source identity must recover src_image1).

    Args:
        rank:     process / GPU index assigned by torch.multiprocessing.spawn.
        config:   global configuration dict.
        reporter: logging helper.
        temp_dir: shared temp dir used for the file:// distributed rendezvous.
    """
    import datetime  # hoisted: used in logging, must exist whenever rank == 0

    version = config["version"]

    ckpt_dir = config["project_checkpoints"]
    sample_dir = config["project_samples"]

    log_freq = config["log_step"]
    model_freq = config["model_save_step"]
    sample_freq = config["sample_step"]
    total_step = config["total_step"]
    random_seed = config["dataset_params"]["random_seed"]

    id_w = config["id_weight"]
    rec_w = config["reconstruct_weight"]
    rec_fm_w = config["rec_feature_match_weight"]
    cycle_fm_w = config["cycle_feature_match_weight"]
    cycle_w = config["cycle_weight"]
    num_gpus = len(config["gpus"])
    batch_gpu = config["batch_size"] // num_gpus

    # Distributed rendezvous through a shared file (gloo on Windows).
    init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
    if os.name == 'nt':
        init_method = 'file:///' + init_file.replace('\\', '/')
        torch.distributed.init_process_group(
            backend='gloo', init_method=init_method,
            rank=rank, world_size=num_gpus)
    else:
        init_method = f'file://{init_file}'
        torch.distributed.init_process_group(
            backend='nccl', init_method=init_method,
            rank=rank, world_size=num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank)
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

    if rank == 0:
        # ImageNet statistics used to de-normalize images for visualization.
        img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    # Initialize.
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cuda.matmul.allow_tf32 = False   # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False         # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                   # Improves training speed.
    grid_sample_gradfix.enabled = True              # Avoids errors with the augmentation pipe.

    # Create dataloader.
    if rank == 0:
        print('Loading training set...')
    dataset = config["dataset_paths"][config["dataset_name"]]
    print("Prepare the train dataloader...")
    dlModulename = config["dataloader"]
    package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
    dataloader_class = getattr(package, 'GetLoader')
    dataloader = dataloader_class(dataset, rank, num_gpus, batch_gpu,
                                  **config["dataset_params"])

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    gen, dis, arcface = init_framework(config, reporter, device, rank)

    # Distribute across GPUs: broadcast rank-0 weights so all replicas start equal.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [gen, dis, arcface]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    #===============build losses===================#
    if rank == 0:
        print('Setting up training phases...')
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity()

    g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)

    # Initialize logs — only rank 0 logs, so only rank 0 builds a logger.
    # BUGFIX: wandb.init on every rank would create num_gpus duplicate runs.
    logger = None
    if rank == 0:
        print('Initializing logs...')
        if config["logger"] == "tensorboard":
            import torch.utils.tensorboard as tensorboard
            logger = tensorboard.SummaryWriter(config["project_summary"])
        elif config["logger"] == "wandb":
            import wandb
            wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
                       tags=[config["tag"]], name=version)
            wandb.config = {
                "total_step": config["total_step"],
                "batch_size": config["batch_size"]
            }
            logger = wandb

    random.seed(random_seed)
    randindex = [i for i in range(batch_gpu)]

    # set the start point for training loop
    if config["phase"] == "finetune":
        start = config["ckpt"]
    else:
        start = 0

    if rank == 0:
        start_time = time.time()
        print("Total step = %d" % total_step)
        print("Start to train at %s"
              % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        from utilities.logo_class import logo_class
        logo_class.print_start_training()

    # The projected discriminator's pretrained feature extractor stays frozen.
    dis.feature_network.requires_grad_(False)

    # Per-branch generator statistics for logging (rec branch on even steps,
    # cycle branch on odd steps).
    log_rec = 0.0
    log_fm = 0.0

    for step in range(start, total_step):
        gen.train()
        dis.train()
        for interval in range(2):
            random.shuffle(randindex)
            src_image1, src_image2 = dataloader.next()

            # Even steps: id source == attribute source (self-reenactment);
            # odd steps: shuffle identities inside the batch.
            if step % 2 == 0:
                img_id = src_image2
            else:
                img_id = src_image2[randindex]

            img_id_112 = F.interpolate(img_id, size=(112, 112), mode='bicubic')
            latent_id = arcface(img_id_112)
            latent_id = F.normalize(latent_id, p=2, dim=1)

            if interval == 0:
                # ---- Discriminator step (hinge loss) ----
                img_fake = gen(src_image1, latent_id)
                gen_logits, _ = dis(img_fake.detach(), None)
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

                real_logits, _ = dis(src_image2, None)
                loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                loss_D = loss_Dgen + loss_Dreal
                d_optimizer.zero_grad(set_to_none=True)
                loss_D.backward()
                with torch.autograd.profiler.record_function('discriminator_opt'):
                    # Manual gradient all-reduce across processes.
                    params = [p for p in dis.parameters() if p.grad is not None]
                    flat = torch.cat([p.grad.flatten() for p in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([p.numel() for p in params])
                    for p, g in zip(params, grads):
                        p.grad = g.reshape(p.shape)
                    d_optimizer.step()
            else:
                # ---- Generator step ----
                img_fake = gen(src_image1, latent_id)
                gen_logits, feat = dis(img_fake, None)
                real_feat = dis.get_feature(src_image1)
                loss_Gmain = (-gen_logits).mean()

                img_fake_down = F.interpolate(img_fake, size=(112, 112), mode='bicubic')
                latent_fake = arcface(img_fake_down)
                latent_fake = F.normalize(latent_fake, p=2, dim=1)
                loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()

                # BUGFIX: the base adversarial + ID loss must be formed BEFORE
                # the branch-specific terms; the original did `loss_G += ...`
                # on an undefined name (NameError) and then overwrote loss_G
                # with the base term afterwards, discarding the branch losses.
                loss_G = loss_Gmain + loss_G_ID * id_w

                if step % 2 == 0:
                    # Reconstruction branch (id source == attribute source).
                    rec_fm = l1_loss(feat["3"], real_feat["3"])
                    loss_G_Rec = l1_loss(img_fake, src_image1)
                    loss_G = loss_G + loss_G_Rec * rec_w + rec_fm * rec_fm_w
                    log_rec = loss_G_Rec.item()
                    log_fm = rec_fm.item()
                else:
                    # Cycle branch: swap back to the source identity and demand
                    # that the round trip recovers src_image1.
                    source1_down = F.interpolate(src_image1, size=(112, 112), mode='bicubic')
                    latent_source1 = arcface(source1_down)
                    latent_source1 = F.normalize(latent_source1, p=2, dim=1)
                    cycle_src = gen(img_fake, latent_source1)
                    cycle_loss = l1_loss(src_image1, cycle_src)
                    cycle_feat = dis.get_feature(cycle_src)
                    cycle_fm = l1_loss(real_feat["3"], cycle_feat["3"])
                    loss_G = loss_G + cycle_loss * cycle_w + cycle_fm * cycle_fm_w
                    log_rec = cycle_loss.item()
                    log_fm = cycle_fm.item()

                g_optimizer.zero_grad(set_to_none=True)
                loss_G.backward()
                with torch.autograd.profiler.record_function('generator_opt'):
                    params = [p for p in gen.parameters() if p.grad is not None]
                    flat = torch.cat([p.grad.flatten() for p in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([p.numel() for p in params])
                    for p, g in zip(params, grads):
                        p.grad = g.reshape(p.shape)
                    g_optimizer.step()

        # Print out log info
        # BUGFIX: the original logged ``feat_match_loss`` and ``loss_G_Rec``,
        # neither of which exists in this trainer — the first log step raised
        # NameError.  We log the active branch's rec/cycle and fm terms instead.
        if rank == 0 and (step + 1) % log_freq == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            epochinformation = "[{}], Elapsed [{}], Step [{}/{}], \
                G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
                D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}".format(
                version, elapsed, step, total_step,
                loss_G_ID.item(), loss_G.item(), log_rec, log_fm,
                loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
            print(epochinformation)
            reporter.writeInfo(epochinformation)

            if config["logger"] == "tensorboard":
                logger.add_scalar('G/G_loss', loss_G.item(), step)
                logger.add_scalar('G/G_Rec', log_rec, step)
                logger.add_scalar('G/G_feat_match', log_fm, step)
                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                logger.add_scalar('D/D_loss', loss_D.item(), step)
                logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
                logger.add_scalar('D/D_real', loss_Dreal.item(), step)
            elif config["logger"] == "wandb":
                logger.log({"G_Loss": loss_G.item()}, step=step)
                logger.log({"G_Rec": log_rec}, step=step)
                logger.log({"G_feat_match": log_fm}, step=step)
                logger.log({"G_ID": loss_G_ID.item()}, step=step)
                logger.log({"D_loss": loss_D.item()}, step=step)
                logger.log({"D_fake": loss_Dgen.item()}, step=step)
                logger.log({"D_real": loss_Dreal.item()}, step=step)
            torch.cuda.empty_cache()

        # Sample a visualization grid: each row is one attribute image swapped
        # with every identity in the batch.
        if rank == 0 and ((step + 1) % sample_freq == 0 or (step + 1) % model_freq == 0):
            gen.eval()
            with torch.no_grad():
                imgs = []
                zero_img = (torch.zeros_like(src_image1[0, ...]))
                imgs.append(zero_img.cpu().numpy())
                save_img = ((src_image1.cpu()) * img_std + img_mean).numpy()
                for r in range(batch_gpu):
                    imgs.append(save_img[r, ...])
                arcface_112 = F.interpolate(src_image2, size=(112, 112), mode='bicubic')
                id_vector_src1 = arcface(arcface_112)
                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)

                for i in range(batch_gpu):
                    imgs.append(save_img[i, ...])
                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
                    img_fake = gen(image_infer, id_vector_src1).cpu()
                    img_fake = img_fake * img_std
                    img_fake = img_fake + img_mean
                    img_fake = img_fake.numpy()
                    for j in range(batch_gpu):
                        imgs.append(img_fake[j, ...])
                print("Save test data")
                imgs = np.stack(imgs, axis=0).transpose(0, 2, 3, 1)
                plot_batch(imgs, os.path.join(sample_dir, 'step_' + str(step + 1) + '.jpg'))
            torch.cuda.empty_cache()

        #===============save checkpoints================#
        if rank == 0 and (step + 1) % model_freq == 0:
            torch.save(gen.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_{}.pth'.format(
                           step + 1, config["checkpoint_names"]["generator_name"])))
            torch.save(dis.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_{}.pth'.format(
                           step + 1, config["checkpoint_names"]["discriminator_name"])))
            # BUGFIX: optimizer states were saved without the ".pth" suffix and
            # under checkpoint_names, while setup_optimizers() loads
            # "step%d_optim_%s.pth" under optimizer_names — finetune resume
            # could never find them.
            torch.save(g_optimizer.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(
                           step + 1, config["optimizer_names"]["generator_name"])))
            torch.save(d_optimizer.state_dict(),
                       os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(
                           step + 1, config["optimizer_names"]["discriminator_name"])))
            print("Save step %d model checkpoint!" % (step + 1))
            torch.cuda.empty_cache()

    print("Rank %d process done!" % rank)
    torch.distributed.barrier()
dataset_name: vggface2_hq dataset_params: random_seed: 1234 @@ -40,18 +40,19 @@ eval_batch_size: 2 # Optimizer optim_type: Adam g_optim_config: - lr: 0.0004 + lr: 0.0006 betas: [ 0, 0.99] eps: !!float 1e-8 d_optim_config: - lr: 0.0004 + lr: 0.0006 betas: [ 0, 0.99] eps: !!float 1e-8 id_weight: 20.0 -reconstruct_weight: 0.1 -feature_match_weight: 0.1 +reconstruct_weight: 10.0 +rec_feature_match_weight: 10.0 +cycle_feature_match_weight: 10.0 cycle_weight: 10.0 # Log