58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
import torch.nn as nn
|
|
import functools
|
|
|
|
class NetworksFactory:
    """Build network modules from a string identifier.

    Each network class is imported inside its own branch, so only the module
    for the requested name is imported.
    """

    def __init__(self):
        pass

    @staticmethod
    def get_by_name(network_name, *args, **kwargs):
        """Instantiate the network registered under ``network_name``.

        Args:
            network_name: ``'generator_wasserstein_gan'`` or
                ``'discriminator_wasserstein_gan'``.
            *args, **kwargs: forwarded unchanged to the selected network's
                constructor.

        Returns:
            The newly constructed network instance.

        Raises:
            ValueError: if ``network_name`` is not one of the names above.
        """
        if network_name == 'generator_wasserstein_gan':
            # Imported lazily so that only the requested module is loaded.
            from .generator_wasserstein_gan import Generator
            network = Generator(*args, **kwargs)
        elif network_name == 'discriminator_wasserstein_gan':
            from .discriminator_wasserstein_gan import Discriminator
            network = Discriminator(*args, **kwargs)
        else:
            raise ValueError("Network %s not recognized." % network_name)

        # Call form of print: valid syntax on Python 3 and, with a single
        # parenthesised argument, prints the same text on Python 2.
        print("Network %s was created" % network_name)
        return network
|
|
|
|
|
|
class NetworkBase(nn.Module):
    """Base ``nn.Module`` with a display name, weight initialisation and a
    normalization-layer factory."""

    def __init__(self):
        super(NetworkBase, self).__init__()
        # Default display name; subclasses presumably overwrite it.
        self._name = 'BaseNetwork'

    @property
    def name(self):
        """Display name of the network (``'BaseNetwork'`` unless overwritten)."""
        return self._name

    def init_weights(self):
        """Re-initialise this module and every submodule in place.

        ``nn.Module.apply`` calls ``_weights_init_fn`` on each submodule and
        on ``self``.
        """
        self.apply(self._weights_init_fn)

    def _weights_init_fn(self, m):
        """Initialise the parameters of a single module ``m`` in place.

        - Class name contains ``'Conv'``: weight ~ N(0, 0.02), bias = 0.
        - Class name contains ``'BatchNorm2d'``: weight ~ N(1, 0.02), bias = 0.
          Only the 2-D variant matches; ``BatchNorm1d``/``BatchNorm3d`` are
          left as they are.
        - Any other module is left untouched.

        A missing or ``None`` weight/bias is skipped rather than dereferenced.
        Examples are ``Conv2d(bias=False)``, ``BatchNorm2d(affine=False)``, or a
        wrapper module whose class name merely contains ``'Conv'``.
        """
        classname = m.__class__.__name__
        # getattr + hasattr(..., 'data') tolerates absent attributes and None.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if 'Conv' in classname:
            if hasattr(weight, 'data'):
                weight.data.normal_(0.0, 0.02)
            if hasattr(bias, 'data'):
                bias.data.fill_(0)
        elif 'BatchNorm2d' in classname:
            if hasattr(weight, 'data'):
                weight.data.normal_(1.0, 0.02)
            if hasattr(bias, 'data'):
                bias.data.fill_(0)

    def _get_norm_layer(self, norm_type='batch'):
        """Return a constructor for the requested 2-D normalization layer.

        Args:
            norm_type: ``'batch'`` -> ``BatchNorm2d`` with ``affine=True``;
                ``'instance'`` -> ``InstanceNorm2d`` with ``affine=False``;
                ``'batchnorm2d'`` -> the ``nn.BatchNorm2d`` class itself.

        Returns:
            A callable that builds the layer when given its remaining
            constructor arguments (e.g. the number of features).

        Raises:
            NotImplementedError: for any other ``norm_type``.
        """
        if norm_type == 'batch':
            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        elif norm_type == 'instance':
            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
        elif norm_type == 'batchnorm2d':
            norm_layer = nn.BatchNorm2d
        else:
            raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

        return norm_layer
|