fix remote checkpoints fetch
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th February 2022 10:02:07 pm
|
||||
# Last Modified: Friday, 18th February 2022 3:23:18 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -34,7 +34,7 @@ def getParameters():
|
||||
help="version name for train, test, finetune")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=530000,
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# test
|
||||
@@ -44,13 +44,15 @@ def getParameters():
|
||||
choices=['localhost', '4card','8card','new4card'])
|
||||
|
||||
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg')
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
|
||||
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
|
||||
help="file path for attribute images or video")
|
||||
|
||||
parser.add_argument('--use_specified_data', action='store_true')
|
||||
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
|
||||
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
|
||||
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
|
||||
|
||||
# # logs (does not to be changed in most time)
|
||||
# parser.add_argument('--dataloader_workers', type=int, default=6)
|
||||
@@ -180,14 +182,14 @@ def main():
|
||||
createDirs(sys_state)
|
||||
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
|
||||
|
||||
# Read model_config.json from remote machine
|
||||
#fetch checkpoints, model_config.json and scripts from remote machine
|
||||
if sys_state["node_name"]!="localhost":
|
||||
remote_mac = env_config["remote_machine"]
|
||||
nodeinf = remote_mac[sys_state["node_name"]]
|
||||
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
|
||||
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
|
||||
|
||||
remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/')
|
||||
remotebase = os.path.join(nodeinf['path'],"train_logs",sys_state["version"]).replace('\\','/')
|
||||
|
||||
# Get the config.json
|
||||
print("ready to get the config.json...")
|
||||
@@ -199,6 +201,28 @@ def main():
|
||||
raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile))
|
||||
print("success get the config.json from server %s"%nodeinf['ip'])
|
||||
|
||||
# Get scripts
|
||||
remoteDir = os.path.join(remotebase, "scripts").replace('\\','/')
|
||||
localDir = os.path.join(sys_state["project_scripts"])
|
||||
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
|
||||
if not ssh_state:
|
||||
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
|
||||
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
|
||||
|
||||
# Get checkpoints
|
||||
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
|
||||
sys_state["checkpoint_names"]["generator_name"])
|
||||
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
|
||||
if not os.path.exists(localFile):
|
||||
|
||||
remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/')
|
||||
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
|
||||
if not ssh_state:
|
||||
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
|
||||
print("Get the checkpoint %s successfully!"%(ckpt_name))
|
||||
else:
|
||||
print("%s exists!"%(ckpt_name))
|
||||
|
||||
# Read model_config.json
|
||||
json_obj = readConfig(config_json)
|
||||
for item in json_obj.items():
|
||||
@@ -206,33 +230,6 @@ def main():
|
||||
pass
|
||||
else:
|
||||
sys_state[item[0]] = item[1]
|
||||
|
||||
# Read scripts from remote machine
|
||||
if sys_state["node_name"]!="localhost":
|
||||
# # Get scripts
|
||||
# remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/')
|
||||
# localFile = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py")
|
||||
# ssh_state = uploader.sshScpGet(remoteFile, localFile)
|
||||
# if not ssh_state:
|
||||
# raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
|
||||
# print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
|
||||
# Get checkpoint of generator
|
||||
localFile = os.path.join(sys_state["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
sys_state["checkpoint_names"]["generator_name"]))
|
||||
if not os.path.exists(localFile):
|
||||
remoteFile = os.path.join(remotebase, "checkpoints",
|
||||
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
sys_state["checkpoint_names"]["generator_name"])).replace('\\','/')
|
||||
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
|
||||
if not ssh_state:
|
||||
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
|
||||
print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
sys_state["checkpoint_names"]["generator_name"])))
|
||||
else:
|
||||
print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
|
||||
sys_state["checkpoint_names"]["generator_name"])))
|
||||
|
||||
# TODO get the checkpoint file path
|
||||
sys_state["ckp_name"] = {}
|
||||
# for data_key in sys_state["checkpoint_names"].keys():
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th February 2022 7:12:56 pm
|
||||
# Last Modified: Friday, 18th February 2022 10:47:55 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -88,11 +88,16 @@ class Tester(object):
|
||||
version = self.config["version"]
|
||||
id_imgs = self.config["id_imgs"]
|
||||
attr_files = self.config["attr_files"]
|
||||
specified_save_path = self.config["specified_save_path"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
imgs_list = []
|
||||
|
||||
self.reporter.writeInfo("Version %s"%version)
|
||||
|
||||
if os.path.isdir(specified_save_path):
|
||||
print("Input a legal specified save path!")
|
||||
save_dir = specified_save_path
|
||||
|
||||
if os.path.isdir(attr_files):
|
||||
print("Input a dir....")
|
||||
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
|
||||
|
||||
+24
-1
@@ -5,12 +5,13 @@
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Lcx
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th January 2021 2:02:12 pm
|
||||
# Last Modified: Friday, 18th February 2022 3:20:14 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import paramiko,os
|
||||
from pathlib import Path
|
||||
# ssh传输类:
|
||||
|
||||
class fileUploaderClass(object):
|
||||
@@ -50,6 +51,28 @@ class fileUploaderClass(object):
|
||||
wocao = sftp.listdir(remoteDir)
|
||||
return wocao
|
||||
|
||||
def sshScpGetDir(self, remoteDir, localDir, showProgress=False):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
try:
|
||||
sftp.stat(remoteDir)
|
||||
print("Remote dir exists!")
|
||||
except:
|
||||
print("Remote dir does not exist!")
|
||||
return False
|
||||
files = sftp.listdir(remoteDir)
|
||||
for i_f in files:
|
||||
i_remote_file = Path(remoteDir,i_f).as_posix()
|
||||
local_file = Path(localDir,i_f)
|
||||
if showProgress:
|
||||
sftp.get(i_remote_file, local_file,callback=self.__putCallBack__)
|
||||
else:
|
||||
sftp.get(i_remote_file, local_file)
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
return True
|
||||
|
||||
def sshScpGet(self, remoteFile, localFile, showProgress=False):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
|
||||
Reference in New Issue
Block a user