Files
SimSwapPlus/test_scripts/tester_video.py
T
chenxuanhong bebaeef2ce update
2022-01-21 18:01:36 +08:00

247 lines
9.7 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 21st January 2022 11:06:37 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import shutil
import torch
import torch.nn.functional as F
from torchvision import transforms
from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import numpy as np
from tqdm import tqdm
from PIL import Image
import glob
from utilities.ImagenetNorm import ImagenetNorm
from parsing_model.model import BiSeNet
from insightface_func.face_detect_crop_single import Face_detect_crop
from utilities.reverse2original import reverse2wholeimage
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
self.transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
def cv2totensor(self, cv2_img):
"""
cv2_img: an image read by cv2, H*W*C
return: an 1*C*H*W tensor
"""
cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB)
cv2_img = torch.from_numpy(cv2_img)
cv2_img = cv2_img.permute(2,0,1).cuda()
temp = cv2_img / 255.0
temp -= self.imagenet_mean
temp /= self.imagenet_std
return temp.unsqueeze(0)
def video_swap(
self,
video_path,
id_vetor,
save_path,
temp_results_dir='./temp_results',
crop_size=512,
use_mask =False
):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
else:
no_audio = False
del video_forcheck
if not no_audio:
video_audio_clip = AudioFileClip(video_path)
video = cv2.VideoCapture(video_path)
ret = True
frame_index = 0
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
# video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
# video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = video.get(cv2.CAP_PROP_FPS)
if os.path.exists(temp_results_dir):
shutil.rmtree(temp_results_dir)
spNorm =ImagenetNorm()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
if ret:
detect_results = self.detect.get(frame,crop_size)
if detect_results is not None:
# print(frame_index)
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame_align_crop_list = detect_results[0]
frame_mat_list = detect_results[1]
swap_result_list = []
frame_align_crop_tenor_list = []
for frame_align_crop in frame_align_crop_list:
frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
swap_result = swap_result* self.imagenet_std + self.imagenet_mean
swap_result = torch.clip(swap_result,0.0,1.0)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
swap_result_list.append(swap_result)
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
frame = frame.astype(np.uint8)
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
else:
break
video.release()
# image_filename_list = []
path = os.path.join(temp_results_dir,'*.jpg')
image_filenames = sorted(glob.glob(path))
clips = ImageSequenceClip(image_filenames,fps = fps)
if not no_audio:
clips = clips.set_audio(video_audio_clip)
basename = os.path.basename(video_path)
basename = os.path.splitext(basename)[0]
save_filename = os.path.join(save_path, basename+".mp4")
index = 0
while(True):
if os.path.exists(save_filename):
save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
index += 1
else:
break
clips.write_videofile(save_filename,audio_codec='aac')
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.network = gen_class(**model_config["g_model"]["module_params"])
# TODO replace below lines to define the model framework
self.network = gen_class(**model_config["g_model"]["module_params"])
self.network = self.network.eval()
# print and recorde model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface1['model'].module
self.arcface.eval()
self.arcface.requires_grad_(False)
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
self.arcface = self.arcface.cuda()
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
model_path = os.path.join(self.config["project_checkpoints"],
"step%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
# self.network.load_state_dict(torch.load(pathwocao))
print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
def test(self):
# save_result = self.config["saveTestResult"]
save_dir = self.config["test_samples_path"]
ckp_step = self.config["checkpoint_step"]
version = self.config["version"]
id_imgs = self.config["id_imgs"]
attr_files = self.config["attr_files"]
self.arcface_ckpt= self.config["arcface_ckpt"]
# models
self.__init_framework__()
mode = None
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id_imgs)
id_img_align_crop, _ = self.detect.get(id_img,512)
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img = self.transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0).cuda()
#create latent id
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
latend_id = self.arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
with torch.no_grad():
self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
use_mask=False,crop_size=512)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))