mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
fix lr scheduler
This commit is contained in:
@@ -72,8 +72,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
'optimizer': generator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'monitor': 'generator_loss'
|
||||
'scheduler': generator_scheduler
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
@@ -81,8 +80,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
'optimizer': discriminator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'monitor': 'discriminator_loss'
|
||||
'scheduler': discriminator_scheduler
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
@@ -91,6 +89,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
source_tensor, target_tensor = batch
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
@@ -157,6 +156,11 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('gaze_loss', gaze_loss)
|
||||
self.log('mask_loss', mask_loss)
|
||||
|
||||
if do_update:
|
||||
generator_scheduler.step(generator_loss)
|
||||
discriminator_scheduler.step(discriminator_loss)
|
||||
|
||||
return generator_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user