diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index ee9a224..aa1c46d 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -55,12 +55,12 @@ class DynamicDataset(Dataset[Tensor]): target_tensor = self.transforms(target_tensor) return source_tensor, target_tensor - def prepare_same_batch(self, source_path : str) -> Batch: + def prepare_equal_batch(self, source_path : str) -> Batch: source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) return source_tensor, source_tensor - def prepare_equal_batch(self, source_path : str) -> Batch: + def prepare_same_batch(self, source_path : str) -> Batch: target_directory_path = os.path.dirname(source_path) target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) target_path = os.path.join(target_directory_path, target_file_name_and_extension)