Update train.py

This commit is contained in:
chenxuanhong
2022-04-21 22:56:03 +08:00
parent 7c44bc4b9a
commit bc5ac1ef22

View File

@@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 10:36:48 pm
# Last Modified: Thursday, 21st April 2022 10:53:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -48,6 +48,7 @@ class TrainOptions:
# for displays
self.parser.add_argument('--tag', type=str, default='simswap')
self.parser.add_argument('--use_tensorboard', type=str2bool, default='True')
# for training
self.parser.add_argument('--dataset', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan", help='path to the face swapping dataset')
@@ -142,9 +143,10 @@ if __name__ == '__main__':
model.initialize(opt)
#####################################################
tensorboard_writer = tensorboard.SummaryWriter(log_path)
logger = tensorboard_writer
if opt.use_tensorboard:
tensorboard_writer = tensorboard.SummaryWriter(log_path)
logger = tensorboard_writer
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(log_name, "a") as log_file:
@@ -247,9 +249,9 @@ if __name__ == '__main__':
"D_real":loss_Dreal.item(),
"D_loss":loss_D.item()
}
for tag, value in errors.items():
logger.add_scalar(tag, value, step)
if opt.use_tensorboard:
for tag, value in errors.items():
logger.add_scalar(tag, value, step)
message = '( step: %d, ) ' % (step)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)