From 857365770f052d4a6aaae08a0c5081cc8f94a331 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 19 Feb 2025 09:14:57 +0100 Subject: [PATCH] Make validation step more solid, failed on empty checksums --- face_swapper/src/training.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 4559bf6..63e439c 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -76,10 +76,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor, _ = batch - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, target_attributes = self.generator(target_tensor, source_embedding) - output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0)) - validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 + + if torch.isnan(source_tensor).any() and torch.isnan(target_tensor).any(): + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) + output_tensor, target_attributes = self.generator(target_tensor, source_embedding) + output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0)) + validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 + else: + validation = torch.tensor(0.0) + self.log('validation', validation) return validation