Files
SimSwapPlus/components/ModulatedDWConv.py
T
chenxuanhong 29d8914c0a update
2022-04-24 15:44:47 +08:00

66 lines
2.3 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: ModulatedDWConv.py
# Created Date: Monday April 18th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 18th April 2022 10:33:48 am
# Modified By: Chen Xuanhong
# Modified from: https://github.com/bes-dev/MobileStyleGAN.pytorch
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModulatedDWConv2d(nn.Module):
    """Style-modulated depthwise-separable 2-D convolution.

    The input feature map is scaled channel-wise by a style-dependent vector,
    passed through a depthwise ``kernel_size x kernel_size`` convolution
    (one filter per channel), and then through a pointwise 1x1 convolution
    mapping ``channels_in`` to ``channels_out`` channels. When ``demodulate``
    is true, each output channel is additionally rescaled by a factor derived
    from the weights (see ``get_demodulation``).

    Args:
        channels_in: number of input feature channels.
        channels_out: number of output feature channels.
        style_dim: length of the style vector passed to ``forward``.
        kernel_size: spatial size of the depthwise kernel. Padding is
            ``kernel_size // 2``, which keeps H and W unchanged for odd sizes;
            an even size makes each spatial dimension grow by one.
        demodulate: whether ``forward`` applies the demodulation factor.
    """

    def __init__(
        self,
        channels_in,
        channels_out,
        style_dim,
        kernel_size,
        demodulate=True
    ):
        super().__init__()
        # NOTE: keep the creation order below. Every randomly initialised
        # tensor is drawn from the global RNG, so reordering would change the
        # values produced for a given seed.

        # Depthwise filters: one k x k kernel per input channel.
        self.weight_dw = nn.Parameter(
            torch.randn(channels_in, 1, kernel_size, kernel_size)
        )
        # Pointwise 1x1 filters mixing channels_in -> channels_out.
        self.weight_permute = nn.Parameter(
            torch.randn(channels_out, channels_in, 1, 1)
        )

        # Affine projection from the style vector to one scale per input
        # channel. The bias starts at 1.0 (presumably, as in StyleGAN2, so the
        # initial scales sit around 1 rather than 0).
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        nn.init.constant_(self.modulation.bias, 1.0)

        self.demodulate = demodulate
        if demodulate:
            # Fixed random tensor used by get_demodulation in place of the
            # real style. It is a buffer, not a Parameter.
            self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))

        # Weight scale 1 / sqrt(channels_in * k^2), applied to the modulation
        # and inside the demodulation norm.
        self.scale = 1.0 / math.sqrt(channels_in * kernel_size ** 2)
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Run modulation, depthwise conv, pointwise conv, and optional demodulation.

        Args:
            x: input feature map. Its channel dim (after broadcasting with the
                modulation) must equal ``channels_in``. Assumed layout is
                (N, C, H, W); TODO confirm against callers.
            style: style batch, fed to the ``nn.Linear(style_dim, ...)``
                projection; assumed shape (N, style_dim).

        Returns:
            The convolved feature map with ``channels_out`` channels.
        """
        modulated = self.get_modulation(style) * x
        # Depthwise: one group per channel of the modulated input.
        out = F.conv2d(
            modulated,
            self.weight_dw,
            padding=self.padding,
            groups=modulated.size(1),
        )
        # Pointwise 1x1 channel mixing.
        out = F.conv2d(out, self.weight_permute)
        if not self.demodulate:
            return out
        return self.get_demodulation(style) * out

    def get_modulation(self, style):
        """Return the scaled per-channel modulation for ``style``.

        For a 2-D ``style`` the result has shape (N, channels_in, 1, 1), so it
        broadcasts over the spatial dims of the feature map.
        """
        projected = self.modulation(style)
        return self.scale * projected.view(style.size(0), -1, 1, 1)

    def get_demodulation(self, style):
        """Return per-output-channel demodulation factors of shape (1, channels_out, 1, 1).

        ``style`` is accepted but not read. The factors come only from the
        weights and the fixed ``style_inv`` buffer, so one set of factors is
        shared by every sample in the batch.
        """
        # (1, C_in, k, k) * (C_out, C_in, 1, 1) -> (C_out, C_in, k, k),
        # then a leading singleton dim -> (1, C_out, C_in, k, k).
        fused = self.weight_dw.transpose(0, 1) * self.weight_permute
        fused = fused.unsqueeze(0)
        # Sum of squares over (C_in, k, k) per output channel -> (1, C_out).
        energy = (self.scale * self.style_inv * fused).pow(2).sum([2, 3, 4])
        # The 1e-8 epsilon keeps rsqrt finite for an all-zero filter.
        norm = torch.rsqrt(energy + 1e-8)
        return norm[:, :, None, None]