mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Similarity validation for embedding converter
This commit is contained in:
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torch.utils.data import DataLoader as TorchDataLoader, Subset
|
||||
|
||||
from .data_loader import DataLoader
|
||||
from .helper import calc_id_embedding
|
||||
@@ -132,8 +132,8 @@ def train() -> None:
|
||||
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
|
||||
|
||||
dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
|
||||
training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user