diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 81b1fb3..5029efb 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -150,7 +150,7 @@ class AdversarialLoss(torch.nn.Module): temp_tensors = [] for discriminator_output_tensor in discriminator_output_tensors: - temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean() + temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean(dim = [ 1, 2, 3 ]).mean() temp_tensors.append(temp_tensor) loss = torch.stack(temp_tensors).mean() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 0eac0e2..48e1199 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -88,7 +88,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss self.log('generator_loss_new', generator_loss, prog_bar = True) - self.log('adversarial_loss_new', adversarial_loss, prog_bar = True) + self.log('loss_adversarial_new', adversarial_loss, prog_bar = True) self.log('loss_reconstruction_new', reconstruction_loss) self.log('loss_identity_new', identity_loss) return generator_loss_set.get('loss_generator')