From 39ce14b5903ce117874d23d0739989b5c4c60e20 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 15 Apr 2025 13:36:21 +0530 Subject: [PATCH] remove discriminator frequency --- face_swapper/src/training.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 554a457..8d1d749 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -116,18 +116,19 @@ class FaceSwapperTrainer(LightningModule): self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss) + if do_update: generator_optimizer.step() generator_optimizer.zero_grad() self.untoggle_optimizer(generator_optimizer) - if self.global_step % 10 == 0: - self.toggle_optimizer(discriminator_optimizer) - self.manual_backward(discriminator_loss) - if do_update: - discriminator_optimizer.step() - discriminator_optimizer.zero_grad() - self.untoggle_optimizer(discriminator_optimizer) + self.toggle_optimizer(discriminator_optimizer) + self.manual_backward(discriminator_loss) + + if do_update: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)