156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
import torch
|
|
from torch.fft import fftn
|
|
|
|
|
|
def roll_quadrants(data, backwards=False):
    """
    Circularly shift a batch of 2D fourier transforms by half their size.

    With ``backwards=False`` the element at spatial index ``[0, 0]`` (the
    zero-frequency term of an unshifted FFT) is moved to the central pixel,
    so that indices ``[0, ..., N-1]`` hold frequencies ``[-N/2, ..., +N/2]``.
    With ``backwards=True`` the inverse shift is applied.

    Args:
        data: fourier transform, (NxHxW)
        backwards: bool, if True undo a previous forward shift

    Returns:
        Shifted fourier transform (same shape as ``data``).

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
        RuntimeWarning: if any spatial size is even.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    spatial_sizes = tuple(data.shape[1:])
    if any(size % 2 == 0 for size in spatial_sizes):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')

    # For every spatial dimension swap the left and right half. Sizes are odd
    # here, so the shift is (N-1)/2; plain ints avoid building a throwaway tensor.
    direction = -1 if backwards else 1
    shifts = [direction * (size // 2) for size in spatial_sizes]
    spatial_dims = tuple(range(1, dim + 1))  # add one for batch dimension
    return data.roll(shifts, dims=spatial_dims)
|
|
|
|
|
|
def batch_fft(data, normalize=False):
    """
    Compute fourier transform of batch.

    Args:
        data: input tensor, (NxHxW). Real or complex; integer and boolean
            tensors are converted to the default floating point dtype first.
        normalize: bool, if True use the orthonormal FFT convention
            (``norm='ortho'``), otherwise the unscaled forward transform
            (``norm='backward'``).

    Returns:
        Batch fourier transform of input data (complex tensor, NxHxW).

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    dims = tuple(range(1, dim + 1))  # add one for batch dimension
    norm = 'ortho' if normalize else 'backward'

    if not torch.is_complex(data):
        # torch.complex only accepts floating point inputs, so promote
        # integer/bool data before attaching a zero imaginary part.
        if not torch.is_floating_point(data):
            data = data.to(torch.get_default_dtype())
        data = torch.complex(data, torch.zeros_like(data))

    return fftn(data, dim=dims, norm=norm)
