Files
disrupting-deepfakes/GANimation/utils/util.py
Nataniel Ruiz Gutierrez 21970b730a All
2019-12-21 16:37:10 -05:00

53 lines
1.5 KiB
Python

from __future__ import print_function
from PIL import Image
import numpy as np
import os
import torchvision
import math
def tensor2im(img, imtype=np.uint8, unnormalize=True, idx=0, nrows=None):
    """Convert an image tensor into a (height, width, channels) numpy array.

    Args:
        img: A torch tensor holding either one channels-first image (3-D) or
            a batch of such images (4-D).
        imtype: numpy dtype of the returned array.
        unnormalize: When true, undo a mean=0.5 / std=0.5 normalization
            (``x * 0.5 + 0.5``) on at most the first three channels.
        idx: For a 4-D input, the index of the sample to convert. A negative
            value tiles the whole batch into one image with
            ``torchvision.utils.make_grid`` instead.
        nrows: Forwarded to ``make_grid`` as its second positional argument;
            defaults to ``int(sqrt(batch_size))``.

    Returns:
        The image as a numpy array of dtype ``imtype``, multiplied by 254.

    The tensor passed in is never modified.
    """
    # Select a sample or create a grid if img is a batch.
    if len(img.shape) == 4:
        if idx >= 0:
            img = img[idx]
        else:
            if nrows is None:
                nrows = int(math.sqrt(img.size(0)))
            img = torchvision.utils.make_grid(img, nrows)

    # detach() lets .numpy() work on tensors that require grad.
    img = img.detach().cpu().float()

    if unnormalize:
        # .cpu()/.float() hand back the same storage when the tensor is
        # already CPU float32, and img[idx] is a view of the batch. Clone so
        # the in-place ops below cannot alter the caller's data.
        img = img.clone()
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        # zip stops at the shortest input, so only the first three channels
        # (or fewer) are un-normalized.
        for channel, m, s in zip(img, mean, std):
            channel.mul_(s).add_(m)

    # Channels-first -> channels-last.
    image_numpy_t = np.transpose(img.numpy(), (1, 2, 0))
    # NOTE(review): scale is 254 rather than 255; kept to preserve existing
    # outputs -- confirm whether this is intentional.
    image_numpy_t = image_numpy_t * 254.0

    return image_numpy_t.astype(imtype)
def tensor2maskim(mask, imtype=np.uint8, idx=0, nrows=1):
    """Convert a mask tensor into an image array via ``tensor2im``.

    The mask is converted without un-normalization. If the resulting array
    has exactly one channel on its last axis, that channel is repeated three
    times; otherwise the array is returned as produced by ``tensor2im``.

    Args:
        mask: Tensor forwarded to ``tensor2im``.
        imtype: numpy dtype of the returned array.
        idx: Sample index forwarded to ``tensor2im``.
        nrows: Grid row setting forwarded to ``tensor2im``.

    Returns:
        The mask as a (height, width, channels) numpy array.
    """
    mask_image = tensor2im(
        mask, imtype=imtype, unnormalize=False, idx=idx, nrows=nrows
    )

    single_channel = mask_image.shape[2] == 1
    if not single_channel:
        return mask_image

    # Replicate the lone channel so the mask has three channels.
    return np.repeat(mask_image, 3, axis=-1)
def mkdirs(paths):
    """Create a single directory or every directory in a list/tuple.

    Args:
        paths: Either one path, or a list or tuple of paths. Each path is
            handed to ``mkdir``.
    """
    # A list/tuple is never a str, so no separate string check is needed;
    # a lone path (str or otherwise) falls through to the single-path branch.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
def mkdir(path):
    """Create directory ``path`` (including parents) if it does not exist.

    An empty ``path`` is a no-op, so callers can pass the result of
    ``os.path.dirname`` on a bare filename. A path that already exists is
    left untouched.

    Args:
        path: Directory to create.

    Raises:
        OSError: If creation fails and the path still does not exist
            (for example, on a permissions error).
    """
    if not path:
        return
    try:
        os.makedirs(path)
    except OSError:
        # Either the path already existed or another process created it
        # between our attempt and now; only a genuine failure is re-raised.
        if not os.path.exists(path):
            raise
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` using PIL.

    The parent directory of ``image_path`` is created first when the path
    has a directory component.

    Args:
        image_numpy: Array passed to ``Image.fromarray``.
        image_path: Destination file path.
    """
    parent_dir = os.path.dirname(image_path)
    # dirname is "" for a bare filename; creating "" would raise, and the
    # current directory already exists anyway.
    if parent_dir:
        mkdir(parent_dir)
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)
def save_str_data(data, path):
    """Write ``data`` to ``path`` as comma-delimited text with ``np.savetxt``.

    Each value is formatted with ``"%s"``. The parent directory of ``path``
    is created first when the path has a directory component.

    Args:
        data: Array-like passed to ``np.savetxt``.
        path: Destination file path.
    """
    parent_dir = os.path.dirname(path)
    # dirname is "" for a bare filename; creating "" would raise, and the
    # current directory already exists anyway.
    if parent_dir:
        mkdir(parent_dir)
    np.savetxt(path, data, delimiter=",", fmt="%s")