This commit is contained in:
harisreedhar
2025-04-15 14:10:58 +05:30
parent 4a319ec9bd
commit 128726701b
+2 -2
View File
@@ -64,8 +64,8 @@ class CycleLoss(nn.Module):
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean()
mae_loss = self.mae_loss(target_tensor, cycle_tensor)
cycle_loss = (feature_loss + mae_loss) * 0.5
reconstruction_loss = self.mae_loss(target_tensor, cycle_tensor)
cycle_loss = (feature_loss + reconstruction_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss