update
This commit is contained in:
@@ -100,9 +100,11 @@ def init_framework(config, reporter, device, rank):
|
||||
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
# arcface = arcface1['model'].module
|
||||
|
||||
arcface = iresnet100(pretrained=False, fp16=False)
|
||||
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
|
||||
arcface.eval()
|
||||
# arcface = iresnet100(pretrained=False, fp16=False)
|
||||
# arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
|
||||
# arcface.eval()
|
||||
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
|
||||
arcface = arcface1['model'].module
|
||||
|
||||
# train in GPU
|
||||
|
||||
@@ -402,7 +404,7 @@ def train_loop(
|
||||
latent_fake = arcface(img_fake_down)
|
||||
latent_fake = F.normalize(latent_fake, p=2, dim=1)
|
||||
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
|
||||
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w
|
||||
if step%2 == 0:
|
||||
#G_Rec
|
||||
rec_fm = l1_loss(feat["3"],real_feat["3"])
|
||||
@@ -418,7 +420,7 @@ def train_loop(
|
||||
cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"])
|
||||
loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w
|
||||
|
||||
loss_G = loss_Gmain + loss_G_ID * id_w
|
||||
|
||||
g_optimizer.zero_grad(set_to_none=True)
|
||||
loss_G.backward()
|
||||
with torch.autograd.profiler.record_function('generator_opt'):
|
||||
@@ -447,18 +449,20 @@ def train_loop(
|
||||
# torch.distributed.all_reduce(ID_Total)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \
|
||||
rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(version, elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \
|
||||
rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
reporter.writeInfo(epochinformation)
|
||||
|
||||
if config["logger"] == "tensorboard":
|
||||
logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
|
||||
logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
|
||||
logger.add_scalar('G/rec_fm', rec_fm.item(), step)
|
||||
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
@@ -466,7 +470,9 @@ def train_loop(
|
||||
elif config["logger"] == "wandb":
|
||||
logger.log({"G_Loss": loss_G.item()}, step = step)
|
||||
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
logger.log({"cycle_loss": cycle_loss.item()}, step = step)
|
||||
logger.log({"cycle_fm": cycle_fm.item()}, step = step)
|
||||
logger.log({"rec_fm": rec_fm.item()}, step = step)
|
||||
logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
|
||||
Reference in New Issue
Block a user