Update train.py
This commit is contained in:
16
train.py
16
train.py
@@ -5,7 +5,7 @@
|
||||
# Created Date: Monday December 27th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 21st April 2022 10:36:48 pm
|
||||
# Last Modified: Thursday, 21st April 2022 10:53:51 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -48,6 +48,7 @@ class TrainOptions:
|
||||
|
||||
# for displays
|
||||
self.parser.add_argument('--tag', type=str, default='simswap')
|
||||
self.parser.add_argument('--use_tensorboard', type=str2bool, default='True')
|
||||
|
||||
# for training
|
||||
self.parser.add_argument('--dataset', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan", help='path to the face swapping dataset')
|
||||
@@ -142,9 +143,10 @@ if __name__ == '__main__':
|
||||
model.initialize(opt)
|
||||
|
||||
#####################################################
|
||||
|
||||
tensorboard_writer = tensorboard.SummaryWriter(log_path)
|
||||
logger = tensorboard_writer
|
||||
if opt.use_tensorboard:
|
||||
tensorboard_writer = tensorboard.SummaryWriter(log_path)
|
||||
logger = tensorboard_writer
|
||||
|
||||
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
||||
|
||||
with open(log_name, "a") as log_file:
|
||||
@@ -247,9 +249,9 @@ if __name__ == '__main__':
|
||||
"D_real":loss_Dreal.item(),
|
||||
"D_loss":loss_D.item()
|
||||
}
|
||||
|
||||
for tag, value in errors.items():
|
||||
logger.add_scalar(tag, value, step)
|
||||
if opt.use_tensorboard:
|
||||
for tag, value in errors.items():
|
||||
logger.add_scalar(tag, value, step)
|
||||
message = '( step: %d, ) ' % (step)
|
||||
for k, v in errors.items():
|
||||
message += '%s: %.3f ' % (k, v)
|
||||
|
||||
Reference in New Issue
Block a user