diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 01cb2c8..704be44 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -46,8 +46,8 @@ class FaceSwapperTrainer(lightning.LightningModule): def configure_optimizers(self) -> Tuple[OptimizerConfig, OptimizerConfig]: learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') - generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)