From fce54eb7db7785df14bac5aacff8a64f02a006bb Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 9 Jun 2025 19:00:56 +0530 Subject: [PATCH] fix ssim --- hyperswap/src/models/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index dbe8e10..49e2c7b 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -105,8 +105,9 @@ class ReconstructionLoss(nn.Module): reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5 - data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) - visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0) + visual_loss = (visual_loss * has_similar_identity).mean() + reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5 weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss