#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: ModulatedDWConv.py
# Created Date: Monday April 18th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 18th April 2022 10:33:48 am
# Modified By: Chen Xuanhong
# Modified from: https://github.com/bes-dev/MobileStyleGAN.pytorch
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModulatedDWConv2d(nn.Module):
    """Style-modulated depthwise-separable 2D convolution.

    The input is scaled per channel by a style-derived modulation, passed
    through a depthwise ``kernel_size x kernel_size`` convolution followed by
    a pointwise (1x1) convolution, and optionally multiplied per output
    channel by a demodulation factor.

    Args:
        channels_in: Number of input channels.
        channels_out: Number of output channels.
        style_dim: Size of the last dimension of the style vector.
        kernel_size: Spatial size of the depthwise kernel. Padding is
            ``kernel_size // 2`` with stride 1, so odd sizes preserve H and W
            while even sizes grow each by 1.
        demodulate: If True, multiply the output by the factor returned by
            :meth:`get_demodulation` and register the ``style_inv`` buffer.
    """

    def __init__(
        self, channels_in, channels_out, style_dim, kernel_size, demodulate=True
    ):
        super().__init__()
        # Depthwise kernel: one (kernel_size x kernel_size) filter per input channel.
        self.weight_dw = nn.Parameter(
            torch.randn(channels_in, 1, kernel_size, kernel_size)
        )
        # Pointwise 1x1 kernel mixing channels_in -> channels_out.
        self.weight_permute = nn.Parameter(
            torch.randn(channels_out, channels_in, 1, 1)
        )

        # Maps a style vector to one modulation coefficient per input channel.
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        # Bias of 1.0: a zero style contribution yields a modulation equal to
        # self.scale rather than 0. nn.init.constant_ runs under no_grad, so
        # this is the autograd-safe equivalent of `.data.fill_(1.0)`.
        nn.init.constant_(self.modulation.bias, 1.0)

        self.demodulate = demodulate
        if self.demodulate:
            # Fixed (non-trainable) buffer of shape (1, 1, channels_in, 1, 1)
            # used instead of the per-sample style when computing the
            # demodulation factor. It is random here but saved in state_dict;
            # presumably real values are loaded from a checkpoint -- TODO confirm.
            self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))

        # Service state: weight scale 1 / sqrt(channels_in * kernel_size**2),
        # applied to the modulation and inside the demodulation norm.
        self.scale = 1.0 / math.sqrt(channels_in * kernel_size ** 2)
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Apply the modulated depthwise-separable convolution.

        Args:
            x: Input feature map; assumed (B, channels_in, H, W) -- the
                depthwise conv requires x.size(1) == channels_in.
            style: Style vectors; assumed (B, style_dim) -- TODO confirm.

        Returns:
            Tensor with channels_out channels; spatial size as described for
            ``kernel_size`` in the class docstring.
        """
        x = self.get_modulation(style) * x
        # groups == number of channels makes this a depthwise convolution.
        x = F.conv2d(x, self.weight_dw, padding=self.padding, groups=x.size(1))
        x = F.conv2d(x, self.weight_permute)
        if self.demodulate:
            x = self.get_demodulation(style) * x
        return x

    def get_modulation(self, style):
        """Return the scaled modulation, shaped (style.size(0), -1, 1, 1).

        For a (B, style_dim) style this is (B, channels_in, 1, 1), which
        broadcasts over the spatial dimensions of the input feature map.
        """
        style = self.modulation(style).view(style.size(0), -1, 1, 1)
        return self.scale * style

    def get_demodulation(self, style):
        """Return the demodulation factor, shaped (1, channels_out, 1, 1).

        ``style`` is accepted but not used: the factor is computed from the
        fixed ``style_inv`` buffer, so it is identical for every sample in
        the batch. Requires ``style_inv``, which exists only when the module
        was built with ``demodulate=True``.
        """
        # Effective separable kernel: (1, channels_out, channels_in, k, k).
        w = (self.weight_dw.transpose(0, 1) * self.weight_permute).unsqueeze(0)
        # Inverse L2 norm over (channels_in, k, k) -> (1, channels_out);
        # the 1e-8 term keeps rsqrt finite when the sum is zero.
        norm = torch.rsqrt(
            (self.scale * self.style_inv * w).pow(2).sum([2, 3, 4]) + 1e-8
        )
        return norm.view(*norm.size(), 1, 1)