update
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import PIL
|
||||
|
||||
def postprocess(x):
|
||||
"""[0,1] to uint8."""
|
||||
|
||||
x = np.clip(255 * x, 0, 255)
|
||||
x = np.cast[np.uint8](x)
|
||||
return x
|
||||
|
||||
def tile(X, rows, cols):
|
||||
"""Tile images for display."""
|
||||
tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype)
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
idx = i * cols + j
|
||||
if idx < X.shape[0]:
|
||||
img = X[idx,...]
|
||||
tiling[
|
||||
i*X.shape[1]:(i+1)*X.shape[1],
|
||||
j*X.shape[2]:(j+1)*X.shape[2],
|
||||
:] = img
|
||||
return tiling
|
||||
|
||||
|
||||
def plot_batch(X, out_path):
|
||||
"""Save batch of images tiled."""
|
||||
n_channels = X.shape[3]
|
||||
if n_channels > 3:
|
||||
X = X[:,:,:,np.random.choice(n_channels, size = 3)]
|
||||
X = postprocess(X)
|
||||
rc = math.sqrt(X.shape[0])
|
||||
rows = cols = math.ceil(rc)
|
||||
canvas = tile(X, rows, cols)
|
||||
canvas = np.squeeze(canvas)
|
||||
PIL.Image.fromarray(canvas).save(out_path)
|
||||
Reference in New Issue
Block a user