From 1f2aa26bd1f0b0035a53d5bd22db5326f44f9c73 Mon Sep 17 00:00:00 2001
From: chenxuanhong
Date: Wed, 23 Feb 2022 15:39:51 +0800
Subject: [PATCH] distillation

---
 train_scripts/trainer_distillation_mgpu.py | 27 ++++++++++++++++++++++-----
 train_yamls/train_distillation.yaml        |  1 +
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/train_scripts/trainer_distillation_mgpu.py b/train_scripts/trainer_distillation_mgpu.py
index d660200..751bcbe 100644
--- a/train_scripts/trainer_distillation_mgpu.py
+++ b/train_scripts/trainer_distillation_mgpu.py
@@ -5,7 +5,7 @@
 # Created Date: Sunday January 9th 2022
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Wednesday, 23rd February 2022 2:36:05 am
+# Last Modified: Wednesday, 23rd February 2022 3:39:20 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2022 Shanghai Jiao Tong University
 #############################################################
@@ -232,6 +232,9 @@ def train_loop(
     id_w = config["id_weight"]
     rec_w = config["reconstruct_weight"]
     feat_w = config["feature_match_weight"]
+    distill_w = config["distillation_weight"]
+    feat_num = len(config["feature_list"])
+
     num_gpus = len(config["gpus"])
     batch_gpu = config["batch_size"] // num_gpus
 
@@ -386,7 +389,7 @@ def train_loop(
 
         if interval == 0:
 
-            img_t = tgen(src_image1, latent_id)
+            img_fake = gen(src_image1, latent_id)
             gen_logits,_ = dis(img_fake.detach(), None)
             loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
 
@@ -425,8 +428,20 @@ def train_loop(
         else:
             # model.netD.requires_grad_(True)
+            img_t = tgen(src_image1, latent_id)
             img_fake = gen(src_image1, latent_id)
+
+            Sacts = [
+                s_feat[key] for key in sorted(s_feat.keys())
+            ]
+            Tacts = [
+                t_feat[key] for key in sorted(t_feat.keys())
+            ]
+            loss_distill = 0
+            for Sact, Tact in zip(Sacts, Tacts):
+                loss_distill += -KA(Sact, Tact)
             # G loss
+            loss_distill /= feat_num
             gen_logits,feat = dis(img_fake, None)
 
             loss_Gmain = (-gen_logits).mean()
 
@@ -437,7 +452,7 @@ def train_loop(
             real_feat = dis.get_feature(src_image1)
             feat_match_loss = l1_loss(feat["3"],real_feat["3"])
             loss_G = loss_Gmain + loss_G_ID * id_w + \
-                feat_match_loss * feat_w
+                feat_match_loss * feat_w + loss_distill * distill_w
             if step%2 == 0:
                 #G_Rec
                 loss_G_Rec = l1_loss(img_fake, src_image1)
@@ -472,10 +487,10 @@ def train_loop(
 
                 epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
                     G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
-                    D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
+                    Distillation_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
                    format(version, elapsed, step, total_step, \
                    loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
-                    loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
+                    loss_distill.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
                 print(epochinformation)
                 reporter.writeInfo(epochinformation)
 
@@ -483,6 +498,7 @@ def train_loop(
                     logger.add_scalar('G/G_loss', loss_G.item(), step)
                     logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
                     logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
+                    logger.add_scalar('G/G_distillation', loss_distill.item(), step)
                     logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                     logger.add_scalar('D/D_loss', loss_D.item(), step)
                     logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
@@ -491,6 +507,7 @@ def train_loop(
                     logger.log({"G_Loss": loss_G.item()}, step = step)
                     logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
                     logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
+                    logger.log({"G_distillation": loss_distill.item()}, step = step)
                     logger.log({"G_ID": loss_G_ID.item()}, step = step)
                     logger.log({"D_loss": loss_D.item()}, step = step)
                     logger.log({"D_fake": loss_Dgen.item()}, step = step)
diff --git a/train_yamls/train_distillation.yaml b/train_yamls/train_distillation.yaml
index bdb0960..22618c5 100644
--- a/train_yamls/train_distillation.yaml
+++ b/train_yamls/train_distillation.yaml
@@ -61,6 +61,7 @@ d_optim_config:
 id_weight: 20.0
 reconstruct_weight: 10.0
 feature_match_weight: 10.0
+distillation_weight: 10.0
 
 # Log
 log_step: 300
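
The generator branch above relies on KA(Sact, Tact) and on the s_feat / t_feat
activation dicts, none of which are defined in this diff. As a rough sketch
only -- assuming KA is a linear centered-kernel-alignment (CKA) similarity and
that s_feat / t_feat map layer names to student / teacher activation tensors of
shape (N, C, H, W) collected by forward hooks -- the distillation term the loop
accumulates looks like this:

    import torch

    def KA(Sact: torch.Tensor, Tact: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # Linear CKA similarity between two activation batches (higher = more similar).
        # Assumption: this stands in for the KA() the patch calls, which is not shown here.
        S = Sact.flatten(start_dim=1)            # (N, C, H, W) -> (N, D)
        T = Tact.flatten(start_dim=1)
        S = S - S.mean(dim=0, keepdim=True)      # center each feature over the batch
        T = T - T.mean(dim=0, keepdim=True)
        cross  = (S.t() @ T).norm(p="fro") ** 2  # ||S^T T||_F^2
        norm_s = (S.t() @ S).norm(p="fro")
        norm_t = (T.t() @ T).norm(p="fro")
        return cross / (norm_s * norm_t + eps)

    def distillation_loss(s_feat: dict, t_feat: dict, feat_num: int) -> torch.Tensor:
        # Same shape as the loop in the patch: accumulate -KA over matched features,
        # then divide by feat_num = len(config["feature_list"]).
        Sacts = [s_feat[key] for key in sorted(s_feat.keys())]
        Tacts = [t_feat[key] for key in sorted(t_feat.keys())]
        loss_distill = 0
        for Sact, Tact in zip(Sacts, Tacts):
            loss_distill += -KA(Sact, Tact)
        return loss_distill / feat_num

With the new distillation_weight: 10.0 entry in train_distillation.yaml, the
result is scaled by distill_w and added to loss_G alongside the ID,
reconstruction and feature-matching terms.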