diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 76a534f..0be3ce7 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -118,6 +118,9 @@ def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') output_resume_path = CONFIG.get('training.output', 'resume_path') + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + dataset = StaticDataset(dataset_file_pattern) training_loader, validation_loader = create_loaders(dataset) embedding_converter_trainer = EmbeddingConverterTrainer() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 046073b..29fa196 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -194,6 +194,9 @@ def train() -> None: dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') output_resume_path = CONFIG.get('training.output', 'resume_path') + if torch.cuda.is_available(): + torch.set_float32_matmul_precision('high') + dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer()