This commit is contained in:
chenxuanhong
2022-01-24 19:01:00 +08:00
parent 0f8c2f929e
commit 94534e2e30
10 changed files with 765 additions and 79 deletions
+15 -10
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 22nd January 2022 12:45:09 pm
# Last Modified: Monday, 24th January 2022 6:56:17 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -28,6 +28,9 @@ class Trainer(TrainerBase):
config,
reporter):
super(Trainer, self).__init__(config, reporter)
import inspect
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
@@ -276,25 +279,27 @@ class Trainer(TrainerBase):
elapsed = str(datetime.timedelta(seconds=elapsed))
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(self.config["version"], elapsed, step, total_step, \
loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(self.config["version"], elapsed, step, total_step, \
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["logger"] == "tensorboard":
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
elif self.config["logger"] == "wandb":
self.logger.log({"G_loss": loss_G.item()}, step = step)
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
self.logger.log({"D_loss": loss_D.item()}, step = step)
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
self.logger.log({"D_real": loss_Dreal.item()}, step = step)