mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
stabilize finetune
This commit is contained in:
@@ -62,10 +62,10 @@ class HyperSwapTrainer(LightningModule):
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
|
||||
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
@@ -73,7 +73,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'interval': 'step'
|
||||
'monitor': 'generator_loss'
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
@@ -82,7 +82,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'interval': 'step'
|
||||
'monitor': 'discriminator_loss'
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
@@ -125,7 +125,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.clip_gradients(
|
||||
generator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
gradient_clip_algorithm = 'norm'
|
||||
)
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
@@ -139,7 +139,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.clip_gradients(
|
||||
discriminator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
gradient_clip_algorithm = 'norm'
|
||||
)
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
|
||||
Reference in New Issue
Block a user