From f90fd73b54c4555eab55f30dd8e8b24f54fdbc63 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 11 Mar 2025 18:16:36 +0530 Subject: [PATCH] add strategy to config --- face_swapper/README.md | 1 + face_swapper/config.ini | 1 + face_swapper/src/training.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/face_swapper/README.md b/face_swapper/README.md index e2b53f9..33f3e6d 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -91,6 +91,7 @@ learning_rate = 0.0004 max_epochs = 50 precision = 16-mixed preview_frequency = 250 +strategy = auto ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 271cab6..1566739 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -49,6 +49,7 @@ learning_rate = max_epochs = precision = preview_frequency = +strategy = [training.output] directory_path = diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7bda57b..a79d142 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -206,6 +206,7 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_precision = CONFIG_PARSER.get('training.trainer', 'precision') + config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path') config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern') logger = TensorBoardLogger('.logs', name = 'face_swapper') @@ -215,6 +216,7 @@ def create_trainer() -> Trainer: log_every_n_steps = 10, max_epochs = config_max_epochs, precision = config_precision, + strategy = config_strategy, callbacks = [ ModelCheckpoint(