From 0055c0c97f07e4f4e93715b1d9453fc326413e5d Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 1 Mar 2025 16:02:35 +0100 Subject: [PATCH] Adjust some namings --- face_swapper/src/models/loss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index c0699f6..2679725 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -79,9 +79,11 @@ class ReconstructionLoss(nn.Module): reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) - same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 + is_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 + reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) - reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5 + reconstruction_loss = (reconstruction_loss * is_similar_identity).mean() * 0.5 + data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() reconstruction_loss = (reconstruction_loss + similarity) * 0.5