diff --git a/embedding_converter/src/models/embedding_converter.py b/embedding_converter/src/models/embedding_converter.py index d9c61c4..de483f3 100644 --- a/embedding_converter/src/models/embedding_converter.py +++ b/embedding_converter/src/models/embedding_converter.py @@ -1,7 +1,5 @@ import torch -from torch import nn - -from ..types import VisionTensor +from torch import Tensor, nn class EmbeddingConverter(nn.Module): @@ -20,7 +18,7 @@ class EmbeddingConverter(nn.Module): nn.Linear(1024, 512) ]) - def forward(self, input_tensor : VisionTensor) -> VisionTensor: + def forward(self, input_tensor : Tensor) -> Tensor: output_tensor = input_tensor / torch.norm(input_tensor) for layer in self.layers[:-1]: