Rename data loader

This commit is contained in:
henryruhs
2025-02-21 08:26:22 +01:00
parent 40dcef7fc7
commit 0a50e2d706
2 changed files with 3 additions and 3 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ from torchvision import transforms
from .types import Batch, ImagePathList, ImagePathSet
class DataLoaderVGG(TensorDataset):
class DataLoader(TensorDataset):
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None:
self.same_person_probability = same_person_probability
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
+2 -2
View File
@@ -12,7 +12,7 @@ from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Subset
from .data_loader import DataLoaderVGG
from .data_loader import DataLoader
from .helper import calc_id_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
@@ -131,7 +131,7 @@ def train() -> None:
num_workers = CONFIG.getint('training.loader', 'num_workers')
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swapper_trainer = FaceSwapperTrainer()