* Add gradient value clip

* Add gradient clip to config

* Fix HifiFace in preview

* Fix HifiFace in preview

* Fix HifiFace in preview

* Adjust save top
This commit is contained in:
Henry Ruhs
2025-05-05 10:03:34 +02:00
committed by GitHub
parent d68b77bd4d
commit 475b8b1538
5 changed files with 17 additions and 2 deletions
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.0 MiB

After

Width:  |  Height:  |  Size: 1.0 MiB

+1 -1
View File
@@ -110,7 +110,7 @@ def create_trainer() -> Trainer:
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1,
save_top_k = 3,
save_top_k = 5,
save_last = True
)
]
+1
View File
@@ -88,6 +88,7 @@ mask_weight = 5.0
[training.trainer]
accumulate_size = 4
learning_rate = 0.0004
gradient_clip = 20.0
max_epochs = 50
strategy = auto
precision = 16-mixed
+1
View File
@@ -46,6 +46,7 @@ mask_weight =
[training.trainer]
accumulate_size =
learning_rate =
gradient_clip =
max_epochs =
strategy =
precision =
+14 -1
View File
@@ -34,6 +34,7 @@ class HyperSwapTrainer(LightningModule):
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
@@ -114,6 +115,12 @@ class HyperSwapTrainer(LightningModule):
self.manual_backward(generator_loss)
if do_update:
if self.config_gradient_clip:
self.clip_gradients(
generator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'value'
)
generator_optimizer.step()
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
@@ -122,6 +129,12 @@ class HyperSwapTrainer(LightningModule):
self.manual_backward(discriminator_loss)
if do_update:
if self.config_gradient_clip:
self.clip_gradients(
discriminator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'value'
)
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
@@ -223,7 +236,7 @@ def create_trainer() -> Trainer:
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_train_steps = 1000,
save_top_k = 3,
save_top_k = 5,
save_last = True
)
],