fix remote checkpoints fetch

This commit is contained in:
chenxuanhong
2022-02-18 15:24:10 +08:00
parent f2d448e643
commit 7b597e7205
3 changed files with 60 additions and 35 deletions
+30 -33
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 10:02:07 pm
# Last Modified: Friday, 18th February 2022 3:23:18 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -34,7 +34,7 @@ def getParameters():
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=530000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
help="checkpoint epoch for test phase or finetune phase")
# test
@@ -44,13 +44,15 @@ def getParameters():
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
# # logs (does not to be changed in most time)
# parser.add_argument('--dataloader_workers', type=int, default=6)
@@ -180,14 +182,14 @@ def main():
createDirs(sys_state)
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
# Read model_config.json from remote machine
#fetch checkpoints, model_config.json and scripts from remote machine
if sys_state["node_name"]!="localhost":
remote_mac = env_config["remote_machine"]
nodeinf = remote_mac[sys_state["node_name"]]
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/')
remotebase = os.path.join(nodeinf['path'],"train_logs",sys_state["version"]).replace('\\','/')
# Get the config.json
print("ready to get the config.json...")
@@ -199,6 +201,28 @@ def main():
raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile))
print("success get the config.json from server %s"%nodeinf['ip'])
# Get scripts
remoteDir = os.path.join(remotebase, "scripts").replace('\\','/')
localDir = os.path.join(sys_state["project_scripts"])
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
if not ssh_state:
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoints
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
sys_state["checkpoint_names"]["generator_name"])
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
if not os.path.exists(localFile):
remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/')
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
if not ssh_state:
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
print("Get the checkpoint %s successfully!"%(ckpt_name))
else:
print("%s exists!"%(ckpt_name))
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
@@ -206,33 +230,6 @@ def main():
pass
else:
sys_state[item[0]] = item[1]
# Read scripts from remote machine
if sys_state["node_name"]!="localhost":
# # Get scripts
# remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/')
# localFile = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py")
# ssh_state = uploader.sshScpGet(remoteFile, localFile)
# if not ssh_state:
# raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
# print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoint of generator
localFile = os.path.join(sys_state["project_checkpoints"],
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"]))
if not os.path.exists(localFile):
remoteFile = os.path.join(remotebase, "checkpoints",
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])).replace('\\','/')
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
if not ssh_state:
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
else:
print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
# for data_key in sys_state["checkpoint_names"].keys():