From fb9ec9095e07863c7eb482a38718e9454ed4e380 Mon Sep 17 00:00:00 2001 From: Giulio Starace Date: Mon, 12 Feb 2024 00:36:42 +0100 Subject: [PATCH] implement face detection; fix image reading --- gigachadify.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/gigachadify.py b/gigachadify.py index 9e72cce..080e92c 100644 --- a/gigachadify.py +++ b/gigachadify.py @@ -3,10 +3,12 @@ from pathlib import Path from PIL import Image, ImageOps import requests from io import BytesIO +import matplotlib.pyplot as plt import numpy as np import torch import insightface +import cv2 GIGACHAD_IMAGE_PATH = "images/gigachad.png" @@ -23,29 +25,35 @@ if torch.cuda is not None: PROVIDER = ["CPUExecutionProvider"] -def get_face_analysis_model() -> torch.nn.Module: +def get_face_analysis_model() -> insightface.app.FaceAnalysis: return insightface.app.FaceAnalysis( - name="buffalo_l", root=INSIGHTFACE_PATH, providers=PROVIDER + name="buffalo_l", providers=PROVIDER, root=INSIGHTFACE_PATH ) -def detect_faces(image: torch.Tensor) -> torch.Tensor: - raise NotImplementedError +def detect_faces( + img_data: np.ndarray, + face_analyser: insightface.app.FaceAnalysis, + det_size=(640, 640), +) -> list: + face_analyser.prepare(ctx_id=0, det_size=det_size) + return face_analyser.get(img_data) def face_detection_asserts(base_faces, target_faces): raise NotImplementedError -def get_face_single(faces: torch.Tensor) -> torch.Tensor: - raise NotImplementedError +def get_face_single(faces: torch.tensor) -> torch.tensor: + # TODO + return faces[0] def swap_faces( - base_face_img: torch.Tensor, - target_face_img: torch.Tensor, - face_analysis_model: torch.nn.Module, -) -> torch.Tensor: + base_face_img: np.ndarray, + target_face_img: np.ndarray, + face_analysis_model: insightface.app.FaceAnalysis, +) -> torch.tensor: base_faces = detect_faces(base_face_img, face_analysis_model) target_faces = detect_faces(target_face_img, face_analysis_model) @@ -65,7 +73,7 @@ def is_url(image_path: str) -> bool: return image_path.startswith("http://") or image_path.startswith("https://") -def read_image(image_path: str) -> torch.tensor: +def read_image(image_path: str) -> np.ndarray: if is_url(image_path): response = requests.get(image_path) image = Image.open(BytesIO(response.content)) @@ -75,7 +83,8 @@ def read_image(image_path: str) -> torch.tensor: image = ImageOps.exif_transpose(image) image = image.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image).unsqueeze(0) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image = (image * 255).astype(np.uint8) return image