47 lines
1.9 KiB
Python
47 lines
1.9 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: Generator.py
|
|
# Created Date: Sunday January 16th 2022
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Sunday, 13th February 2022 2:03:21 am
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class LSTU(nn.Module):
    """Gated fusion of an upsampled latent map with an encoder feature map.

    ``bottleneck_in`` is upsampled by ``scale`` with a transposed convolution
    (followed by BatchNorm and a sigmoid).  A forget gate and a reset gate, both
    computed from that upsampled map, then decide how much of the encoder
    feature and of the raw upsampled map reach the output.

    Args:
        in_channel: Channel count of ``encoder_in``.  When it differs from
            ``out_channel``, ``encoder_in`` is projected to ``out_channel``
            channels with a bias-free 1x1 convolution before blending.
        out_channel: Channel count of the upsampled map and of the output.
        latent_channel: Channel count of ``bottleneck_in``.
        scale: Integer spatial upsampling factor applied to ``bottleneck_in``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        latent_channel,
        scale=4,
    ):
        super().__init__()
        sig = nn.Sigmoid()  # stateless, so sharing one instance is harmless
        self.relu = nn.ReLU(True)

        # With kernel_size == stride and padding 0 the transposed conv output is
        # (H - 1) * scale + scale == scale * H exactly.  The former hard-coded
        # kernel_size=4 only achieved that for scale == 4 (still the default,
        # and producing the very same layer).
        self.up_sample = nn.Sequential(
            nn.ConvTranspose2d(
                latent_channel, out_channel,
                kernel_size=scale, stride=scale, padding=0, bias=False,
            ),
            nn.BatchNorm2d(out_channel),
            sig,
        )

        # Both gates are applied to the upsampled map, which has out_channel
        # channels, so that is their input width (it used to be in_channel,
        # which crashed whenever in_channel != out_channel).
        self.forget_gate = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            sig,
        )
        self.reset_gate = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            sig,
        )

        self.conv11 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True)
        )

        # The gated blend multiplies encoder_in element-wise with out_channel-wide
        # gates.  Identity has no parameters, so matching widths keep the same
        # state_dict keys and the same behaviour as before.
        if in_channel == out_channel:
            self.encoder_proj = nn.Identity()
        else:
            self.encoder_proj = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)

    def forward(self, encoder_in, bottleneck_in):
        """Fuse ``encoder_in`` with the upsampled ``bottleneck_in``.

        Args:
            encoder_in: Batched feature map with ``in_channel`` channels whose
                spatial size equals that of ``bottleneck_in`` times ``scale``.
            bottleneck_in: Batched feature map with ``latent_channel`` channels.

        Returns:
            Tensor with ``out_channel`` channels at the spatial size of
            ``encoder_in``.
        """
        # upsample and bring the channel count to out_channel
        h_hat_l_1 = self.up_sample(bottleneck_in)
        h_bar_l = self.conv11(h_hat_l_1)
        f_l = self.forget_gate(h_hat_l_1)
        r_l = self.reset_gate(h_hat_l_1)
        enc = self.encoder_proj(encoder_in)
        # forget gate: convex mix of the 1x1-transformed upsampled map and the
        # encoder feature
        h_hat_l = (1 - f_l) * h_bar_l + f_l * enc
        # reset gate: convex mix of the rectified blend and the raw upsampled map
        x_hat_l = r_l * self.relu(h_hat_l) + (1 - r_l) * h_hat_l_1
        return x_hat_l