mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Next (#75)
* Add gradient value clip * Add gradient clip to config * Fix HifiFace in preview * Fix HifiFace in preview * Fix HifiFace in preview * Adjust save top
This commit is contained in:
Binary file not shown.
|
Before Width: | Height: | Size: 1.0 MiB After Width: | Height: | Size: 1.0 MiB |
@@ -110,7 +110,7 @@ def create_trainer() -> Trainer:
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_epochs = 1,
|
||||
save_top_k = 3,
|
||||
save_top_k = 5,
|
||||
save_last = True
|
||||
)
|
||||
]
|
||||
|
||||
@@ -88,6 +88,7 @@ mask_weight = 5.0
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
learning_rate = 0.0004
|
||||
gradient_clip = 20.0
|
||||
max_epochs = 50
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
|
||||
@@ -46,6 +46,7 @@ mask_weight =
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
learning_rate =
|
||||
gradient_clip =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
|
||||
@@ -34,6 +34,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
|
||||
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
|
||||
@@ -114,6 +115,12 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.manual_backward(generator_loss)
|
||||
|
||||
if do_update:
|
||||
if self.config_gradient_clip:
|
||||
self.clip_gradients(
|
||||
generator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
)
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
@@ -122,6 +129,12 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.manual_backward(discriminator_loss)
|
||||
|
||||
if do_update:
|
||||
if self.config_gradient_clip:
|
||||
self.clip_gradients(
|
||||
discriminator_optimizer,
|
||||
gradient_clip_val = self.config_gradient_clip,
|
||||
gradient_clip_algorithm = 'value'
|
||||
)
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
@@ -223,7 +236,7 @@ def create_trainer() -> Trainer:
|
||||
dirpath = config_directory_path,
|
||||
filename = config_file_pattern,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 3,
|
||||
save_top_k = 5,
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user