Merge pull request #83 from facefusion/ssim-fix

fix ssim
This commit is contained in:
Harisreedhar
2025-06-09 19:07:30 +05:30
committed by GitHub
+3 -2
View File
@@ -105,8 +105,9 @@ class ReconstructionLoss(nn.Module):
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0)
visual_loss = (visual_loss * has_similar_identity).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss