diff --git a/util/reverse2original.py b/util/reverse2original.py new file mode 100644 index 0000000..9a27ecb --- /dev/null +++ b/util/reverse2original.py @@ -0,0 +1,55 @@ +import cv2 +import numpy as np +# import time +from util.add_watermark import watermark_image + +def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '',): + + target_image_list = [] + img_mask_list = [] + for swaped_img, mat in zip(swaped_imgs, mats): + swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0)) + img_white = np.full((crop_size,crop_size), 255, dtype=float) + + # inverse the Affine transformation matrix + mat_rev = np.zeros([2,3]) + div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0] + mat_rev[0][0] = mat[1][1]/div1 + mat_rev[0][1] = -mat[0][1]/div1 + mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1 + div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1] + mat_rev[1][0] = mat[1][0]/div2 + mat_rev[1][1] = -mat[0][0]/div2 + mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2 + + orisize = (oriimg.shape[1], oriimg.shape[0]) + target_image = cv2.warpAffine(swaped_img, mat_rev, orisize) + img_white = cv2.warpAffine(img_white, mat_rev, orisize) + + + img_white[img_white>20] =255 + + img_mask = img_white + + kernel = np.ones((10,10),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + + img_mask /= 255 + + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255 + + img_mask_list.append(img_mask) + target_image_list.append(target_image) + # target_image /= 255 + # target_image = 0 + img = np.array(oriimg, dtype=np.float) + for img_mask, target_image in zip(img_mask_list, target_image_list): + img = img_mask * target_image + (1-img_mask) * img + + final_img = logoclass.apply_frames(img.astype(np.uint8)) + cv2.imwrite(save_path, final_img) + + # cv2.imwrite('E:\\lny\\SimSwap-main\\output\\img_div.jpg', img * 255) + # cv2.imwrite('E:\\lny\\SimSwap-main\\output\\ori_img.jpg', oriimg) + \ No newline at end of file diff --git a/util/videoswap.py b/util/videoswap.py new file mode 100644 index 0000000..a19cecc --- /dev/null +++ b/util/videoswap.py @@ -0,0 +1,101 @@ +import os +import cv2 +import glob +import torch +import shutil +import numpy as np +from tqdm import tqdm +from util.reverse2original import reverse2wholeimage +import moviepy.editor as mp +from moviepy.editor import AudioFileClip, VideoFileClip +from moviepy.video.io.ImageSequenceClip import ImageSequenceClip +import time +from util.add_watermark import watermark_image + + +def _totensor(array): + tensor = torch.from_numpy(array) + img = tensor.transpose(0, 1).transpose(0, 2).contiguous() + return img.float().div(255) + +def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224): + video_audio_clip = AudioFileClip(video_path) + video = cv2.VideoCapture(video_path) + logoclass = watermark_image('./simswaplogo/simswaplogo.png') + ret = True + frame_index = 0 + + frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + + # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + fps = video.get(cv2.CAP_PROP_FPS) + if os.path.exists(temp_results_dir): + shutil.rmtree(temp_results_dir) + + # while ret: + for frame_index in tqdm(range(frame_count)): + ret, frame = video.read() + if ret: + detect_results = detect_model.get(frame,crop_size) + + if detect_results is not None: + # print(frame_index) + if not os.path.exists(temp_results_dir): + os.mkdir(temp_results_dir) + frame_align_crop_list = detect_results[0] + frame_mat_list = detect_results[1] + swap_result_list = [] + + for frame_align_crop in frame_align_crop_list: + + # BGR TO RGB + # frame_align_crop_RGB = frame_align_crop[...,::-1] + + frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() + + swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0] + swap_result_list.append(swap_result) + + + + reverse2wholeimage(swap_result_list, frame_mat_list, crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))) + + else: + if not os.path.exists(temp_results_dir): + os.mkdir(temp_results_dir) + cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) + else: + break + + # TODO,是否应该判断这个break是否是异常抛出 + video.release() + + # image_filename_list = [] + path = os.path.join(temp_results_dir,'*.jpg') + image_filenames = sorted(glob.glob(path)) + + clips = ImageSequenceClip(image_filenames,fps = fps) + + final_clips = clips.set_audio(video_audio_clip) + + # logo = (mp.ImageClip("./simswaplogo/simswap.png") + # .set_duration(clips.duration) # 水印持续时间 + # .resize(height=100) # 水印的高度,会等比缩放 + # .margin(right=8, top=8, opacity=1) # 水印边距和透明度 + # .set_pos(("left"))) # 水印的位置 + + # final_clips = mp.CompositeVideoClip([clips, logo]) + + # final_clips.write_videofile("./output/test_beatuy_480p_full.mp4") + final_clips.write_videofile(save_path) + + # video = VideoFileClip(save_path) + + + + + + # video_audio_clip \ No newline at end of file