mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce Transform classes, Add albumentations
This commit is contained in:
@@ -2,6 +2,7 @@ import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import albumentations
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
@@ -37,16 +38,12 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
|
||||
AlterTransform(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(self.warp_tensor),
|
||||
WarpTransform(self.warp_template),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def warp_tensor(self, temp_tensor : Tensor) -> Tensor:
|
||||
return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0)
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_path)
|
||||
@@ -69,3 +66,35 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
|
||||
class AlterTransform:
|
||||
def __init__(self) -> None:
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
|
||||
return self.transforms(temp_tensor).get('image')
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> albumentations.Compose:
|
||||
return albumentations.Compose(
|
||||
[
|
||||
albumentations.RandomBrightnessContrast(p = 0.3),
|
||||
albumentations.OneOf(
|
||||
[
|
||||
albumentations.MotionBlur(p = 0.1),
|
||||
albumentations.MedianBlur(p = 0.1)
|
||||
], p = 0.3),
|
||||
albumentations.ColorJitter(p = 0.1),
|
||||
])
|
||||
|
||||
|
||||
class WarpTransform:
|
||||
def __init__(self, warp_template : WarpTemplate) -> None:
|
||||
self.warp_template = warp_template
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user