mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -54,7 +54,7 @@ class CycleLoss(nn.Module):
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
|
||||
self.l1_loss = nn.L1Loss()
|
||||
self.mae_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
@@ -63,9 +63,9 @@ class CycleLoss(nn.Module):
|
||||
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
cycle_feature_loss = torch.stack(temp_tensors).mean()
|
||||
cycle_l1_loss = self.l1_loss(target_tensor, cycle_tensor)
|
||||
cycle_loss = (cycle_feature_loss + cycle_l1_loss) * 0.5
|
||||
feature_loss = torch.stack(temp_tensors).mean()
|
||||
mae_loss = self.mae_loss(target_tensor, cycle_tensor)
|
||||
cycle_loss = (feature_loss + mae_loss) * 0.5
|
||||
weighted_feature_loss = cycle_loss * self.config_cycle_weight
|
||||
return cycle_loss, weighted_feature_loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user