Some code review

This commit is contained in:
henryruhs
2025-03-10 11:09:50 +01:00
parent d2efb2fd08
commit f9d105ea2b
+11 -11
View File
@@ -16,23 +16,23 @@ class MaskNet(nn.Module):
self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
def create_down_samples(self, input_channels : int, base_channels: int) -> nn.ModuleList:
down_samples = nn.ModuleList(
@staticmethod
def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList:
return nn.ModuleList(
[
DownSample(input_channels, base_channels),
DownSample(base_channels, base_channels * 2),
DownSample(base_channels * 2, base_channels * 4)
])
return down_samples
def create_up_samples(self, base_channels : int) -> nn.ModuleList:
down_samples = nn.ModuleList(
@staticmethod
def create_up_samples(base_channels : int) -> nn.ModuleList:
return nn.ModuleList(
[
UpSample(base_channels * 4, base_channels * 2),
UpSample(base_channels * 2, base_channels),
UpSample(base_channels, base_channels)
])
return down_samples
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
@@ -51,14 +51,14 @@ class MaskNet(nn.Module):
class BottleNeck(nn.Module):
def __init__(self, channels : int):
def __init__(self, base_channels : int):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(channels),
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace = True),
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(channels),
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace = True)
)
self.relu = nn.ReLU(inplace = True)