fix lr scheduler

This commit is contained in:
harisreedhar
2025-06-11 15:38:28 +05:30
parent a06f5fd9e8
commit fc766b8327
+8 -4
View File
@@ -72,8 +72,7 @@ class HyperSwapTrainer(LightningModule):
'optimizer': generator_optimizer,
'lr_scheduler':
{
'scheduler': generator_scheduler,
'monitor': 'generator_loss'
'scheduler': generator_scheduler
}
}
discriminator_config =\
@@ -81,8 +80,7 @@ class HyperSwapTrainer(LightningModule):
'optimizer': discriminator_optimizer,
'lr_scheduler':
{
'scheduler': discriminator_scheduler,
'monitor': 'discriminator_loss'
'scheduler': discriminator_scheduler
}
}
return generator_config, discriminator_config
@@ -91,6 +89,7 @@ class HyperSwapTrainer(LightningModule):
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
@@ -157,6 +156,11 @@ class HyperSwapTrainer(LightningModule):
self.log('identity_loss', identity_loss)
self.log('gaze_loss', gaze_loss)
self.log('mask_loss', mask_loss)
if do_update:
generator_scheduler.step(generator_loss)
discriminator_scheduler.step(discriminator_loss)
return generator_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor: