mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
follow the not invented here syndrome
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
same_person_probability =
|
||||
equal_probability =
|
||||
|
||||
[training.loader]
|
||||
batch_size =
|
||||
|
||||
@@ -10,18 +10,18 @@ from .types import Batch
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str, same_person_probability : float) -> None:
|
||||
self.same_person_probability = same_person_probability
|
||||
def __init__(self, file_pattern : str, equal_probability : float) -> None:
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
self.equal_probability = equal_probability
|
||||
|
||||
def __getitem__(self, index : int) -> Batch: # type:ignore[override]
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
source_image_path = self.file_paths[index]
|
||||
|
||||
if random.random() < self.same_person_probability:
|
||||
return self.prepare_same_person(source_image_path)
|
||||
if random.random() < self.equal_probability:
|
||||
return self.prepare_equal_batch(source_image_path)
|
||||
|
||||
return self.prepare_different_person(source_image_path)
|
||||
return self.prepare_different_batch(source_image_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
@@ -39,7 +39,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def prepare_different_person(self, source_image_path : str) -> Batch:
|
||||
def prepare_different_batch(self, source_image_path : str) -> Batch:
|
||||
target_image_path = random.choice(self.file_paths)
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
@@ -47,8 +47,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_same_person(self, source_image_path : str) -> Batch:
|
||||
def prepare_equal_batch(self, source_image_path : str) -> Batch:
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = source_tensor.clone()
|
||||
return source_tensor, target_tensor
|
||||
return source_tensor, source_tensor
|
||||
|
||||
@@ -143,10 +143,10 @@ def create_trainer() -> Trainer:
|
||||
|
||||
def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability')
|
||||
dataset_equal_probability = CONFIG.getfloat('training.dataset', 'equal_probability')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
dataset = DynamicDataset(dataset_file_pattern, same_person_probability)
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_equal_probability)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
Reference in New Issue
Block a user