add masknet layer

This commit is contained in:
harisreedhar
2025-03-19 19:27:09 +05:30
parent 10b6f801d1
commit b6a2734622
+4 -2
View File
@@ -14,7 +14,7 @@ class MaskNet(nn.Module):
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
self.up_samples = self.create_up_samples(self.config_num_filters)
self.bottleneck = BottleNeck(self.config_num_filters * 2)
self.bottleneck = BottleNeck(self.config_num_filters * 4)
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
@@ -23,13 +23,15 @@ class MaskNet(nn.Module):
return nn.ModuleList(
[
DownSample(input_channels, num_filters),
DownSample(num_filters, num_filters * 2)
DownSample(num_filters, num_filters * 2),
DownSample(num_filters * 2, num_filters * 4)
])
@staticmethod
def create_up_samples(num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
UpSample(num_filters * 4, num_filters * 2),
UpSample(num_filters * 2, num_filters),
UpSample(num_filters, num_filters)
])