From ab0a59fb74f30bebf94fd0912dc69fe5a8054cbf Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 26 Feb 2025 14:01:41 +0530 Subject: [PATCH] changes --- face_swapper/src/training.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index d732d37..046073b 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule): learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer) - discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer) + generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6) + discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6) generator_config =\ { @@ -58,7 +58,6 @@ class FaceSwapperTrainer(lightning.LightningModule): { 'scheduler': generator_scheduler, 'interval': 'step', - 'frequency': 1000 } } discriminator_config =\ @@ -68,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule): { 'scheduler': discriminator_scheduler, 'interval': 'step', - 'frequency': 1000 } } return generator_config, discriminator_config @@ -97,7 +95,7 @@ class FaceSwapperTrainer(lightning.LightningModule): generator_optimizer.step() generator_scheduler = self.lr_schedulers()[0] - generator_scheduler.step(generator_loss) + generator_scheduler.step() discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) @@ -108,7 +106,7 @@ class FaceSwapperTrainer(lightning.LightningModule): discriminator_optimizer.step() discriminator_scheduler = self.lr_schedulers()[1] - discriminator_scheduler.step(discriminator_loss) + discriminator_scheduler.step() if self.global_step % preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor)