Add cosine_similarity

This commit is contained in:
henryruhs
2025-02-21 10:15:56 +01:00
parent 04eaa831ea
commit 5934b47961
+2 -3
View File
@@ -49,9 +49,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
validation = self.mse_loss(output_embedding, target_embedding)
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation', validation, prog_bar = True)
return validation
@@ -113,7 +112,7 @@ def create_trainer() -> Trainer:
save_last = True
)
],
val_check_interval = 10
val_check_interval = 100
)