From 58a85a80bb5cc5c274d37691a65bd1330d06f554 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 26 Feb 2025 10:54:38 +0100 Subject: [PATCH] Use high float32 matmul precision --- embedding_converter/src/training.py | 3 +++ face_swapper/src/training.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 76a534f..0be3ce7 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -118,6 +118,9 @@ def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') output_resume_path = CONFIG.get('training.output', 'resume_path') + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + dataset = StaticDataset(dataset_file_pattern) training_loader, validation_loader = create_loaders(dataset) embedding_converter_trainer = EmbeddingConverterTrainer() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 046073b..29fa196 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -194,6 +194,9 @@ def train() -> None: dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') output_resume_path = CONFIG.get('training.output', 'resume_path') + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer()