update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th March 2022 1:01:52 am
|
||||
# Last Modified: Sunday, 27th March 2022 12:58:54 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -397,8 +397,9 @@ def train_loop(
|
||||
# model.netD.requires_grad_(True)
|
||||
img_fake = gen(src_image1, latent_id)
|
||||
# G loss
|
||||
gen_logits,feat = dis(img_fake, None)
|
||||
real_feat = dis.get_feature(src_image1)
|
||||
# gen_logits,feat = dis(img_fake, None)
|
||||
gen_logits,_ = dis(img_fake, None)
|
||||
# real_feat = dis.get_feature(src_image1)
|
||||
loss_Gmain = (-gen_logits).mean()
|
||||
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
|
||||
latent_fake = arcface(img_fake_down)
|
||||
@@ -407,18 +408,18 @@ def train_loop(
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
rec_fm = l1_loss(feat["3"],real_feat["3"])
|
||||
# rec_fm = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"])
|
||||
loss_G_Rec = l1_loss(img_fake, src_image1)
|
||||
loss_G += loss_G_Rec * rec_w + rec_fm * rec_fm_w
|
||||
loss_G += loss_G_Rec * rec_w #+ rec_fm * rec_fm_w
|
||||
else:
|
||||
source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic')
|
||||
latent_source1 = arcface(source1_down)
|
||||
latent_source1 = F.normalize(latent_source1, p=2, dim=1)
|
||||
cycle_src = gen(img_fake, latent_source1)
|
||||
cycle_loss = l1_loss(src_image1,cycle_src)
|
||||
cycle_feat = dis.get_feature(cycle_src)
|
||||
cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"])
|
||||
loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w
|
||||
# cycle_feat = dis.get_feature(cycle_src)
|
||||
# cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"])
|
||||
loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w
|
||||
|
||||
|
||||
g_optimizer.zero_grad(set_to_none=True)
|
||||
@@ -448,12 +449,14 @@ def train_loop(
|
||||
# ID_Total= loss_G_ID
|
||||
# torch.distributed.all_reduce(ID_Total)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \
|
||||
rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
# epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \
|
||||
# rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
# format(version, elapsed, step, total_step, \
|
||||
# loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \
|
||||
# rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(version, elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \
|
||||
rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
reporter.writeInfo(epochinformation)
|
||||
|
||||
@@ -461,8 +464,8 @@ def train_loop(
|
||||
logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
|
||||
logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
|
||||
logger.add_scalar('G/rec_fm', rec_fm.item(), step)
|
||||
# logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
|
||||
# logger.add_scalar('G/rec_fm', rec_fm.item(), step)
|
||||
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
@@ -471,8 +474,8 @@ def train_loop(
|
||||
logger.log({"G_Loss": loss_G.item()}, step = step)
|
||||
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
logger.log({"cycle_loss": cycle_loss.item()}, step = step)
|
||||
logger.log({"cycle_fm": cycle_fm.item()}, step = step)
|
||||
logger.log({"rec_fm": rec_fm.item()}, step = step)
|
||||
# logger.log({"cycle_fm": cycle_fm.item()}, step = step)
|
||||
# logger.log({"rec_fm": rec_fm.item()}, step = step)
|
||||
logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
|
||||
Reference in New Issue
Block a user