diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/dataset.py similarity index 98% rename from face_swapper/src/data_loader.py rename to face_swapper/src/dataset.py index 7410c09..2d42149 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/dataset.py @@ -9,7 +9,7 @@ from torchvision import transforms from .types import Batch -class DataLoader(Dataset[Tensor]): +class DynamicDataset(Dataset[Tensor]): def __init__(self, file_pattern : str, same_person_probability : float) -> None: self.same_person_probability = same_person_probability self.file_paths = glob.glob(file_pattern) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index ecc70b6..4e2b970 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -12,7 +12,7 @@ from torch import Tensor, nn from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split -from .data_loader import DataLoader +from .dataset import DynamicDataset from .helper import calc_id_embedding from .models.discriminator import Discriminator from .models.generator import Generator @@ -145,7 +145,7 @@ def train() -> None: same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability') output_resume_path = CONFIG.get('training.output', 'resume_path') - dataset = DataLoader(dataset_file_pattern, same_person_probability) + dataset = DynamicDataset(dataset_file_pattern, same_person_probability) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer()