update teseter_image

This commit is contained in:
chenxuanhong
2022-02-18 17:01:25 +08:00
parent 7b597e7205
commit d81ee31117
3 changed files with 36 additions and 22 deletions
+29 -19
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 18th February 2022 3:23:18 pm
# Last Modified: Friday, 18th February 2022 3:59:57 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -34,25 +34,25 @@ def getParameters():
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=680000,
help="checkpoint epoch for test phase or finetune phase")
# test
parser.add_argument('-t', '--test_script_name', type=str, default='image')
parser.add_argument('-b', '--batch_size', type=int, default=1)
parser.add_argument('-n', '--node_name', type=str, default='localhost',
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('-n', '--node_ip', type=str, default='2001:da8:8000:6880:f284:d61c:3c76:f9cb')
parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
parser.add_argument('--specified_save_path', type=str, default="", help='save results to specified dir')
# # logs (does not to be changed in most time)
# parser.add_argument('--dataloader_workers', type=int, default=6)
@@ -183,9 +183,16 @@ def main():
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
#fetch checkpoints, model_config.json and scripts from remote machine
if sys_state["node_name"]!="localhost":
remote_mac = env_config["remote_machine"]
nodeinf = remote_mac[sys_state["node_name"]]
if sys_state["node_ip"]!="localhost":
machine_config = env_config["machine_config"]
machine_config = readConfig(machine_config)
nodeinf = None
for item in machine_config:
if item["ip"] == sys_state["node_ip"]:
nodeinf = item
break
if not nodeinf:
raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"]))
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
@@ -207,9 +214,18 @@ def main():
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
if not ssh_state:
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoints
print("Get the scripts successful!")
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
# Get checkpoints
if sys_state["node_ip"]!="localhost":
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
sys_state["checkpoint_names"]["generator_name"])
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
@@ -223,13 +239,7 @@ def main():
else:
print("%s exists!"%(ckpt_name))
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
# for data_key in sys_state["checkpoint_names"].keys():