This commit is contained in:
chenxuanhong
2022-01-21 18:01:36 +08:00
parent e698d99173
commit bebaeef2ce
8 changed files with 678 additions and 101 deletions
+12 -13
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 7:44:02 pm
# Last Modified: Friday, 21st January 2022 10:55:59 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -30,24 +30,23 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='fastnst_3',
parser.add_argument('-v', '--version', type=str, default='2layerFM',
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=-1) # >0 if it is set as -1, program will use CPU
parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19,
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=310000,
help="checkpoint epoch for test phase or finetune phase")
# test
parser.add_argument('-t', '--test_script_name', type=str, default='FastNST')
parser.add_argument('-t', '--test_script_name', type=str, default='video')
parser.add_argument('-b', '--batch_size', type=int, default=1)
parser.add_argument('-n', '--node_name', type=str, default='localhost',
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('--save_test_result', action='store_false')
parser.add_argument('--test_dataloader', type=str, default='dir')
parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\dlrb2.jpeg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\G2010.mp4',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
@@ -235,10 +234,10 @@ def main():
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
for data_key in sys_state["checkpoint_names"].keys():
sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
"%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"][data_key]))
# for data_key in sys_state["checkpoint_names"].keys():
# sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
# "%d_%s.pth"%(sys_state["checkpoint_epoch"],
# sys_state["checkpoint_names"][data_key]))
# Get the test configurations
sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]