diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 6c1aba6..4b1f6b1 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -49,9 +49,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: with torch.no_grad(): source_embedding = self.source_embedder(batch) - target_embedding = self.target_embedder(batch) output_embedding = self(source_embedding) - validation = self.mse_loss(output_embedding, target_embedding) + validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation', validation, prog_bar = True) return validation @@ -113,7 +112,7 @@ def create_trainer() -> Trainer: save_last = True ) ], - val_check_interval = 10 + val_check_interval = 100 )