update teseter_image

This commit is contained in:
chenxuanhong
2022-02-18 17:01:25 +08:00
parent 7b597e7205
commit d81ee31117
3 changed files with 36 additions and 22 deletions
+2 -1
View File
@@ -11,6 +11,7 @@
"train_config_path":"./train_yamls",
"train_scripts_path":"./train_scripts",
"test_scripts_path":"./test_scripts",
"config_json_name":"model_config.json"
"config_json_name":"model_config.json",
"machine_config":"./GUI/machines.json"
}
}
+29 -19
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 18th February 2022 3:23:18 pm
# Last Modified: Friday, 18th February 2022 3:59:57 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -34,25 +34,25 @@ def getParameters():
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=680000,
help="checkpoint epoch for test phase or finetune phase")
# test
parser.add_argument('-t', '--test_script_name', type=str, default='image')
parser.add_argument('-b', '--batch_size', type=int, default=1)
parser.add_argument('-n', '--node_name', type=str, default='localhost',
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('-n', '--node_ip', type=str, default='2001:da8:8000:6880:f284:d61c:3c76:f9cb')
parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector')
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
help="file path for attribute images or video")
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
parser.add_argument('--specified_save_path', type=str, default="", help='save results to specified dir')
# # logs (does not to be changed in most time)
# parser.add_argument('--dataloader_workers', type=int, default=6)
@@ -183,9 +183,16 @@ def main():
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
#fetch checkpoints, model_config.json and scripts from remote machine
if sys_state["node_name"]!="localhost":
remote_mac = env_config["remote_machine"]
nodeinf = remote_mac[sys_state["node_name"]]
if sys_state["node_ip"]!="localhost":
machine_config = env_config["machine_config"]
machine_config = readConfig(machine_config)
nodeinf = None
for item in machine_config:
if item["ip"] == sys_state["node_ip"]:
nodeinf = item
break
if not nodeinf:
raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"]))
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
@@ -207,9 +214,18 @@ def main():
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
if not ssh_state:
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoints
print("Get the scripts successful!")
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
# Get checkpoints
if sys_state["node_ip"]!="localhost":
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
sys_state["checkpoint_names"]["generator_name"])
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
@@ -223,13 +239,7 @@ def main():
else:
print("%s exists!"%(ckpt_name))
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
# for data_key in sys_state["checkpoint_names"].keys():
+5 -2
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 18th February 2022 10:47:55 am
# Last Modified: Friday, 18th February 2022 5:00:28 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -87,6 +87,7 @@ class Tester(object):
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
crop_mode = self.config["crop_mode"]
attr_files = self.config["attr_files"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
@@ -113,7 +114,9 @@ class Tester(object):
# models
self.__init_framework__()
mode = None
mode = crop_mode.lower()
if mode == "vggface":
mode = "none"
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)