diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 68b702c..e7476a3 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -28,9 +28,9 @@ class DiscriminatorLoss(nn.Module): negative_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ]) negative_tensors.append(negative_tensor) - discriminator_positive_loss = torch.stack(positive_tensors).mean() - discriminator_negative_loss = torch.stack(negative_tensors).mean() - discriminator_loss = (discriminator_positive_loss + discriminator_negative_loss) * 0.5 + positive_loss = torch.stack(positive_tensors).mean() + negative_loss = torch.stack(negative_tensors).mean() + discriminator_loss = (positive_loss + negative_loss) * 0.5 return discriminator_loss