mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Remove Numpy and CV2 to fully use Tensors
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import glob
|
||||
|
||||
import cv2
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision import transforms, io
|
||||
|
||||
from .types import Batch
|
||||
|
||||
@@ -15,8 +14,8 @@ class StaticDataset(Dataset[Tensor]):
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
vision_frame = cv2.imread(file_path)
|
||||
return self.transforms(vision_frame)
|
||||
temp_tensor = io.read_image(file_path)
|
||||
return self.transforms(temp_tensor)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
@@ -29,6 +28,5 @@ class StaticDataset(Dataset[Tensor]):
|
||||
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda temp_tensor: temp_tensor[[2, 1, 0], :, :]),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
transforms.Normalize(0.5, 0.5)
|
||||
])
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
|
||||
Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
VisionFrame : TypeAlias = NDArray[Any]
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import glob
|
||||
import random
|
||||
|
||||
import cv2
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision import transforms, io
|
||||
|
||||
from .types import Batch
|
||||
|
||||
@@ -35,19 +34,18 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda temp_tensor: temp_tensor[[2, 1, 0], :, :]),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
transforms.Normalize(0.5, 0.5)
|
||||
])
|
||||
|
||||
def prepare_different_batch(self, source_image_path : str) -> Batch:
|
||||
target_image_path = random.choice(self.file_paths)
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
source_tensor = io.read_image(source_image_path)
|
||||
target_tensor = io.read_image(target_image_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_equal_batch(self, source_image_path : str) -> Batch:
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
source_tensor = io.read_image(source_image_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
return source_tensor, source_tensor
|
||||
|
||||
@@ -1,25 +1,6 @@
|
||||
import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame
|
||||
|
||||
|
||||
def convert_to_tensor(vision_frame : VisionFrame) -> Tensor:
|
||||
output_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
|
||||
output_tensor = output_tensor / 255.0
|
||||
output_tensor = (output_tensor - 0.5) * 2
|
||||
output_tensor = output_tensor.unsqueeze(0)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def convert_to_vision_frame(input_tensor : Tensor) -> VisionFrame:
|
||||
vision_frame = input_tensor.detach().cpu().numpy()[0]
|
||||
vision_frame = vision_frame.transpose(1, 2, 0)
|
||||
vision_frame = (vision_frame + 1) * 127.5
|
||||
vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8)
|
||||
vision_frame = vision_frame[:, :, ::-1]
|
||||
return vision_frame
|
||||
from .types import EmbedderModule, Embedding, Padding
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
import configparser
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calc_embedding, convert_to_tensor, convert_to_vision_frame
|
||||
from .helper import calc_embedding
|
||||
from .models.generator import Generator
|
||||
from .types import EmbedderModule, GeneratorModule, VisionFrame
|
||||
from .types import EmbedderModule, GeneratorModule
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
source_tensor = convert_to_tensor(source_vision_frame)
|
||||
target_tensor = convert_to_tensor(target_vision_frame)
|
||||
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_tensor : Tensor, target_tensor : Tensor) -> Tensor:
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = generator(source_embedding, target_tensor)[0]
|
||||
output_vision_frame = convert_to_vision_frame(output_tensor)
|
||||
return output_vision_frame
|
||||
return output_tensor
|
||||
|
||||
|
||||
def infer() -> None:
|
||||
@@ -34,7 +32,7 @@ def infer() -> None:
|
||||
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder.eval()
|
||||
|
||||
source_vision_frame = cv2.imread(source_path)
|
||||
target_vision_frame = cv2.imread(target_path)
|
||||
output_vision_frame = run_swap(generator, embedder, source_vision_frame, target_vision_frame)
|
||||
cv2.imwrite(output_path, output_vision_frame)
|
||||
source_tensor = io.read_image(source_path)
|
||||
target_tensor = io.read_image(target_path)
|
||||
output_tensor = run_swap(generator, embedder, source_tensor, target_tensor)
|
||||
io.write_jpeg(output_tensor, output_path)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Tuple, TypeAlias
|
||||
from typing import Tuple, TypeAlias
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
@@ -12,7 +11,5 @@ FaceLandmark203 : TypeAlias = Tensor
|
||||
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
VisionFrame : TypeAlias = NDArray[Any]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu124
|
||||
lightning==2.5.0
|
||||
numpy==2.2.3
|
||||
onnx==1.17.0
|
||||
onnxruntime==1.20.1
|
||||
opencv-python==4.11.0.86
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.6.0
|
||||
torchvision==0.21.0
|
||||
|
||||
Reference in New Issue
Block a user