335 lines
10 KiB
Python
335 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: utilities.py
|
|
# Created Date: Monday April 6th 2020
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Tuesday, 12th October 2021 2:18:05 pm
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import cv2
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
|
|
# Gram Matrix
|
|
def Gram(tensor: torch.Tensor):
    """Compute the batched Gram matrix of a feature tensor.

    Args:
        tensor: feature maps of shape [B, C, H, W].

    Returns:
        Tensor of shape [B, C, C], normalized by C*H*W.
    """
    batch, channels, height, width = tensor.shape
    features = tensor.reshape(batch, channels, height * width)
    gram = features @ features.transpose(1, 2)
    return gram / (channels * height * width)
|
|
|
|
def build_tensorboard(summary_path):
    """Create a tensorboardX ``SummaryWriter`` that logs into *summary_path*."""
    from tensorboardX import SummaryWriter
    return SummaryWriter(log_dir=summary_path)
|
|
|
|
|
|
|
|
def denorm(x):
    """Map a [-1, 1] normalized tensor back to [0, 1], clamping out-of-range values."""
    return ((x + 1) / 2).clamp_(0, 1)
|
|
|
|
def tensor2img(img_tensor):
    """Convert a [-1, 1] normalized image tensor to a [0, 255] numpy array.

    Input image tensor shape must be [B C H W];
    the returned numpy array shape is [B H W C].
    """
    arr = img_tensor.numpy()
    arr = np.clip((arr + 1) / 2, 0.0, 1.0) * 255
    return arr.transpose((0, 2, 3, 1))
|
|
|
|
def img2tensor255(path, max_size=None):
    """Load an image file into a float tensor with values in [0, 255].

    Args:
        path: path to the image file.
        max_size: optional upper bound on the longer image side; when given,
            the image is resized (aspect ratio preserved) so that
            max(H, W) == max_size.

    Returns:
        Tensor of shape [1, C, H, W] with values scaled to [0, 255].
    """
    image = Image.open(path)

    if max_size is None:
        itot_t = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255)),
        ])
    else:
        # Bug fix: PIL images expose `.size` as (width, height) and have no
        # `.shape` attribute — the original `H, W, C = image.shape` raised
        # AttributeError whenever max_size was supplied.
        W, H = image.size
        scale = float(max_size) / max(H, W)
        image_size = (int(H * scale), int(W * scale))  # Resize expects (h, w)
        itot_t = transforms.Compose([
            # No ToPILImage here: the input is already a PIL image, and
            # transforms.ToPILImage expects a tensor/ndarray (would raise).
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255)),
        ])

    # Convert image to tensor and add the batch_size dimension.
    tensor = itot_t(image)
    return tensor.unsqueeze(dim=0)
