diff --git a/hyperswap/README.md b/hyperswap/README.md index eb773de..03d6d0e 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -89,19 +89,34 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 -generator_learning_rate = 0.0004 -discriminator_learning_rate = 0.0002 -momentum = 0.5 gradient_clip = 20.0 noise_factor = 0.05 -scheduler_factor = 0.7 -scheduler_patience = 2000 max_epochs = 50 strategy = auto precision = 16-mixed +preview_frequency = 100 +``` + +``` +[training.logger] logger_path = .logs logger_name = hyperswap -preview_frequency = 100 +``` + +``` +[training.generator] +learning_rate = 0.0004 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 +``` + +``` +[training.discriminator] +learning_rate = 0.0002 +momentum = 0.5 +scheduler_factor = 0.7 +scheduler_patience = 2000 ``` ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index f93daa9..2f37ba9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -47,19 +47,28 @@ mask_weight = [training.trainer] accumulate_size = -generator_learning_rate = -discriminator_learning_rate = -momentum = gradient_clip = noise_factor = -scheduler_factor = -scheduler_patience = max_epochs = strategy = precision = +preview_frequency = + +[training.logger] logger_path = logger_name = -preview_frequency = + +[training.generator] +learning_rate = +momentum = +scheduler_factor = +scheduler_patience = + +[training.discriminator] +learning_rate = +momentum = +scheduler_factor = +scheduler_patience = [training.output] directory_path = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 0559111..442650f 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,11 +35,14 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate') - self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate') - self.config_momentum = config_parser.getfloat('training.trainer', 'momentum') - self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor') - self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience') + self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum') + self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.generator', 'scheduler_factor') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience') + self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() @@ -66,10 +69,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) - generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) - discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8) generator_config =\ { @@ -234,8 +237,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name)