Normalize validation output

This commit is contained in:
henryruhs
2025-02-19 08:40:50 +01:00
parent 0d45568bd1
commit d25f2865a9
+1 -1
View File
@@ -79,7 +79,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
validation = nn.functional.cosine_similarity(source_embedding, output_embedding).mean()
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation', validation)
return validation