Follow convention of the other project

This commit is contained in:
henryruhs
2025-02-21 16:23:54 +01:00
parent b4bbd862e2
commit 09e913233b
+10 -9
View File
@@ -1,6 +1,6 @@
import configparser
import os
from typing import Any, Tuple
from typing import Tuple
import lightning
import torch
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split
from torch.utils.data import DataLoader, Dataset, random_split
from .dataset import DynamicDataset
from .helper import calc_id_embedding
@@ -94,21 +94,22 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]:
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
validation_size = len(dataset) - training_size # type:ignore[arg-type]
training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size])
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * loader_split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset