mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Follow sequences pattern
This commit is contained in:
@@ -51,21 +51,19 @@ class MaskNet(nn.Module):
|
||||
class BottleNeck(nn.Module):
|
||||
def __init__(self, num_filters : int):
|
||||
super().__init__()
|
||||
self.layers = self.create_layers(num_filters)
|
||||
self.sequences = nn.Sequential(*self.layers)
|
||||
self.sequences = self.create_sequences(num_filters)
|
||||
self.relu = nn.ReLU(inplace = True)
|
||||
|
||||
@staticmethod
|
||||
def create_layers(num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
def create_sequences(num_filters : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True)
|
||||
])
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.sequences(input_tensor) + input_tensor
|
||||
|
||||
Reference in New Issue
Block a user