Modernize data loader, remove read image helper

This commit is contained in:
henryruhs
2025-02-21 01:09:57 +01:00
parent d1bf54276d
commit 3cf9711df0
5 changed files with 22 additions and 46 deletions
+8 -14
View File
@@ -1,34 +1,29 @@
import glob
import random
import cv2
import torch
import torchvision.transforms as transforms
from torchvision import transforms
from torch.utils.data import Dataset
from .helper import read_image
from .types import Batch, Paths
class DataLoaderRecognition(Dataset[torch.Tensor]):
def __init__(self, dataset_file_pattern : str) -> None:
self.image_paths = self.prepare_image_paths(dataset_file_pattern)
self.dataset_total = len(self.image_paths)
self.image_paths = glob.glob(dataset_file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
    """
    Load one randomly chosen image and return it after applying ``self.transforms``.

    :param index: not used; the image path is drawn with ``random.choice`` on every call
    :return: the result of ``self.transforms`` applied to the decoded image
    """
    image_path = random.choice(self.image_paths)
    # cv2.imread signals an unreadable file by returning None instead of raising,
    # so a bad path only surfaces inside the transform pipeline.
    vision_frame = cv2.imread(image_path)
    return self.transforms(vision_frame)
def __len__(self) -> int:
return self.dataset_total
def prepare_image_paths(self, dataset_file_pattern : str) -> Paths:
    """
    Return the paths matching a glob pattern.

    :param dataset_file_pattern: pattern handed to ``glob.glob``
    :return: the list produced by ``glob.glob`` (empty when nothing matches)
    """
    # A second, unreachable return statement used to follow this one.
    return glob.glob(dataset_file_pattern)
def compose_transforms(self) -> transforms:
transform = transforms.Compose(
return transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
@@ -37,4 +32,3 @@ class DataLoaderRecognition(Dataset[torch.Tensor]):
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transform
-7
View File
@@ -1,7 +0,0 @@
import cv2
from .types import VisionFrame
def read_image(image_path : str) -> VisionFrame:
    """Decode the image file at ``image_path`` with ``cv2.imread`` and return the result."""
    vision_frame = cv2.imread(image_path)
    return vision_frame
+11 -11
View File
@@ -3,11 +3,11 @@ import os.path
import random
from typing import Tuple
import cv2
import torch
import torchvision.transforms as transforms
from torchvision import transforms
from torch.utils.data import TensorDataset
from .helper import read_image
from .types import Batch, ImagePathList, ImagePathSet
@@ -16,11 +16,10 @@ class DataLoaderVGG(TensorDataset):
self.same_person_probability = same_person_probability
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
self.dataset_total = len(self.image_paths)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
source_image_path = self.image_paths[index]
source_image_path = self.image_paths.get(index)
if random.random() > self.same_person_probability:
return self.prepare_same_person(source_image_path)
@@ -28,8 +27,9 @@ class DataLoaderVGG(TensorDataset):
return self.prepare_different_person(source_image_path)
def __len__(self) -> int:
return self.dataset_total
return len(self.image_paths)
#todo: remove this method - only use glob.glob in init()
def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]:
image_paths = []
image_path_set = {}
@@ -40,7 +40,7 @@ class DataLoaderVGG(TensorDataset):
return image_paths, image_path_set
def compose_transforms(self) -> transforms:
transform = transforms.Compose(
return transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
@@ -50,22 +50,22 @@ class DataLoaderVGG(TensorDataset):
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transform
def prepare_different_person(self, source_image_path : str) -> Batch:
    """
    Pair the source image with a target drawn at random from all image paths.

    :param source_image_path: path of the source image
    :return: ``(source_tensor, target_tensor, is_same_person)`` where ``is_same_person`` is ``torch.tensor(0)``
    """
    is_same_person = torch.tensor(0)
    # NOTE(review): the draw covers every entry of image_paths, so it can return the
    # source path itself if that path is in the list, while the label stays 0.
    target_image_path = random.choice(self.image_paths)
    # Decode each image once; the earlier body read both files twice and discarded
    # the first pair of results.
    source_vision_frame = cv2.imread(source_image_path)
    target_vision_frame = cv2.imread(target_image_path)
    source_tensor = self.transforms(source_vision_frame)
    target_tensor = self.transforms(target_vision_frame)
    return source_tensor, target_tensor, is_same_person
def prepare_same_person(self, source_image_path : str) -> Batch:
    """
    Pair the source image with a target looked up under the source's directory.

    :param source_image_path: path of the source image
    :return: ``(source_tensor, target_tensor, is_same_person)`` where ``is_same_person`` is ``torch.tensor(1)``
    """
    is_same_person = torch.tensor(1)
    #todo: why not like in prepare_different_person
    # Presumably image_path_set maps a directory path to the image paths inside it - TODO confirm.
    # .get() yields None for an unknown directory, which random.choice rejects with a TypeError.
    target_image_path = random.choice(self.image_path_set.get(os.path.dirname(source_image_path)))
    # Decode each image once; the earlier body read both files twice and discarded
    # the first pair of results.
    source_vision_frame = cv2.imread(source_image_path)
    target_vision_frame = cv2.imread(target_image_path)
    source_tensor = self.transforms(source_vision_frame)
    target_tensor = self.transforms(target_vision_frame)
    return source_tensor, target_tensor, is_same_person
-11
View File
@@ -8,17 +8,6 @@ from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
def is_windows() -> bool:
    """Tell whether ``platform.system()`` reports Windows, ignoring letter case."""
    system_name = platform.system().lower()
    return system_name == 'windows'
def read_image(image_path : str) -> VisionFrame:
    """
    Decode the image file at ``image_path``.

    On Windows the file is loaded into a uint8 buffer with ``numpy.fromfile`` and decoded
    with ``cv2.imdecode``; on every other platform ``cv2.imread`` is called directly.
    """
    if not is_windows():
        return cv2.imread(image_path)

    raw_bytes = numpy.fromfile(image_path, dtype = numpy.uint8)
    return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
vision_tensor = vision_tensor / 255.0
+3 -3
View File
@@ -3,7 +3,7 @@ import configparser
import cv2
import torch
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor
from .models.generator import Generator
from .types import EmbedderModule, GeneratorModule, VisionFrame
@@ -34,7 +34,7 @@ def infer() -> None:
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = read_image(source_path)
target_vision_frame = read_image(target_path)
source_vision_frame = cv2.imread(source_path)
target_vision_frame = cv2.imread(target_path)
output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)