support multi-gpu
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
from torch.fft import fftn
|
||||
|
||||
|
||||
def roll_quadrants(data, backwards=False):
|
||||
"""
|
||||
Shift low frequencies to the center of fourier transform, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1]
|
||||
Args:
|
||||
data: fourier transform, (NxHxW)
|
||||
backwards: bool, if True shift high frequencies back to center
|
||||
|
||||
Returns:
|
||||
Shifted fourier transform.
|
||||
"""
|
||||
dim = data.ndim - 1
|
||||
|
||||
if dim != 2:
|
||||
raise AttributeError(f'Data must be 2d but it is {dim}d.')
|
||||
if any(s % 2 == 0 for s in data.shape[1:]):
|
||||
raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')
|
||||
|
||||
# for each dimension swap left and right half
|
||||
dims = tuple(range(1, dim+1)) # add one for batch dimension
|
||||
shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd
|
||||
if backwards:
|
||||
shifts *= -1
|
||||
return data.roll(shifts.tolist(), dims=dims)
|
||||
|
||||
|
||||
def batch_fft(data, normalize=False):
|
||||
"""
|
||||
Compute fourier transform of batch.
|
||||
Args:
|
||||
data: input tensor, (NxHxW)
|
||||
|
||||
Returns:
|
||||
Batch fourier transform of input data.
|
||||
"""
|
||||
|
||||
dim = data.ndim - 1 # subtract one for batch dimension
|
||||
if dim != 2:
|
||||
raise AttributeError(f'Data must be 2d but it is {dim}d.')
|
||||
|
||||
dims = tuple(range(1, dim + 1)) # add one for batch dimension
|
||||
if normalize:
|
||||
norm = 'ortho'
|
||||
else:
|
||||
norm = 'backward'
|
||||
|
||||
if not torch.is_complex(data):
|
||||
data = torch.complex(data, torch.zeros_like(data))
|
||||
freq = fftn(data, dim=dims, norm=norm)
|
||||
|
||||
return freq
|
||||
|
||||
|
||||
def azimuthal_average(image, center=None):
|
||||
# modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
|
||||
"""
|
||||
Calculate the azimuthally averaged radial profile.
|
||||
Requires low frequencies to be at the center of the image.
|
||||
Args:
|
||||
image: Batch of 2D images, NxHxW
|
||||
center: The [x,y] pixel coordinates used as the center. The default is
|
||||
None, which then uses the center of the image (including
|
||||
fracitonal pixels).
|
||||
|
||||
Returns:
|
||||
Azimuthal average over the image around the center
|
||||
"""
|
||||
# Check input shapes
|
||||
assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
|
||||
f'(but it is len(center)={len(center)}.'
|
||||
# Calculate the indices from the image
|
||||
H, W = image.shape[-2:]
|
||||
h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
|
||||
|
||||
if center is None:
|
||||
center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0])
|
||||
|
||||
# Compute radius for each pixel wrt center
|
||||
r = torch.stack([w-center[0], h-center[1]]).norm(2, 0)
|
||||
|
||||
# Get sorted radii
|
||||
r_sorted, ind = r.flatten().sort()
|
||||
i_sorted = image.flatten(-2, -1)[..., ind]
|
||||
|
||||
# Get the integer part of the radii (bin size = 1)
|
||||
r_int = r_sorted.long() # attribute to the smaller integer
|
||||
|
||||
# Find all pixels that fall within each radial bin.
|
||||
deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii
|
||||
rind = torch.where(deltar)[0] # location of changed radius
|
||||
|
||||
# compute number of elements in each bin
|
||||
nind = rind + 1 # number of elements = idx + 1
|
||||
nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders
|
||||
nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius
|
||||
|
||||
# Cumulative sum to figure out sums for each radius bin
|
||||
if H % 2 == 0:
|
||||
raise NotImplementedError('Not sure if implementation correct, please check')
|
||||
rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders
|
||||
else:
|
||||
rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders
|
||||
csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius
|
||||
tbin = csim[..., rind[1:]] - csim[..., rind[:-1]]
|
||||
# add mean
|
||||
tbin = torch.cat([csim[:, 0:1], tbin], 1)
|
||||
|
||||
radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins
|
||||
|
||||
return radial_prof
|
||||
|
||||
|
||||
def get_spectrum(data, normalize=False):
|
||||
dim = data.ndim - 1 # subtract one for batch dimension
|
||||
if dim != 2:
|
||||
raise AttributeError(f'Data must be 2d but it is {dim}d.')
|
||||
|
||||
freq = batch_fft(data, normalize=normalize)
|
||||
power_spec = freq.real ** 2 + freq.imag ** 2
|
||||
N = data.shape[1]
|
||||
if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum
|
||||
# and is not averaged with the mean value
|
||||
N_2 = N//2
|
||||
power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1)
|
||||
power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2)
|
||||
|
||||
power_spec = roll_quadrants(power_spec)
|
||||
power_spec = azimuthal_average(power_spec)
|
||||
return power_spec
|
||||
|
||||
|
||||
def plot_std(mean, std, x=None, ax=None, **kwargs):
|
||||
import matplotlib.pyplot as plt
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(1)
|
||||
|
||||
# plot error margins in same color as line
|
||||
err_kwargs = {
|
||||
'alpha': 0.3
|
||||
}
|
||||
|
||||
if 'c' in kwargs.keys():
|
||||
err_kwargs['color'] = kwargs['c']
|
||||
elif 'color' in kwargs.keys():
|
||||
err_kwargs['color'] = kwargs['color']
|
||||
|
||||
if x is None:
|
||||
x = torch.linspace(0, 1, len(mean)) # use normalized x axis
|
||||
ax.plot(x, mean, **kwargs)
|
||||
ax.fill_between(x, mean-std, mean+std, **err_kwargs)
|
||||
|
||||
return ax
|
||||
Reference in New Issue
Block a user