diff --git a/crossface/README.md b/crossface/README.md index 1d80794..07e68b3 100644 --- a/crossface/README.md +++ b/crossface/README.md @@ -45,10 +45,18 @@ target_path = .models/arcface_simswap.pt ``` [training.trainer] -learning_rate = 0.001 max_epochs = 4096 strategy = auto precision = 16-mixed +``` + +``` +[training.optimizer] +learning_rate = 0.001 +``` + +``` +[training.logger] logger_path = .logs logger_name = crossface_simswap ``` diff --git a/crossface/config.ini b/crossface/config.ini index a040687..4ec26ca 100644 --- a/crossface/config.ini +++ b/crossface/config.ini @@ -11,10 +11,14 @@ source_path = target_path = [training.trainer] -learning_rate = max_epochs = strategy = precision = + +[training.optimizer] +learning_rate = + +[training.logger] logger_path = logger_name = diff --git a/crossface/src/training.py b/crossface/src/training.py index 20ab9e8..6df45ca 100644 --- a/crossface/src/training.py +++ b/crossface/src/training.py @@ -23,7 +23,7 @@ class CrossFaceTrainer(LightningModule): super().__init__() self.config_source_path = config_parser.get('training.model', 'source_path') self.config_target_path = config_parser.get('training.model', 'target_path') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate') self.crossface = CrossFace() self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval() self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval() @@ -50,7 +50,7 @@ class CrossFaceTrainer(LightningModule): return validation_score def configure_optimizers(self) -> OptimizerSet: - optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate) + optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) optimizer_set =\ { @@ -91,8 +91,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name) diff --git a/hyperswap/README.md b/hyperswap/README.md index 538f0b2..f8592fe 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -89,15 +89,34 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 -learning_rate = 0.0004 gradient_clip = 20.0 noise_factor = 0.05 max_epochs = 50 strategy = auto precision = 16-mixed +preview_frequency = 100 +``` + +``` +[training.optimizer.generator] +learning_rate = 0.0004 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 +``` + +``` +[training.optimizer.discriminator] +learning_rate = 0.0002 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 +``` + +``` +[training.logger] logger_path = .logs logger_name = hyperswap -preview_frequency = 100 ``` ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 1f979c7..46a5556 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -47,15 +47,28 @@ mask_weight = [training.trainer] accumulate_size = -learning_rate = gradient_clip = noise_factor = max_epochs = strategy = precision = +preview_frequency = + +[training.optimizer.generator] +learning_rate = +momentum = +scheduler_factor = +scheduler_patience = + +[training.optimizer.discriminator] +learning_rate = +momentum = +scheduler_factor = +scheduler_patience = + +[training.logger] logger_path = logger_name = -preview_frequency = [training.output] directory_path = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 0e57faf..621faaa 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,9 +35,16 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') + self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience') + self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate') + self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') + self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() @@ -62,18 +69,17 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) - discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8) generator_config =\ { 'optimizer': generator_optimizer, 'lr_scheduler': { - 'scheduler': generator_scheduler, - 'interval': 'step' + 'scheduler': generator_scheduler } } discriminator_config =\ @@ -81,8 +87,7 @@ class HyperSwapTrainer(LightningModule): 'optimizer': discriminator_optimizer, 'lr_scheduler': { - 'scheduler': discriminator_scheduler, - 'interval': 'step' + 'scheduler': discriminator_scheduler } } return generator_config, discriminator_config @@ -91,6 +96,7 @@ class HyperSwapTrainer(LightningModule): source_tensor, target_tensor = batch do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) @@ -125,7 +131,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( generator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) generator_optimizer.step() generator_optimizer.zero_grad() @@ -139,7 +145,7 @@ class HyperSwapTrainer(LightningModule): self.clip_gradients( discriminator_optimizer, gradient_clip_val = self.config_gradient_clip, - gradient_clip_algorithm = 'value' + gradient_clip_algorithm = 'norm' ) discriminator_optimizer.step() discriminator_optimizer.zero_grad() @@ -157,6 +163,11 @@ class HyperSwapTrainer(LightningModule): self.log('identity_loss', identity_loss) self.log('gaze_loss', gaze_loss) self.log('mask_loss', mask_loss) + + if do_update: + generator_scheduler.step(generator_loss) + discriminator_scheduler.step(discriminator_loss) + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: @@ -226,8 +237,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name)