stabilize finetune

This commit is contained in:
harisreedhar
2025-06-11 14:15:11 +05:30
parent ce7aaa57dc
commit a06f5fd9e8
+8 -8
View File
@@ -62,10 +62,10 @@ class HyperSwapTrainer(LightningModule):
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
generator_config =\
{
@@ -73,7 +73,7 @@ class HyperSwapTrainer(LightningModule):
'lr_scheduler':
{
'scheduler': generator_scheduler,
'interval': 'step'
'monitor': 'generator_loss'
}
}
discriminator_config =\
@@ -82,7 +82,7 @@ class HyperSwapTrainer(LightningModule):
'lr_scheduler':
{
'scheduler': discriminator_scheduler,
'interval': 'step'
'monitor': 'discriminator_loss'
}
}
return generator_config, discriminator_config
@@ -125,7 +125,7 @@ class HyperSwapTrainer(LightningModule):
self.clip_gradients(
generator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'value'
gradient_clip_algorithm = 'norm'
)
generator_optimizer.step()
generator_optimizer.zero_grad()
@@ -139,7 +139,7 @@ class HyperSwapTrainer(LightningModule):
self.clip_gradients(
discriminator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'value'
gradient_clip_algorithm = 'norm'
)
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()