This commit is contained in:
chenxuanhong
2022-01-21 18:01:36 +08:00
parent e698d99173
commit bebaeef2ce
8 changed files with 678 additions and 101 deletions
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 8:22:37 pm
# Last Modified: Sunday, 4th July 2021 11:32:14 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -35,13 +35,13 @@ class Tester(object):
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
1,
config["batch_size"],
["png","jpg"])
self.test_loader= dataloader
self.test_iter = len(dataloader)
# if len(dataloader)%config["batch_size"]>0:
# self.test_iter+=1
self.test_iter = len(dataloader)//config["batch_size"]
if len(dataloader)%config["batch_size"]>0:
self.test_iter+=1
def __init_framework__(self):
@@ -52,14 +52,19 @@ class Tester(object):
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
script_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
script_name = "components."+self.config["module_script_name"]
class_name = self.config["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
n_class = len(self.config["selectedStyleDir"])
# TODO replace below lines to define the model framework
self.network = network_class(**model_config["g_model"]["module_params"])
self.network = network_class(self.config["GConvDim"],
self.config["GKS"],
self.config["resNum"],
n_class
#**self.config["module_params"]
)
# print and record model structure
self.reporter.writeInfo("Model structure:")
@@ -68,14 +73,12 @@ class Tester(object):
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
def test(self):
@@ -84,13 +87,18 @@ class Tester(object):
ckp_epoch = self.config["checkpoint_epoch"]
version = self.config["version"]
batch_size = self.config["batch_size"]
win_size = self.config["model_configs"]["g_model"]["module_params"]["window_size"]
style_names = self.config["selectedStyleDir"]
n_class = len(style_names)
# models
self.__init_framework__()
condition_labels = torch.ones((n_class, batch_size, 1)).long()
for i in range(n_class):
condition_labels[i,:,:] = condition_labels[i,:,:]*i
if self.config["cuda"] >=0:
condition_labels = condition_labels.cuda()
total = len(self.test_loader)
print("total:", total)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
@@ -98,25 +106,18 @@ class Tester(object):
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total)):
for _ in tqdm(range(total//batch_size)):
contents, img_names = self.test_loader()
B, C, H, W = contents.shape
crop_h = H - H%32
crop_w = W - W%32
crop_s = min(crop_h, crop_w)
contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2),
(W//2 - crop_s//2):(crop_s//2 + W//2)]
if self.config["cuda"] >=0:
contents = contents.cuda()
res = self.network(contents, (crop_s, crop_s))
print("res shape:", res.shape)
res = tensor2img(res.cpu())
temp_img = res[0,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
print(save_dir)
print(img_names[0])
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format(
img_names[0], version, ckp_epoch)),temp_img)
for i in range(n_class):
if self.config["cuda"] >=0:
contents = contents.cuda()
res, _ = self.network(contents, condition_labels[i, 0, :])
res = tensor2img(res.cpu())
for t in range(batch_size):
temp_img = res[t,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
+174 -51
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 4th July 2021 11:32:14 am
# Last Modified: Friday, 21st January 2022 11:06:37 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
@@ -15,12 +15,24 @@
import os
import cv2
import time
import shutil
import torch
from utilities.utilities import tensor2img
import torch.nn.functional as F
from torchvision import transforms
# from utilities.Reporter import Reporter
from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import numpy as np
from tqdm import tqdm
from PIL import Image
import glob
from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
class Tester(object):
def __init__(self, config, reporter):
@@ -29,20 +41,126 @@ class Tester(object):
# logger
self.reporter = reporter
#============build evaluation dataloader==============#
print("Prepare the test dataloader...")
dlModulename = config["test_dataloader"]
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
config["batch_size"],
["png","jpg"])
self.test_loader= dataloader
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
def cv2totensor(self, cv2_img):
    """
    Convert an image read by cv2 into a normalized network input.

    The image is converted from BGR to RGB, reordered from H*W*C to
    C*H*W, scaled by 1/255 and normalized with ``self.imagenet_mean`` /
    ``self.imagenet_std``.

    cv2_img: an image read by cv2, H*W*C (BGR channel order)
    return: an 1*C*H*W tensor on the same device as ``self.imagenet_mean``
    """
    rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
    # Follow the device of the normalization constants instead of
    # hard-coding .cuda(), so the arithmetic below never mixes devices.
    device = self.imagenet_mean.device
    chw = torch.from_numpy(rgb_img).permute(2, 0, 1).to(device)
    # Dividing by a Python float promotes the integer image to floating point.
    scaled = chw / 255.0
    normalized = (scaled - self.imagenet_mean) / self.imagenet_std
    return normalized.unsqueeze(0)
self.test_iter = len(dataloader)//config["batch_size"]
if len(dataloader)%config["batch_size"]>0:
self.test_iter+=1
def video_swap(
    self,
    video_path,
    id_vetor,
    save_path,
    temp_results_dir='./temp_results',
    crop_size=512,
    use_mask =False
):
    """
    Run the generator on every detected face of a video and write an mp4.

    Each frame is read with cv2, faces are detected/cropped through
    ``self.detect``, passed through ``self.network`` together with
    ``id_vetor``, de-normalized and pasted back by ``reverse2wholeimage``.
    Frames are written as JPEGs into ``temp_results_dir`` and finally
    assembled (with the source audio track, if any) into
    ``<save_path>/<video basename>.mp4``; a numeric suffix is appended when
    that file already exists.

    video_path: path of the input video
    id_vetor: identity input forwarded as the second argument of
        ``self.network`` (``test()`` passes the L2-normalized ArcFace
        embedding)
    save_path: directory in which the result video is written
    temp_results_dir: scratch directory for per-frame JPEGs; it is deleted
        and recreated at the start of every call
    crop_size: crop size handed to the detector and to reverse2wholeimage
    use_mask: when True a BiSeNet parsing model is loaded from
        ``./parsing_model/79999_iter.pth`` and handed to reverse2wholeimage

    Requires ``self.detect``, ``self.network``, ``self.imagenet_mean`` and
    ``self.imagenet_std`` to be set by the caller beforehand.
    """
    # Probe the audio track; close the probe clip explicitly instead of
    # relying on `del` to release the underlying reader.
    video_forcheck = VideoFileClip(video_path)
    try:
        no_audio = video_forcheck.audio is None
    finally:
        video_forcheck.close()

    video_audio_clip = None if no_audio else AudioFileClip(video_path)
    try:
        # Start from an empty scratch directory. makedirs also copes with
        # nested paths, and creating it once avoids a per-frame exists check.
        if os.path.exists(temp_results_dir):
            shutil.rmtree(temp_results_dir)
        os.makedirs(temp_results_dir, exist_ok=True)

        spNorm = ImagenetNorm()
        if use_mask:
            n_classes = 19
            net = BiSeNet(n_classes=n_classes)
            net.cuda()
            save_pth = os.path.join('./parsing_model', '79999_iter.pth')
            net.load_state_dict(torch.load(save_pth))
            net.eval()
        else:
            net = None

        video = cv2.VideoCapture(video_path)
        try:
            frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = video.get(cv2.CAP_PROP_FPS)

            for frame_index in tqdm(range(frame_count)):
                ret, frame = video.read()
                if not ret:
                    # The reported frame count may exceed the readable frames.
                    break

                frame_file = os.path.join(
                    temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))

                detect_results = self.detect.get(frame, crop_size)
                if detect_results is None:
                    # No face found: keep the untouched frame.
                    cv2.imwrite(frame_file, frame.astype(np.uint8))
                    continue

                frame_align_crop_list = detect_results[0]
                frame_mat_list = detect_results[1]
                swap_result_list = []
                frame_align_crop_tenor_list = []
                for frame_align_crop in frame_align_crop_list:
                    frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
                    swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
                    # Undo the input normalization and clamp to [0, 1].
                    swap_result = swap_result * self.imagenet_std + self.imagenet_mean
                    swap_result = torch.clip(swap_result, 0.0, 1.0)
                    # NOTE(review): the raw frame is written to the same path
                    # that is handed to reverse2wholeimage below; presumably
                    # it is overwritten there - confirm before removing.
                    cv2.imwrite(frame_file, frame)
                    swap_result_list.append(swap_result)
                    frame_align_crop_tenor_list.append(frame_align_crop_tenor)

                reverse2wholeimage(
                    frame_align_crop_tenor_list, swap_result_list,
                    frame_mat_list, crop_size, frame, frame_file,
                    pasring_model=net, use_mask=use_mask, norm=spNorm)
        finally:
            # Release the capture even when a frame fails to process.
            video.release()

        image_filenames = sorted(glob.glob(os.path.join(temp_results_dir, '*.jpg')))
        clips = ImageSequenceClip(image_filenames, fps=fps)
        if not no_audio:
            clips = clips.set_audio(video_audio_clip)

        # Never overwrite an existing result: basename.mp4, basename_0.mp4, ...
        basename = os.path.splitext(os.path.basename(video_path))[0]
        save_filename = os.path.join(save_path, basename + ".mp4")
        index = 0
        while os.path.exists(save_filename):
            save_filename = os.path.join(save_path, basename + "_%d.mp4" % index)
            index += 1

        clips.write_videofile(save_filename, audio_codec='aac')
    finally:
        if video_audio_clip is not None:
            video_audio_clip.close()
