Files
disrupting-deepfakes/GANimation/networks/networks.py
T
Nataniel Ruiz Gutierrez 21970b730a All
2019-12-21 16:37:10 -05:00

58 lines
1.8 KiB
Python

import torch.nn as nn
import functools
class NetworksFactory:
def __init__(self):
pass
@staticmethod
def get_by_name(network_name, *args, **kwargs):
if network_name == 'generator_wasserstein_gan':
from .generator_wasserstein_gan import Generator
network = Generator(*args, **kwargs)
elif network_name == 'discriminator_wasserstein_gan':
from .discriminator_wasserstein_gan import Discriminator
network = Discriminator(*args, **kwargs)
else:
raise ValueError("Network %s not recognized." % network_name)
print "Network %s was created" % network_name
return network
class NetworkBase(nn.Module):
def __init__(self):
super(NetworkBase, self).__init__()
self._name = 'BaseNetwork'
@property
def name(self):
return self._name
def init_weights(self):
self.apply(self._weights_init_fn)
def _weights_init_fn(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
if hasattr(m.bias, 'data'):
m.bias.data.fill_(0)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def _get_norm_layer(self, norm_type='batch'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type =='batchnorm2d':
norm_layer = nn.BatchNorm2d
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer