This commit is contained in:
harisreedhar
2025-03-11 12:48:44 +05:30
committed by henryruhs
parent 90cb6afe10
commit d809f66216
3 changed files with 19 additions and 6 deletions
+1
View File
@@ -86,6 +86,7 @@ expression_weight = 0.0
```
[training.trainer]
accumulate_size = 4
learning_rate = 0.0004
max_epochs = 50
precision = 16-mixed
+1
View File
@@ -44,6 +44,7 @@ pose_weight =
expression_weight =
[training.trainer]
accumulate_size =
learning_rate =
max_epochs =
precision =
+17 -6
View File
@@ -33,6 +33,7 @@ class FaceSwapperTrainer(LightningModule):
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
self.config_parser_path = config_parser.get('training.model', 'parser_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
@@ -95,6 +96,7 @@ class FaceSwapperTrainer(LightningModule):
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
is_accumulation_step = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_attributes = self.generator.get_attributes(target_tensor)
@@ -111,18 +113,24 @@ class FaceSwapperTrainer(LightningModule):
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
if is_accumulation_step:
generator_optimizer.step()
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
self.toggle_optimizer(masker_optimizer)
mask_tensor = self.masker(target_tensor, target_attributes[-1].detach())
mask_loss = self.mask_loss(target_tensor, mask_tensor)
masker_optimizer.zero_grad()
self.manual_backward(mask_loss)
masker_optimizer.step()
if is_accumulation_step:
masker_optimizer.step()
masker_optimizer.zero_grad()
self.untoggle_optimizer(masker_optimizer)
self.toggle_optimizer(discriminator_optimizer)
@@ -130,9 +138,12 @@ class FaceSwapperTrainer(LightningModule):
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()
if is_accumulation_step:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0: