Added the ability for using mask

This commit is contained in:
NNNNAI
2021-07-19 11:48:44 +08:00
parent ca800ef00b
commit a770565268
4 changed files with 180 additions and 21 deletions
+129 -9
View File
@@ -1,13 +1,91 @@
import cv2
import numpy as np
# import time
from util.add_watermark import watermark_image
import torch
from torch.nn import functional as F
import torch.nn as nn
def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', no_simswaplogo = False):
def encode_segmentation_rgb(segmentation, no_neck=True):
parse = segmentation
face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
mouth_id = 11
hair_id = 17
face_map = np.zeros([parse.shape[0], parse.shape[1]])
mouth_map = np.zeros([parse.shape[0], parse.shape[1]])
hair_map = np.zeros([parse.shape[0], parse.shape[1]])
for valid_id in face_part_ids:
valid_index = np.where(parse==valid_id)
face_map[valid_index] = 255
valid_index = np.where(parse==mouth_id)
mouth_map[valid_index] = 255
valid_index = np.where(parse==hair_id)
hair_map[valid_index] = 255
return np.stack([face_map, mouth_map, hair_map], axis=2)
class SoftErosion(nn.Module):
def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
super(SoftErosion, self).__init__()
r = kernel_size // 2
self.padding = r
self.iterations = iterations
self.threshold = threshold
# Create kernel
y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
kernel = dist.max() - dist
kernel /= kernel.sum()
kernel = kernel.view(1, 1, *kernel.shape)
self.register_buffer('weight', kernel)
def forward(self, x):
x = x.float()
for i in range(self.iterations - 1):
x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
mask = x >= self.threshold
x[mask] = 1.0
x[~mask] /= x[~mask].max()
return x, mask
def postprocess(swapped_face, target, target_mask,smooth_mask):
# target_mask = cv2.resize(target_mask, (self.size, self.size))
mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1/255.0).cuda()
face_mask_tensor = mask_tensor[0] + mask_tensor[1]
soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
soft_face_mask_tensor.squeeze_()
soft_face_mask = soft_face_mask_tensor.cpu().numpy()
soft_face_mask = soft_face_mask[:, :, np.newaxis]
result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
result = result[:,:,::-1]# .astype(np.uint8)
return result
def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', \
no_simswaplogo = False,pasring_model =None,norm = None, use_mask = False):
target_image_list = []
img_mask_list = []
for swaped_img, mat in zip(swaped_imgs, mats):
if use_mask:
smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()
else:
pass
# print(len(swaped_imgs))
# print(mats)
# print(len(b_align_crop_tenor_list))
for swaped_img, mat ,source_img in zip(swaped_imgs, mats,b_align_crop_tenor_list):
swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0))
img_white = np.full((crop_size,crop_size), 255, dtype=float)
@@ -23,7 +101,27 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat
mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
orisize = (oriimg.shape[1], oriimg.shape[0])
target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)
if use_mask:
source_img_norm = norm(source_img)
source_img_512 = F.interpolate(source_img_norm,size=(512,512))
out = pasring_model(source_img_512)[0]
parsing = out.squeeze(0).detach().cpu().numpy().argmax(0)
vis_parsing_anno = parsing.copy().astype(np.uint8)
tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
# face_mask_tensor = tgt_mask[...,0] + tgt_mask[...,1]
target_mask = cv2.resize(tgt_mask, (224, 224))
# print(source_img)
target_image_parsing = postprocess(swaped_img, source_img[0].cpu().detach().numpy().transpose((1, 2, 0)), target_mask,smooth_mask)
target_image_parsing = cv2.warpAffine(target_image_parsing, mat_rev, orisize)
# target_image_parsing = cv2.warpAffine(swaped_img, mat_rev, orisize)
else:
target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)
# source_image = cv2.warpAffine(source_img, mat_rev, orisize)
img_white = cv2.warpAffine(img_white, mat_rev, orisize)
@@ -31,16 +129,39 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat
img_mask = img_white
kernel = np.ones((10,10),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
if use_mask:
kernel = np.ones((10,10),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
else:
kernel = np.ones((40,40),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel_size = (20, 20)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
# kernel = np.ones((10,10),np.uint8)
# img_mask = cv2.erode(img_mask,kernel,iterations = 1)
img_mask /= 255
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
# pasing mask
# target_image_parsing = postprocess(target_image, source_image, tgt_mask)
if use_mask:
target_image = np.array(target_image_parsing, dtype=np.float) * 255
else:
target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255
img_mask_list.append(img_mask)
target_image_list.append(target_image)
# target_image /= 255
# target_image = 0
img = np.array(oriimg, dtype=np.float)
@@ -52,7 +173,6 @@ def reverse2wholeimage(swaped_imgs, mats, crop_size, oriimg, logoclass, save_pat
final_img = logoclass.apply_frames(final_img)
cv2.imwrite(save_path, final_img)
# cv2.imwrite('E:\\lny\\SimSwap-main\\output\\img_div.jpg', img * 255)
# cv2.imwrite('E:\\lny\\SimSwap-main\\output\\ori_img.jpg', oriimg)
+19 -5
View File
@@ -11,14 +11,15 @@ from moviepy.editor import AudioFileClip, VideoFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import time
from util.add_watermark import watermark_image
from util.norm import SpecificNorm
from parsing_model.model import BiSeNet
def _totensor(array):
tensor = torch.from_numpy(array)
img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255)
def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False):
def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
@@ -45,6 +46,17 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
if os.path.exists(temp_results_dir):
shutil.rmtree(temp_results_dir)
spNorm =SpecificNorm()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
@@ -58,7 +70,7 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
frame_align_crop_list = detect_results[0]
frame_mat_list = detect_results[1]
swap_result_list = []
frame_align_crop_tenor_list = []
for frame_align_crop in frame_align_crop_list:
# BGR TO RGB
@@ -68,10 +80,12 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0]
swap_result_list.append(swap_result)
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
reverse2wholeimage(swap_result_list, frame_mat_list, crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo)
reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame, logoclass,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
@@ -95,5 +109,5 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
clips = clips.set_audio(video_audio_clip)
clips.write_videofile(save_path)
clips.write_videofile(save_path,audio_codec='aac')
+17 -4
View File
@@ -13,13 +13,14 @@ import time
from util.add_watermark import watermark_image
from util.norm import SpecificNorm
import torch.nn.functional as F
from parsing_model.model import BiSeNet
def _totensor(array):
tensor = torch.from_numpy(array)
img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255)
def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False):
def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
@@ -49,6 +50,16 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
spNorm =SpecificNorm()
mse = torch.nn.MSELoss().cuda()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
@@ -85,12 +96,13 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
swap_result_list = []
swap_result_matrix_list = []
swap_result_ori_pic_list = []
for tmp_index, min_index in enumerate(min_indexs):
if min_value[tmp_index] < id_thres:
swap_result = swap_model(None, frame_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0]
swap_result_list.append(swap_result)
swap_result_matrix_list.append(frame_mat_list[tmp_index])
swap_result_ori_pic_list.append(frame_align_crop_tenor_list[tmp_index])
else:
pass
@@ -98,7 +110,8 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
if len(swap_result_list) !=0:
reverse2wholeimage(swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo)
reverse2wholeimage(swap_result_ori_pic_list,swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
@@ -129,5 +142,5 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
clips = clips.set_audio(video_audio_clip)
clips.write_videofile(save_path)
clips.write_videofile(save_path,audio_codec='aac')
+15 -3
View File
@@ -13,13 +13,14 @@ import time
from util.add_watermark import watermark_image
from util.norm import SpecificNorm
import torch.nn.functional as F
from parsing_model.model import BiSeNet
def _totensor(array):
tensor = torch.from_numpy(array)
img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255)
def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False):
def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
video_forcheck = VideoFileClip(video_path)
if video_forcheck.audio is None:
no_audio = True
@@ -49,6 +50,16 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
spNorm =SpecificNorm()
mse = torch.nn.MSELoss().cuda()
if use_mask:
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
net.load_state_dict(torch.load(save_pth))
net.eval()
else:
net =None
# while ret:
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
@@ -83,7 +94,8 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
if min_value < id_thres:
swap_result = swap_model(None, frame_align_crop_tenor_list[min_index], id_vetor, None, True)[0]
reverse2wholeimage([swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo)
reverse2wholeimage([frame_align_crop_tenor_list[min_index]], [swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,\
os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask= use_mask, norm = spNorm)
else:
if not os.path.exists(temp_results_dir):
os.mkdir(temp_results_dir)
@@ -114,5 +126,5 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
clips = clips.set_audio(video_audio_clip)
clips.write_videofile(save_path)
clips.write_videofile(save_path,audio_codec='aac')