limit discriminator training every 10 steps

This commit is contained in:
harisreedhar
2025-04-11 23:18:32 +05:30
parent dc2b2dc982
commit b215db68c3
+7 -6
View File
@@ -121,12 +121,13 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % 10 == 0:
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)