From a22adaf51fe130d24643d82397dd07ca3f71a8c3 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 1 Mar 2025 16:06:38 +0100 Subject: [PATCH] Adjust some namings --- face_swapper/src/models/loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 2679725..c2be15e 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -79,14 +79,14 @@ class ReconstructionLoss(nn.Module): reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) - is_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 + has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) - reconstruction_loss = (reconstruction_loss * is_similar_identity).mean() * 0.5 + reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) - similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() - reconstruction_loss = (reconstruction_loss + similarity) * 0.5 + visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5 weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss