implement face detection; fix image reading

This commit is contained in:
Giulio Starace
2024-02-12 00:36:42 +01:00
parent a2f5a0ebe7
commit fb9ec9095e
+21 -12
View File
@@ -3,10 +3,12 @@ from pathlib import Path
from PIL import Image, ImageOps
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import torch
import insightface
import cv2
GIGACHAD_IMAGE_PATH = "images/gigachad.png"
@@ -23,29 +25,35 @@ if torch.cuda is not None:
PROVIDER = ["CPUExecutionProvider"]
def get_face_analysis_model() -> insightface.app.FaceAnalysis:
    """Construct the insightface ``buffalo_l`` face-analysis application.

    The application is created with the module-level ``PROVIDER`` list as
    its execution providers and ``INSIGHTFACE_PATH`` as its ``root``.
    """
    model_options = {
        "name": "buffalo_l",
        "providers": PROVIDER,
        "root": INSIGHTFACE_PATH,
    }
    return insightface.app.FaceAnalysis(**model_options)
def detect_faces(
    img_data: np.ndarray,
    face_analyser: "insightface.app.FaceAnalysis",
    det_size=(640, 640),
    *,
    ctx_id: int = 0,
) -> list:
    """Run face detection on an image and return the detected faces.

    Args:
        img_data: Image array, handed unchanged to ``face_analyser.get``.
            NOTE(review): ``read_image`` in this module yields uint8 BGR
            arrays; presumably that is the layout expected here - confirm.
        face_analyser: Face-analysis application to use. It is prepared on
            every call, so an unprepared instance may be passed in.
        det_size: Detection input size forwarded to ``prepare``.
        ctx_id: Device context id forwarded to ``prepare``. Defaults to 0,
            the value that used to be hard-coded.
            NOTE(review): insightface appears to treat a negative id as
            "CPU only" - confirm before relying on it.

    Returns:
        The result of ``face_analyser.get(img_data)`` (the detected faces).
    """
    face_analyser.prepare(ctx_id=ctx_id, det_size=det_size)
    return face_analyser.get(img_data)
def face_detection_asserts(base_faces, target_faces):
    """Placeholder: not implemented yet, always raises NotImplementedError."""
    raise NotImplementedError()
def get_face_single(faces: list) -> "insightface.app.common.Face":
    """Pick a single face out of a detection result.

    Args:
        faces: Indexable collection of detected faces.
            NOTE(review): presumably the list returned by ``detect_faces``
            (insightface ``Face`` objects) - confirm against the caller.

    Returns:
        The first entry of ``faces``.

    Raises:
        IndexError: If ``faces`` is empty, i.e. no face was detected.
    """
    # TODO: choose the face deliberately; for now the first entry is used.
    try:
        return faces[0]
    except IndexError as err:
        # Same exception type as plain indexing, but with a message that
        # says what actually went wrong.
        raise IndexError(
            "get_face_single: no faces were detected in the image"
        ) from err
def swap_faces(
base_face_img: torch.Tensor,
target_face_img: torch.Tensor,
face_analysis_model: torch.nn.Module,
) -> torch.Tensor:
base_face_img: np.ndarray,
target_face_img: np.ndarray,
face_analysis_model: insightface.app.FaceAnalysis,
) -> torch.tensor:
base_faces = detect_faces(base_face_img, face_analysis_model)
target_faces = detect_faces(target_face_img, face_analysis_model)
@@ -65,7 +73,7 @@ def is_url(image_path: str) -> bool:
return image_path.startswith("http://") or image_path.startswith("https://")
def read_image(image_path: str) -> torch.tensor:
def read_image(image_path: str) -> np.ndarray:
if is_url(image_path):
response = requests.get(image_path)
image = Image.open(BytesIO(response.content))
@@ -75,7 +83,8 @@ def read_image(image_path: str) -> torch.tensor:
image = ImageOps.exif_transpose(image)
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
image = (image * 255).astype(np.uint8)
return image