mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-04-19 15:56:37 +02:00)
Follow convention of the other project
@@ -1,6 +1,6 @@
 import configparser
 import os
-from typing import Any, Tuple
+from typing import Tuple

 import lightning
 import torch
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import TensorBoardLogger
 from torch import Tensor, nn
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split
+from torch.utils.data import DataLoader, Dataset, random_split

 from .dataset import DynamicDataset
 from .helper import calc_id_embedding
@@ -94,21 +94,22 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
 		self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]


-def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]:
+def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
 	batch_size = CONFIG.getint('training.loader', 'batch_size')
 	num_workers = CONFIG.getint('training.loader', 'num_workers')

 	training_dataset, validate_dataset = split_dataset(dataset)
-	training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
-	validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
+	training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
+	validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
 	return training_loader, validation_loader


-def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
+def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
 	loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
-	training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
-	validation_size = len(dataset) - training_size # type:ignore[arg-type]
-	training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size])
+	dataset_size = len(dataset) # type:ignore[arg-type]
+	training_size = int(dataset_size * loader_split_ratio)
+	validation_size = int(dataset_size - training_size)
+	training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
 	return training_dataset, validate_dataset
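
Note: the loader convention this commit settles on can be exercised outside the repository. What follows is a minimal sketch, not code from facefusion-labs: a TensorDataset stands in for DynamicDataset, and the batch_size, num_workers and split_ratio values that the real functions read through CONFIG are hard-coded assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in for DynamicDataset: 1000 random image tensors (assumption for this sketch)
dataset = TensorDataset(torch.randn(1000, 3, 256, 256))

# Hard-coded stand-ins for the CONFIG.getint() / CONFIG.getfloat() lookups
batch_size = 8
num_workers = 2
split_ratio = 0.8

# Same size arithmetic as split_dataset(): truncate the training share,
# then derive the validation share so the two always sum to len(dataset)
dataset_size = len(dataset)
training_size = int(dataset_size * split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])

# Same loader convention as create_loaders(): drop the ragged tail batch
# during training, keep every sample for validation
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)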
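One property of the rewritten split arithmetic worth noting: because validation_size is derived by subtraction rather than computed independently, the two lengths always sum to len(dataset), which random_split requires. For example, with 1001 samples and a split_ratio of 0.8, int(1001 * 0.8) = 800 and 1001 - 800 = 201, so random_split receives lengths summing to exactly 1001.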