From 2f28fb664bbc9f214df9ca4b0f3e457e3d356935 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 14:43:21 +0530 Subject: [PATCH] apply crossface --- crossface/README.md | 10 +++++++++- crossface/config.ini | 6 +++++- crossface/src/training.py | 8 ++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/crossface/README.md b/crossface/README.md index 1d80794..07e68b3 100644 --- a/crossface/README.md +++ b/crossface/README.md @@ -45,10 +45,18 @@ target_path = .models/arcface_simswap.pt ``` [training.trainer] -learning_rate = 0.001 max_epochs = 4096 strategy = auto precision = 16-mixed +``` + +``` +[training.optimizer] +learning_rate = 0.001 +``` + +``` +[training.logger] logger_path = .logs logger_name = crossface_simswap ``` diff --git a/crossface/config.ini b/crossface/config.ini index a040687..4ec26ca 100644 --- a/crossface/config.ini +++ b/crossface/config.ini @@ -11,10 +11,14 @@ source_path = target_path = [training.trainer] -learning_rate = max_epochs = strategy = precision = + +[training.optimizer] +learning_rate = + +[training.logger] logger_path = logger_name = diff --git a/crossface/src/training.py b/crossface/src/training.py index 20ab9e8..6df45ca 100644 --- a/crossface/src/training.py +++ b/crossface/src/training.py @@ -23,7 +23,7 @@ class CrossFaceTrainer(LightningModule): super().__init__() self.config_source_path = config_parser.get('training.model', 'source_path') self.config_target_path = config_parser.get('training.model', 'target_path') - self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') + self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate') self.crossface = CrossFace() self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval() self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval() @@ -50,7 +50,7 @@ class CrossFaceTrainer(LightningModule): return validation_score def configure_optimizers(self) -> OptimizerSet: - optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate) + optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) optimizer_set =\ { @@ -91,8 +91,8 @@ def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') - config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path') - config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name') + config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path') + config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger(config_logger_path, config_logger_name)