diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index 87bb5f0..5dc73b0 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -1,4 +1,5 @@ import glob +import os import random from torch import Tensor @@ -37,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) - def prepare_different_batch(self, source_image_path : str) -> Batch: - target_image_path = random.choice(self.file_paths) - source_tensor = io.read_image(source_image_path) - target_tensor = io.read_image(target_image_path) + def prepare_different_batch(self, source_path : str) -> Batch: + target_path = random.choice(self.file_paths) + source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) + target_tensor = io.read_image(target_path) target_tensor = self.transforms(target_tensor) return source_tensor, target_tensor - def prepare_equal_batch(self, source_image_path : str) -> Batch: - source_tensor = io.read_image(source_image_path) + def prepare_equal_batch(self, source_path : str) -> Batch: + target_directory_path = os.path.dirname(source_path) + target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) + target_path = os.path.join(target_directory_path, target_file_name_and_extension) + source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) - return source_tensor, source_tensor + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + return source_tensor, target_tensor