mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Modernize data loader, remove read image helper
This commit is contained in:
@@ -1,34 +1,29 @@
|
||||
import glob
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .helper import read_image
|
||||
from .types import Batch, Paths
|
||||
|
||||
|
||||
class DataLoaderRecognition(Dataset[torch.Tensor]):
|
||||
def __init__(self, dataset_file_pattern : str) -> None:
|
||||
self.image_paths = self.prepare_image_paths(dataset_file_pattern)
|
||||
self.dataset_total = len(self.image_paths)
|
||||
self.image_paths = glob.glob(dataset_file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
target_image_path = random.choice(self.image_paths)
|
||||
target_vision_frame = read_image(target_image_path)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return target_tensor
|
||||
image_path = random.choice(self.image_paths)
|
||||
vision_frame = cv2.imread(image_path)
|
||||
return self.transforms(vision_frame)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_total
|
||||
|
||||
def prepare_image_paths(self, dataset_file_pattern : str) -> Paths:
|
||||
return glob.glob(dataset_file_pattern)
|
||||
return len(self.image_paths)
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
transform = transforms.Compose(
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
@@ -37,4 +32,3 @@ class DataLoaderRecognition(Dataset[torch.Tensor]):
|
||||
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
return transform
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from .types import VisionFrame
|
||||
|
||||
|
||||
def read_image(image_path : str) -> VisionFrame:
    # Thin wrapper around cv2.imread: the path is passed through unchanged
    # and whatever cv2.imread returns is handed back to the caller as-is.
    vision_frame = cv2.imread(image_path)
    return vision_frame
|
||||
@@ -3,11 +3,11 @@ import os.path
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
from .helper import read_image
|
||||
from .types import Batch, ImagePathList, ImagePathSet
|
||||
|
||||
|
||||
@@ -16,11 +16,10 @@ class DataLoaderVGG(TensorDataset):
|
||||
self.same_person_probability = same_person_probability
|
||||
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
|
||||
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
|
||||
self.dataset_total = len(self.image_paths)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
source_image_path = self.image_paths[index]
|
||||
source_image_path = self.image_paths[index]
|
||||
|
||||
if random.random() > self.same_person_probability:
|
||||
return self.prepare_same_person(source_image_path)
|
||||
@@ -28,8 +27,9 @@ class DataLoaderVGG(TensorDataset):
|
||||
return self.prepare_different_person(source_image_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_total
|
||||
return len(self.image_paths)
|
||||
|
||||
#todo: remove this method - only use glob.glob in init()
|
||||
def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]:
|
||||
image_paths = []
|
||||
image_path_set = {}
|
||||
@@ -40,7 +40,7 @@ class DataLoaderVGG(TensorDataset):
|
||||
return image_paths, image_path_set
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
transform = transforms.Compose(
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
@@ -50,22 +50,22 @@ class DataLoaderVGG(TensorDataset):
|
||||
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
return transform
|
||||
|
||||
def prepare_different_person(self, source_image_path : str) -> Batch:
|
||||
is_same_person = torch.tensor(0)
|
||||
target_image_path = random.choice(self.image_paths)
|
||||
source_vision_frame = read_image(source_image_path)
|
||||
target_vision_frame = read_image(target_image_path)
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return source_tensor, target_tensor, is_same_person
|
||||
|
||||
def prepare_same_person(self, source_image_path : str) -> Batch:
|
||||
is_same_person = torch.tensor(1)
|
||||
#todo: why not like in prepare_different_person
|
||||
target_image_path = random.choice(self.image_path_set.get(os.path.dirname(source_image_path)))
|
||||
source_vision_frame = read_image(source_image_path)
|
||||
target_vision_frame = read_image(target_image_path)
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return source_tensor, target_tensor, is_same_person
|
||||
|
||||
@@ -8,17 +8,6 @@ from torch import Tensor, nn
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
|
||||
|
||||
def is_windows() -> bool:
    # Compare the lower-cased platform.system() value against 'windows'.
    system_name = platform.system()
    return system_name.lower() == 'windows'
|
||||
|
||||
|
||||
def read_image(image_path : str) -> VisionFrame:
    # Non-Windows platforms read the file directly with cv2.imread.
    if not is_windows():
        return cv2.imread(image_path)

    # On Windows the raw bytes are loaded with numpy.fromfile and decoded by
    # cv2.imdecode instead of handing the path to OpenCV.
    # NOTE(review): presumably a workaround for cv2.imread path handling on
    # Windows - confirm before relying on this explanation.
    raw_bytes = numpy.fromfile(image_path, dtype = numpy.uint8)
    return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
|
||||
|
||||
|
||||
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
|
||||
vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
|
||||
vision_tensor = vision_tensor / 255.0
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
|
||||
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor
|
||||
from .models.generator import Generator
|
||||
from .types import EmbedderModule, GeneratorModule, VisionFrame
|
||||
|
||||
@@ -34,7 +34,7 @@ def infer() -> None:
|
||||
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
id_embedder.eval()
|
||||
|
||||
source_vision_frame = read_image(source_path)
|
||||
target_vision_frame = read_image(target_path)
|
||||
source_vision_frame = cv2.imread(source_path)
|
||||
target_vision_frame = cv2.imread(target_path)
|
||||
output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame)
|
||||
cv2.imwrite(output_path, output_vision_frame)
|
||||
|
||||
Reference in New Issue
Block a user