From 5bba2a1c69acefe304fb4f9734962bd6f1e90e4b Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 23 Feb 2025 19:10:38 +0100 Subject: [PATCH] Remove the condition from reconstruction loss --- face_swapper/src/models/loss.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index e7476a3..df3343a 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -79,12 +79,8 @@ class ReconstructionLoss(nn.Module): temp_tensors = [] for _source_tensor, _target_tensor in zip(source_tensor, target_tensor): - temp_tensor = self.mse_loss(_source_tensor, _target_tensor) - - if torch.equal(_source_tensor, _target_tensor): - temp_tensors.append(temp_tensor) - else: - temp_tensors.append(temp_tensor * 0) + temp_tensor = self.mse_loss(_source_tensor, _target_tensor) * torch.equal(_source_tensor, _target_tensor) + temp_tensors.append(temp_tensor) reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor))