diff --git a/.github/previews/crossface.png b/.github/previews/crossface.png index 640ff83..21a8dd4 100644 Binary files a/.github/previews/crossface.png and b/.github/previews/crossface.png differ diff --git a/crossface/src/training.py b/crossface/src/training.py index f3b140c..20ab9e8 100644 --- a/crossface/src/training.py +++ b/crossface/src/training.py @@ -110,7 +110,7 @@ def create_trainer() -> Trainer: dirpath = config_directory_path, filename = config_file_pattern, every_n_epochs = 1, - save_top_k = 3, + save_top_k = 5, save_last = True ) ] diff --git a/hyperswap/README.md b/hyperswap/README.md index a3f9654..9ce5bfd 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -88,6 +88,7 @@ mask_weight = 5.0 [training.trainer] accumulate_size = 4 learning_rate = 0.0004 +gradient_clip = 20.0 max_epochs = 50 strategy = auto precision = 16-mixed diff --git a/hyperswap/config.ini b/hyperswap/config.ini index fd94c01..b7d50c0 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -46,6 +46,7 @@ mask_weight = [training.trainer] accumulate_size = learning_rate = +gradient_clip = max_epochs = strategy = precision = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index acba51f..58275c1 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -34,6 +34,7 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() @@ -114,6 +115,12 @@ class HyperSwapTrainer(LightningModule): self.manual_backward(generator_loss) if do_update: + if self.config_gradient_clip: + self.clip_gradients( + generator_optimizer, + gradient_clip_val = self.config_gradient_clip, + gradient_clip_algorithm = 'value' + ) generator_optimizer.step() generator_optimizer.zero_grad() self.untoggle_optimizer(generator_optimizer) @@ -122,6 +129,12 @@ class HyperSwapTrainer(LightningModule): self.manual_backward(discriminator_loss) if do_update: + if self.config_gradient_clip: + self.clip_gradients( + discriminator_optimizer, + gradient_clip_val = self.config_gradient_clip, + gradient_clip_algorithm = 'value' + ) discriminator_optimizer.step() discriminator_optimizer.zero_grad() self.untoggle_optimizer(discriminator_optimizer) @@ -223,7 +236,7 @@ def create_trainer() -> Trainer: dirpath = config_directory_path, filename = config_file_pattern, every_n_train_steps = 1000, - save_top_k = 3, + save_top_k = 5, save_last = True ) ],