update
This commit is contained in:
@@ -1,91 +1,91 @@
|
||||
{
|
||||
"GUI.py": 1645109256.0056663,
|
||||
"test.py": 1646330130.1009316,
|
||||
"train.py": 1643397924.974299,
|
||||
"components\\Generator.py": 1644689001.9005148,
|
||||
"components\\projected_discriminator.py": 1642348101.4661522,
|
||||
"components\\pg_modules\\blocks.py": 1640773190.0,
|
||||
"components\\pg_modules\\diffaug.py": 1640773190.0,
|
||||
"components\\pg_modules\\discriminator.py": 1642349784.9407308,
|
||||
"components\\pg_modules\\networks_fastgan.py": 1640773190.0,
|
||||
"components\\pg_modules\\networks_stylegan2.py": 1640773190.0,
|
||||
"components\\pg_modules\\projector.py": 1642349764.3896568,
|
||||
"data_tools\\data_loader.py": 1611123530.660446,
|
||||
"data_tools\\data_loader_condition.py": 1625411562.8217106,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877,
|
||||
"data_tools\\StyleResize.py": 1624954084.7176485,
|
||||
"data_tools\\test_dataloader_dir.py": 1634041792.6743984,
|
||||
"losses\\PerceptualLoss.py": 1615020169.668723,
|
||||
"losses\\SliceWassersteinDistance.py": 1634022704.6082795,
|
||||
"models\\arcface_models.py": 1642390690.623,
|
||||
"models\\config.py": 1632643596.2908099,
|
||||
"models\\__init__.py": 1642390864.8828168,
|
||||
"test_scripts\\tester_common.py": 1625369535.199175,
|
||||
"GUI.py": 1647657822.9152665,
|
||||
"test.py": 1647657822.945273,
|
||||
"train.py": 1647657822.9562755,
|
||||
"components\\Generator.py": 1647657822.93127,
|
||||
"components\\projected_discriminator.py": 1647657822.938272,
|
||||
"components\\pg_modules\\blocks.py": 1647657822.9362714,
|
||||
"components\\pg_modules\\diffaug.py": 1647657822.9362714,
|
||||
"components\\pg_modules\\discriminator.py": 1647657822.937271,
|
||||
"components\\pg_modules\\networks_fastgan.py": 1647657822.937271,
|
||||
"components\\pg_modules\\networks_stylegan2.py": 1647657822.937271,
|
||||
"components\\pg_modules\\projector.py": 1647657822.938272,
|
||||
"data_tools\\data_loader.py": 1647657822.9392715,
|
||||
"data_tools\\data_loader_condition.py": 1647657822.9402719,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1647657822.9392715,
|
||||
"data_tools\\StyleResize.py": 1647657822.9392715,
|
||||
"data_tools\\test_dataloader_dir.py": 1647657822.941272,
|
||||
"losses\\PerceptualLoss.py": 1647657822.9432724,
|
||||
"losses\\SliceWassersteinDistance.py": 1647657822.9432724,
|
||||
"models\\arcface_models.py": 1647657822.9442725,
|
||||
"models\\config.py": 1647657822.9442725,
|
||||
"models\\__init__.py": 1647657822.9442725,
|
||||
"test_scripts\\tester_common.py": 1647657822.9472733,
|
||||
"test_scripts\\tester_FastNST.py": 1634041357.607633,
|
||||
"train_scripts\\trainer_base.py": 1642396105.3868554,
|
||||
"train_scripts\\trainer_FM.py": 1643021959.3577182,
|
||||
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
|
||||
"utilities\\checkpoint_manager.py": 1611123530.6624403,
|
||||
"utilities\\figure.py": 1611123530.6634378,
|
||||
"utilities\\json_config.py": 1611123530.6614666,
|
||||
"utilities\\learningrate_scheduler.py": 1611123530.675422,
|
||||
"utilities\\logo_class.py": 1633883995.3093486,
|
||||
"utilities\\plot.py": 1641911100.7995758,
|
||||
"utilities\\reporter.py": 1646311333.3067005,
|
||||
"utilities\\save_heatmap.py": 1611123530.679439,
|
||||
"utilities\\sshupload.py": 1645168814.6421573,
|
||||
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
|
||||
"utilities\\utilities.py": 1634019485.0783668,
|
||||
"utilities\\yaml_config.py": 1611123530.6614666,
|
||||
"train_yamls\\train_512FM.yaml": 1643021615.8106658,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
|
||||
"components\\Generator_reduce.py": 1645020911.0651233,
|
||||
"train_scripts\\trainer_base.py": 1647657822.9582758,
|
||||
"train_scripts\\trainer_FM.py": 1647657822.957276,
|
||||
"train_scripts\\trainer_naiv512.py": 1647657822.9602764,
|
||||
"utilities\\checkpoint_manager.py": 1647657822.9652774,
|
||||
"utilities\\figure.py": 1647657822.9652774,
|
||||
"utilities\\json_config.py": 1647657822.9652774,
|
||||
"utilities\\learningrate_scheduler.py": 1647657822.9652774,
|
||||
"utilities\\logo_class.py": 1647657822.9662776,
|
||||
"utilities\\plot.py": 1647657822.9662776,
|
||||
"utilities\\reporter.py": 1647657822.9662776,
|
||||
"utilities\\save_heatmap.py": 1647657822.967278,
|
||||
"utilities\\sshupload.py": 1647657822.967278,
|
||||
"utilities\\transfer_checkpoint.py": 1647657822.967278,
|
||||
"utilities\\utilities.py": 1647657822.9682784,
|
||||
"utilities\\yaml_config.py": 1647657822.9682784,
|
||||
"train_yamls\\train_512FM.yaml": 1647657822.961277,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1647657822.957276,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1647657822.961277,
|
||||
"components\\Generator_reduce.py": 1647657822.934271,
|
||||
"insightface_func\\face_detect_crop_multi.py": 1643796928.6362474,
|
||||
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
|
||||
"insightface_func\\__init__.py": 1624197300.011183,
|
||||
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
|
||||
"losses\\PatchNCE.py": 1642989384.9713614,
|
||||
"losses\\PatchNCE.py": 1647677173.0239084,
|
||||
"parsing_model\\model.py": 1626745709.554252,
|
||||
"parsing_model\\resnet.py": 1626745709.554252,
|
||||
"test_scripts\\tester_common copy.py": 1625369535.199175,
|
||||
"test_scripts\\tester_video.py": 1642734397.3307388,
|
||||
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
|
||||
"train_scripts\\trainer_GramFM.py": 1643095575.2628715,
|
||||
"utilities\\ImagenetNorm.py": 1642732910.5280058,
|
||||
"utilities\\reverse2original.py": 1642733688.7976837,
|
||||
"train_yamls\\train_cycleloss.yaml": 1642577741.345273,
|
||||
"train_yamls\\train_GramFM.yaml": 1643398791.363959,
|
||||
"train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789,
|
||||
"face_crop.py": 1643789609.1834445,
|
||||
"face_crop_video.py": 1643815024.5516832,
|
||||
"similarity.py": 1643269705.1073737,
|
||||
"train_multigpu.py": 1646329983.38444,
|
||||
"components\\arcface_decoder.py": 1643396144.2575414,
|
||||
"test_scripts\\tester_common copy.py": 1647657822.9472733,
|
||||
"test_scripts\\tester_video.py": 1647657822.9482737,
|
||||
"train_scripts\\trainer_cycleloss.py": 1647657822.9592762,
|
||||
"train_scripts\\trainer_GramFM.py": 1647657822.9582758,
|
||||
"utilities\\ImagenetNorm.py": 1647657822.9642777,
|
||||
"utilities\\reverse2original.py": 1647657822.9662776,
|
||||
"train_yamls\\train_cycleloss.yaml": 1647700399.7641768,
|
||||
"train_yamls\\train_GramFM.yaml": 1647657822.9622767,
|
||||
"train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277,
|
||||
"face_crop.py": 1647657822.9422722,
|
||||
"face_crop_video.py": 1647657822.9422722,
|
||||
"similarity.py": 1647657822.945273,
|
||||
"train_multigpu.py": 1647700474.445049,
|
||||
"components\\arcface_decoder.py": 1647657822.9352713,
|
||||
"components\\Generator_nobias.py": 1643179001.810856,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644861019.9044807,
|
||||
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898,
|
||||
"test_scripts\\tester_arcface_Rec.py": 1643431261.9333818,
|
||||
"test_scripts\\tester_image.py": 1645547412.8218117,
|
||||
"torch_utils\\custom_ops.py": 1640773190.0,
|
||||
"torch_utils\\misc.py": 1640773190.0,
|
||||
"torch_utils\\persistence.py": 1640773190.0,
|
||||
"torch_utils\\training_stats.py": 1640773190.0,
|
||||
"torch_utils\\utils_spectrum.py": 1640773190.0,
|
||||
"torch_utils\\__init__.py": 1640773190.0,
|
||||
"torch_utils\\ops\\bias_act.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_resample.py": 1640773190.0,
|
||||
"torch_utils\\ops\\filtered_lrelu.py": 1640773190.0,
|
||||
"torch_utils\\ops\\fma.py": 1640773190.0,
|
||||
"torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\upfirdn2d.py": 1640773190.0,
|
||||
"torch_utils\\ops\\__init__.py": 1640773190.0,
|
||||
"train_scripts\\trainer_arcface_rec.py": 1643399647.0182135,
|
||||
"train_scripts\\trainer_multigpu_base.py": 1644131205.772292,
|
||||
"train_scripts\\trainer_multi_gpu.py": 1644854424.6483445,
|
||||
"train_yamls\\train_arcface_rec.yaml": 1643398807.3434353,
|
||||
"train_yamls\\train_multigpu.yaml": 1644549590.0652373,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1647657822.9402719,
|
||||
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1647657822.9392715,
|
||||
"test_scripts\\tester_arcface_Rec.py": 1647657822.946273,
|
||||
"test_scripts\\tester_image.py": 1647657822.9472733,
|
||||
"torch_utils\\custom_ops.py": 1647657822.9482737,
|
||||
"torch_utils\\misc.py": 1647657822.9492736,
|
||||
"torch_utils\\persistence.py": 1647657822.9552753,
|
||||
"torch_utils\\training_stats.py": 1647657822.9562755,
|
||||
"torch_utils\\utils_spectrum.py": 1647657822.9562755,
|
||||
"torch_utils\\__init__.py": 1647657822.9482737,
|
||||
"torch_utils\\ops\\bias_act.py": 1647657822.9502747,
|
||||
"torch_utils\\ops\\conv2d_gradfix.py": 1647657822.9512744,
|
||||
"torch_utils\\ops\\conv2d_resample.py": 1647657822.9512744,
|
||||
"torch_utils\\ops\\filtered_lrelu.py": 1647657822.9532747,
|
||||
"torch_utils\\ops\\fma.py": 1647657822.9532747,
|
||||
"torch_utils\\ops\\grid_sample_gradfix.py": 1647657822.9542756,
|
||||
"torch_utils\\ops\\upfirdn2d.py": 1647657822.9552753,
|
||||
"torch_utils\\ops\\__init__.py": 1647657822.9492736,
|
||||
"train_scripts\\trainer_arcface_rec.py": 1647657822.9582758,
|
||||
"train_scripts\\trainer_multigpu_base.py": 1647657822.9602764,
|
||||
"train_scripts\\trainer_multi_gpu.py": 1647657822.9592762,
|
||||
"train_yamls\\train_arcface_rec.yaml": 1647657822.9622767,
|
||||
"train_yamls\\train_multigpu.yaml": 1647657822.963277,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955,
|
||||
"wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548,
|
||||
@@ -100,50 +100,91 @@
|
||||
"wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357,
|
||||
"dnnlib\\util.py": 1640773190.0,
|
||||
"dnnlib\\__init__.py": 1640773190.0,
|
||||
"components\\Generator_ori.py": 1644689174.414655,
|
||||
"losses\\cos.py": 1644229583.4023254,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826,
|
||||
"speed_test.py": 1646304298.3483005,
|
||||
"components\\DeConv_Invo.py": 1644426607.1588645,
|
||||
"dnnlib\\util.py": 1647657822.941272,
|
||||
"dnnlib\\__init__.py": 1647657822.941272,
|
||||
"components\\Generator_ori.py": 1647657822.9332705,
|
||||
"losses\\cos.py": 1647657822.9442725,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1647657822.9402719,
|
||||
"speed_test.py": 1647657822.945273,
|
||||
"components\\DeConv_Invo.py": 1647657822.9302697,
|
||||
"components\\Generator_reduce_up.py": 1644688655.2096283,
|
||||
"components\\Generator_upsample.py": 1644689723.8293872,
|
||||
"components\\misc\\Involution.py": 1644509321.5267963,
|
||||
"train_yamls\\train_Invoup.yaml": 1644689981.9794765,
|
||||
"flops.py": 1646330033.710075,
|
||||
"detection_test.py": 1644935512.6830947,
|
||||
"components\\DeConv_Depthwise.py": 1645064447.4379447,
|
||||
"components\\DeConv_Depthwise1.py": 1644946969.5054545,
|
||||
"components\\Generator_upsample.py": 1647657822.9352713,
|
||||
"components\\misc\\Involution.py": 1647657822.9352713,
|
||||
"train_yamls\\train_Invoup.yaml": 1647657822.9622767,
|
||||
"flops.py": 1647657822.9422722,
|
||||
"detection_test.py": 1647657822.941272,
|
||||
"components\\DeConv_Depthwise.py": 1647657822.9292698,
|
||||
"components\\DeConv_Depthwise1.py": 1647657822.9292698,
|
||||
"components\\Generator_modulation_depthwise.py": 1644861291.4467516,
|
||||
"components\\Generator_modulation_depthwise_config.py": 1645262162.9779513,
|
||||
"components\\Generator_modulation_up.py": 1644946498.7005584,
|
||||
"components\\Generator_oriae_modulation.py": 1644897798.1987727,
|
||||
"components\\Generator_ori_config.py": 1646329319.6131227,
|
||||
"train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593,
|
||||
"train_yamls\\train_Depthwise.yaml": 1644860961.099242,
|
||||
"train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077,
|
||||
"train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747,
|
||||
"train_distillation_mgpu.py": 1645554603.908166,
|
||||
"components\\DeConv.py": 1645263338.9001615,
|
||||
"components\\DeConv_Depthwise_ECA.py": 1645265769.1076133,
|
||||
"components\\ECA.py": 1614848426.9604986,
|
||||
"components\\ECA_Depthwise_Conv.py": 1645265754.2023985,
|
||||
"components\\Generator_eca_depthwise.py": 1645266338.9750814,
|
||||
"losses\\KA.py": 1645546325.331715,
|
||||
"train_scripts\\trainer_distillation_mgpu.py": 1645601961.4139585,
|
||||
"train_yamls\\train_distillation.yaml": 1645600099.540936,
|
||||
"annotation.py": 1645931038.719335,
|
||||
"components\\DeConv_ECA_Invo.py": 1645869347.379311,
|
||||
"components\\DeConv_Invobn.py": 1645862876.018001,
|
||||
"components\\Generator_Invobn_config.py": 1645929418.6924264,
|
||||
"components\\Generator_Invobn_config1.py": 1645862695.8743145,
|
||||
"components\\misc\\Involution_BN.py": 1645867197.3984175,
|
||||
"components\\misc\\Involution_ECA.py": 1645869012.4927464,
|
||||
"train_yamls\\train_Invobn_config.yaml": 1646101598.499709,
|
||||
"components\\Generator_Invobn_config2.py": 1645962618.7056074,
|
||||
"components\\Generator_Invobn_config3.py": 1646302561.1984286,
|
||||
"components\\Generator_ori_modulation_config.py": 1646329636.719998,
|
||||
"test_scripts\\tester_image_allstep.py": 1646312637.9363256,
|
||||
"train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162
|
||||
"components\\Generator_modulation_depthwise_config.py": 1647657822.93227,
|
||||
"components\\Generator_modulation_up.py": 1647657822.9332705,
|
||||
"components\\Generator_oriae_modulation.py": 1647657822.934271,
|
||||
"components\\Generator_ori_config.py": 1647657822.934271,
|
||||
"train_scripts\\trainer_multi_gpu1.py": 1647657822.9602764,
|
||||
"train_yamls\\train_Depthwise.yaml": 1647657822.961277,
|
||||
"train_yamls\\train_depthwise_modulation.yaml": 1647657822.963277,
|
||||
"train_yamls\\train_oriae_modulation.yaml": 1647657822.9642777,
|
||||
"train_distillation_mgpu.py": 1647657822.9562755,
|
||||
"components\\DeConv.py": 1647657822.9292698,
|
||||
"components\\DeConv_Depthwise_ECA.py": 1647657822.9292698,
|
||||
"components\\ECA.py": 1647657822.9302697,
|
||||
"components\\ECA_Depthwise_Conv.py": 1647657822.93127,
|
||||
"components\\Generator_eca_depthwise.py": 1647657822.93227,
|
||||
"losses\\KA.py": 1647657822.9432724,
|
||||
"train_scripts\\trainer_distillation_mgpu.py": 1647657822.9592762,
|
||||
"train_yamls\\train_distillation.yaml": 1647657822.963277,
|
||||
"annotation.py": 1647657822.9172668,
|
||||
"components\\DeConv_ECA_Invo.py": 1647657822.9302697,
|
||||
"components\\DeConv_Invobn.py": 1647657822.9302697,
|
||||
"components\\Generator_Invobn_config.py": 1647657822.93127,
|
||||
"components\\Generator_Invobn_config1.py": 1647657822.93127,
|
||||
"components\\misc\\Involution_BN.py": 1647657822.9362714,
|
||||
"components\\misc\\Involution_ECA.py": 1647657822.9362714,
|
||||
"train_yamls\\train_Invobn_config.yaml": 1647657822.9622767,
|
||||
"components\\Generator_Invobn_config2.py": 1647657822.93227,
|
||||
"components\\Generator_Invobn_config3.py": 1647657822.93227,
|
||||
"components\\Generator_ori_modulation_config.py": 1647657822.934271,
|
||||
"test_scripts\\tester_image_allstep.py": 1647657822.9482737,
|
||||
"train_yamls\\train_ori_modulation_config.yaml": 1647657822.9642777,
|
||||
"test_arcface.py": 1647657822.946273,
|
||||
"arcface_torch\\dataset.py": 1647657822.9222684,
|
||||
"arcface_torch\\eval_ijbc.py": 1647657822.9242685,
|
||||
"arcface_torch\\inference.py": 1647657822.9242685,
|
||||
"arcface_torch\\losses.py": 1647657822.9242685,
|
||||
"arcface_torch\\lr_scheduler.py": 1647657822.9242685,
|
||||
"arcface_torch\\onnx_helper.py": 1647657822.9252684,
|
||||
"arcface_torch\\onnx_ijbc.py": 1647657822.9252684,
|
||||
"arcface_torch\\partial_fc.py": 1647657822.9252684,
|
||||
"arcface_torch\\torch2onnx.py": 1647657822.9262686,
|
||||
"arcface_torch\\train.py": 1647657822.9262686,
|
||||
"arcface_torch\\backbones\\iresnet.py": 1647657822.918267,
|
||||
"arcface_torch\\backbones\\iresnet2060.py": 1647657822.9192681,
|
||||
"arcface_torch\\backbones\\mobilefacenet.py": 1647657822.9192681,
|
||||
"arcface_torch\\backbones\\__init__.py": 1647657822.918267,
|
||||
"arcface_torch\\configs\\3millions.py": 1647657822.9192681,
|
||||
"arcface_torch\\configs\\base.py": 1647657822.9192681,
|
||||
"arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\__init__.py": 1647657822.9192681,
|
||||
"arcface_torch\\eval\\verification.py": 1647657822.923268,
|
||||
"arcface_torch\\eval\\__init__.py": 1647657822.923268,
|
||||
"arcface_torch\\utils\\plot.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_callbacks.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_config.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_logging.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\__init__.py": 1647657822.9262686,
|
||||
"components\\LSTU.py": 1647697688.593807,
|
||||
"test_scripts\\tester_ID_Pose.py": 1647657822.946273,
|
||||
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762,
|
||||
"train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475,
|
||||
"train_scripts\\trainer_multi_gpu_cycle.py": 1647699496.9083836,
|
||||
"components\\Generator_LSTU_config.py": 1647697793.0348723
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 27th February 2022 7:50:18 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from LSTU import LSTU
|
||||
|
||||
# from components.DeConv_Invo import DeConv
|
||||
class InstanceNorm(nn.Module):
|
||||
def __init__(self, epsilon=1e-8):
|
||||
"""
|
||||
@notice: avoid in-place ops.
|
||||
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
||||
"""
|
||||
super(InstanceNorm, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
x = x - torch.mean(x, (2, 3), True)
|
||||
tmp = torch.mul(x, x) # or x ** 2
|
||||
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
||||
return x * tmp
|
||||
|
||||
class ApplyStyle(nn.Module):
|
||||
"""
|
||||
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
||||
"""
|
||||
def __init__(self, latent_size, channels):
|
||||
super(ApplyStyle, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels * 2)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, 2, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
|
||||
return x
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Adain, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
if padding_type == 'reflect':
|
||||
conv1 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv1 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = ApplyStyle(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
conv2 = []
|
||||
if padding_type == 'reflect':
|
||||
conv2 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv2 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
||||
self.style2 = ApplyStyle(latent_size, dim)
|
||||
|
||||
|
||||
def forward(self, x, dlatents_in_slice):
|
||||
y = self.conv1(x)
|
||||
y = self.style1(y, dlatents_in_slice)
|
||||
y = self.act1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.style2(y, dlatents_in_slice)
|
||||
out = x + y
|
||||
return out
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
id_dim = kwargs["id_dim"]
|
||||
k_size = kwargs["g_kernel_size"]
|
||||
res_num = kwargs["res_num"]
|
||||
in_channel = kwargs["in_channel"]
|
||||
up_mode = kwargs["up_mode"]
|
||||
|
||||
aggregator = kwargs["aggregator"]
|
||||
res_mode = aggregator
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
from components.DeConv_Depthwise import DeConv
|
||||
|
||||
# self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
|
||||
# nn.BatchNorm2d(64), activation)
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(3, in_channel, kernel_size=3, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_channel),
|
||||
activation)
|
||||
# self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(
|
||||
nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, groups=in_channel),
|
||||
nn.Conv2d(in_channel, in_channel*2, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*2),
|
||||
activation)
|
||||
|
||||
self.down2 = nn.Sequential(
|
||||
nn.Conv2d(in_channel*2, in_channel*2, kernel_size=3, padding=1, groups=in_channel*2),
|
||||
nn.Conv2d(in_channel*2, in_channel*4, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*4),
|
||||
activation)
|
||||
|
||||
self.lstu = LSTU(in_channel*4,in_channel*4,in_channel*8,4)
|
||||
|
||||
self.down3 = nn.Sequential(
|
||||
nn.Conv2d(in_channel*4, in_channel*4, kernel_size=3, padding=1, groups=in_channel*4),
|
||||
nn.Conv2d(in_channel*4, in_channel*8, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*8),
|
||||
activation)
|
||||
|
||||
self.down4 = nn.Sequential(
|
||||
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, padding=1, groups=in_channel*8),
|
||||
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*8),
|
||||
activation)
|
||||
|
||||
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
ResnetBlock_Adain(in_channel*8, latent_size=id_dim,
|
||||
padding_type=padding_type, activation=activation, res_mode=res_mode)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
DeConv(in_channel*8,in_channel*8,3,up_mode=up_mode),
|
||||
nn.BatchNorm2d(in_channel*8),
|
||||
activation
|
||||
)
|
||||
|
||||
self.up3 = nn.Sequential(
|
||||
DeConv(in_channel*8,in_channel*4,3,up_mode=up_mode),
|
||||
nn.BatchNorm2d(in_channel*4),
|
||||
activation
|
||||
)
|
||||
|
||||
self.up2 = nn.Sequential(
|
||||
DeConv(in_channel*4,in_channel*2,3,up_mode=up_mode),
|
||||
nn.BatchNorm2d(in_channel*2),
|
||||
activation
|
||||
)
|
||||
|
||||
self.up1 = nn.Sequential(
|
||||
DeConv(in_channel*2,in_channel,3,up_mode=up_mode),
|
||||
nn.BatchNorm2d(in_channel),
|
||||
activation
|
||||
)
|
||||
# self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
|
||||
self.last_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))
|
||||
# self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||
# nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# for layer in self.encoder:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
# for layer in self.encoder2:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
def forward(self, img, id):
|
||||
res = self.first_layer(img)
|
||||
res = self.down1(res)
|
||||
res1 = self.down2(res)
|
||||
res = self.down3(res1)
|
||||
res = self.down4(res)
|
||||
|
||||
for i in range(len(self.BottleNeck)):
|
||||
res = self.BottleNeck[i](res, id)
|
||||
|
||||
res = self.up4(res)
|
||||
res = self.up3(res)
|
||||
skip = self.lstu(res1)
|
||||
res = self.up2(res + skip)
|
||||
res = self.up1(res)
|
||||
res = self.last_layer(res)
|
||||
|
||||
return res
|
||||
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 13th February 2022 2:03:21 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LSTU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
latent_channel,
|
||||
scale = 4
|
||||
):
|
||||
super().__init__()
|
||||
sig = nn.Sigmoid()
|
||||
self.relu = nn.Relu()
|
||||
|
||||
self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
|
||||
|
||||
def forward(self, encoder_in, bottleneck_in):
|
||||
h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
|
||||
h_bar_l = self.conv11(h_hat_l_1)
|
||||
f_l = self.forget_gate(h_hat_l_1)
|
||||
r_l = self.reset_gate (h_hat_l_1)
|
||||
h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
|
||||
x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
|
||||
return x_hat_l
|
||||
+2
-2
@@ -167,8 +167,8 @@ class PatchNCELoss(nn.Module):
|
||||
|
||||
def forward(self, feat_q, feat_k):
|
||||
num_patches = feat_q.shape[0]
|
||||
dim = feat_q.shape[1]
|
||||
feat_k = feat_k.detach()
|
||||
dim = feat_q.shape[1]
|
||||
feat_k = feat_k.detach()
|
||||
|
||||
# pos logit
|
||||
l_pos = torch.bmm(
|
||||
|
||||
+4
-4
@@ -31,9 +31,9 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='ori_tiny',
|
||||
parser.add_argument('-v', '--version', type=str, default='cycle_lstu1',
|
||||
help="version name for train, test, finetune")
|
||||
parser.add_argument('-t', '--tag', type=str, default='tiny',
|
||||
parser.add_argument('-t', '--tag', type=str, default='cycle',
|
||||
help="tag for current experiment")
|
||||
|
||||
parser.add_argument('-p', '--phase', type=str, default="train",
|
||||
@@ -46,9 +46,9 @@ def getParameters():
|
||||
|
||||
# training
|
||||
parser.add_argument('--experiment_description', type=str,
|
||||
default="只用conv,训练最小的模型")
|
||||
default="cycle配合LSTU")
|
||||
|
||||
parser.add_argument('--train_yaml', type=str, default="train_ori_modulation_config.yaml")
|
||||
parser.add_argument('--train_yaml', type=str, default="train_cycleloss.yaml")
|
||||
|
||||
# system logger
|
||||
parser.add_argument('--logger', type=str,
|
||||
|
||||
@@ -0,0 +1,525 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th March 2022 1:01:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch_utils import misc
|
||||
from torch_utils import training_stats
|
||||
from torch_utils.ops import conv2d_gradfix
|
||||
from torch_utils.ops import grid_sample_gradfix
|
||||
|
||||
from arcface_torch.backbones.iresnet import iresnet100
|
||||
|
||||
from utilities.plot import plot_batch
|
||||
from losses.cos import cosin_metric
|
||||
from train_scripts.trainer_multigpu_base import TrainerBase
|
||||
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
import inspect
|
||||
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
|
||||
|
||||
def train(self):
|
||||
# Launch processes.
|
||||
num_gpus = len(self.config["gpus"])
|
||||
print('Launching processes...')
|
||||
torch.multiprocessing.set_start_method('spawn')
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(config, reporter, device, rank):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.empty_cache()
|
||||
model_config = config["model_configs"]
|
||||
|
||||
if config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
|
||||
tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
|
||||
shutil.copyfile(file1,tgtfile1)
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
|
||||
tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
|
||||
shutil.copyfile(file1,tgtfile1)
|
||||
|
||||
elif config["phase"] == "finetune":
|
||||
gscript_name = config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
reporter.writeInfo("Generator structure:")
|
||||
reporter.writeModel(gen.__str__())
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
|
||||
|
||||
# print and recorde model structure
|
||||
reporter.writeInfo("Discriminator structure:")
|
||||
reporter.writeModel(dis.__str__())
|
||||
|
||||
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
# arcface = arcface1['model'].module
|
||||
|
||||
arcface = iresnet100(pretrained=False, fp16=False)
|
||||
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
|
||||
arcface.eval()
|
||||
|
||||
# train in GPU
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if config["phase"] == "finetune":
|
||||
model_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(config["ckpt"],
|
||||
config["checkpoint_names"]["generator_name"]))
|
||||
gen.load_state_dict(torch.load(model_path), map_location=torch.device("cpu"))
|
||||
|
||||
model_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(config["ckpt"],
|
||||
config["checkpoint_names"]["discriminator_name"]))
|
||||
dis.load_state_dict(torch.load(model_path), map_location=torch.device("cpu"))
|
||||
|
||||
print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
|
||||
|
||||
gen = gen.to(device)
|
||||
dis = dis.to(device)
|
||||
arcface= arcface.to(device)
|
||||
arcface.requires_grad_(False)
|
||||
arcface.eval()
|
||||
|
||||
|
||||
|
||||
return gen, dis, arcface
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def setup_optimizers(config, reporter, gen, dis, rank):
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.empty_cache()
|
||||
g_train_opt = config['g_optim_config']
|
||||
d_train_opt = config['d_optim_config']
|
||||
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
if config["phase"] == "finetune":
|
||||
opt_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(config["ckpt"],
|
||||
config["optimizer_names"]["generator_name"]))
|
||||
g_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
opt_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(config["ckpt"],
|
||||
config["optimizer_names"]["discriminator_name"]))
|
||||
d_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"]))
|
||||
return g_optimizer, d_optimizer
|
||||
|
||||
|
||||
def train_loop(
|
||||
rank,
|
||||
config,
|
||||
reporter,
|
||||
temp_dir
|
||||
):
|
||||
|
||||
version = config["version"]
|
||||
|
||||
ckpt_dir = config["project_checkpoints"]
|
||||
sample_dir = config["project_samples"]
|
||||
|
||||
log_freq = config["log_step"]
|
||||
model_freq = config["model_save_step"]
|
||||
sample_freq = config["sample_step"]
|
||||
total_step = config["total_step"]
|
||||
random_seed = config["dataset_params"]["random_seed"]
|
||||
|
||||
|
||||
id_w = config["id_weight"]
|
||||
rec_w = config["reconstruct_weight"]
|
||||
feat_w = config["feature_match_weight"]
|
||||
num_gpus = len(config["gpus"])
|
||||
batch_gpu = config["batch_size"] // num_gpus
|
||||
|
||||
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
|
||||
if os.name == 'nt':
|
||||
init_method = 'file:///' + init_file.replace('\\', '/')
|
||||
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus)
|
||||
else:
|
||||
init_method = f'file://{init_file}'
|
||||
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus)
|
||||
|
||||
# Init torch_utils.
|
||||
sync_device = torch.device('cuda', rank)
|
||||
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
|
||||
|
||||
|
||||
|
||||
if rank == 0:
|
||||
img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
|
||||
|
||||
# Initialize.
|
||||
device = torch.device('cuda', rank)
|
||||
np.random.seed(random_seed * num_gpus + rank)
|
||||
torch.manual_seed(random_seed * num_gpus + rank)
|
||||
torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy.
|
||||
torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy.
|
||||
conv2d_gradfix.enabled = True # Improves training speed.
|
||||
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
|
||||
|
||||
# Create dataloader.
|
||||
if rank == 0:
|
||||
print('Loading training set...')
|
||||
|
||||
dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
dataloader_class= dataloaderClass
|
||||
dataloader = dataloader_class(dataset,
|
||||
rank,
|
||||
num_gpus,
|
||||
batch_gpu,
|
||||
**config["dataset_params"])
|
||||
|
||||
# Construct networks.
|
||||
if rank == 0:
|
||||
print('Constructing networks...')
|
||||
gen, dis, arcface = init_framework(config, reporter, device, rank)
|
||||
|
||||
# Check for existing checkpoint
|
||||
|
||||
# Print network summary tables.
|
||||
# if rank == 0:
|
||||
# attr = torch.empty([batch_gpu, 3, 512, 512], device=device)
|
||||
# id = torch.empty([batch_gpu, 3, 112, 112], device=device)
|
||||
# latent = misc.print_module_summary(arcface, [id])
|
||||
# img = misc.print_module_summary(gen, [attr, latent])
|
||||
# misc.print_module_summary(dis, [img, None])
|
||||
# del attr
|
||||
# del id
|
||||
# del latent
|
||||
# del img
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Distribute across GPUs.
|
||||
if rank == 0:
|
||||
print(f'Distributing across {num_gpus} GPUs...')
|
||||
for module in [gen, dis, arcface]:
|
||||
if module is not None and num_gpus > 1:
|
||||
for param in misc.params_and_buffers(module):
|
||||
torch.distributed.broadcast(param, src=0)
|
||||
|
||||
# Setup training phases.
|
||||
if rank == 0:
|
||||
print('Setting up training phases...')
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
# MSE_loss = torch.nn.MSELoss()
|
||||
l1_loss = torch.nn.L1Loss()
|
||||
cos_loss = torch.nn.CosineSimilarity()
|
||||
|
||||
g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
|
||||
|
||||
# Initialize logs.
|
||||
if rank == 0:
|
||||
print('Initializing logs...')
|
||||
#==============build tensorboard=================#
|
||||
if config["logger"] == "tensorboard":
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
|
||||
logger = tensorboard_writer
|
||||
|
||||
elif config["logger"] == "wandb":
|
||||
import wandb
|
||||
wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
|
||||
tags=[config["tag"]], name=version)
|
||||
|
||||
wandb.config = {
|
||||
"total_step": config["total_step"],
|
||||
"batch_size": config["batch_size"]
|
||||
}
|
||||
logger = wandb
|
||||
|
||||
|
||||
random.seed(random_seed)
|
||||
randindex = [i for i in range(batch_gpu)]
|
||||
|
||||
# set the start point for training loop
|
||||
if config["phase"] == "finetune":
|
||||
start = config["ckpt"]
|
||||
else:
|
||||
start = 0
|
||||
if rank == 0:
|
||||
import datetime
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
print("Total step = %d"%total_step)
|
||||
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
|
||||
dis.feature_network.requires_grad_(False)
|
||||
|
||||
for step in range(start, total_step):
|
||||
gen.train()
|
||||
dis.train()
|
||||
for interval in range(2):
|
||||
random.shuffle(randindex)
|
||||
src_image1, src_image2 = dataloader.next()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("dataloader:",elapsed)
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
latent_id = arcface(img_id_112)
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
|
||||
if interval == 0:
|
||||
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
gen_logits,_ = dis(img_fake.detach(), None)
|
||||
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
|
||||
|
||||
real_logits,_ = dis(src_image2,None)
|
||||
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
|
||||
|
||||
loss_D = loss_Dgen + loss_Dreal
|
||||
d_optimizer.zero_grad(set_to_none=True)
|
||||
loss_D.backward()
|
||||
with torch.autograd.profiler.record_function('discriminator_opt'):
|
||||
# params = [param for param in dis.parameters() if param.grad is not None]
|
||||
# if len(params) > 0:
|
||||
# flat = torch.cat([param.grad.flatten() for param in params])
|
||||
# if num_gpus > 1:
|
||||
# torch.distributed.all_reduce(flat)
|
||||
# flat /= num_gpus
|
||||
# misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
# grads = flat.split([param.numel() for param in params])
|
||||
# for param, grad in zip(params, grads):
|
||||
# param.grad = grad.reshape(param.shape)
|
||||
params = [param for param in dis.parameters() if param.grad is not None]
|
||||
flat = torch.cat([param.grad.flatten() for param in params])
|
||||
torch.distributed.all_reduce(flat)
|
||||
flat /= num_gpus
|
||||
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
grads = flat.split([param.numel() for param in params])
|
||||
for param, grad in zip(params, grads):
|
||||
param.grad = grad.reshape(param.shape)
|
||||
d_optimizer.step()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("Discriminator training:",elapsed)
|
||||
else:
|
||||
|
||||
# model.netD.requires_grad_(True)
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
# G loss
|
||||
gen_logits,feat = dis(img_fake, None)
|
||||
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
|
||||
latent_fake = arcface(img_fake_down)
|
||||
latent_fake = F.normalize(latent_fake, p=2, dim=1)
|
||||
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
|
||||
real_feat = dis.get_feature(src_image1)
|
||||
feat_match_loss = l1_loss(feat["3"],real_feat["3"])
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w + \
|
||||
feat_match_loss * feat_w
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
loss_G += loss_G_Rec * rec_w
|
||||
|
||||
g_optimizer.zero_grad(set_to_none=True)
|
||||
loss_G.backward()
|
||||
with torch.autograd.profiler.record_function('generator_opt'):
|
||||
params = [param for param in gen.parameters() if param.grad is not None]
|
||||
flat = torch.cat([param.grad.flatten() for param in params])
|
||||
torch.distributed.all_reduce(flat)
|
||||
flat /= num_gpus
|
||||
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
grads = flat.split([param.numel() for param in params])
|
||||
for param, grad in zip(params, grads):
|
||||
param.grad = grad.reshape(param.shape)
|
||||
g_optimizer.step()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("Generator training:",elapsed)
|
||||
|
||||
|
||||
# Print out log info
|
||||
if rank == 0 and (step + 1) % log_freq == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("ready to report losses")
|
||||
# ID_Total= loss_G_ID
|
||||
# torch.distributed.all_reduce(ID_Total)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(version, elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
reporter.writeInfo(epochinformation)
|
||||
|
||||
if config["logger"] == "tensorboard":
|
||||
logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif config["logger"] == "wandb":
|
||||
logger.log({"G_Loss": loss_G.item()}, step = step)
|
||||
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
|
||||
gen.eval()
|
||||
with torch.no_grad():
|
||||
imgs = []
|
||||
zero_img = (torch.zeros_like(src_image1[0,...]))
|
||||
imgs.append(zero_img.cpu().numpy())
|
||||
save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
|
||||
for r in range(batch_gpu):
|
||||
imgs.append(save_img[r,...])
|
||||
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
|
||||
id_vector_src1 = arcface(arcface_112)
|
||||
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
|
||||
|
||||
for i in range(batch_gpu):
|
||||
|
||||
imgs.append(save_img[i,...])
|
||||
image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
|
||||
img_fake = gen(image_infer, id_vector_src1).cpu()
|
||||
|
||||
img_fake = img_fake * img_std
|
||||
img_fake = img_fake + img_mean
|
||||
img_fake = img_fake.numpy()
|
||||
for j in range(batch_gpu):
|
||||
imgs.append(img_fake[j,...])
|
||||
print("Save test data")
|
||||
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
|
||||
plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if rank == 0 and (step+1) % model_freq==0:
|
||||
|
||||
torch.save(gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.save(g_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.save(d_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
config["checkpoint_names"]["discriminator_name"])))
|
||||
print("Save step %d model checkpoint!"%(step+1))
|
||||
torch.cuda.empty_cache()
|
||||
print("Rank %d process done!"%rank)
|
||||
torch.distributed.barrier()
|
||||
@@ -0,0 +1,534 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th March 2022 1:01:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch_utils import misc
|
||||
from torch_utils import training_stats
|
||||
from torch_utils.ops import conv2d_gradfix
|
||||
from torch_utils.ops import grid_sample_gradfix
|
||||
|
||||
from arcface_torch.backbones.iresnet import iresnet100
|
||||
|
||||
from utilities.plot import plot_batch
|
||||
from losses.cos import cosin_metric
|
||||
from train_scripts.trainer_multigpu_base import TrainerBase
|
||||
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
import inspect
|
||||
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
|
||||
|
||||
def train(self):
|
||||
# Launch processes.
|
||||
num_gpus = len(self.config["gpus"])
|
||||
print('Launching processes...')
|
||||
torch.multiprocessing.set_start_method('spawn')
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(config, reporter, device, rank):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.empty_cache()
|
||||
model_config = config["model_configs"]
|
||||
|
||||
if config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
|
||||
tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
|
||||
shutil.copyfile(file1,tgtfile1)
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
|
||||
tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
|
||||
shutil.copyfile(file1,tgtfile1)
|
||||
|
||||
elif config["phase"] == "finetune":
|
||||
gscript_name = config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
reporter.writeInfo("Generator structure:")
|
||||
reporter.writeModel(gen.__str__())
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
|
||||
|
||||
# print and recorde model structure
|
||||
reporter.writeInfo("Discriminator structure:")
|
||||
reporter.writeModel(dis.__str__())
|
||||
|
||||
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
# arcface = arcface1['model'].module
|
||||
|
||||
arcface = iresnet100(pretrained=False, fp16=False)
|
||||
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
|
||||
arcface.eval()
|
||||
|
||||
# train in GPU
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if config["phase"] == "finetune":
|
||||
model_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(config["ckpt"],
|
||||
config["checkpoint_names"]["generator_name"]))
|
||||
gen.load_state_dict(torch.load(model_path), map_location=torch.device("cpu"))
|
||||
|
||||
model_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(config["ckpt"],
|
||||
config["checkpoint_names"]["discriminator_name"]))
|
||||
dis.load_state_dict(torch.load(model_path), map_location=torch.device("cpu"))
|
||||
|
||||
print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
|
||||
|
||||
gen = gen.to(device)
|
||||
dis = dis.to(device)
|
||||
arcface= arcface.to(device)
|
||||
arcface.requires_grad_(False)
|
||||
arcface.eval()
|
||||
|
||||
|
||||
|
||||
return gen, dis, arcface
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def setup_optimizers(config, reporter, gen, dis, rank):
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.empty_cache()
|
||||
g_train_opt = config['g_optim_config']
|
||||
d_train_opt = config['d_optim_config']
|
||||
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
if config["phase"] == "finetune":
|
||||
opt_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(config["ckpt"],
|
||||
config["optimizer_names"]["generator_name"]))
|
||||
g_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
opt_path = os.path.join(config["project_checkpoints"],
|
||||
"step%d_optim_%s.pth"%(config["ckpt"],
|
||||
config["optimizer_names"]["discriminator_name"]))
|
||||
d_optimizer.load_state_dict(torch.load(opt_path))
|
||||
|
||||
print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"]))
|
||||
return g_optimizer, d_optimizer
|
||||
|
||||
|
||||
def train_loop(
|
||||
rank,
|
||||
config,
|
||||
reporter,
|
||||
temp_dir
|
||||
):
|
||||
|
||||
version = config["version"]
|
||||
|
||||
ckpt_dir = config["project_checkpoints"]
|
||||
sample_dir = config["project_samples"]
|
||||
|
||||
log_freq = config["log_step"]
|
||||
model_freq = config["model_save_step"]
|
||||
sample_freq = config["sample_step"]
|
||||
total_step = config["total_step"]
|
||||
random_seed = config["dataset_params"]["random_seed"]
|
||||
|
||||
|
||||
id_w = config["id_weight"]
|
||||
rec_w = config["reconstruct_weight"]
|
||||
rec_fm_w = config["rec_feature_match_weight"]
|
||||
cycle_fm_w = config["cycle_feature_match_weight"]
|
||||
cycle_w = config["cycle_weight"]
|
||||
num_gpus = len(config["gpus"])
|
||||
batch_gpu = config["batch_size"] // num_gpus
|
||||
|
||||
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
|
||||
if os.name == 'nt':
|
||||
init_method = 'file:///' + init_file.replace('\\', '/')
|
||||
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus)
|
||||
else:
|
||||
init_method = f'file://{init_file}'
|
||||
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus)
|
||||
|
||||
# Init torch_utils.
|
||||
sync_device = torch.device('cuda', rank)
|
||||
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
|
||||
|
||||
|
||||
|
||||
if rank == 0:
|
||||
img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
|
||||
|
||||
# Initialize.
|
||||
device = torch.device('cuda', rank)
|
||||
np.random.seed(random_seed * num_gpus + rank)
|
||||
torch.manual_seed(random_seed * num_gpus + rank)
|
||||
torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy.
|
||||
torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy.
|
||||
conv2d_gradfix.enabled = True # Improves training speed.
|
||||
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
|
||||
|
||||
# Create dataloader.
|
||||
if rank == 0:
|
||||
print('Loading training set...')
|
||||
|
||||
dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
dataloader_class= dataloaderClass
|
||||
dataloader = dataloader_class(dataset,
|
||||
rank,
|
||||
num_gpus,
|
||||
batch_gpu,
|
||||
**config["dataset_params"])
|
||||
|
||||
# Construct networks.
|
||||
if rank == 0:
|
||||
print('Constructing networks...')
|
||||
gen, dis, arcface = init_framework(config, reporter, device, rank)
|
||||
|
||||
# Check for existing checkpoint
|
||||
|
||||
# Print network summary tables.
|
||||
# if rank == 0:
|
||||
# attr = torch.empty([batch_gpu, 3, 512, 512], device=device)
|
||||
# id = torch.empty([batch_gpu, 3, 112, 112], device=device)
|
||||
# latent = misc.print_module_summary(arcface, [id])
|
||||
# img = misc.print_module_summary(gen, [attr, latent])
|
||||
# misc.print_module_summary(dis, [img, None])
|
||||
# del attr
|
||||
# del id
|
||||
# del latent
|
||||
# del img
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Distribute across GPUs.
|
||||
if rank == 0:
|
||||
print(f'Distributing across {num_gpus} GPUs...')
|
||||
for module in [gen, dis, arcface]:
|
||||
if module is not None and num_gpus > 1:
|
||||
for param in misc.params_and_buffers(module):
|
||||
torch.distributed.broadcast(param, src=0)
|
||||
|
||||
# Setup training phases.
|
||||
if rank == 0:
|
||||
print('Setting up training phases...')
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
# MSE_loss = torch.nn.MSELoss()
|
||||
l1_loss = torch.nn.L1Loss()
|
||||
cos_loss = torch.nn.CosineSimilarity()
|
||||
|
||||
g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
|
||||
|
||||
# Initialize logs.
|
||||
if rank == 0:
|
||||
print('Initializing logs...')
|
||||
#==============build tensorboard=================#
|
||||
if config["logger"] == "tensorboard":
|
||||
import torch.utils.tensorboard as tensorboard
|
||||
tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
|
||||
logger = tensorboard_writer
|
||||
|
||||
elif config["logger"] == "wandb":
|
||||
import wandb
|
||||
wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
|
||||
tags=[config["tag"]], name=version)
|
||||
|
||||
wandb.config = {
|
||||
"total_step": config["total_step"],
|
||||
"batch_size": config["batch_size"]
|
||||
}
|
||||
logger = wandb
|
||||
|
||||
|
||||
random.seed(random_seed)
|
||||
randindex = [i for i in range(batch_gpu)]
|
||||
|
||||
# set the start point for training loop
|
||||
if config["phase"] == "finetune":
|
||||
start = config["ckpt"]
|
||||
else:
|
||||
start = 0
|
||||
if rank == 0:
|
||||
import datetime
|
||||
start_time = time.time()
|
||||
|
||||
# Caculate the epoch number
|
||||
print("Total step = %d"%total_step)
|
||||
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
|
||||
dis.feature_network.requires_grad_(False)
|
||||
|
||||
for step in range(start, total_step):
|
||||
gen.train()
|
||||
dis.train()
|
||||
for interval in range(2):
|
||||
random.shuffle(randindex)
|
||||
src_image1, src_image2 = dataloader.next()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("dataloader:",elapsed)
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
latent_id = arcface(img_id_112)
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
|
||||
if interval == 0:
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
gen_logits,_ = dis(img_fake.detach(), None)
|
||||
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
|
||||
|
||||
real_logits,_ = dis(src_image2,None)
|
||||
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
|
||||
|
||||
loss_D = loss_Dgen + loss_Dreal
|
||||
d_optimizer.zero_grad(set_to_none=True)
|
||||
loss_D.backward()
|
||||
with torch.autograd.profiler.record_function('discriminator_opt'):
|
||||
# params = [param for param in dis.parameters() if param.grad is not None]
|
||||
# if len(params) > 0:
|
||||
# flat = torch.cat([param.grad.flatten() for param in params])
|
||||
# if num_gpus > 1:
|
||||
# torch.distributed.all_reduce(flat)
|
||||
# flat /= num_gpus
|
||||
# misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
# grads = flat.split([param.numel() for param in params])
|
||||
# for param, grad in zip(params, grads):
|
||||
# param.grad = grad.reshape(param.shape)
|
||||
params = [param for param in dis.parameters() if param.grad is not None]
|
||||
flat = torch.cat([param.grad.flatten() for param in params])
|
||||
torch.distributed.all_reduce(flat)
|
||||
flat /= num_gpus
|
||||
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
grads = flat.split([param.numel() for param in params])
|
||||
for param, grad in zip(params, grads):
|
||||
param.grad = grad.reshape(param.shape)
|
||||
d_optimizer.step()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("Discriminator training:",elapsed)
|
||||
else:
|
||||
|
||||
# model.netD.requires_grad_(True)
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
# G loss
|
||||
gen_logits,feat = dis(img_fake, None)
|
||||
real_feat = dis.get_feature(src_image1)
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
|
||||
latent_fake = arcface(img_fake_down)
|
||||
latent_fake = F.normalize(latent_fake, p=2, dim=1)
|
||||
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
|
||||
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
rec_fm = l1_loss(feat["3"],real_feat["3"])
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
loss_G += loss_G_Rec * rec_w + rec_fm * rec_fm_w
|
||||
else:
|
||||
source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic')
|
||||
latent_source1 = arcface(source1_down)
|
||||
latent_source1 = F.normalize(latent_source1, p=2, dim=1)
|
||||
cycle_src = gen(img_fake, latent_source1)
|
||||
cycle_loss = l1_loss(src_image1,cycle_src)
|
||||
cycle_feat = dis.get_feature(cycle_src)
|
||||
cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"])
|
||||
loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w
|
||||
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w
|
||||
g_optimizer.zero_grad(set_to_none=True)
|
||||
loss_G.backward()
|
||||
with torch.autograd.profiler.record_function('generator_opt'):
|
||||
params = [param for param in gen.parameters() if param.grad is not None]
|
||||
flat = torch.cat([param.grad.flatten() for param in params])
|
||||
torch.distributed.all_reduce(flat)
|
||||
flat /= num_gpus
|
||||
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
|
||||
grads = flat.split([param.numel() for param in params])
|
||||
for param, grad in zip(params, grads):
|
||||
param.grad = grad.reshape(param.shape)
|
||||
g_optimizer.step()
|
||||
# if rank ==0:
|
||||
|
||||
# elapsed = time.time() - start_time
|
||||
# elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("Generator training:",elapsed)
|
||||
|
||||
|
||||
# Print out log info
|
||||
if rank == 0 and (step + 1) % log_freq == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
# print("ready to report losses")
|
||||
# ID_Total= loss_G_ID
|
||||
# torch.distributed.all_reduce(ID_Total)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(version, elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
reporter.writeInfo(epochinformation)
|
||||
|
||||
if config["logger"] == "tensorboard":
|
||||
logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif config["logger"] == "wandb":
|
||||
logger.log({"G_Loss": loss_G.item()}, step = step)
|
||||
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
|
||||
gen.eval()
|
||||
with torch.no_grad():
|
||||
imgs = []
|
||||
zero_img = (torch.zeros_like(src_image1[0,...]))
|
||||
imgs.append(zero_img.cpu().numpy())
|
||||
save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
|
||||
for r in range(batch_gpu):
|
||||
imgs.append(save_img[r,...])
|
||||
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
|
||||
id_vector_src1 = arcface(arcface_112)
|
||||
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
|
||||
|
||||
for i in range(batch_gpu):
|
||||
|
||||
imgs.append(save_img[i,...])
|
||||
image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
|
||||
img_fake = gen(image_infer, id_vector_src1).cpu()
|
||||
|
||||
img_fake = img_fake * img_std
|
||||
img_fake = img_fake + img_mean
|
||||
img_fake = img_fake.numpy()
|
||||
for j in range(batch_gpu):
|
||||
imgs.append(img_fake[j,...])
|
||||
print("Save test data")
|
||||
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
|
||||
plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
#===============adjust learning rate============#
|
||||
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
# print("Learning rate decay")
|
||||
# for p in self.optimizer.param_groups:
|
||||
# p['lr'] *= self.config["lr_decay"]
|
||||
# print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if rank == 0 and (step+1) % model_freq==0:
|
||||
|
||||
torch.save(gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
|
||||
config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.save(g_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
config["checkpoint_names"]["generator_name"])))
|
||||
|
||||
torch.save(d_optimizer.state_dict(),
|
||||
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
|
||||
config["checkpoint_names"]["discriminator_name"])))
|
||||
print("Save step %d model checkpoint!"%(step+1))
|
||||
torch.cuda.empty_cache()
|
||||
print("Rank %d process done!"%rank)
|
||||
torch.distributed.barrier()
|
||||
@@ -1,10 +1,10 @@
|
||||
# Related scripts
|
||||
train_script_name: cycleloss
|
||||
train_script_name: multi_gpu_cycle
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: Generator
|
||||
script: Generator_LSTU_config
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 512
|
||||
@@ -22,10 +22,10 @@ model_configs:
|
||||
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
|
||||
|
||||
# Training information
|
||||
batch_size: 12
|
||||
batch_size: 24
|
||||
|
||||
# Dataset
|
||||
dataloader: VGGFace2HQ
|
||||
dataloader: VGGFace2HQ_multigpu
|
||||
dataset_name: vggface2_hq
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
@@ -40,18 +40,19 @@ eval_batch_size: 2
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: 0.0004
|
||||
lr: 0.0006
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
d_optim_config:
|
||||
lr: 0.0004
|
||||
lr: 0.0006
|
||||
betas: [ 0, 0.99]
|
||||
eps: !!float 1e-8
|
||||
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 0.1
|
||||
feature_match_weight: 0.1
|
||||
reconstruct_weight: 10.0
|
||||
rec_feature_match_weight: 10.0
|
||||
cycle_feature_match_weight: 10.0
|
||||
cycle_weight: 10.0
|
||||
|
||||
# Log
|
||||
|
||||
Reference in New Issue
Block a user