From 6fed877d338fb31aefdb6b970b7d785adcc2db27 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Feb 2025 13:30:08 +0530 Subject: [PATCH] reconstruction-loss fix --- face_swapper/src/models/loss.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 6df2c7b..e4592f6 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -72,21 +72,22 @@ class AttributeLoss(nn.Module): class ReconstructionLoss(nn.Module): def __init__(self) -> None: super().__init__() + self.mse_loss = nn.MSELoss() def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - batch_size = CONFIG.getint('training.loader', 'batch_size') reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') - reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1) - reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5 + temp_tensors = [] - if torch.equal(source_tensor, target_tensor): - reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4) - else: - reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4) + for _source_tensor, _target_tensor in zip(source_tensor, target_tensor): + temp_tensor = self.mse_loss(_source_tensor, _target_tensor) + if torch.equal(_source_tensor, _target_tensor): + temp_tensors.append(temp_tensor) + else: + temp_tensors.append(temp_tensor * 0) + reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() - reconstruction_loss = (reconstruction_loss + similarity) * 0.5 weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss