Merge pull request #67 from facefusion/multi-dataset

Multi dataset
This commit is contained in:
Harisreedhar
2025-03-26 17:43:35 +05:30
committed by GitHub
+20 -3
View File
@@ -1,7 +1,7 @@
import os
import warnings
from configparser import ConfigParser
from typing import Tuple
from typing import List, Tuple
import torch
import torchvision
@@ -9,7 +9,7 @@ from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
@@ -179,6 +179,23 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
return training_dataset, validate_dataset
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
datasets = []
for config_section in config_parser.sections():
if config_section.startswith('training.dataset'):
current_config_parser = ConfigParser()
current_config_parser.add_section('training.dataset')
for key, value in config_parser.items(config_section):
current_config_parser.set('training.dataset', key, value)
datasets.append(DynamicDataset(current_config_parser))
return datasets
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
@@ -216,7 +233,7 @@ def train() -> None:
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(CONFIG_PARSER)
dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
trainer = create_trainer()