update involution
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: DeConv copy.py
|
||||
# Created Date: Tuesday July 20th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 10th February 2022 1:10:04 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
from components.misc.Involution import involution
|
||||
from torch import nn
|
||||
|
||||
class DeConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size = 3, upsampl_scale = 2, padding="zero"):
|
||||
super().__init__()
|
||||
self.upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampl_scale)
|
||||
padding_size = int((kernel_size -1)/2)
|
||||
self.conv1x1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size= 1)
|
||||
# self.same_padding = nn.ReflectionPad2d(padding_size)
|
||||
if padding.lower() == "reflect":
|
||||
|
||||
self.conv = involution(out_channels,kernel_size,1)
|
||||
# self.conv = nn.Sequential(
|
||||
# nn.ReflectionPad2d(padding_size),
|
||||
# nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size= kernel_size, bias= False))
|
||||
# for layer in self.conv:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
elif padding.lower() == "zero":
|
||||
self.conv = involution(out_channels,kernel_size,1)
|
||||
# nn.init.xavier_uniform_(self.conv.weight)
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# nn.init.xavier_uniform_(self.conv.weight)
|
||||
|
||||
def forward(self, input):
|
||||
h = self.conv1x1(input)
|
||||
h = self.upsampling(h)
|
||||
h = self.conv(h)
|
||||
return h
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 10th February 2022 3:14:08 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from components.DeConv_Invo import DeConv
|
||||
|
||||
class Demodule(nn.Module):
|
||||
def __init__(self, epsilon=1e-8):
|
||||
"""
|
||||
@notice: avoid in-place ops.
|
||||
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
||||
"""
|
||||
super(Demodule, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
tmp = torch.mul(x, x) # or x ** 2
|
||||
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
||||
return x * tmp
|
||||
|
||||
class ApplyStyle(nn.Module):
|
||||
"""
|
||||
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
||||
"""
|
||||
def __init__(self, latent_size, channels):
|
||||
super(ApplyStyle, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels * 2)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, 2, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
|
||||
return x
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, latent_size, channels):
|
||||
super(Modulation, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * style
|
||||
return x
|
||||
|
||||
class ResnetBlock_Modulation(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Modulation, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
if padding_type == 'reflect':
|
||||
conv1 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv1 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), Demodule()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = Modulation(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
conv2 = []
|
||||
if padding_type == 'reflect':
|
||||
conv2 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv2 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
||||
self.style2 = Modulation(latent_size, dim)
|
||||
|
||||
|
||||
def forward(self, x, dlatents_in_slice):
|
||||
y = self.conv1(x)
|
||||
y = self.style1(y, dlatents_in_slice)
|
||||
y = self.act1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.style2(y, dlatents_in_slice)
|
||||
out = x + y
|
||||
return out
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
chn = kwargs["g_conv_dim"]
|
||||
k_size = kwargs["g_kernel_size"]
|
||||
res_num = kwargs["res_num"]
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
|
||||
nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(128), activation)
|
||||
|
||||
self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(256), activation)
|
||||
|
||||
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(512), activation)
|
||||
|
||||
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(512), activation)
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
ResnetBlock_Modulation(512, latent_size=chn, padding_type=padding_type, activation=activation)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
DeConv(512,512,3),
|
||||
nn.BatchNorm2d(512), activation
|
||||
)
|
||||
|
||||
self.up3 = nn.Sequential(
|
||||
DeConv(512,256,3),
|
||||
nn.BatchNorm2d(256), activation
|
||||
)
|
||||
|
||||
self.up2 = nn.Sequential(
|
||||
DeConv(256,128,3),
|
||||
nn.BatchNorm2d(128), activation
|
||||
)
|
||||
|
||||
self.up1 = nn.Sequential(
|
||||
DeConv(128,64,3),
|
||||
nn.BatchNorm2d(64), activation
|
||||
)
|
||||
|
||||
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# for layer in self.encoder:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
# for layer in self.encoder2:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
def forward(self, input, id):
|
||||
x = input # 3*224*224
|
||||
skip1 = self.first_layer(x)
|
||||
skip2 = self.down1(skip1)
|
||||
skip3 = self.down2(skip2)
|
||||
skip4 = self.down3(skip3)
|
||||
res = self.down4(skip4)
|
||||
|
||||
for i in range(len(self.BottleNeck)):
|
||||
x = self.BottleNeck[i](res, id)
|
||||
|
||||
x = self.up4(x)
|
||||
x = self.up3(x)
|
||||
x = self.up2(x)
|
||||
x = self.up1(x)
|
||||
x = self.last_layer(x)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 10th February 2022 1:01:09 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from components.DeConv_Invo import DeConv
|
||||
|
||||
class InstanceNorm(nn.Module):
|
||||
def __init__(self, epsilon=1e-8):
|
||||
"""
|
||||
@notice: avoid in-place ops.
|
||||
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
||||
"""
|
||||
super(InstanceNorm, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
x = x - torch.mean(x, (2, 3), True)
|
||||
tmp = torch.mul(x, x) # or x ** 2
|
||||
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
||||
return x * tmp
|
||||
|
||||
class ApplyStyle(nn.Module):
|
||||
"""
|
||||
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
||||
"""
|
||||
def __init__(self, latent_size, channels):
|
||||
super(ApplyStyle, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels * 2)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, 2, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
|
||||
return x
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Adain, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
if padding_type == 'reflect':
|
||||
conv1 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv1 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = ApplyStyle(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
conv2 = []
|
||||
if padding_type == 'reflect':
|
||||
conv2 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv2 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
||||
self.style2 = ApplyStyle(latent_size, dim)
|
||||
|
||||
|
||||
def forward(self, x, dlatents_in_slice):
|
||||
y = self.conv1(x)
|
||||
y = self.style1(y, dlatents_in_slice)
|
||||
y = self.act1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.style2(y, dlatents_in_slice)
|
||||
out = x + y
|
||||
return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
chn = kwargs["g_conv_dim"]
|
||||
k_size = kwargs["g_kernel_size"]
|
||||
res_num = kwargs["res_num"]
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
|
||||
nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(128), activation)
|
||||
|
||||
self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(256), activation)
|
||||
|
||||
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(512), activation)
|
||||
|
||||
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(512), activation)
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
DeConv(512,512,3),
|
||||
nn.BatchNorm2d(512), activation
|
||||
)
|
||||
|
||||
self.up3 = nn.Sequential(
|
||||
DeConv(512,256,3),
|
||||
nn.BatchNorm2d(256), activation
|
||||
)
|
||||
|
||||
self.up2 = nn.Sequential(
|
||||
DeConv(256,128,3),
|
||||
nn.BatchNorm2d(128), activation
|
||||
)
|
||||
|
||||
self.up1 = nn.Sequential(
|
||||
DeConv(128,64,3),
|
||||
nn.BatchNorm2d(64), activation
|
||||
)
|
||||
|
||||
self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# for layer in self.encoder:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
# for layer in self.encoder2:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
def forward(self, input, id):
|
||||
x = input # 3*224*224
|
||||
skip1 = self.first_layer(x)
|
||||
skip2 = self.down1(skip1)
|
||||
skip3 = self.down2(skip2)
|
||||
skip4 = self.down3(skip3)
|
||||
res = self.down4(skip4)
|
||||
|
||||
for i in range(len(self.BottleNeck)):
|
||||
x = self.BottleNeck[i](res, id)
|
||||
|
||||
x = self.up4(x)
|
||||
x = self.up3(x)
|
||||
x = self.up2(x)
|
||||
x = self.up1(x)
|
||||
x = self.last_layer(x)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Involution.py
|
||||
# Created Date: Tuesday July 20th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 11th February 2022 12:08:41 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.utils import _pair
|
||||
from torch.autograd import Function
|
||||
|
||||
import cupy
|
||||
from string import Template
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
|
||||
Stream = namedtuple('Stream', ['ptr'])
|
||||
|
||||
|
||||
def Dtype(t):
|
||||
if isinstance(t, torch.cuda.FloatTensor):
|
||||
return 'float'
|
||||
elif isinstance(t, torch.cuda.DoubleTensor):
|
||||
return 'double'
|
||||
|
||||
|
||||
@cupy._util.memoize(for_each_device=True)
|
||||
def load_kernel(kernel_name, code, **kwargs):
|
||||
code = Template(code).substitute(**kwargs)
|
||||
kernel_code = cupy.cuda.compile_with_cache(code)
|
||||
return kernel_code.get_function(kernel_name)
|
||||
|
||||
|
||||
CUDA_NUM_THREADS = 1024
|
||||
|
||||
kernel_loop = '''
|
||||
#define CUDA_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
||||
i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
'''
|
||||
|
||||
|
||||
def GET_BLOCKS(N):
|
||||
return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS
|
||||
|
||||
|
||||
_involution_kernel = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_forward_kernel(
|
||||
const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int n = index / ${channels} / ${top_height} / ${top_width};
|
||||
const int c = (index / ${top_height} / ${top_width}) % ${channels};
|
||||
const int h = (index / ${top_width}) % ${top_height};
|
||||
const int w = index % ${top_width};
|
||||
const int g = c / (${channels} / ${groups});
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int kh = 0; kh < ${kernel_h}; ++kh) {
|
||||
#pragma unroll
|
||||
for (int kw = 0; kw < ${kernel_w}; ++kw) {
|
||||
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
|
||||
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
|
||||
if ((h_in >= 0) && (h_in < ${bottom_height})
|
||||
&& (w_in >= 0) && (w_in < ${bottom_width})) {
|
||||
const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
|
||||
* ${bottom_width} + w_in;
|
||||
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h)
|
||||
* ${top_width} + w;
|
||||
value += weight_data[offset_weight] * bottom_data[offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
top_data[index] = value;
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
_involution_kernel_backward_grad_input = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_backward_grad_input_kernel(
|
||||
const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int n = index / ${channels} / ${bottom_height} / ${bottom_width};
|
||||
const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels};
|
||||
const int h = (index / ${bottom_width}) % ${bottom_height};
|
||||
const int w = index % ${bottom_width};
|
||||
const int g = c / (${channels} / ${groups});
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int kh = 0; kh < ${kernel_h}; ++kh) {
|
||||
#pragma unroll
|
||||
for (int kw = 0; kw < ${kernel_w}; ++kw) {
|
||||
const int h_out_s = h + ${pad_h} - kh * ${dilation_h};
|
||||
const int w_out_s = w + ${pad_w} - kw * ${dilation_w};
|
||||
if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) {
|
||||
const int h_out = h_out_s / ${stride_h};
|
||||
const int w_out = w_out_s / ${stride_w};
|
||||
if ((h_out >= 0) && (h_out < ${top_height})
|
||||
&& (w_out >= 0) && (w_out < ${top_width})) {
|
||||
const int offset = ((n * ${channels} + c) * ${top_height} + h_out)
|
||||
* ${top_width} + w_out;
|
||||
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out)
|
||||
* ${top_width} + w_out;
|
||||
value += weight_data[offset_weight] * top_diff[offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bottom_diff[index] = value;
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
_involution_kernel_backward_grad_weight = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_backward_grad_weight_kernel(
|
||||
const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int h = (index / ${top_width}) % ${top_height};
|
||||
const int w = index % ${top_width};
|
||||
const int kh = (index / ${kernel_w} / ${top_height} / ${top_width})
|
||||
% ${kernel_h};
|
||||
const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w};
|
||||
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
|
||||
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
|
||||
if ((h_in >= 0) && (h_in < ${bottom_height})
|
||||
&& (w_in >= 0) && (w_in < ${bottom_width})) {
|
||||
const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups};
|
||||
const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num};
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) {
|
||||
const int top_offset = ((n * ${channels} + c) * ${top_height} + h)
|
||||
* ${top_width} + w;
|
||||
const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
|
||||
* ${bottom_width} + w_in;
|
||||
value += top_diff[top_offset] * bottom_data[bottom_offset];
|
||||
}
|
||||
buffer_data[index] = value;
|
||||
} else {
|
||||
buffer_data[index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
class _involution(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, stride, padding, dilation):
|
||||
assert input.dim() == 4 and input.is_cuda
|
||||
assert weight.dim() == 6 and weight.is_cuda
|
||||
batch_size, channels, height, width = input.size()
|
||||
kernel_h, kernel_w = weight.size()[2:4]
|
||||
output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1)
|
||||
output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1)
|
||||
|
||||
output = input.new(batch_size, channels, output_h, output_w)
|
||||
n = output.numel()
|
||||
|
||||
with torch.cuda.device_of(input):
|
||||
f = load_kernel('involution_forward_kernel', _involution_kernel, Dtype=Dtype(input), nthreads=n,
|
||||
num=batch_size, channels=channels, groups=weight.size()[1],
|
||||
bottom_height=height, bottom_width=width,
|
||||
top_height=output_h, top_width=output_w,
|
||||
kernel_h=kernel_h, kernel_w=kernel_w,
|
||||
stride_h=stride[0], stride_w=stride[1],
|
||||
dilation_h=dilation[0], dilation_w=dilation[1],
|
||||
pad_h=padding[0], pad_w=padding[1])
|
||||
f(block=(CUDA_NUM_THREADS,1,1),
|
||||
grid=(GET_BLOCKS(n),1,1),
|
||||
args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
|
||||
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
|
||||
|
||||
ctx.save_for_backward(input, weight)
|
||||
ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
assert grad_output.is_cuda and grad_output.is_contiguous()
|
||||
input, weight = ctx.saved_tensors
|
||||
stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation
|
||||
|
||||
batch_size, channels, height, width = input.size()
|
||||
kernel_h, kernel_w = weight.size()[2:4]
|
||||
output_h, output_w = grad_output.size()[2:]
|
||||
|
||||
grad_input, grad_weight = None, None
|
||||
|
||||
opt = dict(Dtype=Dtype(grad_output),
|
||||
num=batch_size, channels=channels, groups=weight.size()[1],
|
||||
bottom_height=height, bottom_width=width,
|
||||
top_height=output_h, top_width=output_w,
|
||||
kernel_h=kernel_h, kernel_w=kernel_w,
|
||||
stride_h=stride[0], stride_w=stride[1],
|
||||
dilation_h=dilation[0], dilation_w=dilation[1],
|
||||
pad_h=padding[0], pad_w=padding[1])
|
||||
|
||||
with torch.cuda.device_of(input):
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = input.new(input.size())
|
||||
|
||||
n = grad_input.numel()
|
||||
opt['nthreads'] = n
|
||||
|
||||
f = load_kernel('involution_backward_grad_input_kernel',
|
||||
_involution_kernel_backward_grad_input, **opt)
|
||||
f(block=(CUDA_NUM_THREADS,1,1),
|
||||
grid=(GET_BLOCKS(n),1,1),
|
||||
args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
|
||||
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = weight.new(weight.size())
|
||||
|
||||
n = grad_weight.numel()
|
||||
opt['nthreads'] = n
|
||||
|
||||
f = load_kernel('involution_backward_grad_weight_kernel',
|
||||
_involution_kernel_backward_grad_weight, **opt)
|
||||
f(block=(CUDA_NUM_THREADS,1,1),
|
||||
grid=(GET_BLOCKS(n),1,1),
|
||||
args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
|
||||
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
|
||||
|
||||
return grad_input, grad_weight, None, None, None
|
||||
|
||||
|
||||
def _involution_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1):
|
||||
""" involution kernel
|
||||
"""
|
||||
assert input.size(0) == weight.size(0)
|
||||
assert input.size(-2)//stride == weight.size(-2)
|
||||
assert input.size(-1)//stride == weight.size(-1)
|
||||
if input.is_cuda:
|
||||
out = _involution.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation))
|
||||
if bias is not None:
|
||||
out += bias.view(1,-1,1,1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return out
|
||||
|
||||
|
||||
class involution(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride):
|
||||
super(involution, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.channels = channels
|
||||
reduction_ratio = 4
|
||||
self.group_channels = 8
|
||||
self.groups = self.channels // self.group_channels
|
||||
self.seblock = nn.Sequential(
|
||||
nn.Conv2d(in_channels = channels, out_channels = channels // reduction_ratio, kernel_size= 1),
|
||||
# nn.BatchNorm2d(channels // reduction_ratio),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels = channels // reduction_ratio, out_channels = kernel_size**2 * self.groups, kernel_size= 1)
|
||||
)
|
||||
|
||||
# self.conv1 = ConvModule(
|
||||
# in_channels=channels,
|
||||
# out_channels=channels // reduction_ratio,
|
||||
# kernel_size=1,
|
||||
# conv_cfg=None,
|
||||
# norm_cfg=dict(type='BN'),
|
||||
# act_cfg=dict(type='ReLU'))
|
||||
# self.conv2 = ConvModule(
|
||||
# in_channels=channels // reduction_ratio,
|
||||
# out_channels=kernel_size**2 * self.groups,
|
||||
# kernel_size=1,
|
||||
# stride=1,
|
||||
# conv_cfg=None,
|
||||
# norm_cfg=None,
|
||||
# act_cfg=None)
|
||||
if stride > 1:
|
||||
self.avgpool = nn.AvgPool2d(stride, stride)
|
||||
|
||||
def forward(self, x):
|
||||
# weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
|
||||
weight = self.seblock(x)
|
||||
b, c, h, w = weight.shape
|
||||
weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w)
|
||||
out = _involution_cuda(x, weight, stride=self.stride, padding=(self.kernel_size-1)//2)
|
||||
return out
|
||||
Reference in New Issue
Block a user