update
This commit is contained in:
+15
-10
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 22nd January 2022 12:45:09 pm
|
||||
# Last Modified: Monday, 24th January 2022 6:56:17 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -28,6 +28,9 @@ class Trainer(TrainerBase):
|
||||
config,
|
||||
reporter):
|
||||
super(Trainer, self).__init__(config, reporter)
|
||||
|
||||
import inspect
|
||||
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
|
||||
|
||||
self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
|
||||
self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
|
||||
@@ -276,25 +279,27 @@ class Trainer(TrainerBase):
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
|
||||
G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
|
||||
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
|
||||
format(self.config["version"], elapsed, step, total_step, \
|
||||
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
|
||||
loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeInfo(epochinformation)
|
||||
|
||||
if self.config["logger"] == "tensorboard":
|
||||
self.logger.add_scalar('G/G_loss', loss_G.item(), step)
|
||||
self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
|
||||
self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
|
||||
self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
|
||||
self.logger.add_scalar('D/D_loss', loss_D.item(), step)
|
||||
self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
|
||||
self.logger.add_scalar('D/D_real', loss_Dreal.item(), step)
|
||||
elif self.config["logger"] == "wandb":
|
||||
self.logger.log({"G_loss": loss_G.item()}, step = step)
|
||||
self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
|
||||
self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
|
||||
self.logger.log({"G_ID": loss_G_ID.item()}, step = step)
|
||||
self.logger.log({"D_loss": loss_D.item()}, step = step)
|
||||
self.logger.log({"D_fake": loss_Dgen.item()}, step = step)
|
||||
self.logger.log({"D_real": loss_Dreal.item()}, step = step)
|
||||
|
||||
Reference in New Issue
Block a user