diff --git a/face_swapper/config.ini b/face_swapper/config.ini index fbc18eb..c8a9f50 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -1,6 +1,6 @@ [training.dataset] file_pattern = -same_person_probability = +equal_probability = [training.loader] batch_size = diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index 2d42149..e947de3 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -10,18 +10,18 @@ from .types import Batch class DynamicDataset(Dataset[Tensor]): - def __init__(self, file_pattern : str, same_person_probability : float) -> None: - self.same_person_probability = same_person_probability + def __init__(self, file_pattern : str, equal_probability : float) -> None: self.file_paths = glob.glob(file_pattern) self.transforms = self.compose_transforms() + self.equal_probability = equal_probability - def __getitem__(self, index : int) -> Batch: # type:ignore[override] + def __getitem__(self, index : int) -> Batch: source_image_path = self.file_paths[index] - if random.random() < self.same_person_probability: - return self.prepare_same_person(source_image_path) + if random.random() < self.equal_probability: + return self.prepare_equal_batch(source_image_path) - return self.prepare_different_person(source_image_path) + return self.prepare_different_batch(source_image_path) def __len__(self) -> int: return len(self.file_paths) @@ -39,7 +39,7 @@ class DynamicDataset(Dataset[Tensor]): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) - def prepare_different_person(self, source_image_path : str) -> Batch: + def prepare_different_batch(self, source_image_path : str) -> Batch: target_image_path = random.choice(self.file_paths) source_vision_frame = cv2.imread(source_image_path) target_vision_frame = cv2.imread(target_image_path) @@ -47,8 +47,7 @@ class DynamicDataset(Dataset[Tensor]): target_tensor = self.transforms(target_vision_frame) return source_tensor, target_tensor - def prepare_same_person(self, source_image_path : str) -> Batch: + def prepare_equal_batch(self, source_image_path : str) -> Batch: source_vision_frame = cv2.imread(source_image_path) source_tensor = self.transforms(source_vision_frame) - target_tensor = source_tensor.clone() - return source_tensor, target_tensor + return source_tensor, source_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 5412068..20cfa77 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -143,10 +143,10 @@ def create_trainer() -> Trainer: def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') - same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability') + dataset_equal_probability = CONFIG.getfloat('training.dataset', 'equal_probability') output_resume_path = CONFIG.get('training.output', 'resume_path') - dataset = DynamicDataset(dataset_file_pattern, same_person_probability) + dataset = DynamicDataset(dataset_file_pattern, dataset_equal_probability) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer()