From 7f16d0a10eab8a04c6efbbf6a6dee2d9e0b2b6c9 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 10 Mar 2025 18:50:54 +0100 Subject: [PATCH] Follow sequences pattern --- face_swapper/src/networks/masknet.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index 03b9940..bbe4e55 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -51,21 +51,19 @@ class MaskNet(nn.Module): class BottleNeck(nn.Module): def __init__(self, num_filters : int): super().__init__() - self.layers = self.create_layers(num_filters) - self.sequences = nn.Sequential(*self.layers) + self.sequences = self.create_sequences(num_filters) self.relu = nn.ReLU(inplace = True) @staticmethod - def create_layers(num_filters : int) -> nn.ModuleList: - return nn.ModuleList( - [ + def create_sequences(num_filters : int) -> nn.Sequential: + return nn.Sequential( nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(num_filters), nn.ReLU(inplace = True), nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(num_filters), nn.ReLU(inplace = True) - ]) + ) def forward(self, input_tensor : Tensor) -> Tensor: output_tensor = self.sequences(input_tensor) + input_tensor