Files
facefusion-labs/crossface/src/dataset.py
T
2025-04-24 12:42:53 +02:00

35 lines
1.0 KiB
Python

import glob
from configparser import ConfigParser
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .types import Batch
class StaticDataset(Dataset[Tensor]):
    """Image dataset over every file matching a configured glob pattern.

    Each item is decoded as 3-channel RGB, resized to 112x112, randomly
    colour-jittered and normalised into [-1, 1].
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """Collect the matching image paths and build the transform pipeline.

        :param config_parser: configuration whose ``training.dataset`` section
            provides ``file_pattern``, a pattern passed to ``glob.glob``.
        """
        self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
        # glob.glob returns matches in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> file mapping reproducible across runs/machines.
        self.file_paths = sorted(glob.glob(self.config_file_pattern))
        self.transforms = self.compose_transforms()

    def __getitem__(self, index : int) -> Batch:
        """Load, augment and normalise the image at ``index``.

        :param index: position in ``self.file_paths``.
        :return: float tensor of shape (3, 112, 112) produced by ``self.transforms``.
        """
        file_path = self.file_paths[index]
        # Force RGB decoding: grayscale or RGBA files would otherwise yield 1 or 4
        # channels, which breaks the 3-channel Normalize in the pipeline.
        temp_tensor = io.read_image(file_path, mode = io.ImageReadMode.RGB)
        return self.transforms(temp_tensor)

    def __len__(self) -> int:
        """Return the number of matched image files."""
        return len(self.file_paths)

    @staticmethod
    def compose_transforms() -> transforms.Compose:
        """Build the per-image preprocessing and augmentation pipeline."""
        return transforms.Compose(
        [
            # read_image yields a uint8 CHW tensor; converting to PIL lets the
            # following Resize/ColorJitter operate on a PIL image.
            transforms.ToPILImage(),
            transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
            # Random photometric augmentation, re-sampled on every call.
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            # ToTensor scales uint8 [0, 255] to float [0.0, 1.0] ...
            transforms.ToTensor(),
            # ... and (x - 0.5) / 0.5 maps each channel to [-1.0, 1.0].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])