def __init_framework__(self):
'''
@@ -52,53 +170,68 @@ class Tester(object):
#===============build models================#
print("build models...")
# TODO [import models here]
script_name = "components."+self.config["module_script_name"]
class_name = self.config["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
n_class = len(self.config["selectedStyleDir"])
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = network_class(self.config["GConvDim"],
self.config["GKS"],
self.config["resNum"],
n_class
#**self.config["module_params"]
)
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# print and record model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
def test(self):
# save_result = self.config["saveTestResult"]
save_dir = self.config["test_samples_path"]
ckp_epoch = self.config["checkpoint_epoch"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
batch_size = self.config["batch_size"]
style_names = self.config["selectedStyleDir"]
n_class = len(style_names)
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
self.arcface_ckpt= self.config["arcface_ckpt"]
# models
self.__init_framework__()
condition_labels = torch.ones((n_class, batch_size, 1)).long()
for i in range(n_class):
condition_labels[i,:,:] = condition_labels[i,:,:]*i
if self.config["cuda"] >=0:
condition_labels = condition_labels.cuda()
total = len(self.test_loader)
mode = None
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id_imgs)
id_img_align_crop, _ = self.detect.get(id_img,512)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
@@ -106,18 +239,8 @@ class Tester(object):
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total//batch_size)):
contents, img_names = self.test_loader()
for i in range(n_class):
if self.config["cuda"] >=0:
contents = contents.cuda()
res, _ = self.network(contents, condition_labels[i, 0, :])
res = tensor2img(res.cpu())
for t in range(batch_size):
temp_img = res[t,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
use_mask=False,crop_size=512)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))