From 0a50e2d7062d205af6733604ac451839b63ccfb4 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 21 Feb 2025 08:26:22 +0100 Subject: [PATCH] Rename data loader --- face_swapper/src/data_loader.py | 2 +- face_swapper/src/training.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index 50f3f27..bcc3b15 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -11,7 +11,7 @@ from torchvision import transforms from .types import Batch, ImagePathList, ImagePathSet -class DataLoaderVGG(TensorDataset): +class DataLoader(TensorDataset): def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None: self.same_person_probability = same_person_probability self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path)) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 16c9e86..e950860 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -12,7 +12,7 @@ from torch import Tensor, nn from torch.optim import Optimizer from torch.utils.data import DataLoader, Subset -from .data_loader import DataLoaderVGG +from .data_loader import DataLoader from .helper import calc_id_embedding from .models.discriminator import Discriminator from .models.generator import Generator @@ -131,7 +131,7 @@ def train() -> None: num_workers = CONFIG.getint('training.loader', 'num_workers') resume_file_path = CONFIG.get('training.output', 'resume_file_path') - dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) + dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) face_swapper_trainer = FaceSwapperTrainer()