Restore dataset behaviour for same person

This commit is contained in:
henryruhs
2025-02-25 12:54:54 +01:00
parent e8cc2bfff1
commit 484a49c27d
+13 -7
View File
@@ -1,4 +1,5 @@
import glob
import os
import random
from torch import Tensor
@@ -37,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]):
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_different_batch(self, source_image_path : str) -> Batch:
target_image_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_image_path)
target_tensor = io.read_image(target_image_path)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
def prepare_equal_batch(self, source_image_path : str) -> Batch:
source_tensor = io.read_image(source_image_path)
def prepare_equal_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
return source_tensor, source_tensor
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor