distillation
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 23rd February 2022 2:36:05 am
|
||||
# Last Modified: Wednesday, 23rd February 2022 3:39:20 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -232,6 +232,9 @@ def train_loop(
|
||||
id_w = config["id_weight"]
|
||||
rec_w = config["reconstruct_weight"]
|
||||
feat_w = config["feature_match_weight"]
|
||||
distill_w = config["distillation_weight"]
|
||||
feat_num = len(config["feature_list"])
|
||||
|
||||
num_gpus = len(config["gpus"])
|
||||
batch_gpu = config["batch_size"] // num_gpus
|
||||
|
||||
@@ -386,7 +389,7 @@ def train_loop(
|
||||
|
||||
if interval == 0:
|
||||
|
||||
img_t = tgen(src_image1, latent_id)
|
||||
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
gen_logits,_ = dis(img_fake.detach(), None)
|
||||
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
|
||||
@@ -425,8 +428,20 @@ def train_loop(
|
||||
else:
|
||||
|
||||
# model.netD.requires_grad_(True)
|
||||
img_t = tgen(src_image1, latent_id)
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
|
||||
Sacts = [
|
||||
s_feat[key] for key in sorted(s_feat.keys())
|
||||
]
|
||||
Tacts = [
|
||||
t_feat[key] for key in sorted(t_feat.keys())
|
||||
]
|
||||
loss_distill = 0
|
||||
for Sact, Tact in zip(Sacts, Tacts):
|
||||
loss_distill += -KA(Sact, Tact)
|
||||
# G loss
|
||||
loss_distill /= feat_num
|
||||
gen_logits,feat = dis(img_fake, None)
|
||||
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
@@ -437,7 +452,7 @@ def train_loop(
|
||||
real_feat = dis.get_feature(src_image1)
|
||||
feat_match_loss = l1_loss(feat["3"],real_feat["3"])
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w + \
|
||||
feat_match_loss * feat_w
|
||||
feat_match_loss * feat_w + loss_distill * distill_w
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
@@ -472,10 +487,10 @@ def train_loop(
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
Distillaton_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(version, elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
loss_distill.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
reporter.writeInfo(epochinformation)
|
||||
|
||||
@@ -483,6 +498,7 @@ def train_loop(
|
||||
logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
logger.add_scalar('G/G_distillation', loss_distill.item(), step)
|
||||
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
@@ -491,6 +507,7 @@ def train_loop(
|
||||
logger.log({"G_Loss": loss_G.item()}, step = step)
|
||||
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
logger.log({"G_distillation": loss_distill.item()}, step = step)
|
||||
logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
|
||||
@@ -61,6 +61,7 @@ d_optim_config:
|
||||
id_weight: 20.0
|
||||
reconstruct_weight: 10.0
|
||||
feature_match_weight: 10.0
|
||||
distillation_weight: 10.0
|
||||
|
||||
# Log
|
||||
log_step: 300
|
||||
|
||||
Reference in New Issue
Block a user