From 33d00ac9410ed94758d342c3b28f411d722ca016 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 14 Mar 2025 08:13:21 +0100 Subject: [PATCH] Mask typing and naming related updates --- face_swapper/src/networks/masknet.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index 0449fca..0a465e4 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -14,7 +14,7 @@ class MaskNet(nn.Module): self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters') self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters) self.up_samples = self.create_up_samples(self.config_num_filters) - self.bottleneck = BottleNeck(self.config_num_filters * 4) + self.bottleneck = BottleNeck(self.config_num_filters * 2) self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1) self.sigmoid = nn.Sigmoid() @@ -23,15 +23,13 @@ class MaskNet(nn.Module): return nn.ModuleList( [ DownSample(input_channels, num_filters), - DownSample(num_filters, num_filters * 2), - DownSample(num_filters, num_filters * 4) + DownSample(num_filters, num_filters * 2) ]) @staticmethod def create_up_samples(num_filters : int) -> nn.ModuleList: return nn.ModuleList( [ - UpSample(num_filters * 4, num_filters), UpSample(num_filters * 2, num_filters), UpSample(num_filters, num_filters) ])