AMP fix
This commit is contained in:
@@ -333,7 +333,9 @@ def depth_to_space(x, size):
|
||||
x = tf.reshape(x, (-1, oh, ow, oc, ))
|
||||
return x
|
||||
else:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
cfg = nn.getCurrentDeviceConfig()
|
||||
if not cfg.cpu_only:
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
b,c,h,w = x.shape.as_list()
|
||||
oh, ow = h * size, w * size
|
||||
oc = c // (size * size)
|
||||
@@ -344,11 +346,6 @@ def depth_to_space(x, size):
|
||||
return x
|
||||
nn.depth_to_space = depth_to_space
|
||||
|
||||
def pixel_norm(x, power = 1.0):
|
||||
return x * power * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=nn.conv2d_spatial_axes, keepdims=True) + 1e-06)
|
||||
nn.pixel_norm = pixel_norm
|
||||
|
||||
|
||||
def rgb_to_lab(srgb):
|
||||
srgb_pixels = tf.reshape(srgb, [-1, 3])
|
||||
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
|
||||
|
||||
Reference in New Issue
Block a user