Adjust networks for 512

This commit is contained in:
henryruhs
2025-03-05 15:23:31 +01:00
parent 3e69d5a9a9
commit 7cc893c32e
+22 -25
View File
@@ -15,37 +15,34 @@ class AAD(nn.Module):
self.layers = self.create_layers()
def create_layers(self) -> nn.ModuleList:
    """
    Assemble the AdaptiveFeatureModulation stack for the configured output size.

    The network is a size-specific head followed by a tail shared by all
    sizes, which halves the channel count per block down to a 3-channel
    (image) output:

    * 256 — three 1024-channel blocks, then the shared tail.
    * 512 — four blocks starting at 2048 channels and narrowing to 1024,
      then the shared tail.

    Every block is passed ``self.identity_channels`` and ``self.num_blocks``
    unchanged; their semantics are defined by ``AdaptiveFeatureModulation``
    (declared elsewhere in this module).

    :return: ordered ``nn.ModuleList`` consumed by ``forward``.
    """
    layers = nn.ModuleList()

    # Head for 256: keep the feature width at 1024 (hidden widening to 2048
    # happens inside the middle block's third argument).
    if self.output_size == 256:
        layers.extend(
        [
            AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks),
            AdaptiveFeatureModulation(1024, 1024, 2048, self.identity_channels, self.num_blocks),
            AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks)
        ])
    # Head for 512: start twice as wide (2048) and narrow to 1024 so the
    # shared tail below can be reused unchanged.
    if self.output_size == 512:
        layers.extend(
        [
            AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks),
            AdaptiveFeatureModulation(2048, 2048, 4096, self.identity_channels, self.num_blocks),
            AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks),
            AdaptiveFeatureModulation(2048, 1024, 1024, self.identity_channels, self.num_blocks)
        ])
    # Shared tail: 1024 -> 512 -> 256 -> 128 -> 64 channels, then the final
    # block that maps 64 features to the 3-channel output.
    # NOTE(review): output sizes other than 256/512 build only this tail
    # (expecting 1024 input channels with no head to produce them) — confirm
    # callers restrict output_size to 256 or 512.
    layers.extend(
    [
        AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks),
        AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
        AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
        AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
        AdaptiveFeatureModulation(64, 3, 64, self.identity_channels, self.num_blocks)
    ])
    return layers
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor: