diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 59c47f7..ea03267 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -121,12 +121,13 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer.zero_grad() self.untoggle_optimizer(generator_optimizer) - self.toggle_optimizer(discriminator_optimizer) - self.manual_backward(discriminator_loss) - if do_update: - discriminator_optimizer.step() - discriminator_optimizer.zero_grad() - self.untoggle_optimizer(discriminator_optimizer) + if self.global_step % 10 == 0: + self.toggle_optimizer(discriminator_optimizer) + self.manual_backward(discriminator_loss) + if do_update: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)