From d809f6621698fdf672698e458ac725477763d147 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 11 Mar 2025 12:48:44 +0530 Subject: [PATCH] changes --- face_swapper/README.md | 1 + face_swapper/config.ini | 1 + face_swapper/src/training.py | 23 +++++++++++++++++------ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index 098755e..e2b53f9 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -86,6 +86,7 @@ expression_weight = 0.0 ``` [training.trainer] +accumulate_size = 4 learning_rate = 0.0004 max_epochs = 50 precision = 16-mixed diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 07d282d..271cab6 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -44,6 +44,7 @@ pose_weight = expression_weight = [training.trainer] +accumulate_size = learning_rate = max_epochs = precision = diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 6ba9604..7348d8e 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -33,6 +33,7 @@ class FaceSwapperTrainer(LightningModule): self.config_gazer_path = config_parser.get('training.model', 'gazer_path') self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path') self.config_parser_path = config_parser.get('training.model', 'parser_path') + self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval() @@ -95,6 +96,7 @@ class FaceSwapperTrainer(LightningModule): def training_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch + is_accumulation_step = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_attributes = self.generator.get_attributes(target_tensor) @@ -111,18 +113,24 @@ class FaceSwapperTrainer(LightningModule): gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss - generator_optimizer.zero_grad() self.manual_backward(generator_loss) - generator_optimizer.step() + + if is_accumulation_step: + generator_optimizer.step() + generator_optimizer.zero_grad() + self.untoggle_optimizer(generator_optimizer) self.toggle_optimizer(masker_optimizer) mask_tensor = self.masker(target_tensor, target_attributes[-1].detach()) mask_loss = self.mask_loss(target_tensor, mask_tensor) - masker_optimizer.zero_grad() self.manual_backward(mask_loss) - masker_optimizer.step() + + if is_accumulation_step: + masker_optimizer.step() + masker_optimizer.zero_grad() + self.untoggle_optimizer(masker_optimizer) self.toggle_optimizer(discriminator_optimizer) @@ -130,9 +138,12 @@ class FaceSwapperTrainer(LightningModule): discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) - discriminator_optimizer.zero_grad() self.manual_backward(discriminator_loss) - discriminator_optimizer.step() + + if is_accumulation_step: + discriminator_optimizer.step() + discriminator_optimizer.zero_grad() + self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: