From b27b8663e5908835b75219f01e4ace4eeca73a11 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 27 Feb 2025 16:18:53 +0530 Subject: [PATCH] changes --- face_swapper/src/training.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 44d635d..01cb2c8 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -48,8 +48,8 @@ class FaceSwapperTrainer(lightning.LightningModule): learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6) - discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6) + generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) + discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) generator_config =\ { @@ -82,6 +82,7 @@ class FaceSwapperTrainer(lightning.LightningModule): generator_output_attributes = self.generator.get_attributes(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) + self.toggle_optimizer(generator_optimizer) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) @@ -93,10 +94,9 @@ class FaceSwapperTrainer(lightning.LightningModule): generator_optimizer.zero_grad() self.manual_backward(generator_loss) generator_optimizer.step() + self.untoggle_optimizer(generator_optimizer) - generator_scheduler = self.lr_schedulers()[0] - generator_scheduler.step() - + self.toggle_optimizer(discriminator_optimizer) discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) @@ -104,9 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule): discriminator_optimizer.zero_grad() self.manual_backward(discriminator_loss) discriminator_optimizer.step() - - discriminator_scheduler = self.lr_schedulers()[1] - discriminator_scheduler.step() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor)