Files
SimSwapPlus/utilities/utilities.py
T
chenxuanhong 3783ef0e75 init
2022-01-10 15:03:58 +08:00

335 lines
10 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: utilities.py
# Created Date: Monday April 6th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 2:18:05 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import cv2
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
# Gram Matrix
def Gram(tensor: torch.Tensor):
    """Compute the normalized Gram matrix of a [B, C, H, W] feature tensor.

    Returns a [B, C, C] tensor whose (i, j) entry is the inner product of
    channel feature maps i and j, divided by C*H*W.
    """
    b, c, h, w = tensor.shape
    feats = tensor.view(b, c, h * w)
    # Batched matrix product of the features with their transpose.
    gram = feats @ feats.transpose(1, 2)
    return gram / (c * h * w)
def build_tensorboard(summary_path):
    """Create a tensorboardX ``SummaryWriter`` logging to ``summary_path``."""
    # Imported lazily so tensorboardX is only required when logging is used.
    from tensorboardX import SummaryWriter
    return SummaryWriter(log_dir=summary_path)
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping out-of-range values."""
    half = x.add(1).div(2)
    return half.clamp_(0, 1)
def tensor2img(img_tensor):
    """Convert a [-1, 1]-normalized image tensor to a numpy image array.

    Input tensor shape must be [B, C, H, W]; the returned numpy array has
    shape [B, H, W, C] with float values in [0, 255] (not rounded).
    """
    arr = img_tensor.numpy()
    # Shift from [-1, 1] to [0, 1], clip, then scale to [0, 255].
    arr = np.clip((arr + 1) / 2, 0.0, 1.0) * 255
    return arr.transpose((0, 2, 3, 1))
def img2tensor255(path, max_size=None):
    """Load an image file into a float tensor with values in [0, 255].

    Bug fix: the original `max_size` branch did `H, W, C = image.shape` on a
    PIL Image (which has no `.shape`, only `.size`) and then applied
    `transforms.ToPILImage()` to an already-PIL image — so any call with
    `max_size` raised immediately. This version reads `image.size` and
    resizes the PIL image directly.

    Args:
        path: Path to the image file.
        max_size: Optional upper bound on the longer image side. When given,
            the image is resized (preserving aspect ratio) so that
            max(H, W) == max_size before conversion.

    Returns:
        A [1, C, H, W] float tensor scaled to [0, 255].
    """
    image = Image.open(path)
    ops = []
    if max_size is not None:
        # PIL reports (width, height); transforms.Resize expects (H, W).
        W, H = image.size
        scale = float(max_size) / max(H, W)
        ops.append(transforms.Resize((int(H * scale), int(W * scale))))
    ops += [
        transforms.ToTensor(),                    # [0, 1], shape [C, H, W]
        transforms.Lambda(lambda x: x.mul(255)),  # rescale to [0, 255]
    ]
    tensor = transforms.Compose(ops)(image)
    # Add the batch dimension expected by the rest of the pipeline.
    return tensor.unsqueeze(dim=0)
def img2tensor255crop(path, crop_size=256):
    """Load an image, center-crop it, and return a float tensor in [0, 255].

    Returns a [1, C, crop_size, crop_size] tensor.
    """
    pipeline = transforms.Compose([
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),                    # [0, 1], shape [C, H, W]
        transforms.Lambda(lambda t: t.mul(255)),  # rescale to [0, 255]
    ])
    # Add the batch dimension after conversion.
    return pipeline(Image.open(path)).unsqueeze(dim=0)
# def img2tensor255(path, crop_size=None):
# """
# Input image tensor shape must be [B C H W]
# the return image numpy array shape is [B H W C]
# """
# img = cv2.imread(path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float)
# img = torch.from_numpy(img).transpose((2,0,1)).unsqueeze(0)
# return img
def img2tensor1(img_tensor):
    """Convert a [-1, 1]-normalized [B, C, H, W] tensor to a [B, H, W, C]
    numpy array with float values in [0, 255].

    NOTE(review): despite the name, this converts tensor -> image array and
    is an exact duplicate of ``tensor2img``; kept for backward compatibility.
    """
    arr = np.transpose(img_tensor.numpy(), (0, 2, 3, 1))
    arr = np.clip((arr + 1.0) * 0.5, 0.0, 1.0)
    return arr * 255
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
convertion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, '
f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace convertion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, '
f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr (ITU-R BT.601, SDTV).

    Differs from cv2.cvtColor's `BGR <-> YCrCb`, which implements the JPEG
    variant. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (ndarray): np.uint8 in [0, 255] or np.float32 in [0, 1].
        y_only (bool): Whether to return only the Y channel. Default: False.

    Returns:
        ndarray: Converted image with the same dtype and range as the input.

    NOTE(review): with ``y_only=True`` the weights are applied in R, G, B
    order (the BGR-order line is preserved below, commented out, and the
    author marked the active line "#RGB"), while the full-matrix branch
    uses BGR order — the two branches assume different channel orders.
    Behavior preserved as-is; confirm against the callers.
    """
    src_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        # out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0  # RGB
    else:
        ycbcr_matrix = [[24.966, 112.0, -18.214],
                        [128.553, -74.203, -93.786],
                        [65.481, -37.797, 112.0]]
        out_img = np.matmul(img, ycbcr_matrix) + [16, 128, 128]
    return _convert_output_type_range(out_img, src_dtype)
def to_y_channel(img):
    """Extract the Y channel (YCbCr) from an image with range [0, 255].

    3-channel HWC inputs are converted via ``bgr2ycbcr`` and gain a trailing
    singleton channel axis; other shapes pass through unchanged apart from
    the float conversion.

    Args:
        img (ndarray): Image with range [0, 255].

    Returns:
        ndarray: Float image with range [0, 255], not rounded.
    """
    scaled = img.astype(np.float32) / 255.
    if scaled.ndim == 3 and scaled.shape[2] == 3:
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.
def calculate_psnr(img1,
                   img2,
                   test_y_channel=True):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) between two images.

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Image with range [0, 255].
        img2 (ndarray): Image with range [0, 255].
        test_y_channel (bool): Compare only the Y channel of YCbCr.
            Default: True.

    Returns:
        float: PSNR in dB; ``inf`` when the images are identical.
    """
    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    if test_y_channel:
        a = to_y_channel(a)
        b = to_y_channel(b)
    mse = np.mean(np.square(a - b))
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
def _ssim(img1, img2):
    """Calculate SSIM for one pair of single-channel images.

    Called per-channel by ``calculate_ssim``.

    Args:
        img1 (ndarray): Image with range [0, 255], order 'HWC'.
        img2 (ndarray): Image with range [0, 255], order 'HWC'.

    Returns:
        float: Mean SSIM over the valid (border-cropped) region.
    """
    # Standard SSIM stabilizers for 8-bit dynamic range.
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    a = img1.astype(np.float64)
    b = img2.astype(np.float64)

    # 11x11 Gaussian window with sigma = 1.5.
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    # Crop a 5-pixel border so only fully-valid filter responses remain.
    crop = (slice(5, -5), slice(5, -5))
    mu1 = cv2.filter2D(a, -1, window)[crop]
    mu2 = cv2.filter2D(b, -1, window)[crop]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(a * a, -1, window)[crop] - mu1_sq
    sigma2_sq = cv2.filter2D(b * b, -1, window)[crop] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[crop] - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
def calculate_ssim(img1,
                   img2,
                   test_y_channel=True):
    """Calculate SSIM (structural similarity) between two images.

    Ref: "Image quality assessment: From error visibility to structural
    similarity". Results match the official MATLAB code at
    https://ece.uwaterloo.ca/~z70wang/research/ssim/. Multi-channel images
    are scored per channel and the channel scores are averaged.

    Args:
        img1 (ndarray): Image with range [0, 255].
        img2 (ndarray): Image with range [0, 255].
        test_y_channel (bool): Compare only the Y channel of YCbCr.
            Default: True.

    Returns:
        float: SSIM score.
    """
    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    if test_y_channel:
        a = to_y_channel(a)
        b = to_y_channel(b)
    scores = [_ssim(a[..., ch], b[..., ch]) for ch in range(a.shape[2])]
    return np.array(scores).mean()