From b7a6f00e8bdfa0c433282ce29f5dd3889e3a1356 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 15 Apr 2025 13:36:21 +0530 Subject: [PATCH] remove discriminator frequency --- face_swapper/src/dataset.py | 4 ++-- face_swapper/src/training.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index 0a14eb4..db04932 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -88,13 +88,13 @@ class AugmentTransform: albumentations.OneOf( [ albumentations.MotionBlur(p = 0.1), - albumentations.ZoomBlur(p = 0.1) + albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1) ], p = 0.2), albumentations.RandomBrightnessContrast(p = 0.7), albumentations.ColorJitter(p = 0.2), albumentations.RGBShift(p = 0.7), albumentations.Illumination(p = 0.2), - albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.7) + albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3) ]) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index ea03267..1d994ae 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -116,18 +116,19 @@ class FaceSwapperTrainer(LightningModule): self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss) + if do_update: generator_optimizer.step() generator_optimizer.zero_grad() self.untoggle_optimizer(generator_optimizer) - if self.global_step % 10 == 0: - self.toggle_optimizer(discriminator_optimizer) - self.manual_backward(discriminator_loss) - if do_update: - discriminator_optimizer.step() - discriminator_optimizer.zero_grad() - self.untoggle_optimizer(discriminator_optimizer) + self.toggle_optimizer(discriminator_optimizer) + self.manual_backward(discriminator_loss) + + if do_update: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)