Files
faceswap_pytorch/models/swapnet.py
2019-02-01 15:30:41 +08:00

106 lines
2.9 KiB
Python

"""
Copyright StrangeAI Authors @2019
"""
import torch
import torch.utils.data
from torch import nn, optim
from .padding_same_conv import Conv2d
from alfred.dl.torch.common import device
def toTensor(img):
    """Convert a 4-D NumPy array into a torch tensor on the configured device.

    The axes are permuted with ``(0, 3, 1, 2)`` before conversion, i.e. the
    last axis is moved to position 1.
    NOTE(review): presumably this turns a batch of channels-last images
    (N, H, W, C) into channels-first (N, C, H, W) -- confirm against callers.

    Args:
        img: NumPy array with exactly four dimensions.

    Returns:
        A tensor built with ``torch.from_numpy`` and moved to ``device``
        (the module-level name imported from ``alfred``).
    """
    channels_moved = img.transpose((0, 3, 1, 2))
    return torch.from_numpy(channels_moved).to(device)
def var_to_np(img_var):
    """Return the contents of a tensor as a NumPy array on the CPU.

    Args:
        img_var: A torch tensor; it may live on any device and may be part
            of an autograd graph.

    Returns:
        A ``numpy.ndarray`` holding the same values as ``img_var``.
    """
    # ``detach()`` replaces the legacy ``.data`` attribute: it yields the same
    # values without gradient tracking, but keeps autograd's safety checks
    # for in-place modifications.
    return img_var.detach().cpu().numpy()
class _ConvLayer(nn.Sequential):
    """Sequential stage: a 5x5, stride-2 ``Conv2d`` followed by LeakyReLU(0.1).

    The two sub-modules are registered under the names ``conv2`` and
    ``leakyrelu``; those names are part of the state-dict keys and must not
    change.
    NOTE(review): ``Conv2d`` is the project's ``padding_same_conv`` variant,
    so with stride 2 this presumably halves the spatial size -- confirm.

    Args:
        input_features: Number of input channels passed to the convolution.
        output_features: Number of output channels of the convolution.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        convolution = Conv2d(
            input_features,
            output_features,
            kernel_size=5,
            stride=2,
        )
        activation = nn.LeakyReLU(0.1, inplace=True)
        for layer_name, layer in (('conv2', convolution), ('leakyrelu', activation)):
            self.add_module(layer_name, layer)
class _UpScale(nn.Sequential):
    """Sequential stage: 3x3 ``Conv2d`` -> LeakyReLU(0.1) -> ``_PixelShuffler``.

    The convolution produces ``output_features * 4`` channels, which the
    pixel shuffler then folds into a 2x larger feature map with
    ``output_features`` channels.

    Sub-modules are registered as ``conv2_``, ``leakyrelu`` and
    ``pixelshuffler``; those names are part of the state-dict keys.

    Args:
        input_features: Number of input channels passed to the convolution.
        output_features: Number of channels left after the pixel shuffle.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        # Four times the requested channels: the shuffler consumes a factor
        # of 2 * 2 when rearranging channels into space.
        expanded_features = output_features * 4
        stages = (
            ('conv2_', Conv2d(input_features, expanded_features, kernel_size=3)),
            ('leakyrelu', nn.LeakyReLU(0.1, inplace=True)),
            ('pixelshuffler', _PixelShuffler()),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), -1)``.

        Uses ``view``, so no data is copied and the input must be laid out
        in memory in a way that ``view`` accepts.
        """
        leading = input.size(0)
        return input.view(leading, -1)
class Reshape(nn.Module):
    """View a tensor as ``(-1, channels, height, width)``.

    With the default arguments this is ``(-1, 1024, 4, 4)``, the shape the
    module originally hard-coded.

    Args:
        channels: Size of dimension 1 of the output (default 1024).
        height: Size of dimension 2 of the output (default 4).
        width: Size of dimension 3 of the output (default 4).
    """

    # Class-level default so that instances restored without running
    # ``__init__`` (e.g. a whole model unpickled from an older checkpoint)
    # keep the original 1024 x 4 x 4 behaviour.
    target_shape = (1024, 4, 4)

    def __init__(self, channels=1024, height=4, width=4):
        super().__init__()
        self.target_shape = (channels, height, width)

    def forward(self, input):
        """Return ``input`` viewed with the configured trailing shape.

        The leading dimension is inferred (``-1``) from the number of
        elements in ``input``.
        """
        return input.view(-1, *self.target_shape)
class _PixelShuffler(nn.Module):
    """Rearrange channels into space, enlarging the feature map.

    An input of shape ``(batch, c, h, w)`` becomes
    ``(batch, c // (rh * rw), h * rh, w * rw)`` where ``(rh, rw)`` is the
    upscale factor.

    The channel axis is interpreted as ``(rh, rw, out_channels)``, i.e. the
    upscale factors are the slowest-varying part of the channel index. This
    differs from ``torch.nn.PixelShuffle``, which treats channels as
    ``(out_channels, rh, rw)``, so the two are not interchangeable for
    already-trained weights.

    Args:
        size: ``(rh, rw)`` upscale factors for height and width. Defaults to
            ``(2, 2)``, the factor this module originally hard-coded.
    """

    # Class-level default so that instances restored without running
    # ``__init__`` (e.g. a whole model unpickled from an older checkpoint)
    # keep the original 2x2 behaviour.
    size = (2, 2)

    def __init__(self, size=(2, 2)):
        super().__init__()
        self.size = tuple(size)

    def forward(self, input):
        """Return ``input`` with channels folded into the spatial dimensions.

        ``input`` must be 4-D; its channel count is integer-divided by
        ``rh * rw`` to obtain the output channel count.
        """
        batch_size, c, h, w = input.size()
        rh, rw = self.size
        oh, ow = h * rh, w * rw
        oc = c // (rh * rw)
        # Split channels into (rh, rw, oc) ...
        out = input.view(batch_size, rh, rw, oc, h, w)
        # ... interleave each factor with its spatial axis: (b, oc, h, rh, w, rw) ...
        out = out.permute(0, 3, 4, 1, 5, 2).contiguous()
        # ... and merge the pairs into the enlarged height/width (channel first).
        return out.view(batch_size, oc, oh, ow)
class SwapNet(nn.Module):
    """Auto-encoder with one shared encoder and two decoders (``A`` and ``B``).

    Encoder: four ``_ConvLayer`` stages (3 -> 128 -> 256 -> 512 -> 1024
    channels), ``Flatten``, a 1024-unit linear bottleneck, a linear expansion
    back to ``1024 * 4 * 4`` values, ``Reshape`` and one ``_UpScale`` stage
    down to 512 channels.

    Each decoder: three ``_UpScale`` stages (512 -> 256 -> 128 -> 64
    channels), a 5x5 ``Conv2d`` to 3 channels and a sigmoid, so outputs lie
    in (0, 1).

    NOTE(review): the first linear layer expects ``1024 * 4 * 4`` inputs,
    which presumably corresponds to 64x64 input images if every stride-2
    ``Conv2d`` halves the spatial size -- confirm against
    ``padding_same_conv``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )
        # Two structurally identical decoders with independent weights.
        self.decoder_A = self._build_decoder()
        self.decoder_B = self._build_decoder()

    @staticmethod
    def _build_decoder():
        """Create one decoder stack (same layout for both identities)."""
        return nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            Conv2d(64, 3, kernel_size=5, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, select='A'):
        """Encode ``x`` and decode it with the selected decoder.

        Args:
            x: Input batch fed to the shared encoder.
            select: ``'A'`` uses ``decoder_A``; any other value uses
                ``decoder_B``.

        Returns:
            The output of the chosen decoder.
        """
        latent = self.encoder(x)
        decoder = self.decoder_A if select == 'A' else self.decoder_B
        return decoder(latent)