From f9d105ea2b05789b09153c674391233ec865ec70 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 10 Mar 2025 11:09:50 +0100 Subject: [PATCH] Some code review --- face_swapper/src/networks/masknet.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index 7d26633..5b8305d 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -16,23 +16,23 @@ class MaskNet(nn.Module): self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1) self.sigmoid = nn.Sigmoid() - def create_down_samples(self, input_channels : int, base_channels: int) -> nn.ModuleList: - down_samples = nn.ModuleList( + @staticmethod + def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList: + return nn.ModuleList( [ DownSample(input_channels, base_channels), DownSample(base_channels, base_channels * 2), DownSample(base_channels * 2, base_channels * 4) ]) - return down_samples - def create_up_samples(self, base_channels : int) -> nn.ModuleList: - down_samples = nn.ModuleList( + @staticmethod + def create_up_samples(base_channels : int) -> nn.ModuleList: + return nn.ModuleList( [ UpSample(base_channels * 4, base_channels * 2), UpSample(base_channels * 2, base_channels), UpSample(base_channels, base_channels) ]) - return down_samples def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor: output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1) @@ -51,14 +51,14 @@ class MaskNet(nn.Module): class BottleNeck(nn.Module): - def __init__(self, channels : int): + def __init__(self, base_channels : int): super().__init__() self.conv = nn.Sequential( - nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False), - nn.BatchNorm2d(channels), + nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(base_channels), nn.ReLU(inplace = True), - nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False), - nn.BatchNorm2d(channels), + nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(base_channels), nn.ReLU(inplace = True) ) self.relu = nn.ReLU(inplace = True)