fix remote checkpoints fetch

This commit is contained in:
chenxuanhong
2022-02-18 15:24:10 +08:00
parent f2d448e643
commit 7b597e7205
3 changed files with 60 additions and 35 deletions
+30 -33
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 10:02:07 pm
# Last Modified: Friday, 18th February 2022 3:23:18 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -34,7 +34,7 @@ def getParameters():
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=530000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
help="checkpoint epoch for test phase or finetune phase")
# test
@@ -44,13 +44,15 @@ def getParameters():
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
# # logs (does not need to be changed most of the time)
# parser.add_argument('--dataloader_workers', type=int, default=6)
@@ -180,14 +182,14 @@ def main():
createDirs(sys_state)
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
# Read model_config.json from remote machine
#fetch checkpoints, model_config.json and scripts from remote machine
if sys_state["node_name"]!="localhost":
remote_mac = env_config["remote_machine"]
nodeinf = remote_mac[sys_state["node_name"]]
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/')
remotebase = os.path.join(nodeinf['path'],"train_logs",sys_state["version"]).replace('\\','/')
# Get the config.json
print("ready to get the config.json...")
@@ -199,6 +201,28 @@ def main():
raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile))
print("success get the config.json from server %s"%nodeinf['ip'])
# Get scripts
remoteDir = os.path.join(remotebase, "scripts").replace('\\','/')
localDir = os.path.join(sys_state["project_scripts"])
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
if not ssh_state:
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoints
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
sys_state["checkpoint_names"]["generator_name"])
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
if not os.path.exists(localFile):
remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/')
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
if not ssh_state:
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
print("Get the checkpoint %s successfully!"%(ckpt_name))
else:
print("%s exists!"%(ckpt_name))
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
@@ -206,33 +230,6 @@ def main():
pass
else:
sys_state[item[0]] = item[1]
# Read scripts from remote machine
if sys_state["node_name"]!="localhost":
# # Get scripts
# remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/')
# localFile = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py")
# ssh_state = uploader.sshScpGet(remoteFile, localFile)
# if not ssh_state:
# raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
# print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoint of generator
localFile = os.path.join(sys_state["project_checkpoints"],
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"]))
if not os.path.exists(localFile):
remoteFile = os.path.join(remotebase, "checkpoints",
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])).replace('\\','/')
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
if not ssh_state:
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
else:
print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
# for data_key in sys_state["checkpoint_names"].keys():
+6 -1
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 7:12:56 pm
# Last Modified: Friday, 18th February 2022 10:47:55 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -88,11 +88,16 @@ class Tester(object):
version = self.config["version"]
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
+24 -1
View File
@@ -5,12 +5,13 @@
# Created Date: Tuesday September 24th 2019
# Author: Lcx
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th January 2021 2:02:12 pm
# Last Modified: Friday, 18th February 2022 3:20:14 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2019 Shanghai Jiao Tong University
#############################################################
import paramiko,os
from pathlib import Path
# SSH file-transfer class:
class fileUploaderClass(object):
@@ -50,6 +51,28 @@ class fileUploaderClass(object):
wocao = sftp.listdir(remoteDir)
return wocao
def sshScpGetDir(self, remoteDir, localDir, showProgress=False):
    """Download the files directly inside ``remoteDir`` into ``localDir`` over SFTP.

    Only the top level of ``remoteDir`` is fetched; entries that the server
    reports as directories are skipped instead of aborting the transfer.

    Args:
        remoteDir: POSIX-style path of the directory on the remote machine.
        localDir: Local destination directory; created if it does not exist.
        showProgress: When true, ``self.__putCallBack__`` is handed to
            ``sftp.get`` as the progress callback.

    Returns:
        ``False`` if ``remoteDir`` cannot be stat'ed on the server,
        ``True`` once every file has been downloaded.

    Raises:
        Any connection or transfer error raised by paramiko is propagated;
        the SFTP channel and SSH connection are closed first.
    """
    # Local import keeps this method self-contained; `stat` is only needed
    # to tell directories from regular files in the remote listing.
    import stat

    self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
    try:
        # A single SFTP client is enough; the previous code built one from
        # the transport and then immediately replaced it, leaking the first.
        sftp = self.__ssh__.open_sftp()
        try:
            try:
                sftp.stat(remoteDir)
            except IOError:
                # paramiko reports a missing remote path as IOError.
                print("Remote dir does not exist!")
                return False
            print("Remote dir exists!")

            os.makedirs(str(localDir), exist_ok=True)
            # Looked up lazily so the callback attribute is only required
            # when progress reporting was actually requested.
            progress_cb = self.__putCallBack__ if showProgress else None

            for entry in sftp.listdir_attr(remoteDir):
                mode = entry.st_mode
                if mode is not None and stat.S_ISDIR(mode):
                    # sftp.get cannot fetch a directory (e.g. __pycache__).
                    continue
                remote_path = Path(remoteDir, entry.filename).as_posix()
                local_path = str(Path(localDir, entry.filename))
                sftp.get(remote_path, local_path, callback=progress_cb)
        finally:
            sftp.close()
    finally:
        # Runs on success, on the early `return False`, and on errors.
        self.__ssh__.close()
    return True
def sshScpGet(self, remoteFile, localFile, showProgress=False):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())