Improve naming

This commit is contained in:
henryruhs
2025-02-23 22:07:30 +01:00
parent 8b2b6892aa
commit d7158749c2
+4 -4
View File
@@ -51,7 +51,7 @@ class AADResBlock(nn.Module):
intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels
primary_add_blocks.extend(
[
AADLayer(input_channels, attribute_channels, identity_channels),
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
]
@@ -61,7 +61,7 @@ class AADResBlock(nn.Module):
def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int) -> None:
if input_channels > output_channels:
auxiliary_add_blocks = AADSequential(
AADLayer(input_channels, attribute_channels, identity_channels),
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
)
@@ -84,14 +84,14 @@ class AADSequential(nn.Module):
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
for layer in self.layers:
if isinstance(layer, AADLayer):
if isinstance(layer, FeatureModulation):
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
else:
feature_map = layer(feature_map)
return feature_map
class AADLayer(nn.Module):
class FeatureModulation(nn.Module):
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
super().__init__()
self.input_channels = input_channels