diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 0e57faf..5621635 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -62,10 +62,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) - discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) generator_config =\ { @@ -73,7 +73,7 @@ class HyperSwapTrainer(LightningModule): 'lr_scheduler': { 'scheduler': generator_scheduler, - 'interval': 'step' + 'monitor': 'generator_loss' } } discriminator_config =\ @@ -82,7 +82,7 @@ class HyperSwapTrainer(LightningModule): 'lr_scheduler': { 'scheduler': discriminator_scheduler, - 'interval': 'step' + 'monitor': 'discriminator_loss' } } return generator_config, discriminator_config @@ -125,7 +125,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( generator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) generator_optimizer.step() generator_optimizer.zero_grad() @@ -139,7 +139,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( discriminator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) discriminator_optimizer.step() discriminator_optimizer.zero_grad()