mirror of
https://github.com/thesofakillers/gigachadify.git
synced 2026-03-17 06:22:48 +01:00
face analysis model setup
This commit is contained in:
+30
-7
@@ -1,9 +1,28 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import insightface
|
||||
|
||||
|
||||
GIGACHAD_IMAGE_PATH = "TODO"
|
||||
|
||||
MODELS_DIR = os.path.join(Path(__file__).parent, "models")
|
||||
INSIGHTFACE_PATH = os.path.join(MODELS_DIR, "insightface")
|
||||
|
||||
if torch.cuda is not None:
|
||||
if torch.cuda.is_available():
|
||||
PROVIDER = ["CUDAExecutionProvider"]
|
||||
elif torch.backends.mps.is_available():
|
||||
PROVIDER = ["CoreMLExecutionProvider"]
|
||||
else:
|
||||
PROVIDER = ["CPUExecutionProvider"]
|
||||
|
||||
|
||||
def get_face_analysis_model() -> torch.nn.Module:
|
||||
return insightface.app.FaceAnalysis(
|
||||
name="buffalo_l", root=INSIGHTFACE_PATH, providers=PROVIDER
|
||||
)
|
||||
|
||||
|
||||
def detect_faces(image: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
@@ -18,10 +37,12 @@ def get_face_single(faces: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def swap_faces(
|
||||
base_face_img: torch.Tensor, target_face_img: torch.Tensor
|
||||
base_face_img: torch.Tensor,
|
||||
target_face_img: torch.Tensor,
|
||||
face_analysis_model: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
base_faces = detect_faces(base_face_img)
|
||||
target_faces = detect_faces(target_face_img)
|
||||
base_faces = detect_faces(base_face_img, face_analysis_model)
|
||||
target_faces = detect_faces(target_face_img, face_analysis_model)
|
||||
|
||||
face_detection_asserts(base_faces, target_faces)
|
||||
|
||||
@@ -47,12 +68,14 @@ def get_face_swap_model():
|
||||
|
||||
|
||||
def gigachadify(input_image_path: str, output_image_path: str):
|
||||
chad_base = read_image(GIGACHAD_IMAGE_PATH)
|
||||
input_image = read_image(input_image_path)
|
||||
# chad_base = read_image(GIGACHAD_IMAGE_PATH)
|
||||
# input_image = read_image(input_image_path)
|
||||
|
||||
result = swap_faces(chad_base, input_image)
|
||||
face_analysis_model = get_face_analysis_model()
|
||||
|
||||
save_image(result, output_image_path)
|
||||
# result = swap_faces(chad_base, input_image, face_analysis_model)
|
||||
|
||||
# save_image(result, output_image_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Generated
+1
-1123
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -8,8 +8,8 @@ readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
onnx = "^1.15.0"
|
||||
torch = "^2.2.0"
|
||||
insightface = "^0.7.3"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user