Similarity validation for embedding converter

This commit is contained in:
henryruhs
2025-02-21 08:29:03 +01:00
parent 0a50e2d706
commit 9bd68c3d14
+3 -3
View File
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Subset
from torch.utils.data import DataLoader as TorchDataLoader, Subset
from .data_loader import DataLoader
from .helper import calc_id_embedding
@@ -132,8 +132,8 @@ def train() -> None:
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()