""" Copyright StrangeAI Authors @2019 """ import torch import torch.utils.data from torch import nn, optim from .padding_same_conv import Conv2d from alfred.dl.torch.common import device def toTensor(img): img = torch.from_numpy(img.transpose((0, 3, 1, 2))).to(device) return img def var_to_np(img_var): return img_var.data.cpu().numpy() class _ConvLayer(nn.Sequential): def __init__(self, input_features, output_features): super(_ConvLayer, self).__init__() self.add_module('conv2', Conv2d(input_features, output_features, kernel_size=5, stride=2)) self.add_module('leakyrelu', nn.LeakyReLU(0.1, inplace=True)) class _UpScale(nn.Sequential): def __init__(self, input_features, output_features): super(_UpScale, self).__init__() self.add_module('conv2_', Conv2d(input_features, output_features * 4, kernel_size=3)) self.add_module('leakyrelu', nn.LeakyReLU(0.1, inplace=True)) self.add_module('pixelshuffler', _PixelShuffler()) class Flatten(nn.Module): def forward(self, input): output = input.view(input.size(0), -1) return output class Reshape(nn.Module): def forward(self, input): output = input.view(-1, 1024, 4, 4) # channel * 4 * 4 return output class _PixelShuffler(nn.Module): def forward(self, input): batch_size, c, h, w = input.size() rh, rw = (2, 2) oh, ow = h * rh, w * rw oc = c // (rh * rw) out = input.view(batch_size, rh, rw, oc, h, w) out = out.permute(0, 3, 4, 1, 5, 2).contiguous() out = out.view(batch_size, oc, oh, ow) # channel first return out class SwapNet(nn.Module): def __init__(self): super(SwapNet, self).__init__() self.encoder = nn.Sequential( _ConvLayer(3, 128), _ConvLayer(128, 256), _ConvLayer(256, 512), _ConvLayer(512, 1024), Flatten(), nn.Linear(1024 * 4 * 4, 1024), nn.Linear(1024, 1024 * 4 * 4), Reshape(), _UpScale(1024, 512), ) self.decoder_A = nn.Sequential( _UpScale(512, 256), _UpScale(256, 128), _UpScale(128, 64), Conv2d(64, 3, kernel_size=5, padding=1), nn.Sigmoid(), ) self.decoder_B = nn.Sequential( _UpScale(512, 256), _UpScale(256, 128), _UpScale(128, 64), Conv2d(64, 3, kernel_size=5, padding=1), nn.Sigmoid(), ) def forward(self, x, select='A'): if select == 'A': out = self.encoder(x) out = self.decoder_A(out) else: out = self.encoder(x) out = self.decoder_B(out) return out