Make validation step more solid, failed on empty checksums

This commit is contained in:
henryruhs
2025-02-19 09:14:57 +01:00
parent d25f2865a9
commit 857365770f
+9 -4
View File
@@ -76,10 +76,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, _ = batch
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
if torch.isnan(source_tensor).any() and torch.isnan(target_tensor).any():
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
else:
validation = torch.tensor(0.0)
self.log('validation', validation)
return validation