mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Follow convention of the other project
This commit is contained in:
@@ -9,7 +9,7 @@ from torchvision import transforms
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class DataLoader(Dataset[Tensor]):
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str, same_person_probability : float) -> None:
|
||||
self.same_person_probability = same_person_probability
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
@@ -12,7 +12,7 @@ from torch import Tensor, nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split
|
||||
|
||||
from .data_loader import DataLoader
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import calc_id_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
@@ -145,7 +145,7 @@ def train() -> None:
|
||||
same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
dataset = DataLoader(dataset_file_pattern, same_person_probability)
|
||||
dataset = DynamicDataset(dataset_file_pattern, same_person_probability)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
Reference in New Issue
Block a user