mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Restore dataset behaviour for same person
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
from torch import Tensor
|
||||
@@ -37,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def prepare_different_batch(self, source_image_path : str) -> Batch:
|
||||
target_image_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_image_path)
|
||||
target_tensor = io.read_image(target_image_path)
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_equal_batch(self, source_image_path : str) -> Batch:
|
||||
source_tensor = io.read_image(source_image_path)
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
|
||||
target_directory_path = os.path.dirname(source_path)
|
||||
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
|
||||
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
return source_tensor, source_tensor
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
Reference in New Issue
Block a user