# disrupting-deepfakes/ganimation/model.py
# Last commit: 21970b730a ("All") by Nataniel Ruiz Gutierrez,
# 2019-12-21 16:37:10 -05:00.
# (Original listing: 159 lines, 5.2 KiB, Python.)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from PIL import Image
class BaseNetwork(nn.Module):
    """``nn.Module`` base class providing a shared weight initializer.

    Subclasses build their layers in ``__init__`` and then call
    :meth:`init_weights`.
    """

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self):
        """Apply :meth:`_weights_init_fn` to every submodule and to ``self``."""
        self.apply(self._weights_init_fn)

    def _weights_init_fn(self, m):
        """Initialise a single module in place, dispatching on its class name.

        * Class name contains ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
          weight ~ N(0, 0.02); bias, if present, set to 0.
        * Class name contains ``'BatchNorm2d'``: weight ~ N(1, 0.02);
          bias set to 0.
        * Any other module (including ``InstanceNorm2d``) is left unchanged.

        A module that matches by name but does not own the tensor in question
        is skipped for that tensor instead of raising ``AttributeError``.
        Examples are a ``bias=False`` conv, ``BatchNorm2d(affine=False)``
        (``weight``/``bias`` are ``None``), or a container class whose name
        merely contains ``'Conv'``.

        Args:
            m: the module currently visited by ``nn.Module.apply``.
        """
        classname = m.__class__.__name__
        # getattr with a default: nn.Module raises AttributeError for
        # attributes it does not have, which the old ``m.bias`` access did not
        # survive.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if 'Conv' in classname:
            if hasattr(weight, 'data'):
                weight.data.normal_(0.0, 0.02)
            if hasattr(bias, 'data'):
                bias.data.fill_(0)
        elif 'BatchNorm2d' in classname:
            if hasattr(weight, 'data'):
                weight.data.normal_(1.0, 0.02)
            if hasattr(bias, 'data'):
                bias.data.fill_(0)
