#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Discriminator copy.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th June 2021 4:26:33 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import torch
from torch import nn
from torch.nn import utils

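# Multi-scale conditional discriminator: three strided-conv stages share one
# backbone, and each stage emits its own real/fake logit. The per-stage
# conditioning matches the projection-discriminator pattern (cf. Miyato &
# Koyama, "cGANs with Projection Discriminator"):
#     logit = linear(phi(x)) + <embed(y), phi(x)>
# where phi(x) is the pooled stage feature and y is the class label. All
# weight layers are wrapped in spectral normalization.
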
class Discriminator(nn.Module):

    def __init__(self, chn=32, k_size=3, n_class=3):
        super().__init__()
        # padding_size = int((k_size - 1) / 2)
        slope = 0.2
        enable_bias = True

        # stage 1
        self.block1 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=3, out_channels=chn,
                                          kernel_size=k_size, stride=2,
                                          padding=2, bias=enable_bias)),
            nn.LeakyReLU(slope),
            utils.spectral_norm(nn.Conv2d(in_channels=chn, out_channels=chn * 2,
                                          kernel_size=k_size, stride=2,
                                          padding=2, bias=enable_bias)),  # 1/4
            nn.LeakyReLU(slope)
        )
        # pooled stage feature phi_1(x), used to form the stage-1 logit
        self.aux_classifier1 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 2, out_channels=chn,
                                          kernel_size=5, bias=enable_bias)),
            nn.LeakyReLU(slope),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed1 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear1 = utils.spectral_norm(nn.Linear(chn, 1))

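        # NOTE: with kernel_size=3, stride=2, padding=2, each conv maps a
        # spatial size H to floor((H + 1) / 2) + 1, so the 1/4, 1/8, ...
        # markers are approximate (e.g. 512 -> 257 -> 130 after block1).
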
        # stage 2
        self.block2 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 2, out_channels=chn * 4,
                                          kernel_size=k_size, stride=2,
                                          padding=2, bias=enable_bias)),  # 1/8
            nn.LeakyReLU(slope),
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 4, out_channels=chn * 8,
                                          kernel_size=k_size, stride=2,
                                          padding=2, bias=enable_bias)),  # 1/16
            nn.LeakyReLU(slope)
        )
        self.aux_classifier2 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 8, out_channels=chn,
                                          kernel_size=5, bias=enable_bias)),
            nn.LeakyReLU(slope),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed2 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear2 = utils.spectral_norm(nn.Linear(chn, 1))

        # stage 3 (uses padding=3 rather than 2, so features shrink slightly
        # slower here)
        self.block3 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 8, out_channels=chn * 8,
                                          kernel_size=k_size, stride=2,
                                          padding=3, bias=enable_bias)),  # 1/32
            nn.LeakyReLU(slope),
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 8, out_channels=chn * 16,
                                          kernel_size=k_size, stride=2,
                                          padding=3, bias=enable_bias)),  # 1/64
            nn.LeakyReLU(slope)
        )
        self.aux_classifier3 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels=chn * 16, out_channels=chn,
                                          kernel_size=5, bias=enable_bias)),
            nn.LeakyReLU(slope),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed3 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear3 = utils.spectral_norm(nn.Linear(chn, 1))

        self.__weights_init__()

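    # NOTE (implementation detail, worth re-checking across PyTorch versions):
    # torch.nn.utils.spectral_norm re-registers the wrapped weight as
    # `weight_orig` and leaves `module.weight` as a tensor sharing its storage,
    # so the in-place Xavier init below still reaches the trainable parameter.
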
    def __weights_init__(self):
        print("Init weights")
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input, condition):
        # stage 1: logit = linear(phi(x)) + <embed(y), phi(x)>
        # (projection-style conditioning, repeated at every stage)
        h = self.block1(input)
        prep1 = self.aux_classifier1(h)
        prep1 = prep1.view(prep1.size(0), -1)
        y1 = self.embed1(condition)
        y1 = torch.sum(y1 * prep1, dim=1, keepdim=True)
        prep1 = self.linear1(prep1) + y1

        # stage 2
        h = self.block2(h)
        prep2 = self.aux_classifier2(h)
        prep2 = prep2.view(prep2.size(0), -1)
        y2 = self.embed2(condition)
        y2 = torch.sum(y2 * prep2, dim=1, keepdim=True)
        prep2 = self.linear2(prep2) + y2

        # stage 3
        h = self.block3(h)
        prep3 = self.aux_classifier3(h)
        prep3 = prep3.view(prep3.size(0), -1)
        y3 = self.embed3(condition)
        y3 = torch.sum(y3 * prep3, dim=1, keepdim=True)
        prep3 = self.linear3(prep3) + y3

        out_prep = [prep1, prep2, prep3]
        return out_prep

    def get_outputs_len(self):
        # one nn.Linear per stage, so this equals the number of logits
        # that forward() returns
        num = 0
        for m in self.modules():
            if isinstance(m, nn.Linear):
                num += 1
        return num

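# A minimal sketch (an assumption, not part of the original file) of how the
# three stage logits might be consumed during training, here with a hinge loss.
# `d_real` / `d_fake` are the lists returned by Discriminator.forward for a
# real and a generated batch; the helper name is hypothetical.
def multiscale_hinge_d_loss(d_real, d_fake):
    loss = 0.0
    for real_logit, fake_logit in zip(d_real, d_fake):
        # hinge terms: max(0, 1 - D(x_real)) + max(0, 1 + D(x_fake))
        loss = loss + torch.relu(1.0 - real_logit).mean() \
                    + torch.relu(1.0 + fake_logit).mean()
    return loss
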
if __name__ == "__main__":
    # torchsummary cannot drive this model (forward needs an integer label
    # tensor as a second input), so run a plain smoke test instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Discriminator().to(device)
    outputs = net(torch.randn(2, 3, 512, 512, device=device),
                  torch.randint(0, 3, (2,), device=device))
    print([tuple(o.shape) for o in outputs])  # [(2, 1), (2, 1), (2, 1)]