Files
SimSwapPlus/components/Conditional_Discriminator_Projection_big.py
T
chenxuanhong 3783ef0e75 init
2022-01-10 15:03:58 +08:00

124 lines
4.8 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Discriminator_Projection_big.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th June 2021 4:26:33 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn import utils
class Discriminator(nn.Module):
    """Three-stage conditional discriminator with projection conditioning.

    Each stage downsamples the features with two stride-2, spectral-normalized
    convolutions. An auxiliary head (5x5 conv + LeakyReLU + global average
    pooling) then reduces them to a ``chn``-dim vector ``phi``. The stage score
    is ``linear(phi) + <embed(condition), phi>``, i.e. a projection
    discriminator head. For a 512x512 input the stage outputs are 130x130,
    34x34 and 12x12 before the auxiliary heads.

    Args:
        chn: Base channel width; stages use multiples of it (up to 16*chn).
        k_size: Kernel size of the downsampling convolutions.
        n_class: Number of condition classes (rows of each embedding table).
        in_channels: Channels of the input image (default 3, as before).
        negative_slope: LeakyReLU negative slope (default 0.2, as before).
        bias: Whether the convolution layers carry a bias (default True, as
            before). The final ``nn.Linear`` layers always have a bias.
    """

    def __init__(self, chn=32, k_size=3, n_class=3, in_channels=3,
                 negative_slope=0.2, bias=True):
        super().__init__()
        # NOTE: attribute names (including the historical "classfier"
        # spelling) are state_dict keys -- keep them so existing checkpoints
        # still load. Modules are also created in the original order, which
        # keeps parameter order (optimizer state) and RNG consumption stable.

        # stage 1: in_channels -> chn -> 2*chn
        self.block1 = self._down_block(in_channels, chn, chn * 2, k_size, 2,
                                       negative_slope, bias)
        self.aux_classfier1 = self._aux_head(chn * 2, chn, negative_slope, bias)
        self.embed1 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear1 = utils.spectral_norm(nn.Linear(chn, 1))

        # stage 2: 2*chn -> 4*chn -> 8*chn
        self.block2 = self._down_block(chn * 2, chn * 4, chn * 8, k_size, 2,
                                       negative_slope, bias)
        self.aux_classfier2 = self._aux_head(chn * 8, chn, negative_slope, bias)
        self.embed2 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear2 = utils.spectral_norm(nn.Linear(chn, 1))

        # stage 3: 8*chn -> 8*chn -> 16*chn (padding 3 instead of 2)
        self.block3 = self._down_block(chn * 8, chn * 8, chn * 16, k_size, 3,
                                       negative_slope, bias)
        self.aux_classfier3 = self._aux_head(chn * 16, chn, negative_slope, bias)
        self.embed3 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear3 = utils.spectral_norm(nn.Linear(chn, 1))

        self.__weights_init__()

    @staticmethod
    def _down_block(in_ch, mid_ch, out_ch, k_size, padding, slope, bias):
        """Build two stride-2 spectral-normalized convs, each followed by LeakyReLU."""
        return nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=in_ch, out_channels=mid_ch,
                                          kernel_size=k_size, stride=2,
                                          padding=padding, bias=bias)),
            nn.LeakyReLU(slope),
            utils.spectral_norm(nn.Conv2d(in_channels=mid_ch, out_channels=out_ch,
                                          kernel_size=k_size, stride=2,
                                          padding=padding, bias=bias)),
            nn.LeakyReLU(slope),
        )

    @staticmethod
    def _aux_head(in_ch, out_ch, slope, bias):
        """Build the auxiliary head: 5x5 conv + LeakyReLU + global average pool."""
        return nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                                          kernel_size=5, bias=bias)),
            nn.LeakyReLU(slope),
            nn.AdaptiveAvgPool2d(1),
        )

    def __weights_init__(self):
        """Re-initialise weights.

        Conv, linear and embedding weights get Xavier-uniform init, and
        conv/linear biases are set to zero.
        """
        print("Init weights")
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # With torch.nn.utils.spectral_norm, ``m.weight`` is a plain
                # tensor aliasing ``m.weight_orig`` until the first forward,
                # so this in-place init reaches the trainable parameter.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                else:
                    print("No bias found!")
            elif isinstance(m, nn.Embedding):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input, condition):
        """Score ``input`` at three scales, conditioned on ``condition``.

        Args:
            input: Image batch ``(B, in_channels, H, W)``. H and W must be
                large enough for the final 5x5 auxiliary conv (e.g. >= 64).
            condition: Integer class indices of shape ``(B,)``. A 2-D index
                tensor would broadcast incorrectly against the ``(B, chn)``
                features.

        Returns:
            list of three ``(B, 1)`` tensors, one score per stage.
        """
        stages = (
            (self.block1, self.aux_classfier1, self.embed1, self.linear1),
            (self.block2, self.aux_classfier2, self.embed2, self.linear2),
            (self.block3, self.aux_classfier3, self.embed3, self.linear3),
        )
        h = input
        outputs = []
        for block, aux_head, embed, linear in stages:
            h = block(h)
            feat = aux_head(h).flatten(1)  # (B, chn, 1, 1) -> (B, chn)
            # Projection term: inner product of class embedding and features.
            proj = torch.sum(embed(condition) * feat, dim=1, keepdim=True)
            outputs.append(linear(feat) + proj)
        return outputs

    def get_outputs_len(self):
        """Return the number of ``nn.Linear`` modules (one per stage score, i.e. 3)."""
        return sum(1 for m in self.modules() if isinstance(m, nn.Linear))
if __name__ == "__main__":
    # Smoke test: build the model and push a dummy batch through it.
    # forward() needs both an image batch and class indices, which
    # torchsummary.summary(model, input_size=...) cannot supply (it feeds a
    # single float tensor), so run the forward pass directly instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Discriminator().to(device)
    print(model)
    images = torch.randn(2, 3, 512, 512, device=device)
    labels = torch.zeros(2, dtype=torch.long, device=device)
    with torch.no_grad():
        outputs = model(images, labels)
    for idx, out in enumerate(outputs, start=1):
        print(f"stage {idx} output shape: {tuple(out.shape)}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params:,}")