diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index dbe8e10..49e2c7b 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -105,8 +105,9 @@ class ReconstructionLoss(nn.Module): reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5 - data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) - visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0) + visual_loss = (visual_loss * has_similar_identity).mean() + reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5 weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss