This commit is contained in:
henryruhs
2025-02-14 16:40:39 +01:00
parent 88c4e53192
commit e1e0c11bb5
@@ -1,7 +1,5 @@
import torch
from torch import nn
from ..types import VisionTensor
from torch import Tensor, nn
class EmbeddingConverter(nn.Module):
@@ -20,7 +18,7 @@ class EmbeddingConverter(nn.Module):
nn.Linear(1024, 512)
])
def forward(self, input_tensor : VisionTensor) -> VisionTensor:
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = input_tensor / torch.norm(input_tensor)
for layer in self.layers[:-1]: