face analysis model setup

This commit is contained in:
Giulio Starace
2024-02-11 22:20:14 +01:00
parent d4b7aefe8f
commit 60924207f8
3 changed files with 32 additions and 1131 deletions
+30 -7
View File
@@ -1,9 +1,28 @@
import os
from pathlib import Path
import torch
import insightface
GIGACHAD_IMAGE_PATH = "TODO"
MODELS_DIR = os.path.join(Path(__file__).parent, "models")
INSIGHTFACE_PATH = os.path.join(MODELS_DIR, "insightface")
if torch.cuda is not None:
if torch.cuda.is_available():
PROVIDER = ["CUDAExecutionProvider"]
elif torch.backends.mps.is_available():
PROVIDER = ["CoreMLExecutionProvider"]
else:
PROVIDER = ["CPUExecutionProvider"]
def get_face_analysis_model() -> torch.nn.Module:
return insightface.app.FaceAnalysis(
name="buffalo_l", root=INSIGHTFACE_PATH, providers=PROVIDER
)
def detect_faces(image: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@@ -18,10 +37,12 @@ def get_face_single(faces: torch.Tensor) -> torch.Tensor:
def swap_faces(
base_face_img: torch.Tensor, target_face_img: torch.Tensor
base_face_img: torch.Tensor,
target_face_img: torch.Tensor,
face_analysis_model: torch.nn.Module,
) -> torch.Tensor:
base_faces = detect_faces(base_face_img)
target_faces = detect_faces(target_face_img)
base_faces = detect_faces(base_face_img, face_analysis_model)
target_faces = detect_faces(target_face_img, face_analysis_model)
face_detection_asserts(base_faces, target_faces)
@@ -47,12 +68,14 @@ def get_face_swap_model():
def gigachadify(input_image_path: str, output_image_path: str):
chad_base = read_image(GIGACHAD_IMAGE_PATH)
input_image = read_image(input_image_path)
# chad_base = read_image(GIGACHAD_IMAGE_PATH)
# input_image = read_image(input_image_path)
result = swap_faces(chad_base, input_image)
face_analysis_model = get_face_analysis_model()
save_image(result, output_image_path)
# result = swap_faces(chad_base, input_image, face_analysis_model)
# save_image(result, output_image_path)
if __name__ == "__main__":
Generated
+1 -1123
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -8,8 +8,8 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
onnx = "^1.15.0"
torch = "^2.2.0"
insightface = "^0.7.3"
[build-system]