reconstruction-loss fix

This commit is contained in:
harisreedhar
2025-02-23 13:30:08 +05:30
committed by henryruhs
parent de6cfbc35b
commit 6fed877d33
+9 -8
View File
@@ -72,21 +72,22 @@ class AttributeLoss(nn.Module):
class ReconstructionLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mse_loss = nn.MSELoss()
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5
temp_tensors = []
if torch.equal(source_tensor, target_tensor):
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
else:
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
for _source_tensor, _target_tensor in zip(source_tensor, target_tensor):
temp_tensor = self.mse_loss(_source_tensor, _target_tensor)
if torch.equal(_source_tensor, _target_tensor):
temp_tensors.append(temp_tensor)
else:
temp_tensors.append(temp_tensor * 0)
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss