mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Merge pull request #84 from facefusion/stabilize-finetuning
Change optimizer and expose more parameters to config
This commit is contained in:
+9
-1
@@ -45,10 +45,18 @@ target_path = .models/arcface_simswap.pt
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
learning_rate = 0.001
|
||||
max_epochs = 4096
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer]
|
||||
learning_rate = 0.001
|
||||
```
|
||||
|
||||
```
|
||||
[training.logger]
|
||||
logger_path = .logs
|
||||
logger_name = crossface_simswap
|
||||
```
|
||||
|
||||
@@ -11,10 +11,14 @@ source_path =
|
||||
target_path =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
|
||||
[training.optimizer]
|
||||
learning_rate =
|
||||
|
||||
[training.logger]
|
||||
logger_path =
|
||||
logger_name =
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class CrossFaceTrainer(LightningModule):
|
||||
super().__init__()
|
||||
self.config_source_path = config_parser.get('training.model', 'source_path')
|
||||
self.config_target_path = config_parser.get('training.model', 'target_path')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate')
|
||||
self.crossface = CrossFace()
|
||||
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
|
||||
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
|
||||
@@ -50,7 +50,7 @@ class CrossFaceTrainer(LightningModule):
|
||||
return validation_score
|
||||
|
||||
def configure_optimizers(self) -> OptimizerSet:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate)
|
||||
optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
optimizer_set =\
|
||||
{
|
||||
@@ -91,8 +91,8 @@ def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
|
||||
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
+21
-2
@@ -89,15 +89,34 @@ mask_weight = 5.0
|
||||
```
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
learning_rate = 0.0004
|
||||
gradient_clip = 20.0
|
||||
noise_factor = 0.05
|
||||
max_epochs = 50
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
preview_frequency = 100
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer.generator]
|
||||
learning_rate = 0.0004
|
||||
momentum = 0.5
|
||||
scheduler_factor = 0.7
|
||||
scheduler_patience = 2000
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer.discriminator]
|
||||
learning_rate = 0.0002
|
||||
momentum = 0.5
|
||||
scheduler_factor = 0.7
|
||||
scheduler_patience = 2000
|
||||
```
|
||||
|
||||
```
|
||||
[training.logger]
|
||||
logger_path = .logs
|
||||
logger_name = hyperswap
|
||||
preview_frequency = 100
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
+15
-2
@@ -47,15 +47,28 @@ mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
learning_rate =
|
||||
gradient_clip =
|
||||
noise_factor =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
preview_frequency =
|
||||
|
||||
[training.optimizer.generator]
|
||||
learning_rate =
|
||||
momentum =
|
||||
scheduler_factor =
|
||||
scheduler_patience =
|
||||
|
||||
[training.optimizer.discriminator]
|
||||
learning_rate =
|
||||
momentum =
|
||||
scheduler_factor =
|
||||
scheduler_patience =
|
||||
|
||||
[training.logger]
|
||||
logger_path =
|
||||
logger_name =
|
||||
preview_frequency =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
|
||||
+24
-13
@@ -35,9 +35,16 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate')
|
||||
self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum')
|
||||
self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor')
|
||||
self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience')
|
||||
self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate')
|
||||
self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
|
||||
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
|
||||
self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')
|
||||
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
|
||||
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
@@ -62,18 +69,17 @@ class HyperSwapTrainer(LightningModule):
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
|
||||
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8)
|
||||
|
||||
generator_config =\
|
||||
{
|
||||
'optimizer': generator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'interval': 'step'
|
||||
'scheduler': generator_scheduler
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
@@ -81,8 +87,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
'optimizer': discriminator_optimizer,
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'interval': 'step'
|
||||
'scheduler': discriminator_scheduler
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
@@ -91,6 +96,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
source_tensor, target_tensor = batch
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
@@ -125,7 +131,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.clip_gradients(
|
||||
generator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
gradient_clip_algorithm = 'norm'
|
||||
)
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
@@ -139,7 +145,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.clip_gradients(
|
||||
discriminator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
gradient_clip_algorithm = 'norm'
|
||||
)
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
@@ -157,6 +163,11 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('gaze_loss', gaze_loss)
|
||||
self.log('mask_loss', mask_loss)
|
||||
|
||||
if do_update:
|
||||
generator_scheduler.step(generator_loss)
|
||||
discriminator_scheduler.step(discriminator_loss)
|
||||
|
||||
return generator_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
@@ -226,8 +237,8 @@ def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
|
||||
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
Reference in New Issue
Block a user