From 09e913233b0d2b9b6b557fce4784f686c99f97ff Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 21 Feb 2025 16:23:54 +0100 Subject: [PATCH] Follow convention of the other project --- face_swapper/src/training.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 4e2b970..5412068 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,6 +1,6 @@ import configparser import os -from typing import Any, Tuple +from typing import Tuple import lightning import torch @@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn from torch.optim import Optimizer -from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split +from torch.utils.data import DataLoader, Dataset, random_split from .dataset import DynamicDataset from .helper import calc_id_embedding @@ -94,21 +94,22 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] -def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]: +def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') training_dataset, validate_dataset = split_dataset(dataset) - training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True) + training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) return training_loader, validation_loader -def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]: +def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]: loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio') - training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type] - validation_size = len(dataset) - training_size # type:ignore[arg-type] - training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size]) + dataset_size = len(dataset) # type:ignore[arg-type] + training_size = int(dataset_size * loader_split_ratio) + validation_size = int(dataset_size - training_size) + training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ]) return training_dataset, validate_dataset