diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index e7476a3..df3343a 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -79,12 +79,8 @@ class ReconstructionLoss(nn.Module): temp_tensors = [] for _source_tensor, _target_tensor in zip(source_tensor, target_tensor): - temp_tensor = self.mse_loss(_source_tensor, _target_tensor) - - if torch.equal(_source_tensor, _target_tensor): - temp_tensors.append(temp_tensor) - else: - temp_tensors.append(temp_tensor * 0) + temp_tensor = self.mse_loss(_source_tensor, _target_tensor) * torch.equal(_source_tensor, _target_tensor) + temp_tensors.append(temp_tensor) reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor))