mirror of https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
stabilize training
@@ -89,9 +89,13 @@ mask_weight = 5.0
 ```
 [training.trainer]
 accumulate_size = 4
-learning_rate = 0.0004
+generator_learning_rate = 0.0004
+discriminator_learning_rate = 0.0002
+momentum = 0.5
 gradient_clip = 20.0
 noise_factor = 0.05
+scheduler_factor = 0.7
+scheduler_patience = 2000
 max_epochs = 50
 strategy = auto
 precision = 16-mixed
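For orientation only (not part of the commit): ReduceLROnPlateau multiplies a learning rate by scheduler_factor every time the monitored loss fails to improve for scheduler_patience consecutive scheduler steps, so the values above imply the decay sequence sketched below. This is standalone arithmetic, not repository code.

```python
# Standalone sketch (not repository code): how the two learning rates above
# shrink each time ReduceLROnPlateau fires with scheduler_factor = 0.7 after
# scheduler_patience = 2000 steps without improvement.
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0002
scheduler_factor = 0.7

for plateau in range(1, 4):
	generator_learning_rate *= scheduler_factor
	discriminator_learning_rate *= scheduler_factor
	print(plateau, generator_learning_rate, discriminator_learning_rate)
# roughly 2.8e-4 / 1.4e-4 after the first plateau, 1.96e-4 / 9.8e-5 after the
# second, 1.37e-4 / 6.9e-5 after the third, bounded below by min_lr = 1e-8.
```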
@@ -47,9 +47,13 @@ mask_weight =

 [training.trainer]
 accumulate_size =
-learning_rate =
+generator_learning_rate =
+discriminator_learning_rate =
+momentum =
 gradient_clip =
 noise_factor =
+scheduler_factor =
+scheduler_patience =
 max_epochs =
 strategy =
 precision =
@@ -35,7 +35,11 @@ class HyperSwapTrainer(LightningModule):
 		self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
 		self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
 		self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
-		self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
+		self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate')
+		self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate')
+		self.config_momentum = config_parser.getfloat('training.trainer', 'momentum')
+		self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor')
+		self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience')
 		self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
 		self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
 		self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
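A side note on the parsing calls kept and added above: getfloat and getint come from Python's standard configparser, and 'training.trainer' is a literal section header rather than a nested path. The snippet below is a self-contained illustration against an in-memory copy of the new keys, not a file from the repository; the empty values left in the config template would raise ValueError here until they are filled in.

```python
# Self-contained illustration (not repository code): configparser behaviour for
# the keys introduced by this commit, using an in-memory config string.
from configparser import ConfigParser

config_parser = ConfigParser()
config_parser.read_string('''
[training.trainer]
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0002
momentum = 0.5
scheduler_factor = 0.7
scheduler_patience = 2000
''')

print(config_parser.getfloat('training.trainer', 'generator_learning_rate'))  # 0.0004
print(config_parser.getfloat('training.trainer', 'momentum'))                 # 0.5
print(config_parser.getint('training.trainer', 'scheduler_patience'))         # 2000
# An empty value, as in the config template above, would raise ValueError.
```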
@@ -62,10 +66,10 @@ class HyperSwapTrainer(LightningModule):
 		return output_tensor, output_mask

 	def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
-		generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
-		discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
-		generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
-		discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
+		generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
+		discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
+		generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8)
+		discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8)

 		generator_config =\
 		{
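The hunk stops at the opening of generator_config, so the commit does not show how the optimizer and scheduler pairs are packaged. For reference only, PyTorch Lightning conventionally has configure_optimizers return one dictionary per optimizer, each pairing the optimizer with its scheduler and the metric it monitors; the sketch below shows that generic shape with placeholder objects and is not the repository's code.

```python
# Standalone sketch (assumption, not the repository's code): the dictionary
# shape Lightning accepts from configure_optimizers() when a ReduceLROnPlateau
# scheduler is attached to an optimizer.
import torch

placeholder_module = torch.nn.Linear(4, 4)  # stand-in for the real generator
generator_optimizer = torch.optim.AdamW(placeholder_module.parameters(), lr = 0.0004, betas = (0.5, 0.999))
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000)

generator_config =\
{
	'optimizer': generator_optimizer,
	'lr_scheduler':
	{
		'scheduler': generator_scheduler,
		'monitor': 'train_loss'  # assumed metric name, not visible in this hunk
	}
}
# configure_optimizers() would return such a dictionary for the generator and a
# second one for the discriminator, matching Tuple[OptimizerSet, OptimizerSet].
```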