diff --git a/test.py b/test.py index e4eb2a6..8cba4df 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th February 2022 10:02:07 pm +# Last Modified: Friday, 18th February 2022 3:23:18 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -34,7 +34,7 @@ def getParameters(): help="version name for train, test, finetune") parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU - parser.add_argument('-s', '--checkpoint_step', type=int, default=530000, + parser.add_argument('-s', '--checkpoint_step', type=int, default=640000, help="checkpoint epoch for test phase or finetune phase") # test @@ -44,13 +44,15 @@ def getParameters(): choices=['localhost', '4card','8card','new4card']) - parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\gxt3.jpeg') + parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg') # parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg') - parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID', + parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1', help="file path for attribute images or video") parser.add_argument('--use_specified_data', action='store_true') parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files') + parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir') + parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir') # # logs (does not to be changed in most time) # parser.add_argument('--dataloader_workers', type=int, default=6) @@ -180,14 +182,14 @@ def main(): createDirs(sys_state) config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"]) - # Read model_config.json from remote machine + #fetch checkpoints, model_config.json and scripts from remote machine if sys_state["node_name"]!="localhost": remote_mac = env_config["remote_machine"] nodeinf = remote_mac[sys_state["node_name"]] print("ready to fetch related files from server: %s ......"%nodeinf["ip"]) uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"]) - remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/') + remotebase = os.path.join(nodeinf['path'],"train_logs",sys_state["version"]).replace('\\','/') # Get the config.json print("ready to get the config.json...") @@ -199,6 +201,28 @@ def main(): raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile)) print("success get the config.json from server %s"%nodeinf['ip']) + # Get scripts + remoteDir = os.path.join(remotebase, "scripts").replace('\\','/') + localDir = os.path.join(sys_state["project_scripts"]) + ssh_state = uploader.sshScpGetDir(remoteDir, localDir) + if not ssh_state: + raise Exception(print("Get file %s failed! Program exists!"%remoteFile)) + print("Get the scripts:%s.py successfully"%sys_state["gScriptName"]) + + # Get checkpoints + ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"], + sys_state["checkpoint_names"]["generator_name"]) + localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name) + if not os.path.exists(localFile): + + remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/') + ssh_state = uploader.sshScpGet(remoteFile, localFile, True) + if not ssh_state: + raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile)) + print("Get the checkpoint %s successfully!"%(ckpt_name)) + else: + print("%s exists!"%(ckpt_name)) + # Read model_config.json json_obj = readConfig(config_json) for item in json_obj.items(): @@ -206,33 +230,6 @@ def main(): pass else: sys_state[item[0]] = item[1] - - # Read scripts from remote machine - if sys_state["node_name"]!="localhost": - # # Get scripts - # remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/') - # localFile = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py") - # ssh_state = uploader.sshScpGet(remoteFile, localFile) - # if not ssh_state: - # raise Exception(print("Get file %s failed! Program exists!"%remoteFile)) - # print("Get the scripts:%s.py successfully"%sys_state["gScriptName"]) - # Get checkpoint of generator - localFile = os.path.join(sys_state["project_checkpoints"], - "epoch%d_%s.pth"%(sys_state["checkpoint_epoch"], - sys_state["checkpoint_names"]["generator_name"])) - if not os.path.exists(localFile): - remoteFile = os.path.join(remotebase, "checkpoints", - "epoch%d_%s.pth"%(sys_state["checkpoint_epoch"], - sys_state["checkpoint_names"]["generator_name"])).replace('\\','/') - ssh_state = uploader.sshScpGet(remoteFile, localFile, True) - if not ssh_state: - raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile)) - print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"], - sys_state["checkpoint_names"]["generator_name"]))) - else: - print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"], - sys_state["checkpoint_names"]["generator_name"]))) - # TODO get the checkpoint file path sys_state["ckp_name"] = {} # for data_key in sys_state["checkpoint_names"].keys(): diff --git a/test_scripts/tester_image.py b/test_scripts/tester_image.py index 7be905a..f6ae8e0 100644 --- a/test_scripts/tester_image.py +++ b/test_scripts/tester_image.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th February 2022 7:12:56 pm +# Last Modified: Friday, 18th February 2022 10:47:55 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -88,11 +88,16 @@ class Tester(object): version = self.config["version"] id_imgs = self.config["id_imgs"] attr_files = self.config["attr_files"] + specified_save_path = self.config["specified_save_path"] self.arcface_ckpt= self.config["arcface_ckpt"] imgs_list = [] self.reporter.writeInfo("Version %s"%version) + if os.path.isdir(specified_save_path): + print("Input a legal specified save path!") + save_dir = specified_save_path + if os.path.isdir(attr_files): print("Input a dir....") imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True) diff --git a/utilities/sshupload.py b/utilities/sshupload.py index 6c24213..acdc996 100644 --- a/utilities/sshupload.py +++ b/utilities/sshupload.py @@ -5,12 +5,13 @@ # Created Date: Tuesday September 24th 2019 # Author: Lcx # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 12th January 2021 2:02:12 pm +# Last Modified: Friday, 18th February 2022 3:20:14 pm # Modified By: Chen Xuanhong # Copyright (c) 2019 Shanghai Jiao Tong University ############################################################# import paramiko,os +from pathlib import Path # ssh传输类: class fileUploaderClass(object): @@ -50,6 +51,28 @@ class fileUploaderClass(object): wocao = sftp.listdir(remoteDir) return wocao + def sshScpGetDir(self, remoteDir, localDir, showProgress=False): + self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__) + sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport()) + sftp = self.__ssh__.open_sftp() + try: + sftp.stat(remoteDir) + print("Remote dir exists!") + except: + print("Remote dir does not exist!") + return False + files = sftp.listdir(remoteDir) + for i_f in files: + i_remote_file = Path(remoteDir,i_f).as_posix() + local_file = Path(localDir,i_f) + if showProgress: + sftp.get(i_remote_file, local_file,callback=self.__putCallBack__) + else: + sftp.get(i_remote_file, local_file) + sftp.close() + self.__ssh__.close() + return True + def sshScpGet(self, remoteFile, localFile, showProgress=False): self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__) sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())