-
This commit is contained in:
+13
-2
@@ -853,15 +853,20 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||
nnlib.BilinearInterpolation = BilinearInterpolation
|
||||
|
||||
class WScaleConv2DLayer(KL.Conv2D):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, gain=None, **kwargs):
|
||||
kwargs['kernel_initializer'] = keras.initializers.random_normal()
|
||||
|
||||
if gain is None:
|
||||
gain = np.sqrt(2)
|
||||
|
||||
self.gain = gain
|
||||
|
||||
super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
super().build(input_shape)
|
||||
kernel_shape = K.int_shape(self.kernel)
|
||||
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
|
||||
std = np.sqrt(self.gain) / np.sqrt( np.prod(kernel_shape[:-1]) )
|
||||
self.wscale = K.constant(std, dtype=K.floatx() )
|
||||
|
||||
def call(self, input, **kwargs):
|
||||
@@ -870,6 +875,12 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||
x = super().call(input,**kwargs)
|
||||
self.kernel = k
|
||||
return x
|
||||
|
||||
def get_config(self):
|
||||
config = {"gain": self.gain}
|
||||
base_config = super(WScaleConv2DLayer, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
nnlib.WScaleConv2DLayer = WScaleConv2DLayer
|
||||
|
||||
class SelfAttention(KL.Layer):
|
||||
|
||||
Reference in New Issue
Block a user