update teseter_image
This commit is contained in:
Vendored
+2
-1
@@ -11,6 +11,7 @@
|
||||
"train_config_path":"./train_yamls",
|
||||
"train_scripts_path":"./train_scripts",
|
||||
"test_scripts_path":"./test_scripts",
|
||||
"config_json_name":"model_config.json"
|
||||
"config_json_name":"model_config.json",
|
||||
"machine_config":"./GUI/machines.json"
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 18th February 2022 3:23:18 pm
|
||||
# Last Modified: Friday, 18th February 2022 3:59:57 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -34,25 +34,25 @@ def getParameters():
|
||||
help="version name for train, test, finetune")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=640000,
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=680000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# test
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='image')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=1)
|
||||
parser.add_argument('-n', '--node_name', type=str, default='localhost',
|
||||
choices=['localhost', '4card','8card','new4card'])
|
||||
parser.add_argument('-n', '--node_ip', type=str, default='2001:da8:8000:6880:f284:d61c:3c76:f9cb')
|
||||
parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector')
|
||||
|
||||
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
|
||||
# parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\1',
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
|
||||
help="file path for attribute images or video")
|
||||
|
||||
parser.add_argument('--use_specified_data', action='store_true')
|
||||
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
|
||||
parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
|
||||
parser.add_argument('--specified_save_path', type=str, default="G:\\swap_data\\results", help='save results to specified dir')
|
||||
parser.add_argument('--specified_save_path', type=str, default="", help='save results to specified dir')
|
||||
|
||||
# # logs (does not to be changed in most time)
|
||||
# parser.add_argument('--dataloader_workers', type=int, default=6)
|
||||
@@ -183,9 +183,16 @@ def main():
|
||||
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
|
||||
|
||||
#fetch checkpoints, model_config.json and scripts from remote machine
|
||||
if sys_state["node_name"]!="localhost":
|
||||
remote_mac = env_config["remote_machine"]
|
||||
nodeinf = remote_mac[sys_state["node_name"]]
|
||||
if sys_state["node_ip"]!="localhost":
|
||||
machine_config = env_config["machine_config"]
|
||||
machine_config = readConfig(machine_config)
|
||||
nodeinf = None
|
||||
for item in machine_config:
|
||||
if item["ip"] == sys_state["node_ip"]:
|
||||
nodeinf = item
|
||||
break
|
||||
if not nodeinf:
|
||||
raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"]))
|
||||
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
|
||||
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
|
||||
|
||||
@@ -207,9 +214,18 @@ def main():
|
||||
ssh_state = uploader.sshScpGetDir(remoteDir, localDir)
|
||||
if not ssh_state:
|
||||
raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
|
||||
print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
|
||||
|
||||
# Get checkpoints
|
||||
print("Get the scripts successful!")
|
||||
# Read model_config.json
|
||||
json_obj = readConfig(config_json)
|
||||
for item in json_obj.items():
|
||||
if item[0] in ignoreKey:
|
||||
pass
|
||||
else:
|
||||
sys_state[item[0]] = item[1]
|
||||
|
||||
# Get checkpoints
|
||||
if sys_state["node_ip"]!="localhost":
|
||||
|
||||
ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"],
|
||||
sys_state["checkpoint_names"]["generator_name"])
|
||||
localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name)
|
||||
@@ -223,13 +239,7 @@ def main():
|
||||
else:
|
||||
print("%s exists!"%(ckpt_name))
|
||||
|
||||
# Read model_config.json
|
||||
json_obj = readConfig(config_json)
|
||||
for item in json_obj.items():
|
||||
if item[0] in ignoreKey:
|
||||
pass
|
||||
else:
|
||||
sys_state[item[0]] = item[1]
|
||||
|
||||
# TODO get the checkpoint file path
|
||||
sys_state["ckp_name"] = {}
|
||||
# for data_key in sys_state["checkpoint_names"].keys():
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 18th February 2022 10:47:55 am
|
||||
# Last Modified: Friday, 18th February 2022 5:00:28 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -87,6 +87,7 @@ class Tester(object):
|
||||
ckp_step = self.config["checkpoint_step"]
|
||||
version = self.config["version"]
|
||||
id_imgs = self.config["id_imgs"]
|
||||
crop_mode = self.config["crop_mode"]
|
||||
attr_files = self.config["attr_files"]
|
||||
specified_save_path = self.config["specified_save_path"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
@@ -113,7 +114,9 @@ class Tester(object):
|
||||
# models
|
||||
self.__init_framework__()
|
||||
|
||||
mode = None
|
||||
mode = crop_mode.lower()
|
||||
if mode == "vggface":
|
||||
mode = "none"
|
||||
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
|
||||
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user