class ResidualBlock(BaseNetwork):
    """Residual Block with instance normalization.

    Two 3x3 convolution + affine InstanceNorm stages (ReLU between them)
    wrapped in an identity skip connection.

    Args:
        dim_in: channels of the incoming feature map.
        dim_out: channels produced by each convolution. The skip addition in
            ``forward`` needs the branch output to be shape-compatible with
            the input, so in practice ``dim_out == dim_in``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)
        norm_opts = dict(affine=True, track_running_stats=True)
        # Layers are created in the same order as before so that default
        # initialisation consumes the RNG identically.
        first_conv = nn.Conv2d(dim_in, dim_out, **conv_opts)
        first_norm = nn.InstanceNorm2d(dim_out, **norm_opts)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(dim_out, dim_out, **conv_opts)
        second_norm = nn.InstanceNorm2d(dim_out, **norm_opts)
        # Kept as ``main`` with positional indices 0-4 so state_dict keys
        # (main.0.weight, main.1.running_mean, ...) stay unchanged.
        self.main = nn.Sequential(
            first_conv, first_norm, activation, second_conv, second_norm)
        self.init_weights()

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        residual = self.main(x)
        return x + residual
class Generator(BaseNetwork):
    """Generator network.

    Conditional encoder/decoder built from:

    1. The condition vector ``c`` is tiled over the spatial grid and
       concatenated to the 3-channel input image.
    2. A 7x7 convolution stem.
    3. Two stride-2 down-sampling convolutions (channels x2 each).
    4. ``repeat_num`` residual blocks at ``4 * conv_dim`` channels.
    5. Two stride-2 transposed-conv up-sampling stages (channels /2 each).

    Two heads then read the shared features:

    * ``im_reg``: 3-channel output passed through ``tanh`` (values in (-1, 1)).
    * ``im_att``: 1-channel output passed through ``sigmoid`` (values in (0, 1)).

    Combining ``att`` and ``reg`` into a final image is not done here.
    NOTE(review): presumably the caller blends them GANimation-style;
    confirm at the call site.

    Args:
        conv_dim: channels after the stem convolution; doubled by each
            down-sampling stage and halved by each up-sampling stage.
        c_dim: length of the condition vector ``c``.
        repeat_num: number of residual blocks in the bottleneck.
    """
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()
        layers = []
        # Stem: image (3 ch) + tiled condition (c_dim ch) -> conv_dim channels.
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7,
                                stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(
            conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))
        # Snapshot of the layers built so far (stem only). debug1..debug4 are
        # growing prefixes of the same `layers` list, so they share module
        # instances, and therefore parameters, with self.main. debug4 is the
        # same full stack as main. As nn.Module attributes they are registered
        # submodules, so their entries also appear in state_dict(). Do not drop
        # or reorder them without checking saved checkpoints.
        # NOTE(review): presumably kept for inspecting intermediate
        # activations; confirm against callers.
        self.debug1 = nn.Sequential(*layers)
        # Down-sampling layers.
        # Each 4x4 / stride-2 / pad-1 conv halves H and W and doubles channels.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2,
                                    kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(
                curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        self.debug2 = nn.Sequential(*layers)
        # Bottleneck layers.
        # Each ResidualBlock initialises its own weights in its constructor;
        # the self.init_weights() call at the end re-initialises them again.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
        self.debug3 = nn.Sequential(*layers)
        # Up-sampling layers.
        # Each 4x4 / stride-2 / pad-1 transposed conv doubles H and W and
        # halves channels, mirroring the down-sampling stages.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim //
                                             2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(
                curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2
        # Shared feature trunk used by forward(); curr_dim is back to conv_dim.
        self.main = nn.Sequential(*layers)
        self.debug4 = nn.Sequential(*layers)
        # Same architecture for the color regression
        # (3-channel head squashed to (-1, 1) by tanh).
        layers = []
        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7,
                                stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.im_reg = nn.Sequential(*layers)
        # One Channel output and Sigmoid function for the attention layer
        layers = []
        layers.append(nn.Conv2d(curr_dim, 1, kernel_size=7,
                                stride=1, padding=3, bias=False))
        layers.append(nn.Sigmoid()) # Values between 0 and 1
        self.im_att = nn.Sequential(*layers)
        self.init_weights()
    def forward(self, x, c):
        """Run the generator on an image batch and its condition vectors.

        Args:
            x: image batch of shape (N, 3, H, W). Three channels are required
                because the stem expects 3 + c_dim input channels.
            c: condition batch of shape (N, c_dim). It is unsqueezed twice
                and expanded to (N, c_dim, H, W) below.

        Returns:
            ``(att, reg)``, both computed from the same features:

            * ``att`` has 1 channel, with values in (0, 1).
            * ``reg`` has 3 channels, with values in (-1, 1).

            NOTE(review): their spatial size equals H x W only when H and W
            are divisible by 4; confirm the inputs used.
        """
        # Replicate spatially and concatenate domain information.
        c = c.unsqueeze(2).unsqueeze(3)
        c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        features = self.main(x)
        reg = self.im_reg(features)
        att = self.im_att(features)
        return att, reg
class Discriminator(BaseNetwork):
    """Discriminator network with PatchGAN.

    The trunk is ``repeat_num`` convolutions, each 4x4 / stride 2 / pad 1 and
    followed by LeakyReLU(0.01). Each convolution halves the spatial size and
    every one after the first doubles the channel count.

    Two heads read the trunk output:

    * ``conv1``: a 3x3 conv to a single channel, giving one real/fake score
      per patch location.
    * ``conv2``: a conv whose kernel spans the whole remaining feature map
      (``image_size / 2**repeat_num`` per side), giving ``c_dim`` condition
      outputs.

    Args:
        image_size: input height/width that ``conv2``'s kernel is sized for.
            NOTE(review): assumes square inputs of exactly this size; confirm.
        conv_dim: channels of the first convolution.
        c_dim: number of outputs of the condition head.
        repeat_num: number of stride-2 convolutions in the trunk.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = []
        layers.append(
            nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))
        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2,
                                    kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2
        # After repeat_num stride-2 convs an image_size input is reduced to
        # image_size / 2**repeat_num per side; conv2's kernel covers all of it.
        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(
            curr_dim, c_dim, kernel_size=kernel_size, bias=False)
        self.init_weights()

    @staticmethod
    def _squeeze_keep_batch(t):
        """Drop every size-1 dimension of ``t`` except the leading batch dim."""
        kept = [size for size in t.shape[1:] if size != 1]
        return t.reshape(t.shape[0], *kept)

    def forward(self, x):
        """Score an image batch.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            ``(out_src, out_cls)``: the patch scores from ``conv1`` and the
            condition outputs from ``conv2``. All size-1 dimensions are
            removed except the batch dimension. For N > 1 this matches the
            previous ``squeeze()``. For N == 1 the batch dimension is now
            kept, whereas ``squeeze()`` also dropped it, so a single-image
            batch previously came back with a lower rank.
        """
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return (self._squeeze_keep_batch(out_src),
                self._squeeze_keep_batch(out_cls))