diff --git a/hyperswap/README.md b/hyperswap/README.md index 538f0b2..eb773de 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -89,9 +89,13 @@ mask_weight = 5.0 ``` [training.trainer] accumulate_size = 4 -learning_rate = 0.0004 +generator_learning_rate = 0.0004 +discriminator_learning_rate = 0.0002 +momentum = 0.5 gradient_clip = 20.0 noise_factor = 0.05 +scheduler_factor = 0.7 +scheduler_patience = 2000 max_epochs = 50 strategy = auto precision = 16-mixed diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 1f979c7..f93daa9 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -47,9 +47,13 @@ mask_weight = [training.trainer] accumulate_size = -learning_rate = +generator_learning_rate = +discriminator_learning_rate = +momentum = gradient_clip = noise_factor = +scheduler_factor = +scheduler_patience = max_epochs = strategy = precision = diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 13528e0..0559111 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -35,7 +35,11 @@ class HyperSwapTrainer(LightningModule): self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate') + self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate') + self.config_momentum = config_parser.getfloat('training.trainer', 'momentum') + self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor') + self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience') self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() @@ -62,10 +66,10 @@ class HyperSwapTrainer(LightningModule): return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate * 0.5, betas = (0.5, 0.999), weight_decay = 1e-4, eps = 1e-8) - generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) - discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = 0.7, patience = 2000, min_lr = 1e-8) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8) + generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) + discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8) generator_config =\ {