distillation

This commit is contained in:
chenxuanhong
2022-03-04 19:04:36 +08:00
parent 39964bf613
commit e02f756116
10 changed files with 1232 additions and 23 deletions
+10 -7
View File
@@ -1,6 +1,6 @@
{
"GUI.py": 1645109256.0056663,
"test.py": 1645344802.7112515,
"test.py": 1646330130.1009316,
"train.py": 1643397924.974299,
"components\\Generator.py": 1644689001.9005148,
"components\\projected_discriminator.py": 1642348101.4661522,
@@ -31,7 +31,7 @@
"utilities\\learningrate_scheduler.py": 1611123530.675422,
"utilities\\logo_class.py": 1633883995.3093486,
"utilities\\plot.py": 1641911100.7995758,
"utilities\\reporter.py": 1625413813.7213495,
"utilities\\reporter.py": 1646311333.3067005,
"utilities\\save_heatmap.py": 1611123530.679439,
"utilities\\sshupload.py": 1645168814.6421573,
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
@@ -60,7 +60,7 @@
"face_crop.py": 1643789609.1834445,
"face_crop_video.py": 1643815024.5516832,
"similarity.py": 1643269705.1073737,
"train_multigpu.py": 1646101637.160833,
"train_multigpu.py": 1646329983.38444,
"components\\arcface_decoder.py": 1643396144.2575414,
"components\\Generator_nobias.py": 1643179001.810856,
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644861019.9044807,
@@ -105,13 +105,13 @@
"components\\Generator_ori.py": 1644689174.414655,
"losses\\cos.py": 1644229583.4023254,
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826,
"speed_test.py": 1645863205.1120403,
"speed_test.py": 1646304298.3483005,
"components\\DeConv_Invo.py": 1644426607.1588645,
"components\\Generator_reduce_up.py": 1644688655.2096283,
"components\\Generator_upsample.py": 1644689723.8293872,
"components\\misc\\Involution.py": 1644509321.5267963,
"train_yamls\\train_Invoup.yaml": 1644689981.9794765,
"flops.py": 1646101039.8459642,
"flops.py": 1646330033.710075,
"detection_test.py": 1644935512.6830947,
"components\\DeConv_Depthwise.py": 1645064447.4379447,
"components\\DeConv_Depthwise1.py": 1644946969.5054545,
@@ -119,7 +119,7 @@
"components\\Generator_modulation_depthwise_config.py": 1645262162.9779513,
"components\\Generator_modulation_up.py": 1644946498.7005584,
"components\\Generator_oriae_modulation.py": 1644897798.1987727,
"components\\Generator_ori_config.py": 1644946742.3635018,
"components\\Generator_ori_config.py": 1646329319.6131227,
"train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593,
"train_yamls\\train_Depthwise.yaml": 1644860961.099242,
"train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077,
@@ -142,5 +142,8 @@
"components\\misc\\Involution_ECA.py": 1645869012.4927464,
"train_yamls\\train_Invobn_config.yaml": 1646101598.499709,
"components\\Generator_Invobn_config2.py": 1645962618.7056074,
"components\\Generator_Invobn_config3.py": 1646100847.8995547
"components\\Generator_Invobn_config3.py": 1646302561.1984286,
"components\\Generator_ori_modulation_config.py": 1646329636.719998,
"test_scripts\\tester_image_allstep.py": 1646312637.9363256,
"train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162
}
+5 -5
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 3rd March 2022 6:09:43 pm
# Last Modified: Friday, 4th March 2022 1:41:59 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -106,8 +106,8 @@ class Generator(nn.Module):
activation = nn.ReLU(True)
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3),
nn.Conv2d(3, in_channel, kernel_size=7, padding=0, bias=False),
self.first_layer = nn.Sequential(nn.ReflectionPad2d(1),
nn.Conv2d(3, in_channel, kernel_size=3, padding=0, bias=False),
nn.BatchNorm2d(in_channel), activation)
### downsample
self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False),
@@ -153,8 +153,8 @@ class Generator(nn.Module):
nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel), activation
)
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
nn.Conv2d(in_channel, 3, kernel_size=7, padding=0))
self.last_layer = nn.Sequential(nn.ReflectionPad2d(1),
nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))
# self.__weights_init__()
@@ -0,0 +1,203 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 4th March 2022 1:47:16 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
class Demodule(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(Demodule, self).__init__()
self.epsilon = epsilon
def forward(self, x):
tmp = torch.mul(x, x) # or x ** 2
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
return x * tmp
class ApplyStyle(nn.Module):
"""
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
"""
def __init__(self, latent_size, channels):
super(ApplyStyle, self).__init__()
self.linear = nn.Linear(latent_size, channels * 2)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, 2, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
#x = x * (style[:, 0] + 1.) + style[:, 1]
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
return x
class Modulation(nn.Module):
def __init__(self, latent_size, channels):
super(Modulation, self).__init__()
self.linear = nn.Linear(latent_size, channels)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
#x = x * (style[:, 0] + 1.) + style[:, 1]
x = x * style
return x
class ResnetBlock_Modulation(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True),res_mode="depthwise"):
super(ResnetBlock_Modulation, self).__init__()
p = 0
conv1 = []
if padding_type == 'reflect':
conv1 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv1 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), Demodule()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = Modulation(latent_size, dim)
self.act1 = activation
p = 0
conv2 = []
if padding_type == 'reflect':
conv2 += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv2 += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
# res_mode = "conv"
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = Modulation(latent_size, dim)
def forward(self, x, dlatents_in_slice):
y = self.style1(x, dlatents_in_slice)
y = self.conv1(y)
y = self.act1(y)
y = self.style2(y, dlatents_in_slice)
y = self.conv2(y)
out = x + y
return out
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
id_dim = kwargs["id_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
in_channel = kwargs["in_channel"]
padding_size= int((k_size -1)/2)
padding_type= 'reflect'
activation = nn.ReLU(True)
self.first_layer = nn.Sequential(nn.ReflectionPad2d(1),
nn.Conv2d(3, in_channel, kernel_size=3, padding=0, bias=False),
nn.BatchNorm2d(in_channel), activation)
### downsample
self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(in_channel*2), activation)
self.down2 = nn.Sequential(nn.Conv2d(in_channel*2, in_channel*4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(in_channel*4), activation)
self.down3 = nn.Sequential(nn.Conv2d(in_channel*4, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(in_channel*8), activation)
# self.down4 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),
# nn.BatchNorm2d(in_channel*8), activation)
### resnet blocks
BN = []
for _ in range(res_num):
BN += [
ResnetBlock_Modulation(in_channel*8, latent_size=id_dim,
padding_type=padding_type, activation=activation)]
self.BottleNeck = nn.Sequential(*BN)
# self.up4 = nn.Sequential(
# nn.Upsample(scale_factor=2, mode='bilinear'),
# nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
# nn.BatchNorm2d(in_channel*8), activation
# )
self.up3 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*8, in_channel*4, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel*4), activation
)
self.up2 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*4, in_channel*2, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel*2), activation
)
self.up1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel), activation
)
self.last_layer = nn.Sequential(nn.ReflectionPad2d(1),
nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, img, id):
# x = input # 3*224*224
res = self.first_layer(img)
res = self.down1(res)
res = self.down2(res)
res = self.down3(res)
# res = self.down4(res)
for i in range(len(self.BottleNeck)):
res = self.BottleNeck[i](res, id)
# res = self.up4(res)
res = self.up3(res)
res = self.up2(res)
res = self.up1(res)
res = self.last_layer(res)
return res
+3 -1
View File
@@ -6,7 +6,9 @@
"dataset_paths": {
"vggface2_hq": "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan",
"val_dataset_root": "",
"test_dataset_root": ""
"test_dataset_root": "",
"id_pose_source_root": "",
"id_pose_target_root": ""
},
"train_config_path":"./train_yamls",
"train_scripts_path":"./train_scripts",
+3 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday February 13th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 3rd March 2022 6:15:37 pm
# Last Modified: Friday, 4th March 2022 1:53:53 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -23,7 +23,7 @@ if __name__ == '__main__':
#
# script = "Generator_modulation_up"
script = "Generator_Invobn_config3"
# script = "Generator_ori_config"
# script = "Generator_ori_modulation_config"
# script = "Generator_ori_config"
class_name = "Generator"
arcface_ckpt= "arcface_ckpt/arcface_checkpoint.tar"
@@ -35,7 +35,7 @@ if __name__ == '__main__':
# "up_mode": "nearest",
"up_mode": "bilinear",
"aggregator": "eca_invo",
"res_mode": "eca_invo"
"res_mode": "conv"
}
+4 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 3rd March 2022 9:04:25 pm
# Last Modified: Friday, 4th March 2022 5:40:11 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -30,11 +30,11 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='Invobn_resinvo1', # depthwise depthwise_config0 Invobn_resinvo1
parser.add_argument('-v', '--version', type=str, default='ori_tiny', # depthwise depthwise_config0 Invobn_resinvo1
help="version name for train, test, finetune")
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
parser.add_argument('-s', '--checkpoint_step', type=int, default=150000,
parser.add_argument('-s', '--checkpoint_step', type=int, default=80000,
help="checkpoint epoch for test phase or finetune phase")
parser.add_argument('--start_checkpoint_step', type=int, default=10000,
help="checkpoint epoch for test phase or finetune phase")
@@ -153,6 +153,7 @@ def main():
# read system environment paths
env_config = readConfig('env/env.json')
env_config = env_config["path"]
sys_state["env_config"] = env_config
# obtain all configurations in argparse
config_dic = vars(config)
+346
View File
@@ -0,0 +1,346 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_ID_Pose.py
# Created Date: Friday March 4th 2022
# Author: Liu Naiyuan
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 4th March 2022 5:33:47 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import glob
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils import data
import numpy as np
import PIL
from PIL import Image
class TotalDataset(data.Dataset):
"""Dataset class for the vggface dataset with precalulated face landmarks."""
def __init__(self,image_dir,content_transform, img_size=224):
self.image_dir= image_dir
self.content_transform= content_transform
self.img_size = img_size
self.dataset = []
self.preprocess()
self.num_images = len(self.dataset)
def preprocess(self):
"""Preprocess the Face++ original frames."""
filenames = sorted(glob.glob(os.path.join(self.image_dir, '*'), recursive=False))
# self.total_num = len(lines)
for filename in filenames:
self.dataset.append(filename)
print('Finished preprocessing the Face++ original frames dataset...')
def __getitem__(self, index):
"""Return two src domain images and two dst domain images."""
src_filename = self.dataset[index]
split_tmp = src_filename.split('/')
save_filename = split_tmp[-1]
src_image1 = self.content_transform(Image.open(src_filename))
return src_image1, save_filename
def __len__(self):
"""Return the number of images."""
return len(self.dataset)
def getLoader_sourceface(c_image_dir,
img_size=224, batch_size=16, num_workers=8):
"""Build and return a data loader."""
c_transforms = []
c_transforms.append(T.ToTensor())
c_transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms = T.Compose(c_transforms)
content_dataset = TotalDataset(c_image_dir, c_transforms, 224)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True)
return content_data_loader, len(content_dataset)
def getLoader_targetface(c_image_dir,
img_size=224, batch_size=16, num_workers=8):
"""Build and return a data loader."""
c_transforms = []
c_transforms.append(transforms.ToTensor())
# c_transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms = transforms.Compose(c_transforms)
content_dataset = TotalDataset(c_image_dir, c_transforms, 224)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True)
return content_data_loader, len(content_dataset)
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
def test(self):
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
specified_save_path = self.config["specified_save_path"]
self.arcface_ckpt= self.config["arcface_ckpt"]
imgs_list = []
self.reporter.writeInfo("Version %s"%version)
if os.path.isdir(specified_save_path):
print("Input a legal specified save path!")
save_dir = specified_save_path
if os.path.isdir(attr_files):
print("Input a dir....")
imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
for item in imgs:
imgs_list.append(item)
print(imgs_list)
else:
print("Input an image....")
imgs_list.append(attr_files)
id_basename = os.path.basename(id_imgs)
id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
source_loader, dataet_len = getLoader_sourceface(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=opt.batchSize)
target_loader, dataet_len = getLoader_targetface(
self.config["env_config"]["dataset_paths"]["id_pose_source_root"], batch_size=opt.batchSize)
source_iter = iter(source_loader)
target_iter = iter(target_loader)
# models
self.__init_framework__()
id_img = cv2.imread(id_imgs)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
with torch.no_grad():
for profile_batch, filename_batch in tqdm(source_iter):
profile_batch = profile_batch.cuda()
profile_id_downsample = F.interpolate(profile_batch, (112,112), mode='bicubic')
profile_latent_id = model.netArc(profile_id_downsample)
profile_latent_id = F.normalize(profile_latent_id, p=2, dim=1)
if init_batch ==True:
wholeid_batch = profile_latent_id.cpu()
init_batch = False
else:
wholeid_batch = torch.cat([wholeid_batch,profile_latent_id.cpu()],dim=0)
target_source_pair_dict = np.load(
self.config["env_config"]["dataset_paths"]["pairs_dict"] ,allow_pickle=True).item()
for target_batch, filename_batch in tqdm(target_iter):
target_index_list = []
init_id_batch = True
for filename_tmp in filename_batch:
source_index = int(filename_tmp.split('_')[0])
target_index = target_source_pair_dict[source_index]
target_index_list.append(target_index)
if init_id_batch:
batch_id = wholeid_batch[target_index][None].cuda()
init_id_batch = False
else:
batch_id = torch.cat([batch_id, wholeid_batch[target_index][None].cuda()],dim = 0)
img_fakes = model(None, target_batch.cuda(), batch_id, None, True)
for img_fake, target_index_tmp,filename_tmp in zip(img_fakes, target_index_list,filename_batch):
filename_tmp_split = filename_tmp.split('_')
final_filename = filename_tmp_split[0] + '_' +str(target_index_tmp) + '_' + filename_tmp_split[-1]
save_path = os.path.join(simswap_eval_save_image_path,final_filename)
save_image = postprocess(img_fake.cpu().numpy().transpose(1,2,0))
PIL.Image.fromarray(save_image).save(save_path,quality=95)
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
attr_id = self.arcface(attr_img_arc)
attr_id = F.normalize(attr_id, p=2, dim=1)
results = self.network(attr_img, latend_id)
results = results * self.imagenet_std + self.imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0)
final_img = img1.astype(np.uint8)
attr_basename = os.path.splitext(os.path.basename(img))[0]
final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
save_filename = os.path.join(save_dir,
"id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
attr_basename,ckp_step,version))
cv2.imwrite(save_filename, final_img)
average_cos /= len(imgs_list)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
print("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item()))
if __name__ == '__main__':
opt = TestOptions().parse()
with torch.no_grad():
source_loader, dataet_len = getLoader_sourceface('/home/gdp/harddisk/Data2/Faceswap/FaceForensics++_image_hififacestyle_source_Nonearcstyle', batch_size=opt.batchSize)
target_loader, dataet_len = getLoader_targetface('/home/gdp/harddisk/Data2/Faceswap/FaceForensics++_image_target_even10_pro_withmat_Nonearcstyle_256', batch_size=opt.batchSize)
simswap_eval_save_image_path = opt.output_path
criterion = nn.L1Loss()
if not os.path.exists(simswap_eval_save_image_path):
os.makedirs(simswap_eval_save_image_path)
torch.nn.Module.dump_patches = True
model = create_model(opt)
model.eval()
source_iter = iter(source_loader)
target_iter = iter(target_loader)
init_batch = True
for profile_batch, filename_batch in tqdm(source_iter):
# src_batch, filename_batch = data_iter.next()
profile_batch = profile_batch.cuda()
profile_id_downsample = F.interpolate(profile_batch, (112,112))
profile_latent_id = model.netArc(profile_id_downsample)
profile_latent_id = F.normalize(profile_latent_id, p=2, dim=1)
if init_batch ==True:
wholeid_batch = profile_latent_id.cpu()
init_batch = False
else:
wholeid_batch = torch.cat([wholeid_batch,profile_latent_id.cpu()],dim=0)
print(wholeid_batch.shape)
# np.save("simswap_wholeid_batch.npy", wholeid_batch.detach().cpu().numpy())
target_source_pair_dict = np.load('/home/gdp/harddisk/Data2/Faceswap/npy_file/target_source_pair.npy' ,allow_pickle=True).item()
for target_batch, filename_batch in tqdm(target_iter):
target_index_list = []
init_id_batch = True
for filename_tmp in filename_batch:
source_index = int(filename_tmp.split('_')[0])
target_index = target_source_pair_dict[source_index]
target_index_list.append(target_index)
if init_id_batch:
batch_id = wholeid_batch[target_index][None].cuda()
init_id_batch = False
else:
batch_id = torch.cat([batch_id, wholeid_batch[target_index][None].cuda()],dim = 0)
img_fakes = model(None, target_batch.cuda(), batch_id, None, True)
for img_fake, target_index_tmp,filename_tmp in zip(img_fakes, target_index_list,filename_batch):
filename_tmp_split = filename_tmp.split('_')
final_filename = filename_tmp_split[0] + '_' +str(target_index_tmp) + '_' + filename_tmp_split[-1]
save_path = os.path.join(simswap_eval_save_image_path,final_filename)
save_image = postprocess(img_fake.cpu().numpy().transpose(1,2,0))
PIL.Image.fromarray(save_image).save(save_path,quality=95)
+4 -4
View File
@@ -5,7 +5,7 @@
# Created Date: Tuesday April 28th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 1st March 2022 10:27:16 am
# Last Modified: Friday, 4th March 2022 1:53:03 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
@@ -31,7 +31,7 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='Invobn_resinvo1',
parser.add_argument('-v', '--version', type=str, default='ori_tiny',
help="version name for train, test, finetune")
parser.add_argument('-t', '--tag', type=str, default='tiny',
help="tag for current experiment")
@@ -46,9 +46,9 @@ def getParameters():
# training
parser.add_argument('--experiment_description', type=str,
default="尝试直接训练最小规模的网络,正往由Invo构成,Resblock用Invo+conv, 对齐batchsize 64")
default="只用conv,训练最小的模型")
parser.add_argument('--train_yaml', type=str, default="train_Invobn_config.yaml")
parser.add_argument('--train_yaml', type=str, default="train_ori_modulation_config.yaml")
# system logger
parser.add_argument('--logger', type=str,
@@ -0,0 +1,590 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_naiv512.py
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 4th March 2022 7:02:04 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import random
import shutil
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
from losses.KA import KA
from utilities.plot import plot_batch
from train_scripts.trainer_multigpu_base import TrainerBase
class Trainer(TrainerBase):
def __init__(self,
config,
reporter):
super(Trainer, self).__init__(config, reporter)
import inspect
print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
def train(self):
# Launch processes.
num_gpus = len(self.config["gpus"])
print('Launching processes...')
torch.multiprocessing.set_start_method('spawn')
with tempfile.TemporaryDirectory() as temp_dir:
torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
def add_mapping_hook(network, features,mapping_layers):
mapping_hooks = []
def get_activation(mem, name):
def get_output_hook(module, input, output):
mem[name] = output
return get_output_hook
def add_hook(net, mem, mapping_layers):
for n, m in net.named_modules():
if n in mapping_layers:
mapping_hooks.append(
m.register_forward_hook(get_activation(mem, n)))
add_hook(network, features, mapping_layers)
# TODO modify this function to build your models
def init_framework(config, reporter, device, rank):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
torch.cuda.set_device(rank)
torch.cuda.empty_cache()
model_config = config["model_configs"]
if config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
shutil.copyfile(file1,tgtfile1)
dscript_name = "components." + model_config["d_model"]["script"]
file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
shutil.copyfile(file1,tgtfile1)
elif config["phase"] == "finetune":
gscript_name = config["com_base"] + model_config["g_model"]["script"]
dscript_name = config["com_base"] + model_config["d_model"]["script"]
com_base = "train_logs."+config["teacher_model"]["version"]+".scripts"
tscript_name = com_base +"."+ config["teacher_model"]["model_configs"]["g_model"]["script"]
class_name = config["teacher_model"]["model_configs"]["g_model"]["class_name"]
package = __import__(tscript_name, fromlist=True)
gen_class = getattr(package, class_name)
tgen = gen_class(**config["teacher_model"]["model_configs"]["g_model"]["module_params"])
tgen = tgen.eval()
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
gen = gen_class(**model_config["g_model"]["module_params"])
# print and recorde model structure
reporter.writeInfo("Generator structure:")
reporter.writeModel(gen.__str__())
reporter.writeInfo("Teacher structure:")
reporter.writeModel(tgen.__str__())
class_name = model_config["d_model"]["class_name"]
package = __import__(dscript_name, fromlist=True)
dis_class = getattr(package, class_name)
dis = dis_class(**model_config["d_model"]["module_params"])
# print and recorde model structure
reporter.writeInfo("Discriminator structure:")
reporter.writeModel(dis.__str__())
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
arcface = arcface1['model'].module
# train in GPU
# if in finetune phase, load the pretrained checkpoint
if config["phase"] == "finetune":
model_path = os.path.join(config["project_checkpoints"],
"step%d_%s.pth"%(config["ckpt"],
config["checkpoint_names"]["generator_name"]))
gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model_path = os.path.join(config["project_checkpoints"],
"step%d_%s.pth"%(config["ckpt"],
config["checkpoint_names"]["discriminator_name"]))
dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
model_path = os.path.join(config["teacher_model"]["project_checkpoints"],
"step%d_%s.pth"%(config["teacher_model"]["model_step"],
config["teacher_model"]["checkpoint_names"]["generator_name"]))
tgen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
print('loaded trained teacher backbone model step {}...!'.format(config["teacher_model"]["model_step"]))
tgen = tgen.to(device)
tgen.requires_grad_(False)
gen = gen.to(device)
dis = dis.to(device)
arcface= arcface.to(device)
arcface.requires_grad_(False)
arcface.eval()
t_features = {}
s_features = {}
add_mapping_hook(tgen,t_features,config["feature_list"])
add_mapping_hook(gen,s_features,config["feature_list"])
return tgen, gen, dis, arcface, t_features, s_features
# TODO modify this function to configurate the optimizer of your pipeline
def setup_optimizers(config, reporter, gen, dis, rank):
torch.cuda.set_device(rank)
torch.cuda.empty_cache()
g_train_opt = config['g_optim_config']
d_train_opt = config['d_optim_config']
g_optim_params = []
d_optim_params = []
for k, v in gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
for k, v in dis.named_parameters():
if v.requires_grad:
d_optim_params.append(v)
else:
reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = config['optim_type']
if optim_type == 'Adam':
g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supperted yet.')
# self.optimizers.append(self.optimizer_g)
if config["phase"] == "finetune":
opt_path = os.path.join(config["project_checkpoints"],
"step%d_optim_%s.pth"%(config["ckpt"],
config["optimizer_names"]["generator_name"]))
g_optimizer.load_state_dict(torch.load(opt_path))
opt_path = os.path.join(config["project_checkpoints"],
"step%d_optim_%s.pth"%(config["ckpt"],
config["optimizer_names"]["discriminator_name"]))
d_optimizer.load_state_dict(torch.load(opt_path))
print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"]))
return g_optimizer, d_optimizer
def train_loop(
rank,
config,
reporter,
temp_dir
):
version = config["version"]
ckpt_dir = config["project_checkpoints"]
sample_dir = config["project_samples"]
log_freq = config["log_step"]
model_freq = config["model_save_step"]
sample_freq = config["sample_step"]
total_step = config["total_step"]
random_seed = config["dataset_params"]["random_seed"]
id_w = config["id_weight"]
rec_w = config["reconstruct_weight"]
feat_w = config["feature_match_weight"]
distill_w = config["distillation_weight"]
distill_rec_w = config["teacher_reconstruction"]
distill_feat_w = config["teacher_featurematching"]
feat_num = len(config["feature_list"])
num_gpus = len(config["gpus"])
batch_gpu = config["batch_size"] // num_gpus
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
if os.name == 'nt':
init_method = 'file:///' + init_file.replace('\\', '/')
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus)
else:
init_method = f'file://{init_file}'
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus)
# Init torch_utils.
sync_device = torch.device('cuda', rank)
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
if rank == 0:
img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
# Initialize.
device = torch.device('cuda', rank)
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy.
torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy.
conv2d_gradfix.enabled = True # Improves training speed.
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
# Create dataloader.
if rank == 0:
print('Loading training set...')
dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
dataloader_class= dataloaderClass
dataloader = dataloader_class(dataset,
rank,
num_gpus,
batch_gpu,
**config["dataset_params"])
# Construct networks.
if rank == 0:
print('Constructing networks...')
tgen, gen, dis, arcface, t_feat, s_feat = init_framework(config, reporter, device, rank)
# Check for existing checkpoint
# Print network summary tables.
# if rank == 0:
# attr = torch.empty([batch_gpu, 3, 512, 512], device=device)
# id = torch.empty([batch_gpu, 3, 112, 112], device=device)
# latent = misc.print_module_summary(arcface, [id])
# img = misc.print_module_summary(gen, [attr, latent])
# misc.print_module_summary(dis, [img, None])
# del attr
# del id
# del latent
# del img
# torch.cuda.empty_cache()
# Distribute across GPUs.
if rank == 0:
print(f'Distributing across {num_gpus} GPUs...')
for module in [gen, dis, arcface, tgen]:
if module is not None and num_gpus > 1:
for param in misc.params_and_buffers(module):
torch.distributed.broadcast(param, src=0)
# Setup training phases.
if rank == 0:
print('Setting up training phases...')
#===============build losses===================#
# TODO replace below lines to build your losses
# MSE_loss = torch.nn.MSELoss()
l1_loss = torch.nn.L1Loss()
l1_loss_import = torch.nn.L1Loss(reduce=False)
cos_loss = torch.nn.CosineSimilarity()
g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
# Initialize logs.
if rank == 0:
print('Initializing logs...')
#==============build tensorboard=================#
if config["logger"] == "tensorboard":
import torch.utils.tensorboard as tensorboard
tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
logger = tensorboard_writer
elif config["logger"] == "wandb":
import wandb
wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
tags=[config["tag"]], name=version)
wandb.config = {
"total_step": config["total_step"],
"batch_size": config["batch_size"]
}
logger = wandb
random.seed(random_seed)
randindex = [i for i in range(batch_gpu)]
# set the start point for training loop
if config["phase"] == "finetune":
start = config["ckpt"]
else:
start = 0
if rank == 0:
import datetime
start_time = time.time()
# Caculate the epoch number
print("Total step = %d"%total_step)
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
dis.feature_network.requires_grad_(False)
for step in range(start, total_step):
gen.train()
dis.train()
for interval in range(2):
random.shuffle(randindex)
src_image1, src_image2 = dataloader.next()
# if rank ==0:
# elapsed = time.time() - start_time
# elapsed = str(datetime.timedelta(seconds=elapsed))
# print("dataloader:",elapsed)
if step%2 == 0:
img_id = src_image2
else:
img_id = src_image2[randindex]
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
latent_id = arcface(img_id_112)
latent_id = F.normalize(latent_id, p=2, dim=1)
if interval == 0:
img_fake = gen(src_image1, latent_id)
gen_logits,_ = dis(img_fake.detach(), None)
loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
real_logits,_ = dis(src_image2,None)
loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
loss_D = loss_Dgen + loss_Dreal
d_optimizer.zero_grad(set_to_none=True)
loss_D.backward()
with torch.autograd.profiler.record_function('discriminator_opt'):
# params = [param for param in dis.parameters() if param.grad is not None]
# if len(params) > 0:
# flat = torch.cat([param.grad.flatten() for param in params])
# if num_gpus > 1:
# torch.distributed.all_reduce(flat)
# flat /= num_gpus
# misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
# grads = flat.split([param.numel() for param in params])
# for param, grad in zip(params, grads):
# param.grad = grad.reshape(param.shape)
params = [param for param in dis.parameters() if param.grad is not None]
flat = torch.cat([param.grad.flatten() for param in params])
torch.distributed.all_reduce(flat)
flat /= num_gpus
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
grads = flat.split([param.numel() for param in params])
for param, grad in zip(params, grads):
param.grad = grad.reshape(param.shape)
d_optimizer.step()
# if rank ==0:
# elapsed = time.time() - start_time
# elapsed = str(datetime.timedelta(seconds=elapsed))
# print("Discriminator training:",elapsed)
else:
# model.netD.requires_grad_(True)
t_fake = tgen(src_image1, latent_id)
t_id = arcface(t_fake.detach())
t_feat = dis.get_feature(t_fake.detach())
realism = cos_loss(t_id, latent_id)
img_fake = gen(src_image1, latent_id)
Sacts = [
s_feat[key] for key in sorted(s_feat.keys())
]
Tacts = [
t_feat[key] for key in sorted(t_feat.keys())
]
loss_distill = 0
for Sact, Tact in zip(Sacts, Tacts):
loss_distill += -KA(Sact, Tact)
# G loss
loss_distill /= feat_num
gen_logits,feat = dis(img_fake, None)
loss_Gmain = (-gen_logits).mean()
img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
latent_fake = arcface(img_fake_down)
latent_fake = F.normalize(latent_fake, p=2, dim=1)
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
real_feat = dis.get_feature(src_image1)
feat_match_loss = l1_loss(feat["3"],real_feat["3"])
feat_match_ts = (realism * l1_loss_import(feat["3"],t_feat)).mean()
t_rec_loss = (realism * l1_loss_import(t_fake.detach(), img_fake)).mean()
loss_G = loss_Gmain + loss_G_ID * id_w + \
feat_match_loss * feat_w + loss_distill * distill_w +\
distill_feat_w * feat_match_ts + distill_rec_w * t_rec_loss
if step%2 == 0:
#G_Rec
loss_G_Rec = l1_loss(img_fake, src_image1)
loss_G += loss_G_Rec * rec_w
g_optimizer.zero_grad(set_to_none=True)
loss_G.backward()
with torch.autograd.profiler.record_function('generator_opt'):
params = [param for param in gen.parameters() if param.grad is not None]
flat = torch.cat([param.grad.flatten() for param in params])
torch.distributed.all_reduce(flat)
flat /= num_gpus
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
grads = flat.split([param.numel() for param in params])
for param, grad in zip(params, grads):
param.grad = grad.reshape(param.shape)
g_optimizer.step()
# if rank ==0:
# elapsed = time.time() - start_time
# elapsed = str(datetime.timedelta(seconds=elapsed))
# print("Generator training:",elapsed)
# Print out log info
if rank == 0 and (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# print("ready to report losses")
# ID_Total= loss_G_ID
# torch.distributed.all_reduce(ID_Total)
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
Distillaton_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(version, elapsed, step, total_step, \
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
loss_distill.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation)
reporter.writeInfo(epochinformation)
if config["logger"] == "tensorboard":
logger.add_scalar('G/G_loss', loss_G.item(), step)
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
logger.add_scalar('G/G_distillation', loss_distill.item(), step)
logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
logger.add_scalar('D/D_loss', loss_D.item(), step)
logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
logger.add_scalar('D/D_real', loss_Dreal.item(), step)
elif config["logger"] == "wandb":
logger.log({"G_Loss": loss_G.item()}, step = step)
logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
logger.log({"G_distillation": loss_distill.item()}, step = step)
logger.log({"G_ID": loss_G_ID.item()}, step = step)
logger.log({"D_loss": loss_D.item()}, step = step)
logger.log({"D_fake": loss_Dgen.item()}, step = step)
logger.log({"D_real": loss_Dreal.item()}, step = step)
torch.cuda.empty_cache()
if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
gen.eval()
with torch.no_grad():
imgs = []
zero_img = (torch.zeros_like(src_image1[0,...]))
imgs.append(zero_img.cpu().numpy())
save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
for r in range(batch_gpu):
imgs.append(save_img[r,...])
arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
id_vector_src1 = arcface(arcface_112)
id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
for i in range(batch_gpu):
imgs.append(save_img[i,...])
image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
img_fake = gen(image_infer, id_vector_src1).cpu()
img_fake = img_fake * img_std
img_fake = img_fake + img_mean
img_fake = img_fake.numpy()
for j in range(batch_gpu):
imgs.append(img_fake[j,...])
print("Save test data")
imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
torch.cuda.empty_cache()
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if rank == 0 and (step+1) % model_freq==0:
torch.save(gen.state_dict(),
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
config["checkpoint_names"]["generator_name"])))
torch.save(dis.state_dict(),
os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
config["checkpoint_names"]["discriminator_name"])))
torch.save(g_optimizer.state_dict(),
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
config["checkpoint_names"]["generator_name"])))
torch.save(d_optimizer.state_dict(),
os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
config["checkpoint_names"]["discriminator_name"])))
print("Save step %d model checkpoint!"%(step+1))
torch.cuda.empty_cache()
print("Rank %d process done!"%rank)
torch.distributed.barrier()
@@ -0,0 +1,64 @@
# Related scripts
train_script_name: multi_gpu
# models' scripts
model_configs:
g_model:
script: Generator_ori_modulation_config
class_name: Generator
module_params:
id_dim: 512
g_kernel_size: 3
in_channel: 8
res_num: 9
d_model:
script: projected_discriminator
class_name: ProjectedDiscriminator
module_params:
diffaug: False
interp224: False
backbone_kwargs: {}
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
# Training information
batch_size: 32
# Dataset
dataloader: VGGFace2HQ_multigpu
dataset_name: vggface2_hq
dataset_params:
random_seed: 1234
dataloader_workers: 4
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Dataset
# Optimizer
optim_type: Adam
g_optim_config:
lr: 0.0006
betas: [ 0, 0.99]
eps: !!float 1e-8
d_optim_config:
lr: 0.0006
betas: [ 0, 0.99]
eps: !!float 1e-8
id_weight: 20.0
reconstruct_weight: 10.0
feature_match_weight: 10.0
# Log
log_step: 300
model_save_step: 10000
sample_step: 1000
total_step: 1000000
checkpoint_names:
generator_name: Generator
discriminator_name: Discriminator