From 80e600cbb58076f747e09e018e878dee9c977224 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 26 Mar 2025 16:51:38 +0530 Subject: [PATCH 1/3] changes --- face_swapper/src/training.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index b928679..7c72136 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -9,7 +9,7 @@ from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn -from torch.utils.data import Dataset, random_split +from torch.utils.data import ConcatDataset, Dataset, random_split from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import DynamicDataset @@ -179,6 +179,23 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T return training_dataset, validate_dataset +def combine_datasets() -> Dataset[Tensor]: + datasets = [] + + for config_section in CONFIG_PARSER.sections(): + + if config_section.startswith('training.dataset'): + dataset_config = ConfigParser() + dataset_config.add_section('training.dataset') + + for key, value in CONFIG_PARSER.items(config_section): + dataset_config.set('training.dataset', key, value) + datasets.append(DynamicDataset(dataset_config)) + + combined_dataset = ConcatDataset(datasets) + return combined_dataset + + def create_trainer() -> Trainer: config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs') config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy') @@ -216,8 +233,8 @@ def train() -> None: if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = DynamicDataset(CONFIG_PARSER) - training_loader, validation_loader = create_loaders(dataset) + combine_dataset = combine_datasets() + training_loader, validation_loader = create_loaders(combine_dataset) face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER) trainer = create_trainer() From 9df29f8a229b48fe23688df629984e3f755e7934 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 26 Mar 2025 16:54:05 +0530 Subject: [PATCH 2/3] changes --- face_swapper/src/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7c72136..7507f91 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -192,8 +192,8 @@ def combine_datasets() -> Dataset[Tensor]: dataset_config.set('training.dataset', key, value) datasets.append(DynamicDataset(dataset_config)) - combined_dataset = ConcatDataset(datasets) - return combined_dataset + combine_dataset = ConcatDataset(datasets) + return combine_dataset def create_trainer() -> Trainer: From cc6a99f305a69eb7a61943cf549d3f1cfa9b134a Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 26 Mar 2025 17:13:44 +0530 Subject: [PATCH 3/3] changes --- face_swapper/src/training.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7507f91..026f2a8 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,7 +1,7 @@ import os import warnings from configparser import ConfigParser -from typing import Tuple +from typing import List, Tuple import torch import torchvision @@ -179,21 +179,21 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T return training_dataset, validate_dataset -def combine_datasets() -> Dataset[Tensor]: +def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]: datasets = [] - for config_section in CONFIG_PARSER.sections(): + for config_section in config_parser.sections(): if config_section.startswith('training.dataset'): - dataset_config = ConfigParser() - dataset_config.add_section('training.dataset') + current_config_parser = ConfigParser() + current_config_parser.add_section('training.dataset') - for key, value in CONFIG_PARSER.items(config_section): - dataset_config.set('training.dataset', key, value) - datasets.append(DynamicDataset(dataset_config)) + for key, value in config_parser.items(config_section): + current_config_parser.set('training.dataset', key, value) - combine_dataset = ConcatDataset(datasets) - return combine_dataset + datasets.append(DynamicDataset(current_config_parser)) + + return datasets def create_trainer() -> Trainer: @@ -233,8 +233,8 @@ def train() -> None: if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - combine_dataset = combine_datasets() - training_loader, validation_loader = create_loaders(combine_dataset) + dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER)) + training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER) trainer = create_trainer()