mirror of
https://github.com/thesofakillers/gigachadify.git
synced 2026-03-17 06:22:48 +01:00
implement face detection; fix image reading
This commit is contained in:
+21
-12
@@ -3,10 +3,12 @@ from pathlib import Path
|
||||
from PIL import Image, ImageOps
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import insightface
|
||||
import cv2
|
||||
|
||||
|
||||
GIGACHAD_IMAGE_PATH = "images/gigachad.png"
|
||||
@@ -23,29 +25,35 @@ if torch.cuda is not None:
|
||||
PROVIDER = ["CPUExecutionProvider"]
|
||||
|
||||
|
||||
def get_face_analysis_model() -> torch.nn.Module:
|
||||
def get_face_analysis_model() -> insightface.app.FaceAnalysis:
|
||||
return insightface.app.FaceAnalysis(
|
||||
name="buffalo_l", root=INSIGHTFACE_PATH, providers=PROVIDER
|
||||
name="buffalo_l", providers=PROVIDER, root=INSIGHTFACE_PATH
|
||||
)
|
||||
|
||||
|
||||
def detect_faces(image: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
def detect_faces(
|
||||
img_data: np.ndarray,
|
||||
face_analyser: insightface.app.FaceAnalysis,
|
||||
det_size=(640, 640),
|
||||
) -> list:
|
||||
face_analyser.prepare(ctx_id=0, det_size=det_size)
|
||||
return face_analyser.get(img_data)
|
||||
|
||||
|
||||
def face_detection_asserts(base_faces, target_faces):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_face_single(faces: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
def get_face_single(faces: torch.tensor) -> torch.tensor:
|
||||
# TODO
|
||||
return faces[0]
|
||||
|
||||
|
||||
def swap_faces(
|
||||
base_face_img: torch.Tensor,
|
||||
target_face_img: torch.Tensor,
|
||||
face_analysis_model: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
base_face_img: np.ndarray,
|
||||
target_face_img: np.ndarray,
|
||||
face_analysis_model: insightface.app.FaceAnalysis,
|
||||
) -> torch.tensor:
|
||||
base_faces = detect_faces(base_face_img, face_analysis_model)
|
||||
target_faces = detect_faces(target_face_img, face_analysis_model)
|
||||
|
||||
@@ -65,7 +73,7 @@ def is_url(image_path: str) -> bool:
|
||||
return image_path.startswith("http://") or image_path.startswith("https://")
|
||||
|
||||
|
||||
def read_image(image_path: str) -> torch.tensor:
|
||||
def read_image(image_path: str) -> np.ndarray:
|
||||
if is_url(image_path):
|
||||
response = requests.get(image_path)
|
||||
image = Image.open(BytesIO(response.content))
|
||||
@@ -75,7 +83,8 @@ def read_image(image_path: str) -> torch.tensor:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image).unsqueeze(0)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
image = (image * 255).astype(np.uint8)
|
||||
return image
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user