From d81ee31117efe1ec7f4335c016fc24274041804f Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Fri, 18 Feb 2022 17:01:25 +0800 Subject: [PATCH] update teseter_image --- env/env.json | 3 ++- test.py | 48 ++++++++++++++++++++++-------------- test_scripts/tester_image.py | 7 ++++-- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/env/env.json b/env/env.json index 942909d..abe897f 100644 --- a/env/env.json +++ b/env/env.json @@ -11,6 +11,7 @@ "train_config_path":"./train_yamls", "train_scripts_path":"./train_scripts", "test_scripts_path":"./test_scripts", - "config_json_name":"model_config.json" + "config_json_name":"model_config.json", + "machine_config":"./GUI/machines.json" } } \ No newline at end of file diff --git a/test.py b/test.py index 8cba4df..abe5c37 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 18th February 2022 3:23:18 pm +# Last Modified: Friday, 18th February 2022 3:59:57 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -34,25 +34,25 @@ def getParameters(): help="version name for train, test, finetune") parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU - parser.add_argument('-s', '--checkpoint_step', type=int, default=640000, + parser.add_argument('-s', '--checkpoint_step', type=int, default=680000, help="checkpoint epoch for test phase or finetune phase") # test parser.add_argument('-t', '--test_script_name', type=str, default='image') parser.add_argument('-b', '--batch_size', type=int, default=1) - parser.add_argument('-n', '--node_name', type=str, default='localhost', - choices=['localhost', '4card','8card','new4card']) + parser.add_argument('-n', '--node_ip', type=str, default='2001:da8:8000:6880:f284:d61c:3c76:f9cb') + parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector') parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg') # parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg') - parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1', + parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID', help="file path for attribute images or video") parser.add_argument('--use_specified_data', action='store_true') parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files') parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir') - parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir') + parser.add_argument('--specified_save_path', type=str, default="", help='save results to specified dir') # # logs (does not to be changed in most time) # parser.add_argument('--dataloader_workers', type=int, default=6) @@ -183,9 +183,16 @@ def main(): config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"]) #fetch checkpoints, model_config.json and scripts from remote machine - if sys_state["node_name"]!="localhost": - remote_mac = env_config["remote_machine"] - nodeinf = remote_mac[sys_state["node_name"]] + if sys_state["node_ip"]!="localhost": + machine_config = env_config["machine_config"] + machine_config = readConfig(machine_config) + nodeinf = None + for item in machine_config: + if item["ip"] == sys_state["node_ip"]: + nodeinf = item + break + if not nodeinf: + raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"])) print("ready to fetch related files from server: %s ......"%nodeinf["ip"]) uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"]) @@ -207,9 +214,18 @@ def main(): ssh_state = uploader.sshScpGetDir(remoteDir, localDir) if not ssh_state: raise Exception(print("Get file %s failed! Program exists!"%remoteFile)) - print("Get the scripts:%s.py successfully"%sys_state["gScriptName"]) - - # Get checkpoints + print("Get the scripts successful!") + # Read model_config.json + json_obj = readConfig(config_json) + for item in json_obj.items(): + if item[0] in ignoreKey: + pass + else: + sys_state[item[0]] = item[1] + + # Get checkpoints + if sys_state["node_ip"]!="localhost": + ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"], sys_state["checkpoint_names"]["generator_name"]) localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name) @@ -223,13 +239,7 @@ def main(): else: print("%s exists!"%(ckpt_name)) - # Read model_config.json - json_obj = readConfig(config_json) - for item in json_obj.items(): - if item[0] in ignoreKey: - pass - else: - sys_state[item[0]] = item[1] + # TODO get the checkpoint file path sys_state["ckp_name"] = {} # for data_key in sys_state["checkpoint_names"].keys(): diff --git a/test_scripts/tester_image.py b/test_scripts/tester_image.py index f6ae8e0..6604843 100644 --- a/test_scripts/tester_image.py +++ b/test_scripts/tester_image.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 18th February 2022 10:47:55 am +# Last Modified: Friday, 18th February 2022 5:00:28 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -87,6 +87,7 @@ class Tester(object): ckp_step = self.config["checkpoint_step"] version = self.config["version"] id_imgs = self.config["id_imgs"] + crop_mode = self.config["crop_mode"] attr_files = self.config["attr_files"] specified_save_path = self.config["specified_save_path"] self.arcface_ckpt= self.config["arcface_ckpt"] @@ -113,7 +114,9 @@ class Tester(object): # models self.__init_framework__() - mode = None + mode = crop_mode.lower() + if mode == "vggface": + mode = "none" self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models') self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)