This commit is contained in:
harisreedhar
2025-06-17 13:39:24 +05:30
parent e846d88145
commit 580a179f44
3 changed files with 12 additions and 12 deletions
+2 -2
View File
@@ -104,7 +104,7 @@ logger_name = hyperswap
```
```
[training.generator]
[training.optimizer.generator]
learning_rate = 0.0004
momentum = 0.5
scheduler_factor = 0.7
@@ -112,7 +112,7 @@ scheduler_patience = 2000
```
```
[training.discriminator]
[training.optimizer.discriminator]
learning_rate = 0.0002
momentum = 0.5
scheduler_factor = 0.7
+2 -2
View File
@@ -58,13 +58,13 @@ preview_frequency =
logger_path =
logger_name =
[training.generator]
[training.optimizer.generator]
learning_rate =
momentum =
scheduler_factor =
scheduler_patience =
[training.discriminator]
[training.optimizer.discriminator]
learning_rate =
momentum =
scheduler_factor =
+8 -8
View File
@@ -35,14 +35,14 @@ class HyperSwapTrainer(LightningModule):
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_generator_learning_rate = config_parser.getfloat('training.generator', 'learning_rate')
self.config_discriminator_learning_rate = config_parser.getfloat('training.discriminator', 'learning_rate')
self.config_generator_momentum = config_parser.getfloat('training.generator', 'momentum')
self.config_discriminator_momentum = config_parser.getfloat('training.discriminator', 'momentum')
self.config_generator_scheduler_factor = config_parser.getfloat('training.generator', 'scheduler_factor')
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.discriminator', 'scheduler_factor')
self.config_generator_scheduler_patience = config_parser.getint('training.generator', 'scheduler_patience')
self.config_discriminator_scheduler_patience = config_parser.getint('training.discriminator', 'scheduler_patience')
self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate')
self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate')
self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum')
self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor')
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience')
self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()