From 4a319ec9bd0552451a326f4ff617d66e46471fa3 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 15 Apr 2025 14:05:27 +0530 Subject: [PATCH] changes --- face_swapper/src/models/loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 2087a38..ccbd96a 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -54,7 +54,7 @@ class CycleLoss(nn.Module): super().__init__() self.config_batch_size = config_parser.getint('training.loader', 'batch_size') self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight') - self.l1_loss = nn.L1Loss() + self.mae_loss = nn.L1Loss() def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]: temp_tensors = [] @@ -63,9 +63,9 @@ class CycleLoss(nn.Module): temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean() temp_tensors.append(temp_tensor) - cycle_feature_loss = torch.stack(temp_tensors).mean() - cycle_l1_loss = self.l1_loss(target_tensor, cycle_tensor) - cycle_loss = (cycle_feature_loss + cycle_l1_loss) * 0.5 + feature_loss = torch.stack(temp_tensors).mean() + mae_loss = self.mae_loss(target_tensor, cycle_tensor) + cycle_loss = (feature_loss + mae_loss) * 0.5 weighted_feature_loss = cycle_loss * self.config_cycle_weight return cycle_loss, weighted_feature_loss