diff --git a/hyperswap/README.md b/hyperswap/README.md index 03d6d0e..9fbc68a 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -104,7 +104,7 @@ logger_name = hyperswap ``` ``` -[training.generator] +[training.optimizer.generator] learning_rate = 0.0004 momentum = 0.5 scheduler_factor = 0.7 @@ -112,7 +112,7 @@ scheduler_patience = 2000 ``` ``` -[training.discriminator] +[training.optimizer.discriminator] learning_rate = 0.0002 momentum = 0.5 scheduler_factor = 0.7 diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 2f37ba9..b9085c9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -58,13 +58,13 @@ preview_frequency = logger_path = logger_name = -[training.generator] +[training.optimizer.generator] learning_rate = momentum = scheduler_factor = scheduler_patience = -[training.discriminator] +[training.optimizer.discriminator] learning_rate = momentum = scheduler_factor = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 442650f..b88380b 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,14 +35,14 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate') - self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate') - self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum') - self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum') - self.config_generator_scheduler_factor = config_parser.getfloat('training.generator', 'scheduler_factor') - self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor') - self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience') - self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience') + self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate') + self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum') + self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') + self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor') + self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') + self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience') + self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()