diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 19502cd..b928679 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -139,7 +139,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, _ = self(source_embedding, target_tensor) + output_tensor, _ = self.forward(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)