diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 5621635..13528e0 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -72,8 +72,7 @@ class HyperSwapTrainer(LightningModule): 'optimizer': generator_optimizer, 'lr_scheduler': { - 'scheduler': generator_scheduler, - 'monitor': 'generator_loss' + 'scheduler': generator_scheduler } } discriminator_config =\ @@ -81,8 +80,7 @@ class HyperSwapTrainer(LightningModule): 'optimizer': discriminator_optimizer, 'lr_scheduler': { - 'scheduler': discriminator_scheduler, - 'monitor': 'discriminator_loss' + 'scheduler': discriminator_scheduler } } return generator_config, discriminator_config @@ -91,6 +89,7 @@ class HyperSwapTrainer(LightningModule): source_tensor, target_tensor = batch do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) @@ -157,6 +156,11 @@ class HyperSwapTrainer(LightningModule): self.log('identity_loss', identity_loss) self.log('gaze_loss', gaze_loss) self.log('mask_loss', mask_loss) + + if do_update: + generator_scheduler.step(generator_loss) + discriminator_scheduler.step(discriminator_loss) + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: