From a77056526863fb2a1b49287af8fa216f9fd5bb3c Mon Sep 17 00:00:00 2001 From: NNNNAI <844294823@qq.com> Date: Mon, 19 Jul 2021 11:48:44 +0800 Subject: [PATCH] Added the ability for using mask --- util/reverse2original.py | 138 +++++++++++++++++++++++++++++--- util/videoswap.py | 24 ++++-- util/videoswap_multispecific.py | 21 ++++- util/videoswap_specific.py | 18 ++++- 4 files changed, 180 insertions(+), 21 deletions(-) diff --git a/util/reverse2original.py b/util/reverse2original.py index 0a2f99c..fdb729d 100644 --- a/util/reverse2original.py +++ b/util/reverse2original.py @@ -1,13 +1,91 @@ import cv2 import numpy as np # import time -from util.add_watermark import watermark_image +import torch +from torch.nn import functional as F +import torch.nn as nn -def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', no_simswaplogo = False): + +def encode_segmentation_rgb(segmentation, no_neck=True): + parse = segmentation + + face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14] + mouth_id = 11 + hair_id = 17 + face_map = np.zeros([parse.shape[0], parse.shape[1]]) + mouth_map = np.zeros([parse.shape[0], parse.shape[1]]) + hair_map = np.zeros([parse.shape[0], parse.shape[1]]) + + for valid_id in face_part_ids: + valid_index = np.where(parse==valid_id) + face_map[valid_index] = 255 + valid_index = np.where(parse==mouth_id) + mouth_map[valid_index] = 255 + valid_index = np.where(parse==hair_id) + hair_map[valid_index] = 255 + + return np.stack([face_map, mouth_map, hair_map], axis=2) + + +class SoftErosion(nn.Module): + def __init__(self, kernel_size=15, threshold=0.6, iterations=1): + super(SoftErosion, self).__init__() + r = kernel_size // 2 + self.padding = r + self.iterations = iterations + self.threshold = threshold + + # Create kernel + y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size)) + dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2) + kernel = dist.max() - dist + kernel /= kernel.sum() + kernel = kernel.view(1, 1, *kernel.shape) + self.register_buffer('weight', kernel) + + def forward(self, x): + x = x.float() + for i in range(self.iterations - 1): + x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)) + x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding) + + mask = x >= self.threshold + x[mask] = 1.0 + x[~mask] /= x[~mask].max() + + return x, mask + + +def postprocess(swapped_face, target, target_mask,smooth_mask): + # target_mask = cv2.resize(target_mask, (self.size, self.size)) + + mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1/255.0).cuda() + face_mask_tensor = mask_tensor[0] + mask_tensor[1] + + soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0)) + soft_face_mask_tensor.squeeze_() + + soft_face_mask = soft_face_mask_tensor.cpu().numpy() + soft_face_mask = soft_face_mask[:, :, np.newaxis] + + result = swapped_face * soft_face_mask + target * (1 - soft_face_mask) + result = result[:,:,::-1]# .astype(np.uint8) + return result + +def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', \ + no_simswaplogo = False,pasring_model =None,norm = None, use_mask = False): target_image_list = [] img_mask_list = [] - for swaped_img, mat in zip(swaped_imgs, mats): + if use_mask: + smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda() + else: + pass + + # print(len(swaped_imgs)) + # print(mats) + # print(len(b_align_crop_tenor_list)) + for swaped_img, mat ,source_img in zip(swaped_imgs, mats,b_align_crop_tenor_list): swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0)) img_white = np.full((crop_size,crop_size), 255, dtype=float) @@ -23,7 +101,27 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2 orisize = (oriimg.shape[1], oriimg.shape[0]) - target_image = cv2.warpAffine(swaped_img, mat_rev, orisize) + if use_mask: + source_img_norm = norm(source_img) + source_img_512 = F.interpolate(source_img_norm,size=(512,512)) + out = pasring_model(source_img_512)[0] + parsing = out.squeeze(0).detach().cpu().numpy().argmax(0) + vis_parsing_anno = parsing.copy().astype(np.uint8) + tgt_mask = encode_segmentation_rgb(vis_parsing_anno) + # face_mask_tensor = tgt_mask[...,0] + tgt_mask[...,1] + target_mask = cv2.resize(tgt_mask, (224, 224)) + + # print(source_img) + target_image_parsing = postprocess(swaped_img, source_img[0].cpu().detach().numpy().transpose((1, 2, 0)), target_mask,smooth_mask) + + + target_image_parsing = cv2.warpAffine(target_image_parsing, mat_rev, orisize) + # target_image_parsing = cv2.warpAffine(swaped_img, mat_rev, orisize) + + else: + target_image = cv2.warpAffine(swaped_img, mat_rev, orisize) + # source_image = cv2.warpAffine(source_img, mat_rev, orisize) + img_white = cv2.warpAffine(img_white, mat_rev, orisize) @@ -31,16 +129,39 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat img_mask = img_white - kernel = np.ones((10,10),np.uint8) - img_mask = cv2.erode(img_mask,kernel,iterations = 1) + if use_mask: + kernel = np.ones((10,10),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + else: + kernel = np.ones((40,40),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + kernel_size = (20, 20) + blur_size = tuple(2*i+1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + + # kernel = np.ones((10,10),np.uint8) + # img_mask = cv2.erode(img_mask,kernel,iterations = 1) + + img_mask /= 255 img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) - target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255 + + # pasing mask + + # target_image_parsing = postprocess(target_image, source_image, tgt_mask) + + if use_mask: + target_image = np.array(target_image_parsing, dtype=np.float) * 255 + else: + target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255 + img_mask_list.append(img_mask) target_image_list.append(target_image) + + # target_image /= 255 # target_image = 0 img = np.array(oriimg, dtype=np.float) @@ -52,7 +173,6 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat final_img = logoclass.apply_frames(final_img) cv2.imwrite(save_path, final_img) - # cv2.imwrite('E:\\lny\\SimSwap-main\\output\\img_div.jpg', img * 255) - # cv2.imwrite('E:\\lny\\SimSwap-main\\output\\ori_img.jpg', oriimg) + \ No newline at end of file diff --git a/util/videoswap.py b/util/videoswap.py index 82cd814..f81121d 100644 --- a/util/videoswap.py +++ b/util/videoswap.py @@ -11,14 +11,15 @@ from moviepy.editor import AudioFileClip, VideoFileClip from moviepy.video.io.ImageSequenceClip import ImageSequenceClip import time from util.add_watermark import watermark_image - +from util.norm import SpecificNorm +from parsing_model.model import BiSeNet def _totensor(array): tensor = torch.from_numpy(array) img = tensor.transpose(0, 1).transpose(0, 2).contiguous() return img.float().div(255) -def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False): +def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): video_forcheck = VideoFileClip(video_path) if video_forcheck.audio is None: no_audio = True @@ -45,6 +46,17 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r if os.path.exists(temp_results_dir): shutil.rmtree(temp_results_dir) + spNorm =SpecificNorm() + if use_mask: + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') + net.load_state_dict(torch.load(save_pth)) + net.eval() + else: + net =None + # while ret: for frame_index in tqdm(range(frame_count)): ret, frame = video.read() @@ -58,7 +70,7 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r frame_align_crop_list = detect_results[0] frame_mat_list = detect_results[1] swap_result_list = [] - + frame_align_crop_tenor_list = [] for frame_align_crop in frame_align_crop_list: # BGR TO RGB @@ -68,10 +80,12 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0] swap_result_list.append(swap_result) + frame_align_crop_tenor_list.append(frame_align_crop_tenor) - reverse2wholeimage(swap_result_list, frame_mat_list, crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo) + reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame, logoclass,\ + os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm) else: if not os.path.exists(temp_results_dir): @@ -95,5 +109,5 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r clips = clips.set_audio(video_audio_clip) - clips.write_videofile(save_path) + clips.write_videofile(save_path,audio_codec='aac') diff --git a/util/videoswap_multispecific.py b/util/videoswap_multispecific.py index f5364fe..20b53b6 100644 --- a/util/videoswap_multispecific.py +++ b/util/videoswap_multispecific.py @@ -13,13 +13,14 @@ import time from util.add_watermark import watermark_image from util.norm import SpecificNorm import torch.nn.functional as F +from parsing_model.model import BiSeNet def _totensor(array): tensor = torch.from_numpy(array) img = tensor.transpose(0, 1).transpose(0, 2).contiguous() return img.float().div(255) -def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False): +def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): video_forcheck = VideoFileClip(video_path) if video_forcheck.audio is None: no_audio = True @@ -49,6 +50,16 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id spNorm =SpecificNorm() mse = torch.nn.MSELoss().cuda() + if use_mask: + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') + net.load_state_dict(torch.load(save_pth)) + net.eval() + else: + net =None + # while ret: for frame_index in tqdm(range(frame_count)): ret, frame = video.read() @@ -85,12 +96,13 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id swap_result_list = [] swap_result_matrix_list = [] - + swap_result_ori_pic_list = [] for tmp_index, min_index in enumerate(min_indexs): if min_value[tmp_index] < id_thres: swap_result = swap_model(None, frame_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0] swap_result_list.append(swap_result) swap_result_matrix_list.append(frame_mat_list[tmp_index]) + swap_result_ori_pic_list.append(frame_align_crop_tenor_list[tmp_index]) else: pass @@ -98,7 +110,8 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id if len(swap_result_list) !=0: - reverse2wholeimage(swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo) + reverse2wholeimage(swap_result_ori_pic_list,swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,\ + os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm) else: if not os.path.exists(temp_results_dir): os.mkdir(temp_results_dir) @@ -129,5 +142,5 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id clips = clips.set_audio(video_audio_clip) - clips.write_videofile(save_path) + clips.write_videofile(save_path,audio_codec='aac') diff --git a/util/videoswap_specific.py b/util/videoswap_specific.py index 086eb32..b8d7ee6 100644 --- a/util/videoswap_specific.py +++ b/util/videoswap_specific.py @@ -13,13 +13,14 @@ import time from util.add_watermark import watermark_image from util.norm import SpecificNorm import torch.nn.functional as F +from parsing_model.model import BiSeNet def _totensor(array): tensor = torch.from_numpy(array) img = tensor.transpose(0, 1).transpose(0, 2).contiguous() return img.float().div(255) -def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False): +def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): video_forcheck = VideoFileClip(video_path) if video_forcheck.audio is None: no_audio = True @@ -49,6 +50,16 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod spNorm =SpecificNorm() mse = torch.nn.MSELoss().cuda() + if use_mask: + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') + net.load_state_dict(torch.load(save_pth)) + net.eval() + else: + net =None + # while ret: for frame_index in tqdm(range(frame_count)): ret, frame = video.read() @@ -83,7 +94,8 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod if min_value < id_thres: swap_result = swap_model(None, frame_align_crop_tenor_list[min_index], id_vetor, None, True)[0] - reverse2wholeimage([swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo) + reverse2wholeimage([frame_align_crop_tenor_list[min_index]], [swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,\ + os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask= use_mask, norm = spNorm) else: if not os.path.exists(temp_results_dir): os.mkdir(temp_results_dir) @@ -114,5 +126,5 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod clips = clips.set_audio(video_audio_clip) - clips.write_videofile(save_path) + clips.write_videofile(save_path,audio_codec='aac')