Remove Numpy and CV2 to fully use Tensors

This commit is contained in:
henryruhs
2025-02-24 12:23:57 +01:00
parent 93cbbf52d0
commit 607c55ff1f
7 changed files with 24 additions and 56 deletions
+4 -6
View File
@@ -1,9 +1,8 @@
import glob
import cv2
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import transforms, io
from .types import Batch
@@ -15,8 +14,8 @@ class StaticDataset(Dataset[Tensor]):
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
vision_frame = cv2.imread(file_path)
return self.transforms(vision_frame)
temp_tensor = io.read_image(file_path)
return self.transforms(temp_tensor)
def __len__(self) -> int:
return len(self.file_paths)
@@ -29,6 +28,5 @@ class StaticDataset(Dataset[Tensor]):
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.ToTensor(),
transforms.Lambda(lambda temp_tensor: temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Normalize(0.5, 0.5)
])
-2
View File
@@ -1,10 +1,8 @@
from typing import Any, TypeAlias
from numpy.typing import NDArray
from torch import Tensor
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
VisionFrame : TypeAlias = NDArray[Any]
OptimizerConfig : TypeAlias = Any
+8 -10
View File
@@ -1,10 +1,9 @@
import glob
import random
import cv2
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import transforms, io
from .types import Batch
@@ -35,19 +34,18 @@ class DynamicDataset(Dataset[Tensor]):
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
transforms.ToTensor(),
transforms.Lambda(lambda temp_tensor: temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Normalize(0.5, 0.5)
])
def prepare_different_batch(self, source_image_path : str) -> Batch:
target_image_path = random.choice(self.file_paths)
source_vision_frame = cv2.imread(source_image_path)
target_vision_frame = cv2.imread(target_image_path)
source_tensor = self.transforms(source_vision_frame)
target_tensor = self.transforms(target_vision_frame)
source_tensor = io.read_image(source_image_path)
target_tensor = io.read_image(target_image_path)
source_tensor = self.transforms(source_tensor)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
def prepare_equal_batch(self, source_image_path : str) -> Batch:
source_vision_frame = cv2.imread(source_image_path)
source_tensor = self.transforms(source_vision_frame)
source_tensor = io.read_image(source_image_path)
source_tensor = self.transforms(source_tensor)
return source_tensor, source_tensor
+1 -20
View File
@@ -1,25 +1,6 @@
import numpy
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, VisionFrame
def convert_to_tensor(vision_frame : VisionFrame) -> Tensor:
output_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
output_tensor = output_tensor / 255.0
output_tensor = (output_tensor - 0.5) * 2
output_tensor = output_tensor.unsqueeze(0)
return output_tensor
def convert_to_vision_frame(input_tensor : Tensor) -> VisionFrame:
vision_frame = input_tensor.detach().cpu().numpy()[0]
vision_frame = vision_frame.transpose(1, 2, 0)
vision_frame = (vision_frame + 1) * 127.5
vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8)
vision_frame = vision_frame[:, :, ::-1]
return vision_frame
from .types import EmbedderModule, Embedding, Padding
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
+10 -12
View File
@@ -1,23 +1,21 @@
import configparser
import cv2
import torch
from torch import Tensor
from torchvision import io
from .helper import calc_embedding, convert_to_tensor, convert_to_vision_frame
from .helper import calc_embedding
from .models.generator import Generator
from .types import EmbedderModule, GeneratorModule, VisionFrame
from .types import EmbedderModule, GeneratorModule
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_tensor = convert_to_tensor(source_vision_frame)
target_tensor = convert_to_tensor(target_vision_frame)
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_tensor : Tensor, target_tensor : Tensor) -> Tensor:
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor = generator(source_embedding, target_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_tensor)
return output_vision_frame
return output_tensor
def infer() -> None:
@@ -34,7 +32,7 @@ def infer() -> None:
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
embedder.eval()
source_vision_frame = cv2.imread(source_path)
target_vision_frame = cv2.imread(target_path)
output_vision_frame = run_swap(generator, embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)
source_tensor = io.read_image(source_path)
target_tensor = io.read_image(target_path)
output_tensor = run_swap(generator, embedder, source_tensor, target_tensor)
io.write_jpeg(output_tensor, output_path)
+1 -4
View File
@@ -1,6 +1,5 @@
from typing import Any, Tuple, TypeAlias
from typing import Tuple, TypeAlias
from numpy.typing import NDArray
from torch import Tensor
from torch.nn import Module
@@ -12,7 +11,5 @@ FaceLandmark203 : TypeAlias = Tensor
Padding : TypeAlias = Tuple[int, int, int, int]
VisionFrame : TypeAlias = NDArray[Any]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
-2
View File
@@ -1,9 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cu124
lightning==2.5.0
numpy==2.2.3
onnx==1.17.0
onnxruntime==1.20.1
opencv-python==4.11.0.86
pytorch-msssim==1.0.0
torch==2.6.0
torchvision==0.21.0