mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
import glob
|
|
from configparser import ConfigParser
|
|
|
|
from torch import Tensor
|
|
from torch.utils.data import Dataset
|
|
from torchvision import io, transforms
|
|
|
|
from .types import Batch
|
|
|
|
|
|
class StaticDataset(Dataset[Tensor]):
    """
    Image dataset backed by the files that match a configured glob pattern.

    The pattern is read from the ``file_pattern`` option of the
    ``training.dataset`` config section. Each item is decoded from disk on
    access and passed through the pipeline built by ``compose_transforms``.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Resolve the configured file pattern and build the transform pipeline.

        :param config_parser: parser that provides the ``file_pattern`` option
            in its ``training.dataset`` section.
        """
        self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
        # glob.glob() yields paths in arbitrary order, sort them so an index
        # refers to the same file on every run and on every machine
        self.file_paths = sorted(glob.glob(self.config_file_pattern))
        self.transforms = self.compose_transforms()

    def __getitem__(self, index : int) -> Batch:
        """
        Load the image at ``index`` and return it after applying the transforms.

        :param index: position inside the sorted list of matched file paths.
        :return: the transformed image tensor.
        :raises IndexError: when ``index`` is outside the list of file paths.
        """
        file_path = self.file_paths[index]
        # force three channels, grayscale or RGBA files would otherwise break
        # the three-channel Normalize step at the end of the pipeline
        temp_tensor = io.read_image(file_path, mode = io.ImageReadMode.RGB)
        return self.transforms(temp_tensor)

    def __len__(self) -> int:
        """Return the number of files matched by the configured pattern."""
        return len(self.file_paths)

    @staticmethod
    def compose_transforms() -> transforms.Compose:
        """
        Build the preprocessing and augmentation pipeline applied to each image.

        Steps, in order: convert to a PIL image, resize to 112x112 with bicubic
        interpolation, apply colour jitter, convert back to a tensor and
        normalize each of the three channels with mean 0.5 and std 0.5.

        :return: the composed torchvision transform.
        """
        return transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
|