mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
82 lines
3.0 KiB
Python
82 lines
3.0 KiB
Python
import configparser
import glob
import os
import random

import cv2
import torchvision.transforms as transforms
import tqdm
from PIL import Image
from torch.utils.data import TensorDataset

from .typing import Batch
|
|
|
|
# Module-wide settings parser, populated from 'config.ini'.
# The path is relative, so it is resolved against the current working directory.
# ConfigParser.read skips files it cannot open, so a missing file leaves CONFIG empty.
CONFIG = configparser.ConfigParser()
CONFIG.read([ 'config.ini' ])
|
|
|
|
|
|
def read_image(image_path: str) -> Image.Image:
    """
    Load an image file with OpenCV and return it as a PIL image in RGB order.

    :param image_path: path of the image file to read.
    :return: the decoded image wrapped in a PIL ``Image``.
    :raises OSError: if OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path)

    # cv2.imread does not raise on failure, it hands back None instead;
    # fail here with the offending path rather than on the slice below
    if image is None:
        raise OSError('Could not read image: {}'.format(image_path))
    # OpenCV decodes to BGR channel order, reversing the last axis gives RGB
    return Image.fromarray(image[:, :, ::-1])
|
|
|
|
|
|
class DataLoaderVGG(TensorDataset):
    """
    Dataset yielding ``(source, target, is_same_person)`` samples.

    Images are discovered as ``<dataset_path>/<folder>/<image>``; each entry
    directly below ``dataset_path`` is treated as one identity folder.
    For every index a source image is paired with a randomly chosen target,
    taken either from the same folder (label 1) or from the whole dataset
    (label 0).
    """

    def __init__(self, dataset_path : str) -> None:
        """
        Index the dataset on disk and build the augmentation pipelines.

        :param dataset_path: root directory that contains the identity folders.
        """
        # probability of drawing the target from the source's own folder
        self.same_person_probability = float(CONFIG.get('preparing.dataloader', 'same_person_probability'))
        # every file two levels down whose extension ends in 'g' (jpg, jpeg, png, ...)
        self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
        self.folder_paths = glob.glob('{}/*'.format(dataset_path))
        self.image_path_dict = {}
        self._current_index = 0

        # map each identity folder to the files it contains
        # NOTE(review): this glob is not restricted to '*.*g' like image_paths,
        # so non-image files in a folder can be picked as same-person targets
        for folder_path in tqdm.tqdm(self.folder_paths):
            self.image_path_dict[folder_path] = glob.glob('{}/*'.format(folder_path))
        self.dataset_total = len(self.image_paths)

        # resize, convert to tensor and normalize only (no augmentation)
        self.transforms_basic = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # adds color jitter and a small random affine (4 degrees, 1% translate)
        self.transforms_moderate = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # adds a random horizontal flip and a wider random affine (8 degrees, 2% translate)
        self.transforms_complex = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p = 0.5),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            # trailing resize kept from the original pipeline; the tensor is
            # presumably already 256x256 at this point - TODO confirm it is redundant
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
        ])

    def __getitem__(self, item : int) -> Batch:
        """
        Build one training sample for the image at index ``item``.

        :param item: index into ``image_paths`` selecting the source image.
        :return: tuple of transformed source, transformed target and the
            label ``is_same_person`` (1 when the target was drawn from the
            source's folder, otherwise 0).
        """
        source_image_path = self.image_paths[item]
        source = read_image(source_image_path)

        if random.random() > self.same_person_probability:
            # target is drawn from the whole dataset
            # NOTE(review): it can therefore land in the source's own folder
            # (or be the source itself) while still being labelled 0
            is_same_person = 0
            target_image_path = random.choice(self.image_paths)
            target = read_image(target_image_path)
            source_transform = self.transforms_moderate(source)
            target_transform = self.transforms_complex(target)
        else:
            is_same_person = 1
            # os.path.dirname handles the platform separator, so the result
            # matches the folder keys produced by glob on Windows as well
            source_folder_path = os.path.dirname(source_image_path)
            target_image_path = random.choice(self.image_path_dict[source_folder_path])
            target = read_image(target_image_path)
            source_transform = self.transforms_basic(source)
            target_transform = self.transforms_basic(target)

        return source_transform, target_transform, is_same_person

    def __len__(self) -> int:
        """Return the number of source images found under the dataset path."""
        return self.dataset_total
|