From b215db68c30141590dc5a5e2d3ea51fbb31c8635 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 11 Apr 2025 23:18:32 +0530 Subject: [PATCH] limit discriminator training every 10 steps --- face_swapper/src/training.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 59c47f7..ea03267 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -121,12 +121,13 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer.zero_grad() self.untoggle_optimizer(generator_optimizer) - self.toggle_optimizer(discriminator_optimizer) - self.manual_backward(discriminator_loss) - if do_update: - discriminator_optimizer.step() - discriminator_optimizer.zero_grad() - self.untoggle_optimizer(discriminator_optimizer) + if self.global_step % 10 == 0: + self.toggle_optimizer(discriminator_optimizer) + self.manual_backward(discriminator_loss) + if do_update: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)