From 15ee6fa763e1fb8f1bb9bead0d6934bf9cce3afc Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Fri, 21 Feb 2025 09:11:43 +0100
Subject: [PATCH] Simplify sizes

---
Review notes (this section is commentary and is not part of the commit):
- training_size keeps the int() truncation that the removed lines had.
  loader_split_ratio is read with CONFIG.getfloat, so the bare product
  dataset_size * loader_split_ratio is a float, and validation_size
  derived from it would be a float as well. random_split expects integer
  lengths (or fractions summing to 1), not float absolute sizes.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; tabs are assumed for indentation. TODO confirm against
  embedding_converter/src/training.py before applying.

 embedding_converter/src/training.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py
index 2f90890..ab31c2b 100644
--- a/embedding_converter/src/training.py
+++ b/embedding_converter/src/training.py
@@ -83,9 +83,10 @@ def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataL
 
 def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
 	loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
-	training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
-	validation_size = len(dataset) - training_size # type:ignore[arg-type]
-	training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size])
+	dataset_size = len(dataset)
+	training_size = int(dataset_size * loader_split_ratio)
+	validation_size = dataset_size - training_size
+	training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
 	return training_dataset, validate_dataset
 
 