diff --git a/gigachadify.py b/gigachadify.py index 080e92c..5ac40bd 100644 --- a/gigachadify.py +++ b/gigachadify.py @@ -57,16 +57,18 @@ def swap_faces( base_faces = detect_faces(base_face_img, face_analysis_model) target_faces = detect_faces(target_face_img, face_analysis_model) - face_detection_asserts(base_faces, target_faces) + # face_detection_asserts(base_faces, target_faces) face_swapper = get_face_swap_model() base_face = get_face_single(base_faces) target_face = get_face_single(target_faces) - result = target_face_img + result = base_face_img result = face_swapper.get(result, base_face, target_face) + return result + def is_url(image_path: str) -> bool: """Check if the given path is a URL.""" @@ -88,12 +90,16 @@ def read_image(image_path: str) -> np.ndarray: return image -def save_image(image: torch.tensor, output_image_path: str): - raise NotImplementedError +def save_image(image: np.ndarray, output_image_path: str): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + image.save(output_image_path) def get_face_swap_model(): - pass + return insightface.model_zoo.get_model( + os.path.join(INSIGHTFACE_PATH, "inswapper_128.onnx"), providers=PROVIDER + ) def gigachadify(input_image_path: str, output_image_path: str):