From a06f5fd9e8c17340d9a7dbe6f0e8f2cbd9d205cd Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 11 Jun 2025 14:15:11 +0530 Subject: [PATCH 1/8] stabilize finetune --- hyperswap/src/training.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 0e57faf..5621635 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -62,10 +62,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) - discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) generator_config =\ { @@ -73,7 +73,7 @@ class HyperSwapTrainer(LightningModule): 'lr_scheduler': { 'scheduler': generator_scheduler, - 'interval': 'step' + 'monitor': 'generator_loss' } } discriminator_config =\ @@ -82,7 +82,7 @@ class HyperSwapTrainer(LightningModule): 'lr_scheduler': { 'scheduler': discriminator_scheduler, - 'interval': 'step' + 'monitor': 'discriminator_loss' } } return generator_config, discriminator_config @@ -125,7 +125,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( generator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) generator_optimizer.step() generator_optimizer.zero_grad() @@ -139,7 +139,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( discriminator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) discriminator_optimizer.step() discriminator_optimizer.zero_grad() From fc766b832771372db9419f97388f596489499a4f Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 11 Jun 2025 15:38:28 +0530 Subject: [PATCH 2/8] fix lr scheduler --- hyperswap/src/training.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 5621635..13528e0 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -72,8 +72,7 @@ class HyperSwapTrainer(LightningModule): 'optimizer': generator_optimizer, 'lr_scheduler': { - 'scheduler': generator_scheduler, - 'monitor': 'generator_loss' + 'scheduler': generator_scheduler } } discriminator_config =\ @@ -81,8 +80,7 @@ class HyperSwapTrainer(LightningModule): 'optimizer': discriminator_optimizer, 
'lr_scheduler': { - 'scheduler': discriminator_scheduler, - 'monitor': 'discriminator_loss' + 'scheduler': discriminator_scheduler } } return generator_config, discriminator_config @@ -91,6 +89,7 @@ class HyperSwapTrainer(LightningModule): source_tensor, target_tensor = batch do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) @@ -157,6 +156,11 @@ class HyperSwapTrainer(LightningModule): self.log('identity_loss', identity_loss) self.log('gaze_loss', gaze_loss) self.log('mask_loss', mask_loss) + + if do_update: + generator_scheduler.step(generator_loss) + discriminator_scheduler.step(discriminator_loss) + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: From e894e4172aaa6520b5f1bcfe8ce4152f9a014583 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 13:06:47 +0530 Subject: [PATCH 3/8] stabilize training --- hyperswap/README.md | 6 +++++- hyperswap/config.ini | 6 +++++- hyperswap/src/training.py | 14 +++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index 538f0b2..eb773de 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -89,9 +89,13 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 -learning_rate = 0.0004 +generator_learning_rate = 0.0004 +discriminator_learning_rate = 0.0002 +momentum = 0.5 gradient_clip = 20.0 noise_factor = 0.05 +scheduler_factor = 0.7 +scheduler_patience = 2000 max_epochs = 50 strategy = auto precision = 16-mixed diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 1f979c7..f93daa9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -47,9 +47,13 @@ mask_weight = [training.trainer] accumulate_size = -learning_rate = +generator_learning_rate = +discriminator_learning_rate = +momentum = gradient_clip = noise_factor = +scheduler_factor = +scheduler_patience = max_epochs = strategy = precision = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 13528e0..0559111 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,7 +35,11 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate') + self.config_momentum = config_parser.getfloat('training.trainer', 'momentum') + self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor') + self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') 
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() @@ -62,10 +66,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) - generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) - discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) generator_config =\ { From e846d88145716e524128db5de44b1ddd3274c5e9 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 13:26:06 +0530 Subject: [PATCH 4/8] split config section --- hyperswap/README.md | 27 +++++++++++++++++++++------ hyperswap/config.ini | 21 +++++++++++++++------ hyperswap/src/training.py | 25 ++++++++++++++----------- 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index eb773de..03d6d0e 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -89,19 +89,34 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 -generator_learning_rate = 0.0004 -discriminator_learning_rate = 0.0002 -momentum = 0.5 gradient_clip = 20.0 noise_factor = 0.05 -scheduler_factor = 0.7 -scheduler_patience = 2000 max_epochs = 50 strategy = auto precision = 16-mixed +preview_frequency = 100 +``` + +``` +[training.logger] logger_path = .logs logger_name = hyperswap -preview_frequency = 100 +``` + +``` +[training.generator] +learning_rate = 0.0004 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 +``` + +``` +[training.discriminator] +learning_rate = 0.0002 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 ``` ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index f93daa9..2f37ba9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -47,19 +47,28 @@ mask_weight = [training.trainer] accumulate_size = -generator_learning_rate = -discriminator_learning_rate = -momentum = gradient_clip = noise_factor = -scheduler_factor = -scheduler_patience = max_epochs = strategy = precision = +preview_frequency = + +[training.logger] logger_path = logger_name = -preview_frequency = + +[training.generator] +learning_rate = +momentum = +scheduler_factor = +scheduler_patience = + +[training.discriminator] 
+learning_rate = +momentum = +scheduler_factor = +scheduler_patience = [training.output] directory_path = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 0559111..442650f 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,11 +35,14 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate') - self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate') - self.config_momentum = config_parser.getfloat('training.trainer', 'momentum') - self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor') - self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience') + self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum') + self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.generator', 'scheduler_factor') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience') + self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() @@ -66,10 +69,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) - generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) - discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = 
(self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8) generator_config =\ { @@ -234,8 +237,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name) From 580a179f440b37fc157117fbb681a2c55d156ec6 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 13:39:24 +0530 Subject: [PATCH 5/8] rename --- hyperswap/README.md | 4 ++-- hyperswap/config.ini | 4 ++-- hyperswap/src/training.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index 03d6d0e..9fbc68a 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -104,7 +104,7 @@ logger_name = hyperswap ``` ``` -[training.generator] +[training.optimizer.generator] learning_rate = 0.0004 momentum = 0.5 scheduler_factor = 0.7 @@ -112,7 +112,7 @@ scheduler_patience = 2000 ``` ``` -[training.discriminator] +[training.optimizer.discriminator] learning_rate = 0.0002 momentum = 0.5 scheduler_factor = 0.7 diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 2f37ba9..b9085c9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -58,13 +58,13 @@ preview_frequency = logger_path = logger_name = -[training.generator] +[training.optimizer.generator] learning_rate = momentum = scheduler_factor = scheduler_patience = -[training.discriminator] +[training.optimizer.discriminator] learning_rate = momentum = scheduler_factor = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 442650f..b88380b 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,14 +35,14 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate') - self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate') - self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum') - self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum') - self.config_generator_scheduler_factor = 
config_parser.getfloat('training.generator', 'scheduler_factor') - self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor') - self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience') - self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience') + self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum') + self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience') + self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() From 35c250b0c91a9786ff9122685b9b466f83f18f36 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 13:44:37 +0530 Subject: [PATCH 6/8] rearrange logger --- hyperswap/README.md | 12 ++++++------ hyperswap/config.ini | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index 9fbc68a..f8592fe 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -97,12 +97,6 @@ precision = 16-mixed preview_frequency = 100 ``` -``` -[training.logger] -logger_path = .logs -logger_name = hyperswap -``` - ``` [training.optimizer.generator] learning_rate = 0.0004 @@ -119,6 +113,12 @@ scheduler_factor = 0.7 scheduler_patience = 2000 ``` +``` +[training.logger] +logger_path = .logs +logger_name = hyperswap +``` + ``` [training.output] directory_path = .outputs diff --git a/hyperswap/config.ini b/hyperswap/config.ini index b9085c9..46a5556 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -54,10 +54,6 @@ strategy = precision = preview_frequency = -[training.logger] -logger_path = -logger_name = - [training.optimizer.generator] learning_rate = momentum = @@ -70,6 +66,10 @@ momentum = scheduler_factor = scheduler_patience = +[training.logger] +logger_path = +logger_name = + [training.output] directory_path = file_pattern = From 2f28fb664bbc9f214df9ca4b0f3e457e3d356935 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 14:43:21 +0530 Subject: [PATCH 7/8] apply crossface --- crossface/README.md | 10 +++++++++- crossface/config.ini | 6 +++++- crossface/src/training.py | 8 ++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/crossface/README.md b/crossface/README.md index 1d80794..07e68b3 100644 --- a/crossface/README.md +++ b/crossface/README.md @@ -45,10 +45,18 @@ target_path = .models/arcface_simswap.pt ``` [training.trainer] -learning_rate = 0.001 max_epochs = 4096 strategy = auto 
precision = 16-mixed +``` + +``` +[training.optimizer] +learning_rate = 0.001 +``` + +``` +[training.logger] logger_path = .logs logger_name = crossface_simswap ``` diff --git a/crossface/config.ini b/crossface/config.ini index a040687..4ec26ca 100644 --- a/crossface/config.ini +++ b/crossface/config.ini @@ -11,10 +11,14 @@ source_path = target_path = [training.trainer] -learning_rate = max_epochs = strategy = precision = + +[training.optimizer] +learning_rate = + +[training.logger] logger_path = logger_name = diff --git a/crossface/src/training.py b/crossface/src/training.py index 20ab9e8..6df45ca 100644 --- a/crossface/src/training.py +++ b/crossface/src/training.py @@ -23,7 +23,7 @@ class CrossFaceTrainer(LightningModule): super().__init__() self.config_source_path = config_parser.get('training.model', 'source_path') self.config_target_path = config_parser.get('training.model', 'target_path') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate') self.crossface = CrossFace() self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval() self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval() @@ -50,7 +50,7 @@ class CrossFaceTrainer(LightningModule): return validation_score def configure_optimizers(self) -> OptimizerSet: - optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate) + optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) optimizer_set =\ { @@ -91,8 +91,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name) From f4d4914f5c36df5ab281477ad8134e4b0c9a17a5 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 14:46:56 +0530 Subject: [PATCH 8/8] rearrange --- hyperswap/src/training.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index b88380b..621faaa 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,16 +35,16 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate') - self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate') - self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 
'momentum') - self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') - self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor') - self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') - self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience') - self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') + self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience') + self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate') + self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') + self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() @@ -163,11 +163,11 @@ class HyperSwapTrainer(LightningModule): self.log('identity_loss', identity_loss) self.log('gaze_loss', gaze_loss) self.log('mask_loss', mask_loss) - + if do_update: generator_scheduler.step(generator_loss) discriminator_scheduler.step(discriminator_loss) - + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
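
The patches above converge on one training recipe: AdamW with beta1 = 0.5, a halved discriminator learning rate, norm-based gradient clipping, and ReduceLROnPlateau schedulers that are stepped manually on gradient-accumulation boundaries instead of letting Lightning drive a per-step cosine schedule. The sketch below is a minimal, standalone illustration of that pattern under stated assumptions: the `Linear` modules and random batches are stand-ins for the real generator, discriminator and face data, the loop replaces `LightningModule`'s `optimizers()` / `lr_schedulers()` plumbing, and the hyperparameter values are the defaults listed in the README. It is not the actual `training.py`.

```
import torch

# stand-in modules; the real trainer uses the HyperSwap generator and discriminator
generator = torch.nn.Linear(8, 8)
discriminator = torch.nn.Linear(8, 1)

generator_optimizer = torch.optim.AdamW(generator.parameters(), lr = 0.0004, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr = 0.0002, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8)

accumulate_size = 4
gradient_clip = 20.0

for batch_index in range(8):
    batch = torch.randn(4, 8)

    # placeholder losses; gradients accumulate across batches until the update step
    generator_loss = generator(batch).pow(2).mean()
    generator_loss.backward()
    discriminator_loss = discriminator(batch).pow(2).mean()
    discriminator_loss.backward()

    # optimizers, norm-based clipping and plateau schedulers only advance on
    # accumulation boundaries, mirroring the do_update guard added in patches 1/8 and 2/8
    if (batch_index + 1) % accumulate_size == 0:
        torch.nn.utils.clip_grad_norm_(generator.parameters(), gradient_clip)
        generator_optimizer.step()
        generator_optimizer.zero_grad()
        generator_scheduler.step(generator_loss.item())

        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), gradient_clip)
        discriminator_optimizer.step()
        discriminator_optimizer.zero_grad()
        discriminator_scheduler.step(discriminator_loss.item())
```

Because the schedulers are stepped once per accumulated update rather than per epoch, the configured `scheduler_patience = 2000` is measured in optimizer steps; together with the lower discriminator learning rate and beta1 = 0.5 this is a common GAN stabilization recipe, which appears to be the intent of the "stabilize finetune" and "stabilize training" commits.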