support multi-gpu
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: arcface_decoder.py
|
||||
# Created Date: Saturday January 29th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 2:55:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
activation = nn.ReLU(True)
|
||||
|
||||
self.fc = nn.Linear(512, 7*7*512)
|
||||
|
||||
self.up4 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(512), activation
|
||||
)
|
||||
|
||||
self.up3 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(256), activation
|
||||
)
|
||||
|
||||
self.up2 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(128), activation
|
||||
)
|
||||
|
||||
self.up1 = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(64), activation
|
||||
)
|
||||
|
||||
self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
|
||||
def forward(self, input):
|
||||
x = input #
|
||||
x = self.fc(x)
|
||||
x = x.view(x.size(0),512,7,7)
|
||||
x = self.up4(x)
|
||||
x = self.up3(x)
|
||||
x = self.up2(x)
|
||||
x = self.up1(x)
|
||||
x = self.last_layer(x)
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user