Final refactoring for AAD done

This commit is contained in:
henryruhs
2025-02-24 00:15:13 +01:00
parent d7158749c2
commit bc174186eb
2 changed files with 70 additions and 44 deletions
+69 -43
View File
@@ -14,14 +14,14 @@ class AAD(nn.Module):
def create_layers(identity_channels : int, num_blocks : int) -> nn.ModuleList:
return nn.ModuleList(
[
AADResBlock(1024, 1024, 1024, identity_channels, num_blocks),
AADResBlock(1024, 1024, 2048, identity_channels, num_blocks),
AADResBlock(1024, 1024, 1024, identity_channels, num_blocks),
AADResBlock(1024, 512, 512, identity_channels, num_blocks),
AADResBlock(512, 256, 256, identity_channels, num_blocks),
AADResBlock(256, 128, 128, identity_channels, num_blocks),
AADResBlock(128, 64, 64, identity_channels, num_blocks),
AADResBlock(64, 3, 64, identity_channels, num_blocks)
AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, identity_channels, num_blocks),
AdaptiveFeatureModulation(512, 256, 256, identity_channels, num_blocks),
AdaptiveFeatureModulation(256, 128, 128, identity_channels, num_blocks),
AdaptiveFeatureModulation(128, 64, 64, identity_channels, num_blocks),
AdaptiveFeatureModulation(64, 3, 64, identity_channels, num_blocks)
])
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor:
@@ -36,59 +36,85 @@ class AAD(nn.Module):
return output_tensor
class AADResBlock(nn.Module):
class AdaptiveFeatureModulation(nn.Module):
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.prepare_primary_add_blocks(input_channels, attribute_channels, identity_channels, output_channels, num_blocks)
self.prepare_auxiliary_add_blocks(input_channels, attribute_channels, identity_channels, output_channels)
self.primary_layers = self.create_primary_layers(input_channels, output_channels, attribute_channels, identity_channels, num_blocks)
self.shortcut_layers = self.create_shortcut_layers(input_channels, output_channels, attribute_channels, identity_channels)
def prepare_primary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int, num_blocks : int) -> None:
primary_add_blocks = []
@staticmethod
def create_primary_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> nn.ModuleList:
primary_layers = nn.ModuleList()
for index in range(num_blocks):
intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels
primary_add_blocks.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
]
)
self.primary_add_blocks = AADSequential(*primary_add_blocks)
primary_layers.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True)
])
if index < num_blocks - 1:
primary_layers.append(nn.Conv2d(input_channels, input_channels, kernel_size = 3, padding = 1, bias = False))
else:
primary_layers.append(nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False))
return primary_layers
@staticmethod
def _create_primary_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> nn.ModuleList:
primary_layers = nn.ModuleList()
for index in range(num_blocks):
primary_layers.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True)
])
if index < num_blocks - 1:
primary_layers.append(nn.Conv2d(input_channels, input_channels, kernel_size = 3, padding = 1, bias = False))
else:
primary_layers.append(nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False))
return primary_layers
@staticmethod
def create_shortcut_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int) -> nn.ModuleList:
shortcut_layers = nn.ModuleList()
def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int) -> None:
if input_channels > output_channels:
auxiliary_add_blocks = AADSequential(
shortcut_layers.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
)
self.auxiliary_add_blocks = auxiliary_add_blocks
])
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, identity_embedding)
return shortcut_layers
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
primary_tensor = input_tensor
for primary_layer in self.primary_layers:
if isinstance(primary_layer, FeatureModulation):
primary_tensor = primary_layer(primary_tensor, attribute_embedding, identity_embedding)
else:
primary_tensor = primary_layer(primary_tensor)
if self.input_channels > self.output_channels:
feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, identity_embedding)
shortcut_tensor = input_tensor
output_feature = primary_feature + feature_map
return output_feature
for shortcut_layer in self.shortcut_layers:
if isinstance(shortcut_layer, FeatureModulation):
shortcut_tensor = shortcut_layer(shortcut_tensor, attribute_embedding, identity_embedding)
else:
shortcut_tensor = shortcut_layer(shortcut_tensor)
input_tensor = shortcut_tensor
class AADSequential(nn.Module):
def __init__(self, *args : nn.Module) -> None:
super().__init__()
self.layers = nn.ModuleList(args)
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
for layer in self.layers:
if isinstance(layer, FeatureModulation):
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
else:
feature_map = layer(feature_map)
return feature_map
return primary_tensor + input_tensor
class FeatureModulation(nn.Module):
+1 -1
View File
@@ -9,7 +9,7 @@ class NLD(nn.Module):
self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size)
@staticmethod
def create_nld(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> nn.Sequential:
def create_nld(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> nn.Sequential:
padding = math.ceil((kernel_size - 1) / 2)
current_filters = num_filters
layers =\