mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new AdversarialLoss class fix
This commit is contained in:
@@ -150,7 +150,7 @@ class AdversarialLoss(torch.nn.Module):
|
||||
temp_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean()
|
||||
temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean(dim = [ 1, 2, 3 ]).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
loss = torch.stack(temp_tensors).mean()
|
||||
|
||||
@@ -88,7 +88,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss_new', adversarial_loss, prog_bar = True)
|
||||
self.log('loss_adversarial_new', adversarial_loss, prog_bar = True)
|
||||
self.log('loss_reconstruction_new', reconstruction_loss)
|
||||
self.log('loss_identity_new', identity_loss)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
Reference in New Issue
Block a user