import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from PIL import Image


def _squeeze_non_batch(t):
    """Remove every size-1 dimension of ``t`` except the leading batch dimension.

    ``Tensor.squeeze()`` with no argument also drops dim 0 when the batch holds a
    single sample, e.g. a (1, c_dim, 1, 1) logit map would collapse to (c_dim,)
    instead of (1, c_dim). Squeezing only dims >= 1 keeps the result's rank
    independent of the batch size; for batches larger than one the result is
    identical to ``t.squeeze()``.
    """
    # Walk dims from last to first so earlier indices stay valid after a squeeze.
    for dim in reversed(range(1, t.dim())):
        t = t.squeeze(dim)  # no-op when that dim is not size 1
    return t


class BaseNetwork(nn.Module):
    """``nn.Module`` base class that provides a shared weight-initialisation scheme."""

    def __init__(self):
        super().__init__()

    def init_weights(self):
        """Re-initialise this module and all its submodules with ``_weights_init_fn``."""
        self.apply(self._weights_init_fn)

    def _weights_init_fn(self, m):
        """Initialise a single module ``m`` in place.

        * Class name containing ``'Conv'`` (e.g. Conv2d, ConvTranspose2d):
          weight ~ N(0, 0.02); bias, when present, set to 0.
        * Class name containing ``'BatchNorm2d'``: weight ~ N(1, 0.02), bias set
          to 0 (skipped when the layer has no affine parameters).
        * Every other module, including InstanceNorm2d, is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight, 0.0, 0.02)
            # Layers built with bias=False store None here.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


class ResidualBlock(BaseNetwork):
    """Residual Block with instance normalization.

    Computes ``x + main(x)`` where ``main`` is conv3x3 -> IN -> ReLU -> conv3x3 -> IN.
    The convs use stride 1 / padding 1, so spatial size is preserved. The skip
    connection adds ``x`` directly, so ``dim_in`` is expected to equal ``dim_out``
    (other combinations only work where broadcasting happens to apply).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            # track_running_stats=True: in eval() mode these layers normalise with
            # running statistics instead of per-instance statistics.
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
        self.init_weights()

    def forward(self, x):
        return x + self.main(x)


class Generator(BaseNetwork):
    """Generator network with an attention head and a colour-regression head.

    The domain vector ``c`` is tiled over the image and concatenated to it. A
    shared trunk (``self.main``) then applies:

    * a 7x7 conv;
    * two stride-2 down-sampling convs;
    * ``repeat_num`` residual blocks;
    * two stride-2 transposed-conv up-sampling stages.

    The trunk feeds two heads: ``im_att`` (1 channel, Sigmoid) and ``im_reg``
    (3 channels, Tanh).
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        # NOTE: module creation/registration order is kept exactly as in the
        # original layout: it fixes the state_dict key order and the order in
        # which init_weights() draws random numbers for a given seed.
        layers = []
        # Input has 3 image channels plus c_dim tiled domain channels.
        layers.append(nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1,
                                padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True,
                                        track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))
        # debug1..debug4 are snapshots of the trunk while it is being built.
        # They hold the SAME module objects as self.main (not copies): they add
        # no extra parameters, but they do add duplicate state_dict keys and make
        # init_weights() visit the shared layers more than once. Kept so existing
        # checkpoints continue to load.
        self.debug1 = nn.Sequential(*layers)

        # Down-sampling layers: each k4/s2/p1 conv maps H -> H // 2 and doubles
        # the channel count.
        curr_dim = conv_dim
        for _ in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4,
                                    stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True,
                                            track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        self.debug2 = nn.Sequential(*layers)

        # Bottleneck layers (shape-preserving residual blocks).
        for _ in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
        self.debug3 = nn.Sequential(*layers)

        # Up-sampling layers: each k4/s2/p1 transposed conv maps H -> 2 * H and
        # halves the channel count, returning to conv_dim channels.
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim // 2,
                                             kernel_size=4, stride=2,
                                             padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True,
                                            track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        self.main = nn.Sequential(*layers)
        self.debug4 = nn.Sequential(*layers)

        # Colour-regression head: 3-channel image squashed to [-1, 1] by Tanh.
        self.im_reg = nn.Sequential(
            nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Tanh())

        # Attention head: single-channel mask squashed to [0, 1] by Sigmoid.
        self.im_att = nn.Sequential(
            nn.Conv2d(curr_dim, 1, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Sigmoid())

        self.init_weights()

    def forward(self, x, c):
        """Run the generator.

        Args:
            x: image batch of shape (N, 3, H, W).
            c: domain vectors of shape (N, c_dim); they are tiled to
                (N, c_dim, H, W) and concatenated to ``x`` along channels.

        Returns:
            ``(att, reg)``:

            * ``att``: attention mask, shape (N, 1, H', W'), values in [0, 1].
            * ``reg``: colour image, shape (N, 3, H', W'), values in [-1, 1].

            Here H' = 4 * (H // 4) and likewise for W, so H' = H when H is
            divisible by 4.
        """
        # Replicate spatially and concatenate domain information.
        c = c.unsqueeze(2).unsqueeze(3)
        c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        features = self.main(x)
        reg = self.im_reg(features)
        att = self.im_att(features)
        return att, reg


class Discriminator(BaseNetwork):
    """Discriminator network with PatchGAN.

    ``repeat_num`` k4/s2/p1 convs, each followed by LeakyReLU(0.01), feed two
    heads:

    * ``conv1``: 3x3, 1 channel, per-patch source scores.
    * ``conv2``: domain-classification logits. Its kernel covers the whole final
      feature map, so an ``image_size`` x ``image_size`` input yields a 1x1 map.

    Raises:
        ValueError: if ``image_size // 2 ** repeat_num`` is 0, i.e. the input is
            too small for ``repeat_num`` halvings.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        # Each k4/s2/p1 conv maps H -> H // 2, so the final feature map is
        # image_size // 2**repeat_num wide. conv2 uses that as its kernel.
        kernel_size = int(image_size // 2 ** repeat_num)
        if kernel_size < 1:
            raise ValueError(
                f"image_size={image_size} is too small for repeat_num={repeat_num}; "
                f"need image_size >= {2 ** repeat_num}")

        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4,
                                    stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size,
                               bias=False)
        self.init_weights()

    def forward(self, x):
        """Score an image batch ``x`` of shape (N, 3, H, W).

        Returns:
            ``(out_src, out_cls)``. Every size-1 dimension except the batch
            dimension is removed, so the batch dimension survives even when
            N == 1:

            * ``out_src``: per-patch scores, e.g. (N, H', W').
            * ``out_cls``: domain logits, (N, c_dim) when the final map is 1x1
              (H == W == image_size) and c_dim > 1.
        """
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return _squeeze_non_batch(out_src), _squeeze_non_batch(out_cls)