diff --git a/embedding_converter/src/models/embedding_converter.py b/embedding_converter/src/models/embedding_converter.py index 95a3581..9485779 100644 --- a/embedding_converter/src/models/embedding_converter.py +++ b/embedding_converter/src/models/embedding_converter.py @@ -12,14 +12,13 @@ class EmbeddingConverter(nn.Module): @staticmethod def create_layers() -> nn.ModuleList: - layers = nn.ModuleList( + return nn.ModuleList( [ nn.Linear(512, 1024), nn.Linear(1024, 2048), nn.Linear(2048, 1024), nn.Linear(1024, 512) ]) - return layers def forward(self, input_tensor : VisionTensor) -> VisionTensor: output_tensor = input_tensor / torch.norm(input_tensor)