diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7507f91..026f2a8 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,7 +1,7 @@ import os import warnings from configparser import ConfigParser -from typing import Tuple +from typing import List, Tuple import torch import torchvision @@ -179,21 +179,21 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T return training_dataset, validate_dataset -def combine_datasets() -> Dataset[Tensor]: +def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]: datasets = [] - for config_section in CONFIG_PARSER.sections(): + for config_section in config_parser.sections(): if config_section.startswith('training.dataset'): - dataset_config = ConfigParser() - dataset_config.add_section('training.dataset') + current_config_parser = ConfigParser() + current_config_parser.add_section('training.dataset') - for key, value in CONFIG_PARSER.items(config_section): - dataset_config.set('training.dataset', key, value) - datasets.append(DynamicDataset(dataset_config)) + for key, value in config_parser.items(config_section): + current_config_parser.set('training.dataset', key, value) - combine_dataset = ConcatDataset(datasets) - return combine_dataset + datasets.append(DynamicDataset(current_config_parser)) + + return datasets def create_trainer() -> Trainer: @@ -233,8 +233,8 @@ def train() -> None: if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - combine_dataset = combine_datasets() - training_loader, validation_loader = create_loaders(combine_dataset) + dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER)) + training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER) trainer = create_trainer()