eca depth wise

This commit is contained in:
chenxuanhong
2022-02-19 18:26:22 +08:00
parent e86270032b
commit db049166a0
7 changed files with 264 additions and 52 deletions
@@ -5,7 +5,7 @@
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 2:06:09 am
# Last Modified: Saturday, 19th February 2022 5:16:02 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -79,6 +79,10 @@ class ResnetBlock_Modulation(nn.Module):
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
Demodule()]
elif res_mode.lower() == "depthwise_eca":
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
Demodule()]
self.conv1 = nn.Sequential(*conv1)
self.style1 = Modulation(latent_size, dim)
self.act1 = activation
@@ -99,6 +103,10 @@ class ResnetBlock_Modulation(nn.Module):
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
Demodule()]
elif res_mode.lower() == "depthwise_eca":
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
Demodule()]
self.conv2 = nn.Sequential(*conv2)
self.style2 = Modulation(latent_size, dim)