This commit is contained in:
harisreedhar
2025-04-15 14:14:09 +05:30
parent 128726701b
commit bcf5b4e5a8
+2 -2
View File
@@ -54,7 +54,7 @@ class CycleLoss(nn.Module):
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
self.mae_loss = nn.L1Loss()
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
@@ -64,7 +64,7 @@ class CycleLoss(nn.Module):
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean()
reconstruction_loss = self.mae_loss(target_tensor, cycle_tensor)
reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
cycle_loss = (feature_loss + reconstruction_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss