update
This commit is contained in:
@@ -60,7 +60,7 @@
|
||||
"face_crop.py": 1643789609.1834445,
|
||||
"face_crop_video.py": 1643815024.5516832,
|
||||
"similarity.py": 1643269705.1073737,
|
||||
"train_multigpu.py": 1645935139.5672748,
|
||||
"train_multigpu.py": 1646101637.160833,
|
||||
"components\\arcface_decoder.py": 1643396144.2575414,
|
||||
"components\\Generator_nobias.py": 1643179001.810856,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644861019.9044807,
|
||||
@@ -111,7 +111,7 @@
|
||||
"components\\Generator_upsample.py": 1644689723.8293872,
|
||||
"components\\misc\\Involution.py": 1644509321.5267963,
|
||||
"train_yamls\\train_Invoup.yaml": 1644689981.9794765,
|
||||
"flops.py": 1645883189.3803008,
|
||||
"flops.py": 1646101039.8459642,
|
||||
"detection_test.py": 1644935512.6830947,
|
||||
"components\\DeConv_Depthwise.py": 1645064447.4379447,
|
||||
"components\\DeConv_Depthwise1.py": 1644946969.5054545,
|
||||
@@ -140,5 +140,7 @@
|
||||
"components\\Generator_Invobn_config1.py": 1645862695.8743145,
|
||||
"components\\misc\\Involution_BN.py": 1645867197.3984175,
|
||||
"components\\misc\\Involution_ECA.py": 1645869012.4927464,
|
||||
"train_yamls\\train_Invobn_config.yaml": 1645934993.5420852
|
||||
"train_yamls\\train_Invobn_config.yaml": 1646101598.499709,
|
||||
"components\\Generator_Invobn_config2.py": 1645962618.7056074,
|
||||
"components\\Generator_Invobn_config3.py": 1646100847.8995547
|
||||
}
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"breakpoint": [
|
||||
1054,
|
||||
0
|
||||
1877,
|
||||
29
|
||||
]
|
||||
}
|
||||
+5850
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 27th February 2022 7:50:18 pm
|
||||
# Last Modified: Thursday, 3rd March 2022 6:16:01 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -107,7 +107,7 @@ class ResnetBlock_Modulation(nn.Module):
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
res_mode = "conv"
|
||||
# res_mode = "conv"
|
||||
if res_mode.lower() == "conv":
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()]
|
||||
elif res_mode.lower() == "depthwise":
|
||||
@@ -158,13 +158,13 @@ class Generator(nn.Module):
|
||||
up_mode = kwargs["up_mode"]
|
||||
|
||||
aggregator = kwargs["aggregator"]
|
||||
res_mode = aggregator
|
||||
res_mode = kwargs["res_mode"]
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
# from components.misc.Involution_BN import involution
|
||||
if aggregator == "invo":
|
||||
from components.misc.Involution_BN import involution
|
||||
from components.DeConv_Invobn import DeConv
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 16th February 2022 1:39:02 am
|
||||
# Last Modified: Thursday, 3rd March 2022 6:09:43 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -106,8 +106,8 @@ class Generator(nn.Module):
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(3, in_channel, kernel_size=3, padding=0, bias=False),
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(3, in_channel, kernel_size=7, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_channel), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
@@ -119,8 +119,8 @@ class Generator(nn.Module):
|
||||
self.down3 = nn.Sequential(nn.Conv2d(in_channel*4, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*8), activation)
|
||||
|
||||
self.down4 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*8), activation)
|
||||
# self.down4 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(in_channel*8), activation)
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
@@ -130,11 +130,11 @@ class Generator(nn.Module):
|
||||
padding_type=padding_type, activation=activation)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel*8), activation
|
||||
)
|
||||
# self.up4 = nn.Sequential(
|
||||
# nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
# nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(in_channel*8), activation
|
||||
# )
|
||||
|
||||
self.up3 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
@@ -153,8 +153,8 @@ class Generator(nn.Module):
|
||||
nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel), activation
|
||||
)
|
||||
self.last_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))
|
||||
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(in_channel, 3, kernel_size=7, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
@@ -174,12 +174,12 @@ class Generator(nn.Module):
|
||||
res = self.down1(res)
|
||||
res = self.down2(res)
|
||||
res = self.down3(res)
|
||||
res = self.down4(res)
|
||||
# res = self.down4(res)
|
||||
|
||||
for i in range(len(self.BottleNeck)):
|
||||
res = self.BottleNeck[i](res, id)
|
||||
|
||||
res = self.up4(res)
|
||||
# res = self.up4(res)
|
||||
res = self.up3(res)
|
||||
res = self.up2(res)
|
||||
res = self.up1(res)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday February 13th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 27th February 2022 8:15:11 pm
|
||||
# Last Modified: Thursday, 3rd March 2022 6:15:37 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -34,7 +34,8 @@ if __name__ == '__main__':
|
||||
"res_num": 9,
|
||||
# "up_mode": "nearest",
|
||||
"up_mode": "bilinear",
|
||||
"aggregator": "eca_invo"
|
||||
"aggregator": "eca_invo",
|
||||
"res_mode": "eca_invo"
|
||||
}
|
||||
|
||||
|
||||
|
||||
+9
-7
@@ -5,7 +5,7 @@
|
||||
# Created Date: Thursday February 10th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 26th February 2022 4:13:24 pm
|
||||
# Last Modified: Thursday, 3rd March 2022 6:44:57 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -21,20 +21,22 @@ if __name__ == '__main__':
|
||||
# cudnn.benchmark = True
|
||||
# cudnn.enabled = True
|
||||
# script = "Generator_modulation_up"
|
||||
script = "Generator_Invobn_config1"
|
||||
# script = "Generator_modulation_up"
|
||||
# script = "Generator_Invobn_config3"
|
||||
script = "Generator_ori_config"
|
||||
# script = "Generator_ori_config"
|
||||
class_name = "Generator"
|
||||
arcface_ckpt= "arcface_ckpt/arcface_checkpoint.tar"
|
||||
model_config={
|
||||
"id_dim": 512,
|
||||
"g_kernel_size": 3,
|
||||
"in_channel":16,
|
||||
"res_num": 4,
|
||||
"in_channel":64,
|
||||
"res_num": 9,
|
||||
# "up_mode": "nearest",
|
||||
"up_mode": "bilinear",
|
||||
"res_mode": "depthwise"
|
||||
"aggregator": "eca_invo",
|
||||
"res_mode": "eca_invo"
|
||||
}
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
|
||||
print("GPU used : ", os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
|
||||
@@ -55,7 +57,7 @@ if __name__ == '__main__':
|
||||
id_latent = torch.rand((4,512)).cuda()
|
||||
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0]
|
||||
|
||||
attr = torch.rand((4,3,512,512)).cuda()
|
||||
attr = torch.rand((4,3,224,224)).cuda()
|
||||
|
||||
import datetime
|
||||
start_time = time.time()
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 20th February 2022 4:13:22 pm
|
||||
# Last Modified: Thursday, 3rd March 2022 9:04:25 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -30,15 +30,17 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='depthwise_config0', # depthwise depthwise_config0
|
||||
parser.add_argument('-v', '--version', type=str, default='Invobn_resinvo1', # depthwise depthwise_config0 Invobn_resinvo1
|
||||
help="version name for train, test, finetune")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=250000,
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=150000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
parser.add_argument('--start_checkpoint_step', type=int, default=10000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# test
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='image')
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='image_allstep')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=1)
|
||||
parser.add_argument('-n', '--node_ip', type=str, default='101.33.242.26') # 101.33.242.26 2001:da8:8000:6880:f284:d61c:3c76:f9cb
|
||||
parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector')
|
||||
@@ -193,6 +195,7 @@ def main():
|
||||
break
|
||||
if not nodeinf:
|
||||
raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"]))
|
||||
sys_state["remote_machine"] = nodeinf
|
||||
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
|
||||
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
|
||||
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 3rd March 2022 9:03:57 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import glob
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop
|
||||
|
||||
class Tester(object):
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
self.transformer_Arcface = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
|
||||
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
|
||||
|
||||
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
model_config = self.config["model_configs"]
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.network = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# TODO replace below lines to define the model framework
|
||||
self.network = gen_class(**model_config["g_model"]["module_params"])
|
||||
self.network = self.network.eval()
|
||||
# for name in self.network.state_dict():
|
||||
# print(name)
|
||||
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Model structure:")
|
||||
self.reporter.writeModel(self.network.__str__())
|
||||
|
||||
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
|
||||
self.arcface = arcface1['model'].module
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"step%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
||||
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
self.arcface = self.arcface.cuda()
|
||||
|
||||
|
||||
|
||||
def test(self):
|
||||
|
||||
save_dir = self.config["test_samples_path"]
|
||||
ckp_step = self.config["checkpoint_step"]
|
||||
version = self.config["version"]
|
||||
id_imgs = self.config["id_imgs"]
|
||||
crop_mode = self.config["crop_mode"]
|
||||
attr_files = self.config["attr_files"]
|
||||
specified_save_path = self.config["specified_save_path"]
|
||||
self.arcface_ckpt= self.config["arcface_ckpt"]
|
||||
imgs_list = []
|
||||
|
||||
self.reporter.writeInfo("Version %s"%version)
|
||||
|
||||
if os.path.isdir(specified_save_path):
|
||||
print("Input a legal specified save path!")
|
||||
save_dir = specified_save_path
|
||||
|
||||
if os.path.isdir(attr_files):
|
||||
print("Input a dir....")
|
||||
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
|
||||
for item in imgs:
|
||||
imgs_list.append(item)
|
||||
print(imgs_list)
|
||||
else:
|
||||
print("Input an image....")
|
||||
imgs_list.append(attr_files)
|
||||
id_basename = os.path.basename(id_imgs)
|
||||
id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
|
||||
|
||||
# models
|
||||
self.__init_framework__()
|
||||
|
||||
mode = crop_mode.lower()
|
||||
if mode == "vggface":
|
||||
mode = "none"
|
||||
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
|
||||
self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
|
||||
|
||||
id_img = cv2.imread(id_imgs)
|
||||
id_img_align_crop, _ = self.detect.get(id_img,512)
|
||||
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
|
||||
id_img = self.transformer_Arcface(id_img_align_crop_pil)
|
||||
id_img = id_img.unsqueeze(0).cuda()
|
||||
|
||||
#create latent id
|
||||
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
|
||||
latend_id = self.arcface(id_img)
|
||||
latend_id = F.normalize(latend_id, p=2, dim=1)
|
||||
cos_loss = torch.nn.CosineSimilarity()
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
print('Start =================================== test...')
|
||||
start_time = time.time()
|
||||
self.network.eval()
|
||||
|
||||
total_dict = {}
|
||||
|
||||
from utilities.sshupload import fileUploaderClass
|
||||
nodeinf = self.config["remote_machine"]
|
||||
|
||||
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
|
||||
|
||||
remotebase = os.path.join(nodeinf['path'],"train_logs",self.config["version"]).replace('\\','/')
|
||||
|
||||
|
||||
for istep in range(self.config["start_checkpoint_step"],self.config["checkpoint_step"]+1,10000):
|
||||
ckpt_name = "step%d_%s.pth"%(istep,
|
||||
self.config["checkpoint_names"]["generator_name"])
|
||||
localFile = os.path.join(self.config["project_checkpoints"],ckpt_name)
|
||||
|
||||
if self.config["node_ip"]!="localhost":
|
||||
if not os.path.exists(localFile):
|
||||
remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/')
|
||||
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
|
||||
if not ssh_state:
|
||||
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
|
||||
print("Get the checkpoint %s successfully!"%(ckpt_name))
|
||||
else:
|
||||
print("%s exists!"%(ckpt_name))
|
||||
self.network.load_state_dict(torch.load(localFile, map_location=torch.device("cpu")))
|
||||
print('loaded trained backbone model step {}...!'.format(istep))
|
||||
cos_dict = {}
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.network = self.network.cuda()
|
||||
|
||||
average_cos = 0
|
||||
with torch.no_grad():
|
||||
for img in imgs_list:
|
||||
print(img)
|
||||
attr_img_ori= cv2.imread(img)
|
||||
try:
|
||||
attr_img_align_crop, _ = self.detect.get(attr_img_ori,512)
|
||||
except:
|
||||
continue
|
||||
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
|
||||
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
|
||||
|
||||
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
|
||||
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
|
||||
attr_id = self.arcface(attr_img_arc)
|
||||
attr_id = F.normalize(attr_id, p=2, dim=1)
|
||||
|
||||
results = self.network(attr_img, latend_id)
|
||||
|
||||
results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
|
||||
results_arc = self.arcface(results_arc)
|
||||
results_arc = F.normalize(results_arc, p=2, dim=1)
|
||||
results_cos_dis = 1 - cos_loss(latend_id, results_arc)
|
||||
cos_dict[img] = results_cos_dis.item()
|
||||
average_cos += results_cos_dis
|
||||
|
||||
average_cos /= len(imgs_list)
|
||||
total_dict[str(istep)] = {
|
||||
"step":istep,
|
||||
"Average_cosin": average_cos.item(),
|
||||
"images": cos_dict
|
||||
}
|
||||
|
||||
print("Step: [{}], average cosin similarity between ID and results [{}]".format(istep, average_cos.item()))
|
||||
self.reporter.writeInfo("Step: [{}], average cosin similarity between ID and results [{}]".format(istep, average_cos.item()))
|
||||
self.reporter.writeJson(total_dict)
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
print("Elapsed [{}]".format(elapsed))
|
||||
+3
-3
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 27th February 2022 12:12:19 pm
|
||||
# Last Modified: Tuesday, 1st March 2022 10:27:16 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -31,7 +31,7 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='Invobn_eca1',
|
||||
parser.add_argument('-v', '--version', type=str, default='Invobn_resinvo1',
|
||||
help="version name for train, test, finetune")
|
||||
parser.add_argument('-t', '--tag', type=str, default='tiny',
|
||||
help="tag for current experiment")
|
||||
@@ -46,7 +46,7 @@ def getParameters():
|
||||
|
||||
# training
|
||||
parser.add_argument('--experiment_description', type=str,
|
||||
default="尝试直接训练最小规模的网络,正往由ECA Invo构成")
|
||||
default="尝试直接训练最小规模的网络,正往由Invo构成,Resblock用Invo+conv, 对齐batchsize 64")
|
||||
|
||||
parser.add_argument('--train_yaml', type=str, default="train_Invobn_config.yaml")
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ model_configs:
|
||||
res_num: 9
|
||||
up_mode: bilinear
|
||||
aggregator: "invo"
|
||||
res_mode: "invo"
|
||||
|
||||
d_model:
|
||||
script: projected_discriminator
|
||||
@@ -25,7 +26,7 @@ model_configs:
|
||||
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
|
||||
|
||||
# Training information
|
||||
batch_size: 24
|
||||
batch_size: 64
|
||||
|
||||
# Dataset
|
||||
dataloader: VGGFace2HQ_multigpu
|
||||
|
||||
@@ -5,13 +5,14 @@
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:50:12 pm
|
||||
# Last Modified: Thursday, 3rd March 2022 8:42:13 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
class Reporter:
|
||||
def __init__(self,reportPath):
|
||||
@@ -54,3 +55,8 @@ class Reporter:
|
||||
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
|
||||
logf.writelines("[%d]-[%s]-[logInfo]-epoch[%d]-step[%d] %s\n"%(self.index,timeStr,epoch,step,logText))
|
||||
self.index += 1
|
||||
|
||||
def writeJson(self, info):
|
||||
with open(self.path, 'a+') as cf:
|
||||
configjson = json.dumps(info, indent=4)
|
||||
cf.writelines(configjson)
|
||||
|
||||
Reference in New Issue
Block a user