mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -86,6 +86,7 @@ expression_weight = 0.0
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
learning_rate = 0.0004
|
||||
max_epochs = 50
|
||||
precision = 16-mixed
|
||||
|
||||
@@ -44,6 +44,7 @@ pose_weight =
|
||||
expression_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
learning_rate =
|
||||
max_epochs =
|
||||
precision =
|
||||
|
||||
@@ -33,6 +33,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
|
||||
self.config_parser_path = config_parser.get('training.model', 'parser_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
|
||||
@@ -95,6 +96,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
is_accumulation_step = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_attributes = self.generator.get_attributes(target_tensor)
|
||||
@@ -111,18 +113,24 @@ class FaceSwapperTrainer(LightningModule):
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_loss)
|
||||
generator_optimizer.step()
|
||||
|
||||
if is_accumulation_step:
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
mask_tensor = self.masker(target_tensor, target_attributes[-1].detach())
|
||||
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
||||
|
||||
masker_optimizer.zero_grad()
|
||||
self.manual_backward(mask_loss)
|
||||
masker_optimizer.step()
|
||||
|
||||
if is_accumulation_step:
|
||||
masker_optimizer.step()
|
||||
masker_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(masker_optimizer)
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
@@ -130,9 +138,12 @@ class FaceSwapperTrainer(LightningModule):
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss)
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if is_accumulation_step:
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
|
||||
Reference in New Issue
Block a user