106 lines
2.9 KiB
Python
106 lines
2.9 KiB
Python
"""
|
|
Copyright StrangeAI Authors @2019
|
|
|
|
"""
|
|
import torch
|
|
import torch.utils.data
|
|
from torch import nn, optim
|
|
from .padding_same_conv import Conv2d
|
|
from alfred.dl.torch.common import device
|
|
|
|
|
|
def toTensor(img):
    """Move a batch of images from a NumPy array into a torch tensor.

    ``img`` must be a 4-D ``numpy.ndarray``; its last axis is moved to
    position 1 (presumably NHWC -> NCHW -- confirm against callers). The
    resulting tensor is placed on the module-level ``device`` imported from
    ``alfred.dl.torch.common``.
    """
    channels_first = img.transpose((0, 3, 1, 2))
    tensor = torch.from_numpy(channels_first)
    return tensor.to(device)
|
|
|
|
|
|
def var_to_np(img_var):
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from the autograd graph (so tensors that require
    grad can be converted) and moved to the CPU first. For a tensor that is
    already on the CPU, the returned array shares memory with it, as
    documented for ``Tensor.numpy``.
    """
    # ``detach()`` is the supported replacement for the legacy ``.data``
    # attribute: both drop autograd history, but ``detach()`` keeps in-place
    # modifications visible to autograd's version-counter checks.
    return img_var.detach().cpu().numpy()
|
|
|
|
|
|
class _ConvLayer(nn.Sequential):
    """Encoder stage: a 5x5, stride-2 ``Conv2d`` followed by LeakyReLU(0.1).

    ``Conv2d`` is the project's ``padding_same_conv`` variant (presumably
    'same' padding -- confirm). The submodule names ``conv2`` and
    ``leakyrelu`` become state_dict keys, so they must stay as they are.
    """

    def __init__(self, input_features, output_features):
        super(_ConvLayer, self).__init__()
        conv = Conv2d(input_features, output_features,
                      kernel_size=5, stride=2)
        self.add_module('conv2', conv)
        activation = nn.LeakyReLU(0.1, inplace=True)
        self.add_module('leakyrelu', activation)
|
|
|
|
|
|
class _UpScale(nn.Sequential):
    """Decoder stage: 3x3 ``Conv2d`` -> LeakyReLU(0.1) -> ``_PixelShuffler``.

    The convolution emits ``output_features * 4`` channels, which the 2x2
    pixel shuffler folds into ``output_features`` channels at twice the
    height and width of the convolution output.
    """

    def __init__(self, input_features, output_features):
        super(_UpScale, self).__init__()
        # Four sub-pixel positions (2 x 2) per output pixel.
        expanded = output_features * 4
        stages = (
            ('conv2_', Conv2d(input_features, expanded, kernel_size=3)),
            ('leakyrelu', nn.LeakyReLU(0.1, inplace=True)),
            ('pixelshuffler', _PixelShuffler()),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
|
|
|
|
|
|
class Flatten(nn.Module):
    """Flatten every dimension after the first (batch) dimension.

    An input of shape ``(B, d1, d2, ...)`` becomes ``(B, d1 * d2 * ...)``;
    a 1-D input of shape ``(B,)`` becomes ``(B, 1)``.
    """

    def forward(self, input):
        # Compute the trailing size explicitly instead of passing -1: ``view``
        # cannot infer -1 when the tensor has zero elements (e.g. an empty
        # batch) and raises. ``reshape`` returns a view whenever possible and
        # only copies for non-contiguous inputs, which ``view`` rejected.
        trailing = input.shape[1:].numel()
        return input.reshape(input.size(0), trailing)
|
|
|
|
|
|
class Reshape(nn.Module):
    """Reshape a flat per-sample tensor into ``(-1, *shape)``.

    ``shape`` defaults to ``(1024, 4, 4)`` (channel * 4 * 4), matching the
    original hard-coded target. Pass the per-sample dimensions to use a
    different target, e.g. ``Reshape(512, 8, 8)``. The batch dimension is
    inferred from the element count.
    """

    # Class-level default: ``Reshape()`` uses it, and so do instances
    # unpickled from checkpoints saved before ``shape`` was configurable
    # (their instance ``__dict__`` has no ``shape`` entry).
    shape = (1024, 4, 4)

    def __init__(self, *shape):
        super(Reshape, self).__init__()
        if shape:
            self.shape = tuple(shape)

    def forward(self, input):
        return input.view(-1, *self.shape)
|
|
|
|
|
|
class _PixelShuffler(nn.Module):
|
|
def forward(self, input):
|
|
batch_size, c, h, w = input.size()
|
|
rh, rw = (2, 2)
|
|
oh, ow = h * rh, w * rw
|
|
oc = c // (rh * rw)
|
|
out = input.view(batch_size, rh, rw, oc, h, w)
|
|
out = out.permute(0, 3, 4, 1, 5, 2).contiguous()
|
|
out = out.view(batch_size, oc, oh, ow) # channel first
|
|
|
|
return out
|
|
|
|
|
|
class SwapNet(nn.Module):
    """Autoencoder with one shared encoder and two decoders.

    ``encoder`` feeds two structurally identical decoders, ``decoder_A`` and
    ``decoder_B``. ``forward`` chooses between them with ``select``.
    """

    def __init__(self):
        super(SwapNet, self).__init__()
        # NOTE: attribute names and Sequential indices are state_dict keys,
        # and construction order fixes parameter-init RNG consumption; keep
        # both unchanged.
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            # The Linear below expects 1024 * 4 * 4 features; presumably the
            # input images are 64x64 so four stride-2 stages reach 4x4 --
            # confirm against the data pipeline.
            Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )
        self.decoder_A = self._build_decoder()
        self.decoder_B = self._build_decoder()

    @staticmethod
    def _build_decoder():
        """Return a fresh decoder: three 2x upscales, then a 5x5 conv to 3 channels and a Sigmoid."""
        return nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            Conv2d(64, 3, kernel_size=5, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, select='A'):
        """Encode ``x``, then decode it.

        ``decoder_A`` is used when ``select == 'A'``; any other value selects
        ``decoder_B``.
        """
        decoder = self.decoder_A if select == 'A' else self.decoder_B
        return decoder(self.encoder(x))
|