From 5934b47961ad983a18902bea2b7f684fa780e149 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 21 Feb 2025 10:15:56 +0100 Subject: [PATCH] Add cosine_similarity --- embedding_converter/src/training.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 6c1aba6..4b1f6b1 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -49,9 +49,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: with torch.no_grad(): source_embedding = self.source_embedder(batch) - target_embedding = self.target_embedder(batch) output_embedding = self(source_embedding) - validation = self.mse_loss(output_embedding, target_embedding) + validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation', validation, prog_bar = True) return validation @@ -113,7 +112,7 @@ def create_trainer() -> Trainer: save_last = True ) ], - val_check_interval = 10 + val_check_interval = 100 )