first commit

This commit is contained in:
Your Name
2019-02-01 15:30:41 +08:00
commit df2c0c06fa
81 changed files with 1765 additions and 0 deletions

8
.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
data/
checkpoint/
.idea/
__pycache__/
.vscode
demo_faceswap_video.py
predict_128x128.py

View File

@@ -0,0 +1,204 @@
"""
Copyright StrangeAI authors @2019
assumes you have two directories of faces: to
convert A to B, just put all faces of person A in A,
and all faces of person B in B
"""
import torch
from torch.utils.data import Dataset
import glob
import os
from alfred.dl.torch.common import device
import cv2
from PIL import Image
from torchvision import transforms
import numpy as np
from utils.umeyama import umeyama
import cv2
# Augmentation hyper-parameters shared by the dataset classes below.
random_transform_args = {
    'rotation_range': 10,   # max rotation in degrees (+/-)
    'zoom_range': 0.05,     # max relative scale change (+/-)
    'shift_range': 0.05,    # max translation as a fraction of width/height
    'random_flip': 0.4,     # probability of a horizontal flip
}
def random_transform(image, rotation_range, zoom_range, shift_range, random_flip):
    """Apply a random rotation/zoom/shift (and maybe a horizontal flip).

    The parameters mirror the keys of ``random_transform_args``; the result
    has the same height/width as the input.
    """
    height, width = image.shape[:2]
    angle = np.random.uniform(-rotation_range, rotation_range)
    zoom = np.random.uniform(1 - zoom_range, 1 + zoom_range)
    shift_x = np.random.uniform(-shift_range, shift_range) * width
    shift_y = np.random.uniform(-shift_range, shift_range) * height
    matrix = cv2.getRotationMatrix2D((width // 2, height // 2), angle, zoom)
    matrix[:, 2] += (shift_x, shift_y)
    warped = cv2.warpAffine(image, matrix, (width, height), borderMode=cv2.BORDER_REPLICATE)
    # Horizontal flip with probability ``random_flip``.
    if np.random.random() < random_flip:
        warped = np.fliplr(warped)
    return warped
def random_warp_128(image):
    """Generate a (warped, target) training pair at 128x128 from a 256x256 face.

    The warped image is produced by jittering a coarse 9x9 control grid and
    remapping; the target is the same face aligned with the umeyama similarity
    transform estimated from those control points.
    """
    assert image.shape == (256, 256, 3), 'resize image to 256 256 first'
    # 9x9 grid of control points covering the central 240px of the image.
    range_ = np.linspace(128 - 120, 128 + 120, 9)
    mapx = np.broadcast_to(range_, (9, 9))
    mapy = mapx.T
    # Jitter every control point with gaussian noise.
    mapx = mapx + np.random.normal(size=(9, 9), scale=5)
    mapy = mapy + np.random.normal(size=(9, 9), scale=5)
    # Upsample the sparse grid to a dense map, then crop an 8px border -> 128x128.
    interp_mapx = cv2.resize(mapx, (144, 144))[8:136, 8:136].astype('float32')
    interp_mapy = cv2.resize(mapy, (144, 144))[8:136, 8:136].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:129:16, 0:129:16].T.reshape(-1, 2)
    # Similarity transform (rotation + scale + shift) mapping src -> dst.
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (128, 128))
    return warped_image, target_image
def random_warp_64(image):
    """Generate a (warped, target) training pair at 64x64 from a 256x256 face.

    Same scheme as random_warp_128 but with a coarser 5x5 control grid.
    """
    assert image.shape == (256, 256, 3)
    # 5x5 grid of control points covering the central 240px of the image.
    range_ = np.linspace(128 - 120, 128 + 120, 5)
    mapx = np.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    # Jitter every control point with gaussian noise.
    mapx = mapx + np.random.normal(size=(5, 5), scale=5)
    mapy = mapy + np.random.normal(size=(5, 5), scale=5)
    # Upsample to a dense 80x80 map, then crop an 8px border -> 64x64.
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
    # Similarity transform (rotation + scale + shift) mapping src -> dst.
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64, 64))
    return warped_image, target_image
class FacePairDataset(Dataset):
    """Pairs of randomly-warped faces of two identities (A and B).

    Each item is (target_a, input_a, target_b, input_b) as CHW float tensors
    scaled to [0, 1].  ``target_size`` selects the warp resolution
    (128 -> random_warp_128, anything else -> random_warp_64).
    """

    def __init__(self, a_dir, b_dir, target_size, transform):
        # BUGFIX: was ``super(...).__init__`` (attribute access, never called).
        super(FacePairDataset, self).__init__()
        self.a_dir = a_dir
        self.b_dir = b_dir
        self.target_size = target_size
        self.transform = transform
        # extension can be changed here to png or others
        self.a_images_list = glob.glob(os.path.join(a_dir, '*.png'))
        self.b_images_list = glob.glob(os.path.join(b_dir, '*.png'))

    def __getitem__(self, index):
        # Load one face image of each identity.
        img_a = Image.open(self.a_images_list[index])
        img_b = Image.open(self.b_images_list[index])
        # BUGFIX: the warp helpers assert a 256x256 input, but the original
        # resized to ``target_size`` here, which crashed for any other size.
        img_a = img_a.resize((256, 256), Image.ANTIALIAS)
        img_b = img_b.resize((256, 256), Image.ANTIALIAS)
        if self.transform:
            img_a = self.transform(img_a)
            img_b = self.transform(img_b)
        # Random rotate/zoom/shift/flip augmentation.
        img_a = random_transform(np.array(img_a), **random_transform_args)
        img_b = random_transform(np.array(img_b), **random_transform_args)
        # BUGFIX: the original called an undefined ``random_warp``; dispatch on
        # the requested output resolution instead.
        warp = random_warp_128 if self.target_size == 128 else random_warp_64
        img_a_input, img_a = warp(np.array(img_a))
        img_b_input, img_b = warp(np.array(img_b))
        # HWC uint8 -> CHW float in [0, 1].
        img_a_tensor = torch.Tensor(img_a.transpose(2, 0, 1) / 255.).float()
        img_a_input_tensor = torch.Tensor(img_a_input.transpose(2, 0, 1) / 255.).float()
        img_b_tensor = torch.Tensor(img_b.transpose(2, 0, 1) / 255.).float()
        img_b_input_tensor = torch.Tensor(img_b_input.transpose(2, 0, 1) / 255.).float()
        return img_a_tensor, img_a_input_tensor, img_b_tensor, img_b_input_tensor

    def __len__(self):
        # Pairs are drawn index-aligned, so length is the smaller side.
        return min(len(self.a_images_list), len(self.b_images_list))
class FacePairDataset64x64(Dataset):
    """Pairs of randomly-warped 64x64 faces of two identities (A and B).

    Each item is (target_a, input_a, target_b, input_b) as CHW float tensors
    scaled to [0, 1].
    """

    def __init__(self, a_dir, b_dir, target_size, transform):
        # BUGFIX: was ``super(...).__init__`` (attribute access, never called).
        super(FacePairDataset64x64, self).__init__()
        self.a_dir = a_dir
        self.b_dir = b_dir
        self.target_size = target_size
        self.transform = transform
        # extension can be changed here to png or others
        self.a_images_list = glob.glob(os.path.join(a_dir, '*.png'))
        self.b_images_list = glob.glob(os.path.join(b_dir, '*.png'))

    def __getitem__(self, index):
        # Load one face image of each identity.
        img_a = Image.open(self.a_images_list[index])
        img_b = Image.open(self.b_images_list[index])
        # The warp helper expects a 256x256 input face.
        img_a = img_a.resize((256, 256), Image.ANTIALIAS)
        img_b = img_b.resize((256, 256), Image.ANTIALIAS)
        if self.transform:
            img_a = self.transform(img_a)
            img_b = self.transform(img_b)
        # Random rotate/zoom/shift/flip augmentation, then warp into the
        # (input, target) training pair.  random_warp_64 already returns
        # numpy arrays, so the original's extra np.array() wraps were dropped.
        img_a = random_transform(np.array(img_a), **random_transform_args)
        img_b = random_transform(np.array(img_b), **random_transform_args)
        img_a_input, img_a = random_warp_64(np.array(img_a))
        img_b_input, img_b = random_warp_64(np.array(img_b))
        # HWC uint8 -> CHW float in [0, 1].
        img_a_tensor = torch.Tensor(img_a.transpose(2, 0, 1) / 255.).float()
        img_a_input_tensor = torch.Tensor(img_a_input.transpose(2, 0, 1) / 255.).float()
        img_b_tensor = torch.Tensor(img_b.transpose(2, 0, 1) / 255.).float()
        img_b_input_tensor = torch.Tensor(img_b_input.transpose(2, 0, 1) / 255.).float()
        return img_a_tensor, img_a_input_tensor, img_b_tensor, img_b_input_tensor

    def __len__(self):
        # Pairs are drawn index-aligned, so length is the smaller side.
        return min(len(self.a_images_list), len(self.b_images_list))
class FacePairDataset128x128(Dataset):
    """Pairs of randomly-warped 128x128 faces of two identities (A and B).

    Each item is (target_a, input_a, target_b, input_b) as CHW float tensors
    scaled to [0, 1].
    """

    def __init__(self, a_dir, b_dir, target_size, transform):
        # BUGFIX: was ``super(...).__init__`` (attribute access, never called).
        super(FacePairDataset128x128, self).__init__()
        self.a_dir = a_dir
        self.b_dir = b_dir
        self.target_size = target_size
        self.transform = transform
        self.a_images_list = glob.glob(os.path.join(a_dir, '*.png'))
        self.b_images_list = glob.glob(os.path.join(b_dir, '*.png'))

    def __getitem__(self, index):
        # Load one face image of each identity.
        img_a = Image.open(self.a_images_list[index])
        img_b = Image.open(self.b_images_list[index])
        # The warp helper expects a 256x256 input face.
        img_a = img_a.resize((256, 256), Image.ANTIALIAS)
        img_b = img_b.resize((256, 256), Image.ANTIALIAS)
        if self.transform:
            img_a = self.transform(img_a)
            img_b = self.transform(img_b)
        # Random rotate/zoom/shift/flip augmentation, then warp into the
        # (input, target) training pair.
        img_a = random_transform(np.array(img_a), **random_transform_args)
        img_b = random_transform(np.array(img_b), **random_transform_args)
        img_a_input, img_a = random_warp_128(np.array(img_a))
        img_b_input, img_b = random_warp_128(np.array(img_b))
        # HWC uint8 -> CHW float in [0, 1].
        img_a_tensor = torch.Tensor(img_a.transpose(2, 0, 1) / 255.).float()
        img_a_input_tensor = torch.Tensor(img_a_input.transpose(2, 0, 1) / 255.).float()
        img_b_tensor = torch.Tensor(img_b.transpose(2, 0, 1) / 255.).float()
        img_b_input_tensor = torch.Tensor(img_b_input.transpose(2, 0, 1) / 255.).float()
        return img_a_tensor, img_a_input_tensor, img_b_tensor, img_b_input_tensor

    def __len__(self):
        # Pairs are drawn index-aligned, so length is the smaller side.
        return min(len(self.a_images_list), len(self.b_images_list))

61
dataset/training_data.py Normal file
View File

@@ -0,0 +1,61 @@
import numpy
from utils.umeyama import umeyama
import cv2
# Augmentation hyper-parameters used by get_training_data below.
random_transform_args = {
    'rotation_range': 10,   # max rotation in degrees (+/-)
    'zoom_range': 0.05,     # max relative scale change (+/-)
    'shift_range': 0.05,    # max translation as a fraction of width/height
    'random_flip': 0.4,     # probability of a horizontal flip
}
def random_transform(image, rotation_range, zoom_range, shift_range, random_flip):
    """Randomly rotate/zoom/shift *image*, optionally flipping it horizontally."""
    height, width = image.shape[:2]
    angle = numpy.random.uniform(-rotation_range, rotation_range)
    zoom = numpy.random.uniform(1 - zoom_range, 1 + zoom_range)
    shift_x = numpy.random.uniform(-shift_range, shift_range) * width
    shift_y = numpy.random.uniform(-shift_range, shift_range) * height
    matrix = cv2.getRotationMatrix2D((width // 2, height // 2), angle, zoom)
    matrix[:, 2] += (shift_x, shift_y)
    warped = cv2.warpAffine(image, matrix, (width, height), borderMode=cv2.BORDER_REPLICATE)
    # Horizontal flip with probability ``random_flip``.
    if numpy.random.random() < random_flip:
        warped = warped[:, ::-1]
    return warped
# get pair of random warped images from aligned face image
def random_warp(image):
    """Generate a (warped, target) 64x64 training pair from a 256x256 face.

    NOTE(review): this grid spans 128 +/- 80 while dataset/face_pair_dataset.py
    uses 128 +/- 120 — confirm which span is intended before unifying.
    """
    assert image.shape == (256, 256, 3)
    # 5x5 grid of control points covering the central 160px of the image.
    range_ = numpy.linspace(128 - 80, 128 + 80, 5)
    mapx = numpy.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    # Jitter every control point with gaussian noise.
    mapx = mapx + numpy.random.normal(size=(5, 5), scale=5)
    mapy = mapy + numpy.random.normal(size=(5, 5), scale=5)
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    # just crop the image, remove the top left bottom right 8 pixels (in order to get the pure face)
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = numpy.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = numpy.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
    # Similarity transform (rotation + scale + shift) mapping src -> dst.
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64, 64))
    return warped_image, target_image
def get_training_data(images, batch_size):
    """Sample ``batch_size`` random (warped, target) training pairs from *images*.

    Returns two stacked arrays of shape (batch_size, H, W, C).

    Raises:
        ValueError: if batch_size is not positive (the original code hit an
            UnboundLocalError in that case).
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    indices = numpy.random.randint(len(images), size=batch_size)
    for i, index in enumerate(indices):
        image = images[index]
        image = random_transform(image, **random_transform_args)
        warped_img, target_img = random_warp(image)
        if i == 0:
            # Allocate the output batches lazily, once the pair shape is known.
            warped_images = numpy.empty((batch_size,) + warped_img.shape, warped_img.dtype)
            # BUGFIX: was allocated with warped_img.dtype instead of target_img.dtype.
            target_images = numpy.empty((batch_size,) + target_img.shape, target_img.dtype)
        warped_images[i] = warped_img
        target_images[i] = target_img
    return warped_images, target_images

BIN
images/1.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

BIN
images/2.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
images/3.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
images/4.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
images/5.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

34
images/grid_res.py Normal file
View File

@@ -0,0 +1,34 @@
"""
grid a final image from result images
"""
import cv2
import numpy as np
import os
import sys
import glob
from PIL import Image
# Stitch result tiles (6 per row, 128x128 each) from the given directory
# into a single grid image saved as res_grid.png.
d = sys.argv[1]  # directory containing the result tiles
print('from ', d)
all_img_files = glob.glob(os.path.join(d, '*.png'))
# Every grid row holds exactly 6 tiles.
assert len(all_img_files) % 6 == 0, 'images divided by 6'
all_img_files = sorted(all_img_files)
rows = len(all_img_files) // 6
print(rows)
print(len(all_img_files))
# White canvas large enough for every 128x128 cell.
res_img = Image.new('RGB', (128*6, 128*(len(all_img_files)//6)), (255, 255, 255))
for i in range(len(all_img_files)//6):
    for j in range(6):
        # print('now: ', all_img_files[6*i + j])
        img = Image.open(all_img_files[6*i + j])
        res_img.paste(img, (j*128, i*128))
res_img.save('res_grid.png')
print(np.array(res_img).shape)
cv2.imshow('rr', np.array(res_img))
cv2.waitKey(0)

BIN
images/res/1480_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
images/res/1480_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/1480_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/1480_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

BIN
images/res/1480_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/1480_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/1490_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
images/res/1490_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

BIN
images/res/1490_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
images/res/1490_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
images/res/1490_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/1490_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/15450_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

BIN
images/res/15450_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
images/res/15450_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
images/res/15450_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
images/res/15450_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/15450_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
images/res/15780_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
images/res/15780_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/15780_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res/15780_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
images/res/15780_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
images/res/15780_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
images/res_grid.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 563 KiB

BIN
images/res_right/1480_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
images/res_right/1480_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

BIN
images/res_right/1480_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

BIN
images/res_right/1480_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
images/res_right/1480_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

BIN
images/res_right/1480_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
images/res_right/1490_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
images/res_right/1490_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

BIN
images/res_right/1490_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

BIN
images/res_right/1490_3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

BIN
images/res_right/1490_4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

BIN
images/res_right/1490_5.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
images/resul2t.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

BIN
images/result.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 390 KiB

11
images/rotate.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Rotate every PNG in the given directory by 90 degrees (ImageMagick `convert`)
# and save the results into "<dir>_right".
d=$1
# BUGFIX: message typo "roate"; destination setup hoisted out of the loop.
echo "rotate from $d"
dd=${d}_right
echo "save to $dd"
# -p: no error if the directory already exists.
mkdir -p "$dd"
for szFile in "$d"/*.png
do
    convert "$szFile" -rotate 90 "$dd/$(basename "$szFile")"
done

4
images/run.sh Executable file
View File

@@ -0,0 +1,4 @@
#!/usr/bin/env bash
# Rebuild the result grid: rotate the tiles in ./res, then stitch them.
# -f: don't fail on the first run, when res_right doesn't exist yet.
rm -rf res_right
./rotate.sh res
python3 grid_res.py res_right

8
init_dependencies.sh Normal file
View File

@@ -0,0 +1,8 @@
# Install runtime dependencies: video codecs, a virtual display for headless
# rendering, and the face/video python packages.  dlib must be built manually.
echo 'dlib should build manually.'
sudo apt-get install ffmpeg x264 libx264-dev
sudo apt-get install xvfb
sudo pip3 install pyvirtualdisplay
sudo pip3 install moviepy
sudo pip3 install face_recognition

2
models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
swapnet_128.py
swapnet_256.py

126
models/padding_same_conv.py Normal file
View File

@@ -0,0 +1,126 @@
# modify con2d function to use same padding
# code referd to @famssa in 'https://github.com/pytorch/pytorch/issues/3867'
# and tensorflow source code
import torch.utils.data
from torch.nn import functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple
class _ConvNd(Module):
    """Base N-dimensional convolution module.

    Adapted from ``torch.nn.modules.conv._ConvNd``: stores the convolution
    hyper-parameters and owns the weight/bias Parameters.  Subclasses provide
    ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        # Transposed convolutions keep in_channels first in the weight layout.
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight/bias uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)]."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        # Mention only non-default hyper-parameters to keep the repr short.
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class Conv2d(_ConvNd):
    """2-D convolution with TensorFlow-style 'same' padding.

    Drop-in replacement for ``nn.Conv2d``; the ``padding`` argument is kept
    for signature compatibility but the effective padding is computed per
    input by ``conv2d_same_padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # Normalise scalar hyper-parameters to (h, w) pairs.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return conv2d_same_padding(input, self.weight, self.bias, self.stride,
                                   self.padding, self.dilation, self.groups)
class Conv2dPaddingSame(_ConvNd):
    """'Same'-padded 2-D convolution.

    NOTE(review): byte-for-byte identical to ``Conv2d`` above — kept as a
    separate name for backward compatibility with existing callers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # Normalise scalar hyper-parameters to (h, w) pairs.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2dPaddingSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return conv2d_same_padding(input, self.weight, self.bias, self.stride,
                                   self.padding, self.dilation, self.groups)
# custom conv2d, because pytorch doesn't have a "padding='same'" option.
def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """Run ``F.conv2d`` with TensorFlow-style SAME padding.

    ``stride`` and ``dilation`` must be (h, w) pairs (as produced by ``_pair``).
    The ``padding`` argument is ignored: the needed padding is derived from the
    input and kernel sizes so that ``out = ceil(in / stride)`` per dimension.
    """
    input_rows, input_cols = input.size(2), input.size(3)
    filter_rows, filter_cols = weight.size(2), weight.size(3)
    out_rows = (input_rows + stride[0] - 1) // stride[0]
    out_cols = (input_cols + stride[1] - 1) // stride[1]
    padding_rows = max(0, (out_rows - 1) * stride[0] +
                       (filter_rows - 1) * dilation[0] + 1 - input_rows)
    # BUGFIX: the column padding used to be computed from the ROW geometry
    # (input_rows / filter_rows / stride[0]), which was only correct for
    # square inputs, kernels, and strides.
    padding_cols = max(0, (out_cols - 1) * stride[1] +
                       (filter_cols - 1) * dilation[1] + 1 - input_cols)
    rows_odd = padding_rows % 2 != 0
    cols_odd = padding_cols % 2 != 0
    # When the required padding is odd, put the extra pixel on the
    # bottom/right (TensorFlow's SAME convention) before symmetric padding.
    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)

105
models/swapnet.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Copyright StrangeAI Authors @2019
"""
import torch
import torch.utils.data
from torch import nn, optim
from .padding_same_conv import Conv2d
from alfred.dl.torch.common import device
def toTensor(img):
    # Convert a batched HWC numpy image stack (N, H, W, C) to an NCHW tensor
    # on the globally-selected device (imported from alfred).
    img = torch.from_numpy(img.transpose((0, 3, 1, 2))).to(device)
    return img
def var_to_np(img_var):
    """Return the tensor *img_var* as a numpy array on the CPU."""
    host_copy = img_var.data.cpu()
    return host_copy.numpy()
class _ConvLayer(nn.Sequential):
    """Downsampling block: 5x5 stride-2 'same' conv followed by LeakyReLU(0.1)."""

    def __init__(self, input_features, output_features):
        super(_ConvLayer, self).__init__()
        conv = Conv2d(input_features, output_features,
                      kernel_size=5, stride=2)
        # Module names are kept identical for checkpoint compatibility.
        self.add_module('conv2', conv)
        self.add_module('leakyrelu', nn.LeakyReLU(0.1, inplace=True))
class _UpScale(nn.Sequential):
    """2x upsampling block: conv to 4*C channels, LeakyReLU, pixel shuffle."""

    def __init__(self, input_features, output_features):
        super(_UpScale, self).__init__()
        expand = Conv2d(input_features, output_features * 4,
                        kernel_size=3)
        # Module names are kept identical for checkpoint compatibility.
        self.add_module('conv2_', expand)
        self.add_module('leakyrelu', nn.LeakyReLU(0.1, inplace=True))
        self.add_module('pixelshuffler', _PixelShuffler())
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
class Reshape(nn.Module):
    """View a flat (N, 1024*4*4) tensor as a (N, 1024, 4, 4) feature map."""

    def forward(self, input):
        # channel * 4 * 4
        return input.view(-1, 1024, 4, 4)
class _PixelShuffler(nn.Module):
def forward(self, input):
batch_size, c, h, w = input.size()
rh, rw = (2, 2)
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = input.view(batch_size, rh, rw, oc, h, w)
out = out.permute(0, 3, 4, 1, 5, 2).contiguous()
out = out.view(batch_size, oc, oh, ow) # channel first
return out
class SwapNet(nn.Module):
    """Shared-encoder / dual-decoder autoencoder for 64x64 face swapping.

    Both identities share ``encoder``; ``decoder_A`` reconstructs identity A
    and ``decoder_B`` identity B.  A swap is performed by encoding a face of
    one identity and decoding it with the other identity's decoder.
    """

    def __init__(self):
        super(SwapNet, self).__init__()
        # 64x64x3 -> 4x4x1024, bottlenecked through a 1024-d linear layer,
        # then re-expanded and upscaled to 8x8x512.
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )
        # Identity-specific decoders: 8x8x512 -> 64x64x3 in [0, 1].
        self.decoder_A = nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            Conv2d(64, 3, kernel_size=5, padding=1),
            nn.Sigmoid(),
        )
        self.decoder_B = nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            Conv2d(64, 3, kernel_size=5, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, select='A'):
        # select: 'A' decodes with decoder_A, anything else with decoder_B.
        if select == 'A':
            out = self.encoder(x)
            out = self.decoder_A(out)
        else:
            out = self.encoder(x)
            out = self.decoder_B(out)
        return out

71
predict_64x64.py Normal file
View File

@@ -0,0 +1,71 @@
"""
convert a face to another person
"""
from models.swapnet import SwapNet
import torch
from alfred.dl.torch.common import device
import cv2
import numpy as np
from dataset.training_data import random_warp
from utils.umeyama import umeyama
mean_value = np.array([0.03321508, 0.05035182, 0.02038819])
def process_img(ori_img):
    """Warp *ori_img* the same way the training data is warped.

    Resizes to 256x256, jitters a 5x5 control grid, and remaps — returning a
    64x64 warped crop (see dataset/training_data.random_warp).
    """
    img = cv2.resize(ori_img, (256, 256))
    # 5x5 control grid over the central 160px of the image.
    range_ = np.linspace(128 - 80, 128 + 80, 5)
    mapx = np.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    # warp image like in the training
    mapx = mapx + np.random.normal(size=(5, 5), scale=5)
    mapy = mapy + np.random.normal(size=(5, 5), scale=5)
    # Dense map with an 8px border cropped away -> 64x64 output.
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    warped_image = cv2.remap(img, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    return warped_image
def load_img():
    # Load a saved training target image, scaled to [0, 1] floats.
    a = 'images/34600_test_A_target.png'
    img = cv2.imread(a) / 255.
    return img
def predict():
    """Swap a trump face into a cage face with the pretrained 64x64 model.

    Loads an extracted face crop, runs it through the shared encoder and
    identity B's decoder, and shows the input/output windows.
    """
    # convert trump to cage
    # NOTE: using face extracted (not original image)
    img_f = 'data/trump/464669134_face_0.png'
    ori_img = cv2.imread(img_f)
    img = cv2.resize(ori_img, (64, 64)) / 255.
    img = np.rot90(img)
    # img = load_img()
    # BUGFIX: ``np.float`` was removed from NumPy (1.24+); use the explicit
    # float64 alias, which is what np.float resolved to.
    in_img = np.array(img, dtype=np.float64).transpose(2, 1, 0)
    # normalize img
    in_img = torch.Tensor(in_img).to(device).unsqueeze(0)
    model = SwapNet().to(device)
    # Map CUDA tensors onto the CPU when no GPU is available.
    if torch.cuda.is_available():
        checkpoint = torch.load('checkpoint/faceswap_trump_cage_64x64.pth')
    else:
        checkpoint = torch.load('checkpoint/faceswap_trump_cage_64x64.pth',
                                map_location={'cuda:0': 'cpu'})
    model.load_state_dict(checkpoint['state'])
    model.eval()
    print('model loaded.')
    # Decode with identity B's decoder to perform the swap.
    out = model.forward(in_img, select='B')
    out = np.clip(out.detach().cpu().numpy()[0] * 255, 0, 255).astype('uint8').transpose(2, 1, 0)
    cv2.imshow('original image', ori_img)
    cv2.imshow('network input image', img)
    cv2.imshow('result image', np.rot90(out, axes=(1, 0)))
    cv2.waitKey(0)


if __name__ == '__main__':
    predict()

93
readme.md Normal file
View File

@@ -0,0 +1,93 @@
# High Resolution Face Swap
**a face swap implementation with a much higher resolution result (128x128)**, this is an improved and optimized *swap face application* based on GAN tech. our implementation made the following changes relative to the original *deepfakes* implementation:
- *deepfakes* only support 64x64 input, we make it **deeper** and can output 128x128 size;
- we proposed a new network called *SwapNet, SwapNet128*;
- we changed the pre-proccess step with input data (such as warp face), make it more clear;
- we make the dataset loader more efficient, load pair face data directly from 2 dir;
- we proposed a new **face outline replace** tech to get a much more combination result with
original image, their differences are like below image.
we will continually update this repo, and make face swap much more intuitive and simple, so anyone can build their own face-changing model. Here are some results for 128x128 higher-resolution face swap:
<p align="center">
<img src="https://s2.ax1x.com/2019/02/01/k3uReA.png">
</p>
We have train on trump-cage and fanbingbing-galgadot convert model. The result is not fully trained yet, but it shows a promising result, the face in most situation can works perfect!
final result on face swap directly from original big image:
<p align="center">
<img src="https://s2.ax1x.com/2019/02/01/k30qoV.png">
</p>
<p align="center">
<img src="https://s2.ax1x.com/2019/02/01/k3BCe1.png">
</p>
As you can see above, we can achieve **high resolution** and seamlessly combination with face transformation. final result on face swap directly from video (to be added soon):
## Dependencies
our face swap implementation need *alfred-py* which can installed with:
```
sudo pip3 install alfred-py
```
## Pretrained Model
We only provided pretrained model for 128x128 model, and it was hosted by StrangeAI (http://codes.strangeai.pro). For train from scratch, you can download the trump cage dataset from: https://anonfile.com/p7w3m0d5be/face-swap.zip .
For those already StrangeAI VIP membership users, you can download the whole codes and models from http://strangeai.pro .
## Train & Predict
To run, simply use:
```
python3 predict.py
# train fanbingbing-galgadot face swap
python3 train_trump_cage_64x64.py
python3 train_fbb_gal_128x128.py
```
this will predict on a trump face and convert it into cage face.
## More Info
if you wanna be invited to our computer vision discussion wechat group, you can add me via wechat or found us at: http://strangeai.pro which is **the biggest AI codes sharing platform in China**.
## Note About FaceSwap
We made some failed attempts and experimented with many combinations to produce a good result; here are some notes you need to know to build face swap tech:
- Size is everything: we have try maximum 256x256 as input size, but it fails to swap face style between 2 faces;
- Warp preprocess does not really matter, we have also trying to remove warp preprocess step and directly using target images for train, it can also success train a face swap model, but for dataset augumentation, better to warp it and make some random transform;
- loss is not really matter. Just kick of train, and train about 15000 epochs, and you can get good result;
- For data preparing, better extract faces first using dlib or [alfred](http://github.com/jinfagang/alfred)
## Faceswap Datasets
Actually, we gathered a lot of faces datasets. beside the default one, you may also access them via Baidu cloud disk.
## Copyright
*FaceSwap* is a project opensourced under MIT license, all right reserved by StrangeAI authors. website: http://strangeai.pro

BIN
result.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

52
tests.py Normal file
View File

@@ -0,0 +1,52 @@
from models.swapnet import SwapNet
from models.swapnet_128 import SwapNet128
from utils.model_summary import summary
from alfred.dl.torch.common import device
from dataset.face_pair_dataset import random_warp_128
from dataset.training_data import random_transform, random_transform_args
from PIL import Image
import cv2
import numpy as np
import torch
from utils.umeyama import umeyama
# model = SwapNet().to(device)
# summary(model, input_size=(3, 64, 64))
# def random_warp(image):
# assert image.shape == (256, 256, 3)
# range_ = np.linspace(128 - 120, 128 + 120, 5)
# mapx = np.broadcast_to(range_, (5, 5))
# mapy = mapx.T
# mapx = mapx + np.random.normal(size=(5, 5), scale=5)
# mapy = mapy + np.random.normal(size=(5, 5), scale=5)
# interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
# interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
# # just crop the image, remove the top left bottom right 8 pixels (in order to get the pure face)
# warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
# src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
# dst_points = np.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
# mat = umeyama(src_points, dst_points, True)[0:2]
# target_image = cv2.warpAffine(image, mat, (64, 64))
# return warped_image, target_image
# model = SwapNet128().to(device)
# summary(model, input_size=(3, 128, 128))
# a = Image.open('data/trump_cage/cage/2455911_face_0.png')
# a = a.resize((256, 256), Image.ANTIALIAS)
# a = random_transform(np.array(a), **random_transform_args)
# warped_img, target_img = random_warp_128(np.array(a))
# t = torch.from_numpy(target_img.transpose(2, 0, 1) / 255.).to(device)
# b = t.detach().cpu().numpy().transpose((2, 1, 0))*255
# print(b.shape)
# cv2.imshow('rr', np.array(a))
# cv2.imshow('warped image', np.array(warped_img))
# cv2.imshow('target image', np.array(target_img))
# cv2.imshow('bbbbbbbbb', b)
# cv2.waitKey(0)

View File

@@ -0,0 +1,3 @@
# Download dlib's pre-trained 68-point face landmark model into $HOME.
cd ~
wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
bzip2 -d shape_predictor_68_face_landmarks.dat.bz2

View File

@@ -0,0 +1 @@
# Download the trump/cage face-swap training dataset.
wget https://anonfile.com/p7w3m0d5be/face-swap.zip

7
tools/extract_faces.sh Normal file
View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
# this script will extract all faces from a directory of images
# it uses alfred-py and dlib to do this
# the size of the faces does not matter; they will be resized to fit the faceswap networks
sudo pip3 install alfred-py

158
train_fbb_gal_128x128.py Normal file
View File

@@ -0,0 +1,158 @@
"""
Copyright StrangeAI Authors @2019
As the network without linear connect layer
the feature are not compressed, so the encoder are weak
it consist to many informations, and decoder can not using the abstract
information to construct a new image
"""
from __future__ import print_function
import argparse
import os
import cv2
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from utils.util import get_image_paths, load_images, stack_images
from dataset.training_data import get_training_data
from alfred.dl.torch.common import device
from shutil import copyfile
try:
from models.swapnet_128 import SwapNet128, toTensor, var_to_np
except Exception:
print('can not import swapnet128, if you need high resolution face swap, '
'you can download from http://luoli.ai (you can afford a VIP membership to get all other codes)')
from loguru import logger
from dataset.face_pair_dataset import FacePairDataset128x128
from torchvision import transforms
from torch.utils.data import DataLoader
from alfred.utils.log import init_logger
init_logger()
# Training hyper-parameters.
batch_size = 32
epochs = 100000          # upper bound; training is normally stopped manually
save_per_epoch = 300     # checkpoint interval, in epochs

# Directories of pre-extracted face crops for identity A and identity B.
a_dir = './data/galgadot_fbb/fanbingbing_faces'
b_dir = './data/galgadot_fbb/galgadot_faces'
# we start to train on bigger size
dataset_name = 'galgadot_fbb'
target_size = 128

# Output locations for sample images and model checkpoints.
log_img_dir = './checkpoint/results_{}_{}x{}'.format(dataset_name, target_size, target_size)
log_model_dir = './checkpoint/{}_{}x{}'.format(dataset_name,
                                               target_size, target_size)
# Canonical "resume" checkpoint path; epoch-stamped archives sit beside it.
check_point_save_path = os.path.join(
    log_model_dir, 'faceswap_{}_{}x{}.pth'.format(dataset_name, target_size, target_size))
def main():
    """Train the 128x128 SwapNet on the fanbingbing/galgadot face pair.

    A shared encoder feeds two decoders (a and b); each optimizer updates
    the encoder plus one decoder with an L1 reconstruction loss. Resumes
    from ``check_point_save_path`` when present; periodically saves
    checkpoints and sample swap images.
    """
    os.makedirs(log_img_dir, exist_ok=True)
    os.makedirs(log_model_dir, exist_ok=True)
    logger.info("loading datasets")
    # Augmentation only; the dataset handles resizing and tensor conversion.
    transform = transforms.Compose([
        # transforms.Resize((target_size, target_size)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
        # transforms.ToTensor(),
    ])
    ds = FacePairDataset128x128(a_dir=a_dir, b_dir=b_dir,
                                target_size=target_size, transform=transform)
    dataloader = DataLoader(ds, batch_size, shuffle=True)
    model = SwapNet128()
    model.to(device)
    start_epoch = 0
    logger.info('try resume from checkpoint')
    try:
        if torch.cuda.is_available():
            checkpoint = torch.load(check_point_save_path)
        else:
            # Remap CUDA-saved tensors onto the CPU.
            checkpoint = torch.load(
                check_point_save_path, map_location={'cuda:0': 'cpu'})
        model.load_state_dict(checkpoint['state'])
        start_epoch = checkpoint['epoch']
        logger.info('checkpoint loaded.')
    except FileNotFoundError:
        print('Can\'t found {}'.format(check_point_save_path))
    criterion = nn.L1Loss()
    # Both optimizers share the encoder; each trains one decoder.
    optimizer_1 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_a.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    optimizer_2 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_b.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    logger.info('Start training, from epoch {} '.format(start_epoch))
    try:
        for epoch in range(start_epoch, epochs):
            iter = 0  # batch counter; NOTE(review): shadows the builtin `iter`
            for data in dataloader:
                iter += 1
                # Each sample is a (target, warped-input) pair per identity.
                img_a_target, img_a_input, img_b_target, img_b_input = data
                img_a_target = img_a_target.to(device)
                img_a_input = img_a_input.to(device)
                img_b_target = img_b_target.to(device)
                img_b_input = img_b_input.to(device)
                # print(img_a.size())
                # print(img_b.size())
                optimizer_1.zero_grad()
                optimizer_2.zero_grad()
                predict_a = model(img_a_input, to='a')
                predict_b = model(img_b_input, to='b')
                loss1 = criterion(predict_a, img_a_target)
                loss2 = criterion(predict_b, img_b_target)
                loss1.backward()
                loss2.backward()
                optimizer_1.step()
                optimizer_2.step()
                logger.info('Epoch: {}, iter: {}, lossA: {}, lossB: {}'.format(
                    epoch, iter, loss1.item(), loss2.item()))
                # NOTE(review): unlike the sampling block below this has no
                # `iter == 1` guard, so on a save epoch the checkpoint is
                # rewritten for every batch — confirm intent.
                if epoch % save_per_epoch == 0 and epoch != 0:
                    logger.info('Saving models...')
                    state = {
                        'state': model.state_dict(),
                        'epoch': epoch
                    }
                    torch.save(state, os.path.join(os.path.dirname(
                        check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)))
                    copyfile(os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)),
                             check_point_save_path)
                # Dump sample reconstructions / cross-identity swaps on the
                # first batch every 10 epochs.
                if epoch % 10 == 0 and epoch != 0 and iter == 1:
                    # Tensors are CHW in [0, 1]; transpose(2, 1, 0) yields
                    # W,H,C — assumes square crops (transposes the image) —
                    # TODO confirm.
                    img_a_original = np.array(img_a_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    img_b_original = np.array(img_b_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    a_predict_a = np.array(predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_b = np.array(predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    # Cross-identity swaps: A rendered by decoder b, B by decoder a.
                    a_predict_b = model(img_a_input, to='b')
                    b_predict_a = model(img_b_input, to='a')
                    a_predict_b = np.array(a_predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_a = np.array(b_predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    cv2.imwrite(os.path.join(log_img_dir, '{}_0.png'.format(epoch)), cv2.cvtColor(img_a_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_3.png'.format(epoch)), cv2.cvtColor(img_b_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_1.png'.format(epoch)), cv2.cvtColor(a_predict_a, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_4.png'.format(epoch)), cv2.cvtColor(b_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_2.png'.format(epoch)), cv2.cvtColor(a_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_5.png'.format(epoch)), cv2.cvtColor(b_predict_a, cv2.COLOR_BGR2RGB))
                    logger.info('Record a result')
    except KeyboardInterrupt:
        # NOTE(review): `epoch` is unbound here if interrupted before the
        # first loop iteration begins (NameError) — confirm acceptable.
        logger.info('try saving models...')
        state = {
            'state': model.state_dict(),
            'epoch': epoch
        }
        torch.save(state, os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)))
        copyfile(os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)),
                 check_point_save_path)


if __name__ == "__main__":
    main()

157
train_fbb_gal_64x64.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Copyright StrangeAI Authors @2019
original forked from deepfakes repo
edit and promoted by StrangeAI authors
"""
from __future__ import print_function
import argparse
import os
import cv2
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from models.swapnet import SwapNet, toTensor, var_to_np
from utils.util import get_image_paths, load_images, stack_images
from dataset.training_data import get_training_data
from alfred.dl.torch.common import device
from shutil import copyfile
from loguru import logger
from dataset.face_pair_dataset import FacePairDataset, FacePairDataset64x64
from torchvision import transforms
import sys
# Configure loguru to log to stderr with a compact format.
logger.remove()  # Remove the pre-configured handler
# NOTE(review): logger.start is a deprecated alias of logger.add in loguru —
# confirm the pinned version still provides it.
logger.start(sys.stderr, format="<lvl>{level}</lvl> {time:MM-DD HH:mm:ss} {file}:{line} - {message}")

# Training hyper-parameters.
batch_size = 64
epochs = 100000
save_per_epoch = 300

# Directories of pre-extracted face crops for identity A and identity B.
a_dir = './data/galgadot_fbb/fanbingbing_faces'
b_dir = './data/galgadot_fbb/galgadot_faces'
# we start to train on bigger size
target_size = 64
dataset_name = 'galgadot_fbb'

# Output locations for sample images and model checkpoints.
log_img_dir = './checkpoint/results_{}_{}x{}'.format(dataset_name, target_size, target_size)
log_model_dir = './checkpoint/{}_{}x{}'.format(dataset_name,
                                               target_size, target_size)
check_point_save_path = os.path.join(
    log_model_dir, 'faceswap_{}_{}x{}.pth'.format(dataset_name, target_size, target_size))
def main():
    """Train the 64x64 SwapNet on the fanbingbing/galgadot face pair.

    A shared encoder feeds two decoders (A and B); each optimizer updates
    the encoder plus one decoder with an L1 reconstruction loss. Resumes
    from ``check_point_save_path`` when present; periodically saves
    checkpoints and sample swap images.
    """
    os.makedirs(log_img_dir, exist_ok=True)
    os.makedirs(log_model_dir, exist_ok=True)
    # Augmentation only; the dataset handles resizing and tensor conversion.
    transform = transforms.Compose([
        # transforms.Resize((target_size, target_size)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
        # transforms.ToTensor(),
    ])
    ds = FacePairDataset64x64(a_dir=a_dir, b_dir=b_dir,
                              target_size=target_size, transform=transform)
    dataloader = DataLoader(ds, batch_size, shuffle=True)
    model = SwapNet()
    model.to(device)
    start_epoch = 0
    logger.info('try resume from checkpoint')
    if os.path.isdir('checkpoint'):
        try:
            if torch.cuda.is_available():
                checkpoint = torch.load(check_point_save_path)
            else:
                # Remap CUDA-saved tensors onto the CPU.
                checkpoint = torch.load(
                    check_point_save_path, map_location={'cuda:0': 'cpu'})
            model.load_state_dict(checkpoint['state'])
            start_epoch = checkpoint['epoch']
            logger.info('checkpoint loaded.')
        except FileNotFoundError:
            # BUG FIX: the old message named an unrelated checkpoint file.
            print('Can\'t found {}'.format(check_point_save_path))

    def _save_checkpoint(epoch):
        # Write an epoch-stamped archive, then refresh the resume file.
        # BUG FIX: archive names were hard-coded 'trump_cage_128x128' /
        # '..._256x256' although this script trains `dataset_name` at
        # `target_size`x`target_size`; build them from the config instead.
        state = {
            'state': model.state_dict(),
            'epoch': epoch
        }
        archive = os.path.join(
            os.path.dirname(check_point_save_path),
            'faceswap_{}_{}x{}_{}.pth'.format(dataset_name, target_size,
                                              target_size, epoch))
        torch.save(state, archive)
        copyfile(archive, check_point_save_path)

    criterion = nn.L1Loss()
    # Both optimizers share the encoder; each trains one decoder.
    optimizer_1 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_A.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    optimizer_2 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_B.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    logger.info('Start training, from epoch {} '.format(start_epoch))
    epoch = start_epoch  # BUG FIX: keep `epoch` bound for the interrupt handler
    try:
        for epoch in range(start_epoch, epochs):
            it = 0  # batch counter (not named `iter` to avoid shadowing the builtin)
            for data in dataloader:
                it += 1
                # Each sample is a (target, warped-input) pair per identity.
                img_a_target, img_a_input, img_b_target, img_b_input = data
                img_a_target = img_a_target.to(device)
                img_a_input = img_a_input.to(device)
                img_b_target = img_b_target.to(device)
                img_b_input = img_b_input.to(device)
                optimizer_1.zero_grad()
                optimizer_2.zero_grad()
                predict_a = model(img_a_input, select='A')
                predict_b = model(img_b_input, select='B')
                loss1 = criterion(predict_a, img_a_target)
                loss2 = criterion(predict_b, img_b_target)
                loss1.backward()
                loss2.backward()
                optimizer_1.step()
                optimizer_2.step()
                logger.info('Epoch: {}, iter: {}, lossA: {}, lossB: {}'.format(
                    epoch, it, loss1.item(), loss2.item()))
                # BUG FIX: guard on the first batch so a save epoch does not
                # rewrite the checkpoint once per batch.
                if epoch % save_per_epoch == 0 and epoch != 0 and it == 1:
                    logger.info('Saving models...')
                    _save_checkpoint(epoch)
                # Dump sample reconstructions / cross-identity swaps on the
                # first batch every 10 epochs.
                if epoch % 10 == 0 and epoch != 0 and it == 1:
                    img_a_original = np.array(img_a_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    img_b_original = np.array(img_b_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    a_predict_a = np.array(predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_b = np.array(predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    a_predict_b = model(img_a_input, select='B')
                    b_predict_a = model(img_b_input, select='A')
                    a_predict_b = np.array(a_predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_a = np.array(b_predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    cv2.imwrite(os.path.join(log_img_dir, '{}_0.png'.format(epoch)), cv2.cvtColor(img_a_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_3.png'.format(epoch)), cv2.cvtColor(img_b_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_1.png'.format(epoch)), cv2.cvtColor(a_predict_a, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_4.png'.format(epoch)), cv2.cvtColor(b_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_2.png'.format(epoch)), cv2.cvtColor(a_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_5.png'.format(epoch)), cv2.cvtColor(b_predict_a, cv2.COLOR_BGR2RGB))
                    logger.info('Record a result')
    except KeyboardInterrupt:
        logger.warning('try saving models...do not interrupt')
        _save_checkpoint(epoch)


if __name__ == "__main__":
    main()

158
train_trump_cage_128x128.py Normal file
View File

@@ -0,0 +1,158 @@
"""
Copyright StrangeAI Authors @2019
Because the network has no fully connected layer,
the features are not compressed, so the encoder is weak:
it retains too much information, and the decoder cannot use abstract
features to construct a new image
"""
from __future__ import print_function
import argparse
import os
import cv2
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from utils.util import get_image_paths, load_images, stack_images
from dataset.training_data import get_training_data
from alfred.dl.torch.common import device
from shutil import copyfile
try:
from models.swapnet_128 import SwapNet128, toTensor, var_to_np
except Exception:
print('can not import swapnet128, if you need high resolution face swap, '
'you can download from http://luoli.ai (you can afford a VIP membership to get all other codes)')
from loguru import logger
from dataset.face_pair_dataset import FacePairDataset128x128
from torchvision import transforms
from torch.utils.data import DataLoader
from alfred.utils.log import init_logger
init_logger()
# Training hyper-parameters.
batch_size = 32
epochs = 100000          # upper bound; training is normally stopped manually
save_per_epoch = 300     # checkpoint interval, in epochs

# Directories of pre-extracted face crops for identity A and identity B.
a_dir = './data/trump_cage/trump'
b_dir = './data/trump_cage/cage'
dataset_name = 'trump_cage'
# we start to train on bigger size
target_size = 128

# Output locations for sample images and model checkpoints.
log_img_dir = './checkpoint/results_{}_{}x{}'.format(dataset_name, target_size, target_size)
log_model_dir = './checkpoint/{}_{}x{}'.format(dataset_name,
                                               target_size, target_size)
# Canonical "resume" checkpoint path; epoch-stamped archives sit beside it.
check_point_save_path = os.path.join(
    log_model_dir, 'faceswap_{}_{}x{}.pth'.format(dataset_name, target_size, target_size))
def main():
    """Train the 128x128 SwapNet on the trump/cage face pair.

    A shared encoder feeds two decoders (a and b); each optimizer updates
    the encoder plus one decoder with an L1 reconstruction loss. Resumes
    from ``check_point_save_path`` when present; periodically saves
    checkpoints and sample swap images.
    """
    os.makedirs(log_img_dir, exist_ok=True)
    os.makedirs(log_model_dir, exist_ok=True)
    logger.info("loading datasets")
    # Augmentation only; the dataset handles resizing and tensor conversion.
    transform = transforms.Compose([
        # transforms.Resize((target_size, target_size)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
        # transforms.ToTensor(),
    ])
    ds = FacePairDataset128x128(a_dir=a_dir, b_dir=b_dir,
                                target_size=target_size, transform=transform)
    dataloader = DataLoader(ds, batch_size, shuffle=True)
    model = SwapNet128()
    model.to(device)
    start_epoch = 0
    logger.info('try resume from checkpoint')
    try:
        if torch.cuda.is_available():
            checkpoint = torch.load(check_point_save_path)
        else:
            # Remap CUDA-saved tensors onto the CPU.
            checkpoint = torch.load(
                check_point_save_path, map_location={'cuda:0': 'cpu'})
        model.load_state_dict(checkpoint['state'])
        start_epoch = checkpoint['epoch']
        logger.info('checkpoint loaded.')
    except FileNotFoundError:
        print('Can\'t found {}'.format(check_point_save_path))
    criterion = nn.L1Loss()
    # Both optimizers share the encoder; each trains one decoder.
    optimizer_1 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_a.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    optimizer_2 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_b.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    logger.info('Start training, from epoch {} '.format(start_epoch))
    try:
        for epoch in range(start_epoch, epochs):
            iter = 0  # batch counter; NOTE(review): shadows the builtin `iter`
            for data in dataloader:
                iter += 1
                # Each sample is a (target, warped-input) pair per identity.
                img_a_target, img_a_input, img_b_target, img_b_input = data
                img_a_target = img_a_target.to(device)
                img_a_input = img_a_input.to(device)
                img_b_target = img_b_target.to(device)
                img_b_input = img_b_input.to(device)
                # print(img_a.size())
                # print(img_b.size())
                optimizer_1.zero_grad()
                optimizer_2.zero_grad()
                predict_a = model(img_a_input, to='a')
                predict_b = model(img_b_input, to='b')
                loss1 = criterion(predict_a, img_a_target)
                loss2 = criterion(predict_b, img_b_target)
                loss1.backward()
                loss2.backward()
                optimizer_1.step()
                optimizer_2.step()
                logger.info('Epoch: {}, iter: {}, lossA: {}, lossB: {}'.format(
                    epoch, iter, loss1.item(), loss2.item()))
                # Save once per interval, on the first batch of the epoch.
                if epoch % save_per_epoch == 0 and epoch != 0 and iter == 1:
                    logger.info('Saving models...')
                    state = {
                        'state': model.state_dict(),
                        'epoch': epoch
                    }
                    torch.save(state, os.path.join(os.path.dirname(
                        check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)))
                    copyfile(os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)),
                             check_point_save_path)
                # Dump sample reconstructions / cross-identity swaps on the
                # first batch every 10 epochs.
                if epoch % 10 == 0 and epoch != 0 and iter == 1:
                    # Tensors are CHW in [0, 1]; transpose(2, 1, 0) yields
                    # W,H,C — assumes square crops (transposes the image) —
                    # TODO confirm.
                    img_a_original = np.array(img_a_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    img_b_original = np.array(img_b_target.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    a_predict_a = np.array(predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_b = np.array(predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    # Cross-identity swaps: A rendered by decoder b, B by decoder a.
                    a_predict_b = model(img_a_input, to='b')
                    b_predict_a = model(img_b_input, to='a')
                    a_predict_b = np.array(a_predict_b.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    b_predict_a = np.array(b_predict_a.detach().cpu().numpy()[0].transpose(2, 1, 0)*255, dtype=np.uint8)
                    cv2.imwrite(os.path.join(log_img_dir, '{}_0.png'.format(epoch)), cv2.cvtColor(img_a_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_3.png'.format(epoch)), cv2.cvtColor(img_b_original, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_1.png'.format(epoch)), cv2.cvtColor(a_predict_a, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_4.png'.format(epoch)), cv2.cvtColor(b_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_2.png'.format(epoch)), cv2.cvtColor(a_predict_b, cv2.COLOR_BGR2RGB))
                    cv2.imwrite(os.path.join(log_img_dir, '{}_5.png'.format(epoch)), cv2.cvtColor(b_predict_a, cv2.COLOR_BGR2RGB))
                    logger.info('Record a result')
    except KeyboardInterrupt:
        # NOTE(review): `epoch` is unbound here if interrupted before the
        # first loop iteration begins (NameError) — confirm acceptable.
        logger.info('try saving models...')
        state = {
            'state': model.state_dict(),
            'epoch': epoch
        }
        torch.save(state, os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)))
        copyfile(os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_128x128_{}.pth'.format(dataset_name, epoch)),
                 check_point_save_path)


if __name__ == "__main__":
    main()

144
train_trump_cage_64x64.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Copyright StrangeAI Authors @2019
original forked from deepfakes repo
edit and promoted by StrangeAI authors
"""
from __future__ import print_function
import argparse
import os
import cv2
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from models.swapnet import SwapNet, toTensor, var_to_np
from utils.util import get_image_paths, load_images, stack_images
from dataset.training_data import get_training_data
from alfred.dl.torch.common import device
from shutil import copyfile
from loguru import logger
# Training hyper-parameters.
batch_size = 64
epochs = 100000          # upper bound; training is normally stopped manually
save_per_epoch = 300     # checkpoint interval, in epochs

# Directories of pre-extracted face crops for identity A and identity B.
a_dir = './data/trump_cage/trump'
b_dir = './data/trump_cage/cage'
# we start to train on bigger size
target_size = 64
dataset_name = 'trump_cage'

# Output locations for sample images and model checkpoints.
log_img_dir = './checkpoint/results_{}_{}x{}'.format(dataset_name, target_size, target_size)
log_model_dir = './checkpoint/{}_{}x{}'.format(dataset_name,
                                               target_size, target_size)
check_point_save_path = os.path.join(
    log_model_dir, 'faceswap_{}_{}x{}.pth'.format(dataset_name, target_size, target_size))
def main():
    """Train the 64x64 SwapNet on the trump/cage faces (deepfakes-style).

    Loads all face images into memory, shifts identity A toward identity
    B's per-channel mean, then alternates optimizing decoder A and
    decoder B (shared encoder) on randomly warped batches with an L1
    reconstruction loss. Periodically saves checkpoints and sample images.
    """
    os.makedirs(log_img_dir, exist_ok=True)
    os.makedirs(log_model_dir, exist_ok=True)
    logger.info("loading datasets")
    images_A = get_image_paths(a_dir)
    images_B = get_image_paths(b_dir)
    images_A = load_images(images_A) / 255.0
    images_B = load_images(images_B) / 255.0
    # Shift A toward B's per-channel mean so color statistics roughly match;
    # the printed offset is needed again at conversion time.
    print('mean value to remember: ', images_B.mean(
        axis=(0, 1, 2)) - images_A.mean(axis=(0, 1, 2)))
    images_A += images_B.mean(axis=(0, 1, 2)) - images_A.mean(axis=(0, 1, 2))
    model = SwapNet()
    model.to(device)
    start_epoch = 0
    logger.info('try resume from checkpoint')
    if os.path.isdir('checkpoint'):
        try:
            if torch.cuda.is_available():
                checkpoint = torch.load('./checkpoint/faceswap_trump_cage_64x64.pth')
            else:
                # Remap CUDA-saved tensors onto the CPU.
                checkpoint = torch.load(
                    './checkpoint/faceswap_trump_cage_64x64.pth', map_location={'cuda:0': 'cpu'})
            model.load_state_dict(checkpoint['state'])
            start_epoch = checkpoint['epoch']
            logger.info('checkpoint loaded.')
        except FileNotFoundError:
            print('Can\'t found faceswap_trump_cage.pth')
    criterion = nn.L1Loss()
    # Both optimizers share the encoder; each trains one decoder.
    optimizer_1 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_A.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    optimizer_2 = optim.Adam([{'params': model.encoder.parameters()},
                              {'params': model.decoder_B.parameters()}], lr=5e-5, betas=(0.5, 0.999))
    logger.info('Start training, from epoch {} '.format(start_epoch))
    for epoch in range(start_epoch, epochs):
        # Each "epoch" is one randomly warped batch per identity.
        warped_A, target_A = get_training_data(images_A, batch_size)
        warped_B, target_B = get_training_data(images_B, batch_size)
        warped_A, target_A = toTensor(warped_A), toTensor(target_A)
        warped_B, target_B = toTensor(warped_B), toTensor(target_B)
        warped_A, target_A, warped_B, target_B = Variable(warped_A.float()), Variable(target_A.float()), \
            Variable(warped_B.float()), Variable(target_B.float())
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()
        warped_A_out = model(warped_A, 'A')
        warped_B_out = model(warped_B, 'B')
        loss1 = criterion(warped_A_out, target_A)
        loss2 = criterion(warped_B_out, target_B)
        loss1.backward()
        loss2.backward()
        optimizer_1.step()
        optimizer_2.step()
        logger.info('epoch: {}, lossA: {}, lossB: {}'.format(epoch, loss1.item(), loss2.item()))
        # BUG FIX: this condition previously read `iter == 0`, which compared
        # the *builtin* `iter` function with 0 (always False), so periodic
        # checkpoints were never written. Match the sibling scripts instead.
        if epoch % save_per_epoch == 0 and epoch != 0:
            logger.info('Saving models...')
            state = {
                'state': model.state_dict(),
                'epoch': epoch
            }
            torch.save(state, os.path.join(os.path.dirname(
                check_point_save_path), 'faceswap_{}_64x64_{}.pth'.format(dataset_name, epoch)))
            copyfile(os.path.join(os.path.dirname(check_point_save_path), 'faceswap_{}_64x64_{}.pth'.format(dataset_name, epoch)),
                     check_point_save_path)
        # Dump sample images every 100 epochs: A reconstructed by its own
        # decoder, A swapped through decoder B, plus the warped/target inputs.
        if epoch % 100 == 0:
            test_A_ = warped_A[0:2]
            a_predict_a = var_to_np(model(test_A_, 'A'))[0]*255
            a_predict_b = var_to_np(model(test_A_, 'B'))[0]*255
            warp_a = test_A_[0].detach().cpu().numpy()*255
            target_a = target_A[0].detach().cpu().numpy()*255
            cv2.imwrite(os.path.join(log_img_dir, "{}_res_a_to_a.png".format(epoch)), np.array(a_predict_a.transpose(2, 1, 0)).astype('uint8'))
            cv2.imwrite(os.path.join(log_img_dir, "{}_res_a_to_b.png".format(epoch)), np.array(a_predict_b.transpose(2, 1, 0)).astype('uint8'))
            cv2.imwrite(os.path.join(log_img_dir, "{}_test_A_warped.png".format(epoch)), np.array(warp_a.transpose(2, 1, 0)).astype('uint8'))
            cv2.imwrite(os.path.join(log_img_dir, "{}_test_A_target.png".format(epoch)), np.array(target_a.transpose(2, 1, 0)).astype('uint8'))
            logger.info('Record a result')


if __name__ == "__main__":
    main()

101
utils/face_extractor.py Normal file
View File

@@ -0,0 +1,101 @@
"""
This file is used for extracting the faces from all images in a directory
"""
import glob
try:
import dlib
except ImportError:
print('You have not installed dlib, install from https://github.com/davisking/dlib')
print('see you later.')
exit(0)
import os
import cv2
import numpy as np
from loguru import logger
class FaceExtractor(object):
    """Detect faces with dlib's HOG frontal detector and crop/save them."""

    def __init__(self):
        self.detector = dlib.get_frontal_face_detector()
        # self.predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
        # The 68-landmark model is loaded lazily in get_faces_list().
        self.predictor_path = os.path.expanduser('~/shape_predictor_68_face_landmarks.dat')

    def get_faces_list(self, img, landmark=False):
        """
        get faces and locations

        :param img: image as a numpy array (cv2 frame)
        :param landmark: when True, also return dlib 68-point landmarks
        :return: (faces, locations) or (faces, locations, landmarks)
        """
        assert isinstance(img, np.ndarray), 'img should be numpy array (cv2 frame)'
        if landmark:
            if os.path.exists(self.predictor_path):
                predictor = dlib.shape_predictor(self.predictor_path)
            else:
                logger.error('can not call this method, you should download '
                             'dlib landmark model: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
                exit(0)
        dets = self.detector(img, 1)
        all_faces = []
        locations = []
        landmarks = []
        for i, d in enumerate(dets):
            # BUG FIX: clamp the detection box to the frame; dlib can report
            # negative coordinates, and a negative slice start silently wraps
            # to the other side of the array.
            x = max(int(d.left()), 0)
            y = max(int(d.top()), 0)
            w = int(d.width())
            h = int(d.height())
            face_patch = np.array(img)[y: y + h, x: x + w, 0:3]
            if landmark:
                shape = predictor(img, d)
                landmarks.append(shape)
            locations.append([x, y, w, h])
            all_faces.append(face_patch)
        if landmark:
            return all_faces, locations, landmarks
        else:
            return all_faces, locations

    def get_faces(self, img_d):
        """
        get all faces from img_d and save the crops as PNGs into a sibling
        "<img_d>_faces" directory
        :param img_d: directory containing png/jpg/jpeg images
        :return:
        """
        all_images = []
        for e in ['png', 'jpg', 'jpeg']:
            all_images.extend(glob.glob(os.path.join(img_d, '*.{}'.format(e))))
        print('Found all {} images under {}'.format(len(all_images), img_d))
        s_d = os.path.dirname(img_d) + "_faces"
        if not os.path.exists(s_d):
            os.makedirs(s_d)
        for img_f in all_images:
            # BUG FIX: the second argument of cv2.imread is an ImreadModes
            # flag; the previous cv2.COLOR_BGR2RGB (== 4) was silently taken
            # as IMREAD_ANYCOLOR. Load with the default BGR mode instead.
            img = cv2.imread(img_f)
            if img is None:
                # Robustness: skip unreadable/corrupt files instead of crashing.
                print('=> skip unreadable image {}'.format(img_f))
                continue
            dets = self.detector(img, 1)
            print('=> get {} faces in {}'.format(len(dets), img_f))
            print('=> saving faces...')
            # BUG FIX: draw the debug rectangles on a copy. They were drawn
            # on `img` itself, so crops of later faces in the same image
            # contained the red boxes of earlier faces.
            vis = img.copy()
            for i, d in enumerate(dets):
                save_face_f = os.path.join(s_d, os.path.basename(img_f).split('.')[0]
                                           + '_face_{}.png'.format(i))
                # Clamp the box to the frame (see get_faces_list).
                x = max(int(d.left()), 0)
                y = max(int(d.top()), 0)
                w = int(d.width())
                h = int(d.height())
                face_patch = np.array(img)[y: y + h, x: x + w, 0:3]
                vis = cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 0, 255), 2)
                cv2.imwrite(save_face_f, face_patch)
        print('Done!')

129
utils/model_summary.py Normal file
View File

@@ -0,0 +1,129 @@
# -----------------------
#
# Copyright Jin Fagang @2018
#
# 1/25/19
# torch_summary
# -----------------------
"""
codes token from
https://github.com/sksq96/pytorch-summary
I edit something here, credits belongs to author
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import OrderedDict
import numpy as np
def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a Keras-style per-layer summary of a PyTorch model.

    Runs a fake forward pass with input of shape (2, *input_size) and uses
    forward hooks to record every leaf layer's output shape and parameter
    count, then prints a table plus rough memory estimates.

    :param model: the nn.Module to summarize (must already be on `device`)
    :param input_size: single input shape WITHOUT the batch dim, e.g. (3, 64, 64)
    :param batch_size: batch size shown in the table (-1 = symbolic)
    :param device: "cuda" or "cpu"
    """
    def register_hook(module):
        # Forward hook: record this module's I/O shapes and parameter count.
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                # Multi-output modules: one shape entry per output tensor.
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            # Count weight + bias elements; "trainable" reflects the weight only.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # Only hook leaf layers: skip containers and the model itself.
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    # NOTE(review): `input_size[0] <= 3` assumes a channel-first shape with at
    # most 3 channels — rejects valid inputs with more channels; confirm.
    if isinstance(input_size, tuple) and input_size[0] <= 3:
        # batch_size of 2 for batchnorm
        x = torch.rand(2, *input_size).type(dtype)
    else:
        print('Wrong! you should send input size specific without batch size, etc: (3, 64, 64), channel first.')
        exit(0)

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    try:
        print('fake data input: ', x.size())
        model(x)
    except Exception as e:
        print('summary failed. error: {}'.format(e))
        print('make sure your called model.to(device) ')
        exit(0)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")

88
utils/umeyama.py Normal file
View File

@@ -0,0 +1,88 @@
# # License (Modified BSD) # Copyright (C) 2011, the scikit-image team All rights reserved. # # Redistribution and
# use in source and binary forms, with or without modification, are permitted provided that the following conditions
# are met: # # Redistributions of source code must retain the above copyright notice, this list of conditions and the
# following disclaimer. # Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
# Neither the name of skimage nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS''
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
# GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# umeyama function from scikit-image/skimage/transform/_geometric.py
import numpy as np
def umeyama(src, dst, estimate_scale):
    """Estimate N-D similarity transformation with or without scaling.

    Parameters
    ----------
    src : (M, N) array
        Source coordinates.
    dst : (M, N) array
        Destination coordinates.
    estimate_scale : bool
        Whether to estimate scaling factor.

    Returns
    -------
    T : (N + 1, N + 1)
        The homogeneous similarity transformation matrix. The matrix contains
        NaN values only if the problem is not well-conditioned.

    References
    ----------
    .. [1] "Least-squares estimation of transformation parameters between two
           point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573
    """
    num = src.shape[0]
    dim = src.shape[1]

    # Compute mean of src and dst.
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)

    # Subtract mean from src and dst.
    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    # Eq. (38): cross-covariance of the demeaned point sets.
    A = np.dot(dst_demean.T, src_demean) / num

    # Eq. (39): sign correction so the result is a proper rotation (det = +1).
    d = np.ones((dim,), dtype=np.double)
    if np.linalg.det(A) < 0:
        d[dim - 1] = -1

    T = np.eye(dim + 1, dtype=np.double)

    # NOTE: np.linalg.svd returns V already transposed (V here is V^T in the
    # paper's notation).
    U, S, V = np.linalg.svd(A)

    # Eq. (40) and (43).
    rank = np.linalg.matrix_rank(A)
    if rank == 0:
        return np.nan * T
    elif rank == dim - 1:
        if np.linalg.det(U) * np.linalg.det(V) > 0:
            T[:dim, :dim] = np.dot(U, V)
        else:
            s = d[dim - 1]
            d[dim - 1] = -1
            T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V))
            d[dim - 1] = s
    else:
        # BUG FIX: this previously used V.T; since numpy's svd already
        # returns V^T, transposing again produced a wrong rotation. The
        # scikit-image reference implementation uses V directly here.
        T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V))

    if estimate_scale:
        # Eq. (41) and (42).
        scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d)
    else:
        scale = 1.0

    T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T)
    T[:dim, :dim] *= scale

    return T

40
utils/util.py Normal file
View File

@@ -0,0 +1,40 @@
import cv2
import numpy
import os
def get_image_paths(directory):
    """Return the paths of all ``.png`` entries directly inside *directory*."""
    found = []
    for entry in os.scandir(directory):
        if entry.name.endswith(".png"):
            found.append(entry.path)
    return found
def load_images(image_paths, convert=None):
    """Load every image in *image_paths* resized to 256x256 into one array.

    :param image_paths: iterable of image file paths
    :param convert: optional callable applied to each loaded image
    :return: numpy array of shape (len(image_paths), 256, 256, channels)
    """
    # Lazy pipeline: read (BGR, as cv2 loads) and resize each image on demand.
    iter_all_images = (cv2.resize(cv2.imread(fn), (256, 256)) for fn in image_paths)
    if convert:
        iter_all_images = (convert(img) for img in iter_all_images)
    for i, image in enumerate(iter_all_images):
        if i == 0:
            # Allocate the output once the first image fixes shape and dtype.
            all_images = numpy.empty((len(image_paths),) + image.shape, dtype=image.dtype)
        all_images[i] = image
    # NOTE(review): raises NameError when image_paths is empty, since
    # all_images is only created inside the loop — confirm callers never
    # pass an empty list.
    return all_images
def get_transpose_axes(n):
    """Partition axis indices 0..n-1 into (row_axes, col_axes, [last_axis]).

    Used by stack_images to interleave grid axes with image axes: for even
    n the odd indices become rows, for odd n the even indices do; the final
    axis (channels) is always kept last.
    """
    even = n % 2 == 0
    row_start = 1 if even else 0
    col_start = 0 if even else 1
    row_axes = list(range(row_start, n - 1, 2))
    col_axes = list(range(col_start, n - 1, 2))
    return row_axes, col_axes, [n - 1]
def stack_images(images):
    """Collapse a grid of images into one big tiled image.

    Interleaved grid/image axes are transposed together and merged, keeping
    the trailing channel axis; e.g. shape (R, C, H, W, ch) becomes
    (R*H, C*W, ch). (Axis-partition logic inlined from get_transpose_axes.)
    """
    n = images.ndim
    if n % 2 == 0:
        row_axes = list(range(1, n - 1, 2))
        col_axes = list(range(0, n - 1, 2))
    else:
        row_axes = list(range(0, n - 1, 2))
        col_axes = list(range(1, n - 1, 2))
    order = row_axes + col_axes + [n - 1]
    shape = numpy.array(images.shape)
    merged_shape = [numpy.prod(shape[group]) for group in (row_axes, col_axes, [n - 1])]
    return numpy.transpose(images, axes=order).reshape(merged_shape)