Remove the condition from reconstruction loss

This commit is contained in:
henryruhs
2025-02-23 19:10:38 +01:00
parent 94480e16eb
commit 5bba2a1c69
+2 -6
View File
@@ -79,12 +79,8 @@ class ReconstructionLoss(nn.Module):
temp_tensors = []
for _source_tensor, _target_tensor in zip(source_tensor, target_tensor):
temp_tensor = self.mse_loss(_source_tensor, _target_tensor)
if torch.equal(_source_tensor, _target_tensor):
temp_tensors.append(temp_tensor)
else:
temp_tensors.append(temp_tensor * 0)
temp_tensor = self.mse_loss(_source_tensor, _target_tensor) * torch.equal(_source_tensor, _target_tensor)
temp_tensors.append(temp_tensor)
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))