|
|
|
|
def img2tensor255crop(path, crop_size=256):
    """Load an image, center-crop it, and return a [1, C, H, W] tensor in [0, 255].

    Args:
        path: path to the image file.
        crop_size: side length of the square center crop. Default: 256.
    """
    pipeline = transforms.Compose([
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    img = Image.open(path)
    # Add the batch_size dimension before returning.
    return pipeline(img).unsqueeze(dim=0)
|
|
|
|
# def img2tensor255(path, crop_size=None):
|
|
# """
|
|
# Input image tensor shape must be [B C H W]
|
|
# the return image numpy array shape is [B H W C]
|
|
# """
|
|
# img = cv2.imread(path)
|
|
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float)
|
|
# img = torch.from_numpy(img).transpose((2,0,1)).unsqueeze(0)
|
|
# return img
|
|
|
|
def img2tensor1(img_tensor):
    """Convert a [-1, 1] normalized image tensor to a [0, 255] numpy array.

    NOTE(review): despite its name, this is byte-for-byte the same operation
    as ``tensor2img`` (tensor -> numpy image); kept for backward compatibility.

    Input image tensor shape must be [B C H W];
    the return image numpy array shape is [B H W C].
    """
    out = img_tensor.numpy()
    out = np.clip((out + 1) * 0.5, 0.0, 1.0)
    out = out * 255
    return out.transpose((0, 2, 3, 1))
|
|
|
|
def _convert_input_type_range(img):
|
|
"""Convert the type and range of the input image.
|
|
|
|
It converts the input image to np.float32 type and range of [0, 1].
|
|
It is mainly used for pre-processing the input image in colorspace
|
|
convertion functions such as rgb2ycbcr and ycbcr2rgb.
|
|
|
|
Args:
|
|
img (ndarray): The input image. It accepts:
|
|
1. np.uint8 type with range [0, 255];
|
|
2. np.float32 type with range [0, 1].
|
|
|
|
Returns:
|
|
(ndarray): The converted image with type of np.float32 and range of
|
|
[0, 1].
|
|
"""
|
|
img_type = img.dtype
|
|
img = img.astype(np.float32)
|
|
if img_type == np.float32:
|
|
pass
|
|
elif img_type == np.uint8:
|
|
img /= 255.
|
|
else:
|
|
raise TypeError('The img type should be np.float32 or np.uint8, '
|
|
f'but got {img_type}')
|
|
return img
|
|
|
|
def _convert_output_type_range(img, dst_type):
|
|
"""Convert the type and range of the image according to dst_type.
|
|
|
|
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
|
images will be converted to np.uint8 type with range [0, 255]. If
|
|
`dst_type` is np.float32, it converts the image to np.float32 type with
|
|
range [0, 1].
|
|
It is mainly used for post-processing images in colorspace convertion
|
|
functions such as rgb2ycbcr and ycbcr2rgb.
|
|
|
|
Args:
|
|
img (ndarray): The image to be converted with np.float32 type and
|
|
range [0, 255].
|
|
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
|
converts the image to np.uint8 type with range [0, 255]. If
|
|
dst_type is np.float32, it converts the image to np.float32 type
|
|
with range [0, 1].
|
|
|
|
Returns:
|
|
(ndarray): The converted image with desired type and range.
|
|
"""
|
|
if dst_type not in (np.uint8, np.float32):
|
|
raise TypeError('The dst_type should be np.float32 or np.uint8, '
|
|
f'but got {dst_type}')
|
|
if dst_type == np.uint8:
|
|
img = img.round()
|
|
else:
|
|
img /= 255.
|
|
return img.astype(dst_type)
|
|
|
|
|
|
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr using ITU-R BT.601 coefficients.

    The bgr version of rgb2ycbcr. It implements the ITU-R BT.601 conversion
    for standard-definition television; see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from cv2.cvtColor's `BGR <-> YCrCb`, which implements the
    JPEG (full-range) conversion:
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    NOTE(review): the ``y_only`` branch uses coefficients in RGB order
    ([65.481, 128.553, 24.966]) while the full-conversion matrix is in BGR
    order — this asymmetry was introduced deliberately by the author (the
    BGR-order line is commented out in the original); confirm which channel
    order callers actually pass before changing it.

    Args:
        img (ndarray): input image, np.uint8 in [0, 255] or np.float32 in [0, 1].
        y_only (bool): whether to return only the Y channel. Default: False.

    Returns:
        ndarray: YCbCr image (or Y channel) with the same dtype/range as input.
    """
    src_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        result = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        weights = [[24.966, 112.0, -18.214],
                   [128.553, -74.203, -93.786],
                   [65.481, -37.797, 112.0]]
        result = np.matmul(img, weights) + [16, 128, 128]
    return _convert_output_type_range(result, src_dtype)
|
|
|
|
def to_y_channel(img):
    """Extract the luma (Y) channel of YCbCr for metric computation.

    Args:
        img (ndarray): image with range [0, 255]. 3-channel H x W x 3 inputs
            are converted via ``bgr2ycbcr``; anything else passes through
            unchanged (apart from the float cast).

    Returns:
        ndarray: float image with range [0, 255], not rounded; converted
        3-channel inputs come back as H x W x 1.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)[..., None]
    return img * 255.
|
|
|
|
def calculate_psnr(img1,
                   img2,
                   # crop_border=0,
                   test_y_channel=True):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) between two images.

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): image with range [0, 255].
        img2 (ndarray): image with range [0, 255].
        test_y_channel (bool): compute the metric on the Y channel of YCbCr
            instead of on all channels. Default: True.

    Returns:
        float: PSNR in dB; float('inf') when the images are identical.
    """
    a = img1.astype(np.float64)
    b = img2.astype(np.float64)

    if test_y_channel:
        a = to_y_channel(a)
        b = to_y_channel(b)

    mse = np.mean(np.square(a - b))
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
|
|
|
|
|
|
def _ssim(img1, img2):
    """Calculate SSIM for one channel; helper called by :func:`calculate_ssim`.

    Uses an 11x11 Gaussian window (sigma = 1.5) and discards a 5-pixel
    border, matching the reference MATLAB implementation.

    Args:
        img1 (ndarray): single-channel image with range [0, 255], order 'HWC'.
        img2 (ndarray): single-channel image with range [0, 255], order 'HWC'.

    Returns:
        float: mean SSIM over the valid (border-cropped) region.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    crop = (slice(5, -5), slice(5, -5))
    mu1 = cv2.filter2D(a, -1, window)[crop]
    mu2 = cv2.filter2D(b, -1, window)[crop]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(a ** 2, -1, window)[crop] - mu1_sq
    sigma2_sq = cv2.filter2D(b ** 2, -1, window)[crop] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[crop] - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
|
|
|
|
|
|
def calculate_ssim(img1,
                   img2,
                   test_y_channel=True):
    """Calculate SSIM (structural similarity) between two images.

    Ref:
    Image quality assessment: From error visibility to structural similarity.

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For multi-channel images, SSIM is calculated per channel and averaged.

    Args:
        img1 (ndarray): image with range [0, 255].
        img2 (ndarray): image with range [0, 255].
        test_y_channel (bool): compute the metric on the Y channel of YCbCr.
            Default: True.

    Returns:
        float: ssim result.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # Robustness fix: 2-D (grayscale) inputs previously crashed on
    # `img1.shape[2]` with an IndexError; promote them to H x W x 1 so the
    # per-channel loop below works. 3-D inputs are unaffected.
    if img1.ndim == 2:
        img1 = img1[..., None]
    if img2.ndim == 2:
        img2 = img2[..., None]

    ssims = [_ssim(img1[..., c], img2[..., c]) for c in range(img1.shape[2])]
    return np.array(ssims).mean()