support multi-gpu
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
||||
{
|
||||
if (x.dim() != y.dim())
|
||||
return false;
|
||||
for (int64_t i = 0; i < x.dim(); i++)
|
||||
{
|
||||
if (x.size(i) != y.size(i))
|
||||
return false;
|
||||
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
||||
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
||||
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
||||
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
||||
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
||||
|
||||
// Validate layout.
|
||||
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
||||
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
||||
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
torch::Tensor y = torch::empty_like(x);
|
||||
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
bias_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
||||
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
||||
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
||||
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
||||
p.y = y.data_ptr();
|
||||
p.grad = grad;
|
||||
p.act = act;
|
||||
p.alpha = alpha;
|
||||
p.gain = gain;
|
||||
p.clamp = clamp;
|
||||
p.sizeX = (int)x.numel();
|
||||
p.sizeB = (int)b.numel();
|
||||
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* kernel;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
kernel = choose_bias_act_kernel<scalar_t>(p);
|
||||
});
|
||||
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
p.loopX = 4;
|
||||
int blockSize = 4 * 32;
|
||||
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("bias_act", &bias_act);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel.
|
||||
|
||||
template <class T, int A>
|
||||
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
int G = p.grad;
|
||||
scalar_t alpha = (scalar_t)p.alpha;
|
||||
scalar_t gain = (scalar_t)p.gain;
|
||||
scalar_t clamp = (scalar_t)p.clamp;
|
||||
scalar_t one = (scalar_t)1;
|
||||
scalar_t two = (scalar_t)2;
|
||||
scalar_t expRange = (scalar_t)80;
|
||||
scalar_t halfExpRange = (scalar_t)40;
|
||||
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
||||
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
||||
|
||||
// Loop over elements.
|
||||
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
||||
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
||||
{
|
||||
// Load.
|
||||
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
||||
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
||||
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
||||
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
||||
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
||||
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
||||
scalar_t y = 0;
|
||||
|
||||
// Apply bias.
|
||||
((G == 0) ? x : xref) += b;
|
||||
|
||||
// linear
|
||||
if (A == 1)
|
||||
{
|
||||
if (G == 0) y = x;
|
||||
if (G == 1) y = x;
|
||||
}
|
||||
|
||||
// relu
|
||||
if (A == 2)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : 0;
|
||||
if (G == 1) y = (yy > 0) ? x : 0;
|
||||
}
|
||||
|
||||
// lrelu
|
||||
if (A == 3)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : x * alpha;
|
||||
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
||||
}
|
||||
|
||||
// tanh
|
||||
if (A == 4)
|
||||
{
|
||||
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
||||
if (G == 1) y = x * (one - yy * yy);
|
||||
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
||||
}
|
||||
|
||||
// sigmoid
|
||||
if (A == 5)
|
||||
{
|
||||
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
||||
if (G == 1) y = x * yy * (one - yy);
|
||||
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
||||
}
|
||||
|
||||
// elu
|
||||
if (A == 6)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
||||
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
||||
}
|
||||
|
||||
// selu
|
||||
if (A == 7)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
||||
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
||||
}
|
||||
|
||||
// softplus
|
||||
if (A == 8)
|
||||
{
|
||||
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
||||
if (G == 1) y = x * (one - exp(-yy));
|
||||
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
||||
}
|
||||
|
||||
// swish
|
||||
if (A == 9)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
||||
else
|
||||
{
|
||||
scalar_t c = exp(xref);
|
||||
scalar_t d = c + one;
|
||||
if (G == 1)
|
||||
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
||||
else
|
||||
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
||||
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply gain.
|
||||
y *= gain * dy;
|
||||
|
||||
// Clamp.
|
||||
if (clamp >= 0)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
||||
else
|
||||
y = (yref > -clamp & yref < clamp) ? y : 0;
|
||||
}
|
||||
|
||||
// Store.
|
||||
((T*)p.y)[xi] = (T)y;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
||||
{
|
||||
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
||||
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
||||
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
||||
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
||||
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
||||
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
||||
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
||||
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
||||
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct bias_act_kernel_params
|
||||
{
|
||||
const void* x; // [sizeX]
|
||||
const void* b; // [sizeB] or NULL
|
||||
const void* xref; // [sizeX] or NULL
|
||||
const void* yref; // [sizeX] or NULL
|
||||
const void* dy; // [sizeX] or NULL
|
||||
void* y; // [sizeX]
|
||||
|
||||
int grad;
|
||||
int act;
|
||||
float alpha;
|
||||
float gain;
|
||||
float clamp;
|
||||
|
||||
int sizeX;
|
||||
int sizeB;
|
||||
int stepB;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient bias and activation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
activation_funcs = {
|
||||
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
||||
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
||||
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
||||
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
||||
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
||||
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
||||
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
||||
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
||||
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
||||
}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='bias_act_plugin',
|
||||
sources=['bias_act.cpp', 'bias_act.cu'],
|
||||
headers=['bias_act.h'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math'],
|
||||
)
|
||||
return True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
||||
r"""Fused bias and activation function.
|
||||
|
||||
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
||||
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
||||
the fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports first and second order gradients,
|
||||
but not third order gradients.
|
||||
|
||||
Args:
|
||||
x: Input activation tensor. Can be of any shape.
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The shape must be known, and it must match the dimension of `x`
|
||||
corresponding to `dim`.
|
||||
dim: The dimension in `x` corresponding to the elements of `b`.
|
||||
The value of `dim` is ignored if `b` is not specified.
|
||||
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
||||
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
||||
See `activation_funcs` for a full list. `None` is not allowed.
|
||||
alpha: Shape parameter for the activation function, or `None` to use the default.
|
||||
gain: Scaling factor for the output tensor, or `None` to use default.
|
||||
See `activation_funcs` for the default scaling of each activation function.
|
||||
If unsure, consider specifying 1.
|
||||
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
||||
the clamping (default).
|
||||
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape and datatype as `x`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
||||
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Add bias.
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
||||
assert 0 <= dim < x.ndim
|
||||
assert b.shape[0] == x.shape[dim]
|
||||
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
||||
|
||||
# Evaluate activation function.
|
||||
alpha = float(alpha)
|
||||
x = spec.func(x, alpha=alpha)
|
||||
|
||||
# Scale by gain.
|
||||
gain = float(gain)
|
||||
if gain != 1:
|
||||
x = x * gain
|
||||
|
||||
# Clamp.
|
||||
if clamp >= 0:
|
||||
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_bias_act_cuda_cache = dict()
|
||||
|
||||
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (dim, act, alpha, gain, clamp)
|
||||
if key in _bias_act_cuda_cache:
|
||||
return _bias_act_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class BiasActCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
|
||||
x = x.contiguous(memory_format=ctx.memory_format)
|
||||
b = b.contiguous() if b is not None else _null_tensor
|
||||
y = x
|
||||
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
||||
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
y if 'y' in spec.ref else _null_tensor)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
dy = dy.contiguous(memory_format=ctx.memory_format)
|
||||
x, b, y = ctx.saved_tensors
|
||||
dx = None
|
||||
db = None
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||
dx = dy
|
||||
if act != 'linear' or gain != 1 or clamp >= 0:
|
||||
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
||||
|
||||
return dx, db
|
||||
|
||||
# Backward op.
|
||||
class BiasActCudaGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
|
||||
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
dy if spec.has_2nd_grad else _null_tensor,
|
||||
x, b, y)
|
||||
return dx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
||||
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
||||
dy, x, b, y = ctx.saved_tensors
|
||||
d_dy = None
|
||||
d_x = None
|
||||
d_b = None
|
||||
d_y = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
||||
|
||||
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
||||
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
|
||||
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
||||
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
||||
|
||||
return d_dy, d_x, d_b, d_y
|
||||
|
||||
# Add to cache.
|
||||
_bias_act_cuda_cache[key] = BiasActCuda
|
||||
return BiasActCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,198 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
||||
arbitrarily high order gradients with zero performance penalty."""
|
||||
|
||||
import contextlib
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_weight_gradients(disable=True):
|
||||
global weight_gradients_disabled
|
||||
old = weight_gradients_disabled
|
||||
if disable:
|
||||
weight_gradients_disabled = True
|
||||
yield
|
||||
weight_gradients_disabled = old
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op(input):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if (not enabled) or (not torch.backends.cudnn.enabled):
|
||||
return False
|
||||
if input.device.type != 'cuda':
|
||||
return False
|
||||
return True
|
||||
|
||||
def _tuple_of_ints(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
||||
assert len(xs) == ndim
|
||||
assert all(isinstance(x, int) for x in xs)
|
||||
return xs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_conv2d_gradfix_cache = dict()
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
||||
# Parse arguments.
|
||||
ndim = 2
|
||||
weight_shape = tuple(weight_shape)
|
||||
stride = _tuple_of_ints(stride, ndim)
|
||||
padding = _tuple_of_ints(padding, ndim)
|
||||
output_padding = _tuple_of_ints(output_padding, ndim)
|
||||
dilation = _tuple_of_ints(dilation, ndim)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
||||
if key in _conv2d_gradfix_cache:
|
||||
return _conv2d_gradfix_cache[key]
|
||||
|
||||
# Validate arguments.
|
||||
assert groups >= 1
|
||||
assert len(weight_shape) == ndim + 2
|
||||
assert all(stride[i] >= 1 for i in range(ndim))
|
||||
assert all(padding[i] >= 0 for i in range(ndim))
|
||||
assert all(dilation[i] >= 0 for i in range(ndim))
|
||||
if not transpose:
|
||||
assert all(output_padding[i] == 0 for i in range(ndim))
|
||||
else: # transpose
|
||||
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
||||
|
||||
# Helpers.
|
||||
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
def calc_output_padding(input_shape, output_shape):
|
||||
if transpose:
|
||||
return [0, 0]
|
||||
return [
|
||||
input_shape[i + 2]
|
||||
- (output_shape[i + 2] - 1) * stride[i]
|
||||
- (1 - 2 * padding[i])
|
||||
- dilation[i] * (weight_shape[i + 2] - 1)
|
||||
for i in range(ndim)
|
||||
]
|
||||
|
||||
# Forward & backward.
|
||||
class Conv2d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias):
|
||||
assert weight.shape == weight_shape
|
||||
ctx.save_for_backward(
|
||||
input if weight.requires_grad else _null_tensor,
|
||||
weight if input.requires_grad else _null_tensor,
|
||||
)
|
||||
ctx.input_shape = input.shape
|
||||
|
||||
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
|
||||
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
|
||||
a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
|
||||
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
|
||||
c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
|
||||
c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
|
||||
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
|
||||
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
||||
|
||||
# General case => cuDNN.
|
||||
if transpose:
|
||||
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
||||
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
input_shape = ctx.input_shape
|
||||
grad_input = None
|
||||
grad_weight = None
|
||||
grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
|
||||
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
||||
grad_input = op.apply(grad_output, weight, None)
|
||||
assert grad_input.shape == input_shape
|
||||
|
||||
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
||||
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
||||
assert grad_weight.shape == weight_shape
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum([0, 2, 3])
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
# Gradient with respect to the weights.
|
||||
class Conv2dGradWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input):
|
||||
ctx.save_for_backward(
|
||||
grad_output if input.requires_grad else _null_tensor,
|
||||
input if grad_output.requires_grad else _null_tensor,
|
||||
)
|
||||
ctx.grad_output_shape = grad_output.shape
|
||||
ctx.input_shape = input.shape
|
||||
|
||||
# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
|
||||
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
|
||||
a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
||||
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
||||
c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
|
||||
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
||||
|
||||
# General case => cuDNN.
|
||||
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
|
||||
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
||||
return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_weight):
|
||||
grad_output, input = ctx.saved_tensors
|
||||
grad_output_shape = ctx.grad_output_shape
|
||||
input_shape = ctx.input_shape
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
||||
assert grad2_grad_output.shape == grad_output_shape
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
|
||||
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
||||
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
|
||||
assert grad2_input.shape == input_shape
|
||||
|
||||
return grad2_grad_output, grad2_input
|
||||
|
||||
_conv2d_gradfix_cache[key] = Conv2d
|
||||
return Conv2d
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""2D convolution with optional up/downsampling."""
|
||||
|
||||
import torch
|
||||
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
from . import upfirdn2d
|
||||
from .upfirdn2d import _parse_padding
|
||||
from .upfirdn2d import _get_filter_size
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_weight_shape(w):
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
shape = [int(sz) for sz in w.shape]
|
||||
misc.assert_shape(w, shape)
|
||||
return shape
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
||||
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
||||
"""
|
||||
_out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
|
||||
# Flip weight if requested.
|
||||
# Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
||||
if not flip_weight and (kw > 1 or kh > 1):
|
||||
w = w.flip([2, 3])
|
||||
|
||||
# Execute using conv2d_gradfix.
|
||||
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
||||
return op(x, w, stride=stride, padding=padding, groups=groups)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
||||
r"""2D convolution with optional up/downsampling.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape
|
||||
`[batch_size, in_channels, in_height, in_width]`.
|
||||
w: Weight tensor of shape
|
||||
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
||||
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
||||
calling upfirdn2d.setup_filter(). None = identity (default).
|
||||
up: Integer upsampling factor (default: 1).
|
||||
down: Integer downsampling factor (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
groups: Split input channels into N groups (default: 1).
|
||||
flip_weight: False = convolution, True = correlation (default: True).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
||||
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
||||
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
||||
assert isinstance(up, int) and (up >= 1)
|
||||
assert isinstance(down, int) and (down >= 1)
|
||||
assert isinstance(groups, int) and (groups >= 1)
|
||||
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
fw, fh = _get_filter_size(f)
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
|
||||
# Adjust padding to account for up/downsampling.
|
||||
if up > 1:
|
||||
px0 += (fw + up - 1) // 2
|
||||
px1 += (fw - up) // 2
|
||||
py0 += (fh + up - 1) // 2
|
||||
py1 += (fh - up) // 2
|
||||
if down > 1:
|
||||
px0 += (fw - down + 1) // 2
|
||||
px1 += (fw - down) // 2
|
||||
py0 += (fh - down + 1) // 2
|
||||
py1 += (fh - down) // 2
|
||||
|
||||
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
||||
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
||||
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: downsampling only => use strided convolution.
|
||||
if down > 1 and up == 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
||||
if up > 1:
|
||||
if groups == 1:
|
||||
w = w.transpose(0, 1)
|
||||
else:
|
||||
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
||||
w = w.transpose(1, 2)
|
||||
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
||||
px0 -= kw - 1
|
||||
px1 -= kw - up
|
||||
py0 -= kh - 1
|
||||
py1 -= kh - up
|
||||
pxt = max(min(-px0, -px1), 0)
|
||||
pyt = max(min(-py0, -py1), 0)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
||||
if up == 1 and down == 1:
|
||||
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
||||
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
||||
|
||||
# Fallback: Generic reference implementation.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,300 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "filtered_lrelu.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
|
||||
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
|
||||
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
|
||||
{
|
||||
// Set CUDA device.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
|
||||
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
|
||||
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
|
||||
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x is empty");
|
||||
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
|
||||
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
|
||||
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
|
||||
TORCH_CHECK(fu.numel() > 0, "fu is empty");
|
||||
TORCH_CHECK(fd.numel() > 0, "fd is empty");
|
||||
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
|
||||
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
|
||||
|
||||
// Figure out how much shared memory is available on the device.
|
||||
int maxSharedBytes = 0;
|
||||
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
|
||||
int sharedKB = maxSharedBytes >> 10;
|
||||
|
||||
// Populate enough launch parameters to check if a CUDA kernel exists.
|
||||
filtered_lrelu_kernel_params p;
|
||||
p.up = up;
|
||||
p.down = down;
|
||||
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
|
||||
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
|
||||
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
|
||||
if (!test_spec.exec)
|
||||
{
|
||||
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
|
||||
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
|
||||
}
|
||||
|
||||
// Input/output element size.
|
||||
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
|
||||
|
||||
// Input sizes.
|
||||
int64_t xw = (int)x.size(3);
|
||||
int64_t xh = (int)x.size(2);
|
||||
int64_t fut_w = (int)fu.size(-1) - 1;
|
||||
int64_t fut_h = (int)fu.size(0) - 1;
|
||||
int64_t fdt_w = (int)fd.size(-1) - 1;
|
||||
int64_t fdt_h = (int)fd.size(0) - 1;
|
||||
|
||||
// Logical size of upsampled buffer.
|
||||
int64_t cw = xw * up + (px0 + px1) - fut_w;
|
||||
int64_t ch = xh * up + (py0 + py1) - fut_h;
|
||||
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
|
||||
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
|
||||
|
||||
// Compute output size and allocate.
|
||||
int64_t yw = (cw - fdt_w + (down - 1)) / down;
|
||||
int64_t yh = (ch - fdt_h + (down - 1)) / down;
|
||||
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
|
||||
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
|
||||
|
||||
// Allocate sign tensor.
|
||||
torch::Tensor so;
|
||||
torch::Tensor s = si;
|
||||
bool readSigns = !!s.numel();
|
||||
int64_t sw_active = 0; // Active width of sign tensor.
|
||||
if (writeSigns)
|
||||
{
|
||||
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
|
||||
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
|
||||
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
|
||||
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
|
||||
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
||||
}
|
||||
else if (readSigns)
|
||||
sw_active = s.size(3) << 2;
|
||||
|
||||
// Validate sign tensor if in use.
|
||||
if (readSigns || writeSigns)
|
||||
{
|
||||
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
||||
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
||||
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
||||
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
||||
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
||||
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
|
||||
}
|
||||
|
||||
// Populate rest of CUDA kernel parameters.
|
||||
p.x = x.data_ptr();
|
||||
p.y = y.data_ptr();
|
||||
p.b = b.data_ptr();
|
||||
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
||||
p.fu = fu.data_ptr<float>();
|
||||
p.fd = fd.data_ptr<float>();
|
||||
p.pad0 = make_int2(px0, py0);
|
||||
p.gain = gain;
|
||||
p.slope = slope;
|
||||
p.clamp = clamp;
|
||||
p.flip = (flip_filters) ? 1 : 0;
|
||||
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
|
||||
p.sOfs = make_int2(sx, sy);
|
||||
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
|
||||
|
||||
// x, y, b strides are in bytes.
|
||||
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
|
||||
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
|
||||
p.bStride = sz * b.stride(0);
|
||||
|
||||
// fu, fd strides are in elements.
|
||||
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
|
||||
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
|
||||
|
||||
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
|
||||
bool index64b = false;
|
||||
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
|
||||
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
|
||||
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
|
||||
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
|
||||
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
|
||||
if (s.numel() > INT_MAX) index64b = true;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
filtered_lrelu_kernel_spec spec = { 0 };
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
|
||||
{
|
||||
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
|
||||
{
|
||||
// Choose kernel based on index type, datatype and sign read/write modes.
|
||||
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
|
||||
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
|
||||
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
|
||||
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
|
||||
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
|
||||
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
|
||||
}
|
||||
});
|
||||
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
int bx = spec.numWarps * 32;
|
||||
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
|
||||
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
|
||||
int gz = p.yShape.z * p.yShape.w;
|
||||
|
||||
// Repeat multiple horizontal tiles in a CTA?
|
||||
if (spec.xrep)
|
||||
{
|
||||
p.tilesXrep = spec.xrep;
|
||||
p.tilesXdim = gx;
|
||||
|
||||
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
|
||||
std::swap(gx, gy);
|
||||
}
|
||||
else
|
||||
{
|
||||
p.tilesXrep = 0;
|
||||
p.tilesXdim = 0;
|
||||
}
|
||||
|
||||
// Launch filter setup kernel.
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
|
||||
// Copy kernels to constant memory.
|
||||
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
|
||||
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
|
||||
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
|
||||
|
||||
// Set cache and shared memory configurations for main kernel.
|
||||
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
|
||||
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
|
||||
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
|
||||
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
|
||||
|
||||
// Launch main kernel.
|
||||
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
|
||||
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
|
||||
{
|
||||
p.blockZofs = zofs;
|
||||
int subGz = std::min(maxSubGz, gz - zofs);
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
|
||||
}
|
||||
|
||||
// Done.
|
||||
return std::make_tuple(y, so, 0);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
|
||||
{
|
||||
// Set CUDA device.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x is empty");
|
||||
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
|
||||
|
||||
// Output signs if we don't have sign input.
|
||||
torch::Tensor so;
|
||||
torch::Tensor s = si;
|
||||
bool readSigns = !!s.numel();
|
||||
if (writeSigns)
|
||||
{
|
||||
int64_t sw = x.size(3);
|
||||
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
|
||||
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
// Validate sign tensor if in use.
|
||||
if (readSigns || writeSigns)
|
||||
{
|
||||
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
||||
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
||||
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
||||
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
||||
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
||||
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
|
||||
}
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
filtered_lrelu_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
||||
p.gain = gain;
|
||||
p.slope = slope;
|
||||
p.clamp = clamp;
|
||||
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
|
||||
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
|
||||
p.sOfs = make_int2(sx, sy);
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* func = 0;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
|
||||
{
|
||||
if (writeSigns)
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
|
||||
else if (readSigns)
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
|
||||
else
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
|
||||
});
|
||||
TORCH_CHECK(func, "internal error - CUDA kernel not found");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
int bx = 128; // 4 warps per block.
|
||||
|
||||
// Logical size of launch = writeSigns ? p.s : p.x
|
||||
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
|
||||
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
|
||||
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
|
||||
gx = (gx - 1) / bx + 1;
|
||||
|
||||
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
|
||||
const uint32_t gmax = 65535;
|
||||
gy = std::min(gy, gmax);
|
||||
gz = std::min(gz, gmax);
|
||||
|
||||
// Launch.
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return so;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
|
||||
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct filtered_lrelu_kernel_params
|
||||
{
|
||||
// These parameters decide which kernel to use.
|
||||
int up; // upsampling ratio (1, 2, 4)
|
||||
int down; // downsampling ratio (1, 2, 4)
|
||||
int2 fuShape; // [size, 1] | [size, size]
|
||||
int2 fdShape; // [size, 1] | [size, size]
|
||||
|
||||
int _dummy; // Alignment.
|
||||
|
||||
// Rest of the parameters.
|
||||
const void* x; // Input tensor.
|
||||
void* y; // Output tensor.
|
||||
const void* b; // Bias tensor.
|
||||
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
||||
const float* fu; // Upsampling filter.
|
||||
const float* fd; // Downsampling filter.
|
||||
|
||||
int2 pad0; // Left/top padding.
|
||||
float gain; // Additional gain factor.
|
||||
float slope; // Leaky ReLU slope on negative side.
|
||||
float clamp; // Clamp after nonlinearity.
|
||||
int flip; // Filter kernel flip for gradient computation.
|
||||
|
||||
int tilesXdim; // Original number of horizontal output tiles.
|
||||
int tilesXrep; // Number of horizontal tiles per CTA.
|
||||
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
|
||||
|
||||
int4 xShape; // [width, height, channel, batch]
|
||||
int4 yShape; // [width, height, channel, batch]
|
||||
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
|
||||
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
||||
int swLimit; // Active width of sign tensor in bytes.
|
||||
|
||||
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
|
||||
longlong4 yStride; //
|
||||
int64_t bStride; //
|
||||
longlong3 fuStride; //
|
||||
longlong3 fdStride; //
|
||||
};
|
||||
|
||||
struct filtered_lrelu_act_kernel_params
|
||||
{
|
||||
void* x; // Input/output, modified in-place.
|
||||
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
||||
|
||||
float gain; // Additional gain factor.
|
||||
float slope; // Leaky ReLU slope on negative side.
|
||||
float clamp; // Clamp after nonlinearity.
|
||||
|
||||
int4 xShape; // [width, height, channel, batch]
|
||||
longlong4 xStride; // Input/output tensor strides, same order as in shape.
|
||||
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
|
||||
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
struct filtered_lrelu_kernel_spec
|
||||
{
|
||||
void* setup; // Function for filter kernel setup.
|
||||
void* exec; // Function for main operation.
|
||||
int2 tileOut; // Width/height of launch tile.
|
||||
int numWarps; // Number of warps per thread block, determines launch block size.
|
||||
int xrep; // For processing multiple horizontal tiles per thread block.
|
||||
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
|
||||
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import upfirdn2d
|
||||
from . import bias_act
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='filtered_lrelu_plugin',
|
||||
sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
|
||||
headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math'],
|
||||
)
|
||||
return True
|
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor)
|
||||
assert 1 <= f.ndim <= 2
|
||||
return f.shape[-1], f.shape[0] # width, height
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, (int, np.integer)) for x in padding)
|
||||
padding = [int(x) for x in padding]
|
||||
if len(padding) == 2:
|
||||
px, py = padding
|
||||
padding = [px, px, py, py]
|
||||
px0, px1, py0, py1 = padding
|
||||
return px0, px1, py0, py1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
|
||||
r"""Filtered leaky ReLU for a batch of 2D images.
|
||||
|
||||
Performs the following sequence of operations for each channel:
|
||||
|
||||
1. Add channel-specific bias if provided (`b`).
|
||||
|
||||
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
||||
|
||||
3. Pad the image with the specified number of zeros on each side (`padding`).
|
||||
Negative padding corresponds to cropping the image.
|
||||
|
||||
4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
|
||||
so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
5. Multiply each value by the provided gain factor (`gain`).
|
||||
|
||||
6. Apply leaky ReLU activation function to each value.
|
||||
|
||||
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
|
||||
|
||||
8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
|
||||
it so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
9. Downsample the image by keeping every Nth pixel (`down`).
|
||||
|
||||
The fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports gradients of arbitrary order.
|
||||
|
||||
Args:
|
||||
x: Float32/float16/float64 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
fu: Float32 upsampling FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
fd: Float32 downsampling FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The length of vector must must match the channel dimension of `x`.
|
||||
up: Integer upsampling factor (default: 1).
|
||||
down: Integer downsampling factor. (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
|
||||
slope: Slope on the negative side of leaky ReLU (default: 0.2).
|
||||
clamp: Maximum magnitude for leaky ReLU output (default: None).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
|
||||
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
||||
"""Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
|
||||
existing `upfirdn2n()` and `bias_act()` ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
fu_w, fu_h = _get_filter_size(fu)
|
||||
fd_w, fd_h = _get_filter_size(fd)
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
|
||||
misc.assert_shape(b, [x.shape[1]])
|
||||
assert isinstance(up, int) and up >= 1
|
||||
assert isinstance(down, int) and down >= 1
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
assert gain == float(gain) and gain > 0
|
||||
assert slope == float(slope) and slope >= 0
|
||||
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
||||
|
||||
# Calculate output size.
|
||||
batch_size, channels, in_h, in_w = x.shape
|
||||
in_dtype = x.dtype
|
||||
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
|
||||
out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
|
||||
|
||||
# Compute using existing ops.
|
||||
x = bias_act.bias_act(x=x, b=b) # Apply bias.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
||||
x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
||||
|
||||
# Check output shape & dtype.
|
||||
misc.assert_shape(x, [batch_size, channels, out_h, out_w])
|
||||
assert x.dtype == in_dtype
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_filtered_lrelu_cuda_cache = dict()
|
||||
|
||||
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
||||
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
|
||||
"""
|
||||
assert isinstance(up, int) and up >= 1
|
||||
assert isinstance(down, int) and down >= 1
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
assert gain == float(gain) and gain > 0
|
||||
gain = float(gain)
|
||||
assert slope == float(slope) and slope >= 0
|
||||
slope = float(slope)
|
||||
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
||||
clamp = float(clamp if clamp is not None else 'inf')
|
||||
|
||||
# Lookup from cache.
|
||||
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
|
||||
if key in _filtered_lrelu_cuda_cache:
|
||||
return _filtered_lrelu_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class FilteredLReluCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
|
||||
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
|
||||
if fu is None:
|
||||
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
if fd is None:
|
||||
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert 1 <= fu.ndim <= 2
|
||||
assert 1 <= fd.ndim <= 2
|
||||
|
||||
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
|
||||
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
|
||||
fu = fu.square()[None]
|
||||
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
|
||||
fd = fd.square()[None]
|
||||
|
||||
# Missing sign input tensor.
|
||||
if si is None:
|
||||
si = torch.empty([0])
|
||||
|
||||
# Missing bias tensor.
|
||||
if b is None:
|
||||
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
|
||||
|
||||
# Construct internal sign tensor only if gradients are needed.
|
||||
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
|
||||
|
||||
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
|
||||
x = x.contiguous()
|
||||
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
|
||||
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
|
||||
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
|
||||
|
||||
# Call C++/Cuda plugin if datatype is supported.
|
||||
if x.dtype in [torch.float16, torch.float32]:
|
||||
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
|
||||
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
|
||||
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
|
||||
else:
|
||||
return_code = -1
|
||||
|
||||
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
|
||||
# only the bit-packed sign tensor is retained for gradient computation.
|
||||
if return_code < 0:
|
||||
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
|
||||
|
||||
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
||||
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
||||
|
||||
# Prepare for gradient computation.
|
||||
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
|
||||
ctx.x_shape = x.shape
|
||||
ctx.y_shape = y.shape
|
||||
ctx.s_ofs = sx, sy
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
fu, fd, si = ctx.saved_tensors
|
||||
_, _, xh, xw = ctx.x_shape
|
||||
_, _, yh, yw = ctx.y_shape
|
||||
sx, sy = ctx.s_ofs
|
||||
dx = None # 0
|
||||
dfu = None; assert not ctx.needs_input_grad[1]
|
||||
dfd = None; assert not ctx.needs_input_grad[2]
|
||||
db = None # 3
|
||||
dsi = None; assert not ctx.needs_input_grad[4]
|
||||
dsx = None; assert not ctx.needs_input_grad[5]
|
||||
dsy = None; assert not ctx.needs_input_grad[6]
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
|
||||
pp = [
|
||||
(fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
|
||||
xw * up - yw * down + px0 - (up - 1),
|
||||
(fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
|
||||
xh * up - yh * down + py0 - (up - 1),
|
||||
]
|
||||
gg = gain * (up ** 2) / (down ** 2)
|
||||
ff = (not flip_filter)
|
||||
sx = sx - (fu.shape[-1] - 1) + px0
|
||||
sy = sy - (fu.shape[0] - 1) + py0
|
||||
dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
|
||||
|
||||
if ctx.needs_input_grad[3]:
|
||||
db = dx.sum([0, 2, 3])
|
||||
|
||||
return dx, dfu, dfd, db, dsi, dsx, dsy
|
||||
|
||||
# Add to cache.
|
||||
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
|
||||
return FilteredLReluCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for no signs mode (no gradients required).
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for sign read mode.
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for sign write mode.
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
||||
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def fma(a, b, c): # => a * b + c
|
||||
return _FusedMultiplyAdd.apply(a, b, c)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
||||
out = torch.addcmul(c, a, b)
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.c_shape = c.shape
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout): # pylint: disable=arguments-differ
|
||||
a, b = ctx.saved_tensors
|
||||
c_shape = ctx.c_shape
|
||||
da = None
|
||||
db = None
|
||||
dc = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
da = _unbroadcast(dout * b, a.shape)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = _unbroadcast(dout * a, b.shape)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
dc = _unbroadcast(dout, c_shape)
|
||||
|
||||
return da, db, dc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _unbroadcast(x, shape):
|
||||
extra_dims = x.ndim - len(shape)
|
||||
assert extra_dims >= 0
|
||||
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
||||
if len(dim):
|
||||
x = x.sum(dim=dim, keepdim=True)
|
||||
if extra_dims:
|
||||
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
||||
assert x.shape == shape
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
||||
supports arbitrarily high order gradients between the input and output.
|
||||
Only works on 2D images and assumes
|
||||
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
||||
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def grid_sample(input, grid):
|
||||
if _should_use_custom_op():
|
||||
return _GridSample2dForward.apply(input, grid)
|
||||
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op():
|
||||
return enabled
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
ctx.save_for_backward(input, grid)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
||||
ctx.save_for_backward(grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
_ = grad2_grad_grid # unused
|
||||
grid, = ctx.saved_tensors
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
grad2_grid = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
||||
|
||||
assert not ctx.needs_input_grad[2]
|
||||
return grad2_grad_output, grad2_input, grad2_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
||||
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x has zero size");
|
||||
TORCH_CHECK(f.numel() > 0, "f has zero size");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
||||
TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
|
||||
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
||||
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
||||
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
||||
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
||||
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
||||
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
||||
TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
upfirdn2d_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.f = f.data_ptr<float>();
|
||||
p.y = y.data_ptr();
|
||||
p.up = make_int2(upx, upy);
|
||||
p.down = make_int2(downx, downy);
|
||||
p.pad0 = make_int2(padx0, pady0);
|
||||
p.flip = (flip) ? 1 : 0;
|
||||
p.gain = gain;
|
||||
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
||||
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
||||
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
||||
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
||||
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
||||
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
upfirdn2d_kernel_spec spec;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
||||
});
|
||||
|
||||
// Set looping options.
|
||||
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
||||
p.loopMinor = spec.loopMinor;
|
||||
p.loopX = spec.loopX;
|
||||
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
||||
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
||||
|
||||
// Compute grid size.
|
||||
dim3 blockSize, gridSize;
|
||||
if (spec.tileOutW < 0) // large
|
||||
{
|
||||
blockSize = dim3(4, 32, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
else // small
|
||||
{
|
||||
blockSize = dim3(256, 1, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("upfirdn2d", &upfirdn2d);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
static __device__ __forceinline__ int floor_div(int a, int b)
|
||||
{
|
||||
int t = 1 - a / b;
|
||||
return (a + t * b) / b - t;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Generic CUDA implementation for large filters.
|
||||
|
||||
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
|
||||
// Calculate thread index.
|
||||
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int outY = minorBase / p.launchMinor;
|
||||
minorBase -= outY * p.launchMinor;
|
||||
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Setup Y receptive field.
|
||||
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
||||
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
||||
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
||||
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
||||
if (p.flip)
|
||||
filterY = p.filterSize.y - 1 - filterY;
|
||||
|
||||
// Loop over major, minor, and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
||||
{
|
||||
int nc = major * p.sizeMinor + minor;
|
||||
int n = nc / p.inSize.z;
|
||||
int c = nc - n * p.inSize.z;
|
||||
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
||||
{
|
||||
// Setup X receptive field.
|
||||
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
||||
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
||||
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
||||
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
||||
if (p.flip)
|
||||
filterX = p.filterSize.x - 1 - filterX;
|
||||
|
||||
// Initialize pointers.
|
||||
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
||||
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
||||
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
||||
|
||||
// Inner loop.
|
||||
scalar_t v = 0;
|
||||
for (int y = 0; y < h; y++)
|
||||
{
|
||||
for (int x = 0; x < w; x++)
|
||||
{
|
||||
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
||||
xp += p.inStride.x;
|
||||
fp += filterStepX;
|
||||
}
|
||||
xp += p.inStride.y - w * p.inStride.x;
|
||||
fp += filterStepY - w * filterStepX;
|
||||
}
|
||||
|
||||
// Store result.
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Specialized CUDA implementation for small filters.
|
||||
|
||||
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
||||
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
||||
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
||||
__shared__ volatile scalar_t sf[filterH][filterW];
|
||||
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
||||
|
||||
// Calculate tile index.
|
||||
int minorBase = blockIdx.x;
|
||||
int tileOutY = minorBase / p.launchMinor;
|
||||
minorBase -= tileOutY * p.launchMinor;
|
||||
minorBase *= loopMinor;
|
||||
tileOutY *= tileOutH;
|
||||
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Load filter (flipped).
|
||||
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
||||
{
|
||||
int fy = tapIdx / filterW;
|
||||
int fx = tapIdx - fy * filterW;
|
||||
scalar_t v = 0;
|
||||
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
||||
{
|
||||
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
||||
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
||||
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
||||
}
|
||||
sf[fy][fx] = v;
|
||||
}
|
||||
|
||||
// Loop over major and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
{
|
||||
int baseNC = major * p.sizeMinor + minorBase;
|
||||
int n = baseNC / p.inSize.z;
|
||||
int baseC = baseNC - n * p.inSize.z;
|
||||
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
||||
{
|
||||
// Load input pixels.
|
||||
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
||||
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
||||
int tileInX = floor_div(tileMidX, upx);
|
||||
int tileInY = floor_div(tileMidY, upy);
|
||||
__syncthreads();
|
||||
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
||||
{
|
||||
int relC = inIdx;
|
||||
int relInX = relC / loopMinor;
|
||||
int relInY = relInX / tileInW;
|
||||
relC -= relInX * loopMinor;
|
||||
relInX -= relInY * tileInW;
|
||||
int c = baseC + relC;
|
||||
int inX = tileInX + relInX;
|
||||
int inY = tileInY + relInY;
|
||||
scalar_t v = 0;
|
||||
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
||||
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
sx[relInY][relInX][relC] = v;
|
||||
}
|
||||
|
||||
// Loop over output pixels.
|
||||
__syncthreads();
|
||||
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
||||
{
|
||||
int relC = outIdx;
|
||||
int relOutX = relC / loopMinor;
|
||||
int relOutY = relOutX / tileOutW;
|
||||
relC -= relOutX * loopMinor;
|
||||
relOutX -= relOutY * tileOutW;
|
||||
int c = baseC + relC;
|
||||
int outX = tileOutX + relOutX;
|
||||
int outY = tileOutY + relOutY;
|
||||
|
||||
// Setup receptive field.
|
||||
int midX = tileMidX + relOutX * downx;
|
||||
int midY = tileMidY + relOutY * downy;
|
||||
int inX = floor_div(midX, upx);
|
||||
int inY = floor_div(midY, upy);
|
||||
int relInX = inX - tileInX;
|
||||
int relInY = inY - tileInY;
|
||||
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
||||
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
||||
|
||||
// Inner loop.
|
||||
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
||||
{
|
||||
scalar_t v = 0;
|
||||
#pragma unroll
|
||||
for (int y = 0; y < filterH / upy; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < filterW / upx; x++)
|
||||
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
||||
{
|
||||
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
||||
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
||||
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
||||
|
||||
// No up/downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x upsampling.
|
||||
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
||||
}
|
||||
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
||||
}
|
||||
|
||||
// 4x upsampling.
|
||||
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
|
||||
}
|
||||
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 4x downsampling (inefficient).
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct upfirdn2d_kernel_params
|
||||
{
|
||||
const void* x;
|
||||
const float* f;
|
||||
void* y;
|
||||
|
||||
int2 up;
|
||||
int2 down;
|
||||
int2 pad0;
|
||||
int flip;
|
||||
float gain;
|
||||
|
||||
int4 inSize; // [width, height, channel, batch]
|
||||
int4 inStride;
|
||||
int2 filterSize; // [width, height]
|
||||
int2 filterStride;
|
||||
int4 outSize; // [width, height, channel, batch]
|
||||
int4 outStride;
|
||||
int sizeMinor;
|
||||
int sizeMajor;
|
||||
|
||||
int loopMinor;
|
||||
int loopMajor;
|
||||
int loopX;
|
||||
int launchMinor;
|
||||
int launchMajor;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
struct upfirdn2d_kernel_spec
|
||||
{
|
||||
void* kernel;
|
||||
int tileOutW;
|
||||
int tileOutH;
|
||||
int loopMinor;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,389 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='upfirdn2d_plugin',
|
||||
sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
|
||||
headers=['upfirdn2d.h'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math'],
|
||||
)
|
||||
return True
|
||||
|
||||
def _parse_scaling(scaling):
|
||||
if isinstance(scaling, int):
|
||||
scaling = [scaling, scaling]
|
||||
assert isinstance(scaling, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in scaling)
|
||||
sx, sy = scaling
|
||||
assert sx >= 1 and sy >= 1
|
||||
return sx, sy
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in padding)
|
||||
if len(padding) == 2:
|
||||
padx, pady = padding
|
||||
padding = [padx, padx, pady, pady]
|
||||
padx0, padx1, pady0, pady1 = padding
|
||||
return padx0, padx1, pady0, pady1
|
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
fw = f.shape[-1]
|
||||
fh = f.shape[0]
|
||||
with misc.suppress_tracer_warnings():
|
||||
fw = int(fw)
|
||||
fh = int(fh)
|
||||
misc.assert_shape(f, [fh, fw][:f.ndim])
|
||||
assert fw >= 1 and fh >= 1
|
||||
return fw, fh
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
||||
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
||||
|
||||
Args:
|
||||
f: Torch tensor, numpy array, or python list of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable),
|
||||
`[]` (impulse), or
|
||||
`None` (identity).
|
||||
device: Result device (default: cpu).
|
||||
normalize: Normalize the filter so that it retains the magnitude
|
||||
for constant input signal (DC)? (default: True).
|
||||
flip_filter: Flip the filter? (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
separable: Return a separable filter? (default: select automatically).
|
||||
|
||||
Returns:
|
||||
Float32 tensor of the shape
|
||||
`[filter_height, filter_width]` (non-separable) or
|
||||
`[filter_taps]` (separable).
|
||||
"""
|
||||
# Validate.
|
||||
if f is None:
|
||||
f = 1
|
||||
f = torch.as_tensor(f, dtype=torch.float32)
|
||||
assert f.ndim in [0, 1, 2]
|
||||
assert f.numel() > 0
|
||||
if f.ndim == 0:
|
||||
f = f[np.newaxis]
|
||||
|
||||
# Separable?
|
||||
if separable is None:
|
||||
separable = (f.ndim == 1 and f.numel() >= 8)
|
||||
if f.ndim == 1 and not separable:
|
||||
f = f.ger(f)
|
||||
assert f.ndim == (1 if separable else 2)
|
||||
|
||||
# Apply normalize, flip, gain, and device.
|
||||
if normalize:
|
||||
f /= f.sum()
|
||||
if flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(device=device)
|
||||
return f
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
||||
|
||||
Performs the following sequence of operations for each channel:
|
||||
|
||||
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
||||
|
||||
2. Pad the image with the specified number of zeros on each side (`padding`).
|
||||
Negative padding corresponds to cropping the image.
|
||||
|
||||
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
||||
so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
4. Downsample the image by keeping every Nth pixel (`down`).
|
||||
|
||||
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
||||
The fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports gradients of arbitrary order.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
||||
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
assert f.dtype == torch.float32 and not f.requires_grad
|
||||
batch_size, num_channels, in_height, in_width = x.shape
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Check that upsampled buffer is not smaller than the filter.
|
||||
upW = in_width * upx + padx0 + padx1
|
||||
upH = in_height * upy + pady0 + pady1
|
||||
assert upW >= f.shape[-1] and upH >= f.shape[0]
|
||||
|
||||
# Upsample by inserting zeros.
|
||||
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
||||
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
||||
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
||||
|
||||
# Pad or crop.
|
||||
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
||||
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
||||
|
||||
# Setup filter.
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(x.dtype)
|
||||
if not flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
|
||||
# Convolve with the filter.
|
||||
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
||||
if f.ndim == 4:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
||||
else:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
||||
|
||||
# Downsample by throwing away pixels.
|
||||
x = x[:, :, ::downy, ::downx]
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_upfirdn2d_cuda_cache = dict()
|
||||
|
||||
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
if key in _upfirdn2d_cuda_cache:
|
||||
return _upfirdn2d_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class Upfirdn2dCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
if f.ndim == 1 and f.shape[0] == 1:
|
||||
f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
y = x
|
||||
if f.ndim == 2:
|
||||
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
else:
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
|
||||
ctx.save_for_backward(f)
|
||||
ctx.x_shape = x.shape
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
f, = ctx.saved_tensors
|
||||
_, _, ih, iw = ctx.x_shape
|
||||
_, _, oh, ow = dy.shape
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
fw - padx0 - 1,
|
||||
iw * upx - ow * downx + padx0 - upx + 1,
|
||||
fh - pady0 - 1,
|
||||
ih * upy - oh * downy + pady0 - upy + 1,
|
||||
]
|
||||
dx = None
|
||||
df = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
||||
|
||||
assert not ctx.needs_input_grad[1]
|
||||
return dx, df
|
||||
|
||||
# Add to cache.
|
||||
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
||||
return Upfirdn2dCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape matches the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + fw // 2,
|
||||
padx1 + (fw - 1) // 2,
|
||||
pady0 + fh // 2,
|
||||
pady1 + (fh - 1) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a multiple of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
upx, upy = _parse_scaling(up)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw + upx - 1) // 2,
|
||||
padx1 + (fw - upx) // 2,
|
||||
pady0 + (fh + upy - 1) // 2,
|
||||
pady1 + (fh - upy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a fraction of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the input. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw - downx + 1) // 2,
|
||||
padx1 + (fw - downx) // 2,
|
||||
pady0 + (fh - downy + 1) // 2,
|
||||
pady1 + (fh - downy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
Reference in New Issue
Block a user