Use high float32 matmul precision

This commit is contained in:
henryruhs
2025-02-26 10:54:38 +01:00
parent ab0a59fb74
commit 58a85a80bb
2 changed files with 6 additions and 0 deletions
+3
View File
@@ -118,6 +118,9 @@ def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
output_resume_path = CONFIG.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(dataset_file_pattern)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer()
+3
View File
@@ -194,6 +194,9 @@ def train() -> None:
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
output_resume_path = CONFIG.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()