Files
SimSwapPlus/components/arcface_decoder.py
T
2022-02-08 16:37:30 +08:00

64 lines
2.0 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: arcface_decoder.py
# Created Date: Saturday January 29th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 29th January 2022 2:55:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class Decoder(nn.Module):
    """Decode a 512-dimensional embedding into a 3-channel 112x112 map.

    The embedding is projected by a linear layer to a 512x7x7 feature map,
    which then passes through four upsampling stages (each doubling the
    spatial size: 7 -> 14 -> 28 -> 56 -> 112) while the channel count goes
    512 -> 512 -> 256 -> 128 -> 64. A final 3x3 convolution maps the 64
    channels to 3. No output activation is applied, so values are unbounded.

    NOTE(review): judging by the file name the input is presumably an
    ArcFace identity embedding -- confirm against the caller.
    """

    def __init__(
        self,
        **kwargs
    ):
        """Create the projection layer, the four upsampling stages and the head.

        Args:
            **kwargs: Accepted so the class can be constructed from a generic
                configuration dict; no key is read here.
        """
        super().__init__()
        self.fc = nn.Linear(512, 7 * 7 * 512)
        # Attribute names and the layer order inside each stage are kept as
        # they were so state_dict keys (e.g. "up4.1.weight") stay loadable.
        self.up4 = self._up_block(512, 512)
        self.up3 = self._up_block(512, 256)
        self.up2 = self._up_block(256, 128)
        self.up1 = self._up_block(128, 64)
        self.last_layer = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    @staticmethod
    def _up_block(in_channels, out_channels):
        """Return one x2 bilinear-upsample -> 3x3 conv -> BatchNorm -> ReLU stage."""
        return nn.Sequential(
            # align_corners=False is what the implicit default resolves to;
            # stating it keeps the behaviour explicit.
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, input):
        """Run the decoder.

        Args:
            input: Tensor whose first axis is the batch and whose remaining
                axes hold 512 values per sample, e.g. ``(B, 512)`` or
                ``(B, 512, 1, 1)``.

        Returns:
            Tensor of shape ``(B, 3, 112, 112)``.
        """
        # Collapse everything after the batch axis so embeddings carrying
        # trailing singleton dimensions are accepted as well as plain (B, 512).
        x = input.reshape(input.size(0), -1)
        x = self.fc(x)
        x = x.view(x.size(0), 512, 7, 7)
        for stage in (self.up4, self.up3, self.up2, self.up1):
            x = stage(x)
        return self.last_layer(x)