This commit is contained in:
harisreedhar
2025-06-09 19:00:56 +05:30
parent 3e9c8a37e7
commit fce54eb7db
+3 -2
View File
@@ -105,8 +105,9 @@ class ReconstructionLoss(nn.Module):
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0)
visual_loss = (visual_loss * has_similar_identity).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss