This commit is contained in:
harisreedhar
2025-03-23 18:45:35 +05:30
parent 9ede8a2a7d
commit 602e890af2
+1 -1
View File
@@ -139,7 +139,7 @@ class FaceSwapperTrainer(LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = self(source_embedding, target_tensor)
output_tensor, _ = self.forward(source_embedding, target_tensor)
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)