distillation

This commit is contained in:
chenxuanhong
2022-02-23 15:39:51 +08:00
parent a0428c8c73
commit 1f2aa26bd1
2 changed files with 23 additions and 5 deletions
+22 -5
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 23rd February 2022 2:36:05 am
# Last Modified: Wednesday, 23rd February 2022 3:39:20 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -232,6 +232,9 @@ def train_loop(
id_w = config["id_weight"]
rec_w = config["reconstruct_weight"]
feat_w = config["feature_match_weight"]
distill_w = config["distillation_weight"]
feat_num = len(config["feature_list"])
num_gpus = len(config["gpus"])
batch_gpu = config["batch_size"] // num_gpus
@@ -386,7 +389,7 @@ def train_loop(
if interval == 0:
img_t = tgen(src_image1, latent_id)
img_fake = gen(src_image1, latent_id)
gen_logits,_ = dis(img_fake.detach(), None)
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
@@ -425,8 +428,20 @@ def train_loop(
else:
# model.netD.requires_grad_(True)
img_t = tgen(src_image1, latent_id)
img_fake = gen(src_image1, latent_id)
Sacts = [
s_feat[key] for key in sorted(s_feat.keys())
]
Tacts = [
t_feat[key] for key in sorted(t_feat.keys())
]
loss_distill = 0
for Sact, Tact in zip(Sacts, Tacts):
loss_distill += -KA(Sact, Tact)
# G loss
loss_distill /= feat_num
gen_logits,feat = dis(img_fake, None)
loss_Gmain = (-gen_logits).mean()
@@ -437,7 +452,7 @@ def train_loop(
real_feat = dis.get_feature(src_image1)
feat_match_loss = l1_loss(feat["3"],real_feat["3"])
loss_G = loss_Gmain + loss_G_ID * id_w + \
feat_match_loss * feat_w
feat_match_loss * feat_w + loss_distill * distill_w
if step%2 == 0:
#G_Rec
loss_G_Rec = l1_loss(img_fake, src_image1)
@@ -472,10 +487,10 @@ def train_loop(
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
Distillaton_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(version, elapsed, step, total_step, \
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
loss_distill.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation)
reporter.writeInfo(epochinformation)
@@ -483,6 +498,7 @@ def train_loop(
logger.add_scalar('G/G_loss', loss_G.item(), step)
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
logger.add_scalar('G/G_distillation', loss_distill.item(), step)
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
logger.add_scalar('D/D_loss', loss_D.item(), step)
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
@@ -491,6 +507,7 @@ def train_loop(
logger.log({"G_Loss": loss_G.item()}, step = step)
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
logger.log({"G_distillation": loss_distill.item()}, step = step)
logger.log({"G_ID": loss_G_ID.item()}, step = step)
logger.log({"D_loss": loss_D.item()}, step = step)
logger.log({"D_fake": loss_Dgen.item()}, step = step)
+1
View File
@@ -61,6 +61,7 @@ d_optim_config:
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 10.0
distillation_weight: 10.0
# Log
log_step: 300