From 9f3daca179ec60f216cad49099129db806c4005f Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Thu, 21 Apr 2022 22:39:55 +0800 Subject: [PATCH] update --- README.md | 1 + models/base_model.py | 7 ------- train.py | 10 +++++----- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 620b7d4..9ed8c16 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ If you find this project useful, please star it. It is the greatest appreciation Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ). The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results. +In order to ensure normal training, the batch size must be greater than 1. - Train 256 models ``` diff --git a/models/base_model.py b/models/base_model.py index 3c6ca43..0a6474a 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -51,13 +51,6 @@ class BaseModel(torch.nn.Module): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.state_dict(), save_path) - # helper saving function that can be used by subclasses - def save_network(self, network, network_label, epoch_label, gpu_ids=None): - save_filename = '%s_net_%s.pth' % (epoch_label, network_label) - save_path = os.path.join(self.save_dir, save_filename) - torch.save(network.cpu().state_dict(), save_path) - if len(gpu_ids) and torch.cuda.is_available(): - network.cuda() # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label, save_dir=''): diff --git a/train.py b/train.py index 87a3b58..c2bd526 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Monday December 27th 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 21st April 2022 8:10:05 pm +# Last Modified: Thursday, 21st April 2022 10:36:48 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -44,7 +44,7 @@ class TrainOptions: self.parser.add_argument('--isTrain', type=str2bool, default='True') # input/output sizes - self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size') + self.parser.add_argument('--batchSize', type=int, default=2, help='input batch size') # for displays self.parser.add_argument('--tag', type=str, default='simswap') @@ -69,9 +69,9 @@ class TrainOptions: self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT") self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step') - self.parser.add_argument("--log_frep", type=int, default=250, help='frequence for printing log information') - self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling') - self.parser.add_argument("--model_freq", type=int, default=10000, help='frequence for saving the model') + self.parser.add_argument("--log_frep", type=int, default=10, help='frequence for printing log information') + self.parser.add_argument("--sample_freq", type=int, default=30, help='frequence for sampling') + self.parser.add_argument("--model_freq", type=int, default=40, help='frequence for saving the model')