diff --git a/train.py b/train.py index c2bd526..2bd8b25 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Monday December 27th 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 21st April 2022 10:36:48 pm +# Last Modified: Thursday, 21st April 2022 10:53:51 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -48,6 +48,7 @@ class TrainOptions: # for displays self.parser.add_argument('--tag', type=str, default='simswap') + self.parser.add_argument('--use_tensorboard', type=str2bool, default='True') # for training self.parser.add_argument('--dataset', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan", help='path to the face swapping dataset') @@ -142,9 +143,10 @@ if __name__ == '__main__': model.initialize(opt) ##################################################### - - tensorboard_writer = tensorboard.SummaryWriter(log_path) - logger = tensorboard_writer + if opt.use_tensorboard: + tensorboard_writer = tensorboard.SummaryWriter(log_path) + logger = tensorboard_writer + log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(log_name, "a") as log_file: @@ -247,9 +249,9 @@ if __name__ == '__main__': "D_real":loss_Dreal.item(), "D_loss":loss_D.item() } - - for tag, value in errors.items(): - logger.add_scalar(tag, value, step) + if opt.use_tensorboard: + for tag, value in errors.items(): + logger.add_scalar(tag, value, step) message = '( step: %d, ) ' % (step) for k, v in errors.items(): message += '%s: %.3f ' % (k, v)