diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index e950860..35a24de 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn from torch.optim import Optimizer -from torch.utils.data import DataLoader, Subset +from torch.utils.data import DataLoader as TorchDataLoader, Subset from .data_loader import DataLoader from .helper import calc_id_embedding @@ -132,8 +132,8 @@ def train() -> None: resume_file_path = CONFIG.get('training.output', 'resume_file_path') dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) - training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer()