split config section

This commit is contained in:
harisreedhar
2025-06-17 13:26:06 +05:30
parent e894e4172a
commit e846d88145
3 changed files with 50 additions and 23 deletions
+21 -6
View File
@@ -89,19 +89,34 @@ mask_weight = 5.0
```
[training.trainer]
accumulate_size = 4
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0002
momentum = 0.5
gradient_clip = 20.0
noise_factor = 0.05
scheduler_factor = 0.7
scheduler_patience = 2000
max_epochs = 50
strategy = auto
precision = 16-mixed
preview_frequency = 100
```
```
[training.logger]
logger_path = .logs
logger_name = hyperswap
preview_frequency = 100
```
```
[training.generator]
learning_rate = 0.0004
momentum = 0.5
scheduler_factor = 0.7
scheduler_patience = 2000
```
```
[training.discriminator]
learning_rate = 0.0002
momentum = 0.5
scheduler_factor = 0.7
scheduler_patience = 2000
```
```
+15 -6
View File
@@ -47,19 +47,28 @@ mask_weight =
[training.trainer]
accumulate_size =
generator_learning_rate =
discriminator_learning_rate =
momentum =
gradient_clip =
noise_factor =
scheduler_factor =
scheduler_patience =
max_epochs =
strategy =
precision =
preview_frequency =
[training.logger]
logger_path =
logger_name =
preview_frequency =
[training.generator]
learning_rate =
momentum =
scheduler_factor =
scheduler_patience =
[training.discriminator]
learning_rate =
momentum =
scheduler_factor =
scheduler_patience =
[training.output]
directory_path =
+14 -11
View File
@@ -35,11 +35,14 @@ class HyperSwapTrainer(LightningModule):
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_generator_learning_rate = config_parser.getfloat('training.trainer', 'generator_learning_rate')
self.config_discriminator_learning_rate = config_parser.getfloat('training.trainer', 'discriminator_learning_rate')
self.config_momentum = config_parser.getfloat('training.trainer', 'momentum')
self.config_scheduler_factor = config_parser.getfloat('training.trainer', 'scheduler_factor')
self.config_scheduler_patience = config_parser.getint('training.trainer', 'scheduler_patience')
self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate')
self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate')
self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum')
self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum')
self.config_generator_scheduler_factor = config_parser.getfloat('training.generator', 'scheduler_factor')
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor')
self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience')
self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
@@ -66,10 +69,10 @@ class HyperSwapTrainer(LightningModule):
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_scheduler_factor, patience = self.config_scheduler_patience, min_lr = 1e-8)
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8)
generator_config =\
{
@@ -234,8 +237,8 @@ def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)