From d25f2865a9c793c096469e090db86eb8fcda0dc9 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 19 Feb 2025 08:40:50 +0100 Subject: [PATCH] Normalize validation output --- face_swapper/src/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 1ca38ba..4559bf6 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -79,7 +79,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) output_tensor, target_attributes = self.generator(target_tensor, source_embedding) output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0)) - validation = nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation', validation) return validation