|
|
|
|
|
|
def azimuthal_average(image, center=None):
    # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    Calculate the azimuthally averaged radial profile.

    Every pixel is assigned to the radial bin ``floor(r)`` (bin size = 1),
    where ``r`` is its distance to ``center``; the profile holds the mean of
    the pixels in each bin, ordered by increasing radius.
    Requires low frequencies to be at the center of the image.

    Args:
        image: Batch of 2D images, NxHxW (any number of leading batch
            dimensions is accepted; the last two are treated as HxW).
        center: The [x,y] pixel coordinates used as the center. The default is
            None, which then uses the center of the image (including
            fractional pixels).

    Returns:
        Azimuthal average over the image around the center, float64 tensor of
        shape (..., number_of_bins). Only bins that contain at least one pixel
        appear, so the bin index equals the integer radius only if every
        radius up to the maximum is represented.

    Raises:
        NotImplementedError: if the image height is even.
    """
    # Check input shapes
    assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
                                                 f'(but it is len(center)={len(center)}.'

    H, W = image.shape[-2:]
    if H % 2 == 0:
        raise NotImplementedError('Not sure if implementation correct, please check')

    # Pixel index grids of shape (H, W): rows vary along dim 0, columns along dim 1
    h = torch.arange(0, H).view(-1, 1).expand(H, W)
    w = torch.arange(0, W).view(1, -1).expand(H, W)

    if center is None:
        center_x, center_y = (W - 1) / 2.0, (H - 1) / 2.0
    else:
        # floats keep the radius computation in floating point even when the
        # caller passes integer coordinates
        center_x, center_y = float(center[0]), float(center[1])

    # Compute radius for each pixel wrt center
    r = torch.stack([w - center_x, h - center_y]).norm(2, 0)

    # Order pixels by increasing radius
    r_sorted, order = r.flatten().sort()
    i_sorted = image.flatten(-2, -1)[..., order]

    # Get the integer part of the radii (bin size = 1)
    r_int = r_sorted.long()  # attribute to the smaller integer

    # Position of the last pixel of every bin: wherever the integer radius
    # changes between neighbours, plus the very last pixel.
    changes = torch.where(r_int[1:] != r_int[:-1])[0]
    bin_ends = torch.cat([changes, torch.tensor([H * W - 1])])

    # Number of pixels per bin
    nr = torch.cat([bin_ends[:1] + 1, bin_ends[1:] - bin_ends[:-1]])

    # Cumulative sum evaluated at the bin ends; differences give per-bin sums.
    # The first bin is the cumulative sum up to its own end (it may contain
    # more than one pixel when the center does not coincide with a pixel).
    csim = i_sorted.cumsum(-1, dtype=torch.float64)
    at_ends = csim[..., bin_ends]
    tbin = torch.cat([at_ends[..., :1], at_ends[..., 1:] - at_ends[..., :-1]], dim=-1)

    radial_prof = tbin / nr.to(tbin.device)  # normalize by counted bins

    return radial_prof
|
|
|
|
|
|
def get_spectrum(data, normalize=False):
    """
    Compute the azimuthally averaged power spectrum of a batch of 2D inputs.

    The power spectrum ``|FFT|^2`` is computed per sample, padded to odd
    spatial sizes where necessary, shifted so the zero frequency sits at the
    central pixel and finally averaged over rings of equal integer radius.

    Args:
        data: input tensor, (NxHxW)
        normalize: bool, forwarded to ``batch_fft``

    Returns:
        Radial profile of the power spectrum as returned by
        ``azimuthal_average``.

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2

    # For an even size N, duplicate the value at index N/2 so that it ends up
    # at both ends of the shifted spectrum and is not averaged with the mean
    # value. Each spatial axis is treated on its own so non-square inputs are
    # padded with the correct index.
    H, W = power_spec.shape[1:]
    if H % 2 == 0:
        half_h = H // 2
        power_spec = torch.cat([power_spec[:, :half_h + 1], power_spec[:, half_h:]], dim=1)
    if W % 2 == 0:
        half_w = W // 2
        power_spec = torch.cat([power_spec[:, :, :half_w + 1], power_spec[:, :, half_w:]], dim=2)

    power_spec = roll_quadrants(power_spec)
    power_spec = azimuthal_average(power_spec)
    return power_spec
|
|
|
|
|
|
def plot_std(mean, std, x=None, ax=None, **kwargs):
    """
    Plot a curve together with a shaded band of +/- one ``std`` around it.

    Args:
        mean: 1D sequence/tensor with the values of the curve
        std: deviation per point, same length as ``mean``
        x: x coordinates; the default None uses a normalized axis, i.e.
            ``len(mean)`` evenly spaced points in [0, 1]
        ax: axes object to draw on (must provide ``plot`` and
            ``fill_between``); a new matplotlib figure is created if None
        **kwargs: forwarded to ``ax.plot``

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        # matplotlib is only needed when we have to create the figure ourselves
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(1)

    if x is None:
        x = torch.linspace(0, 1, len(mean))  # use normalized x axis
    lines = ax.plot(x, mean, **kwargs)

    # plot error margins in same color as line: prefer an explicitly requested
    # color, otherwise reuse whatever color the line was actually drawn with
    err_kwargs = {'alpha': 0.3}
    if 'c' in kwargs:
        err_kwargs['color'] = kwargs['c']
    elif 'color' in kwargs:
        err_kwargs['color'] = kwargs['color']
    elif lines:
        err_kwargs['color'] = lines[0].get_color()

    ax.fill_between(x, mean - std, mean + std, **err_kwargs)

    return ax
|