remove discriminator frequency

This commit is contained in:
harisreedhar
2025-04-15 13:36:21 +05:30
parent b215db68c3
commit b7a6f00e8b
2 changed files with 10 additions and 9 deletions
+2 -2
View File
@@ -88,13 +88,13 @@ class AugmentTransform:
albumentations.OneOf(
[
albumentations.MotionBlur(p = 0.1),
albumentations.ZoomBlur(p = 0.1)
albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1)
], p = 0.2),
albumentations.RandomBrightnessContrast(p = 0.7),
albumentations.ColorJitter(p = 0.2),
albumentations.RGBShift(p = 0.7),
albumentations.Illumination(p = 0.2),
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.7)
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
])
+8 -7
View File
@@ -116,18 +116,19 @@ class FaceSwapperTrainer(LightningModule):
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
generator_optimizer.step()
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
if self.global_step % 10 == 0:
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)