From 401aa780b01502ab1c0081625e07a3b6c3ca1cb8 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 16 Oct 2023 11:24:00 +0200 Subject: [PATCH] Uniform model handling for predictor --- facefusion/predictor.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/facefusion/predictor.py b/facefusion/predictor.py index 0e2745f2..24040135 100644 --- a/facefusion/predictor.py +++ b/facefusion/predictor.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict import threading from functools import lru_cache @@ -9,15 +9,21 @@ from tqdm import tqdm import facefusion.globals from facefusion import wording -from facefusion.typing import Frame +from facefusion.typing import Frame, ModelValue from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_fps from facefusion.utilities import resolve_relative_path, conditional_download PREDICTOR = None THREAD_LOCK : threading.Lock = threading.Lock() NAME = 'FACEFUSION.PREDICTOR' -MODEL_URL = 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx' -MODEL_PATH = resolve_relative_path('../.assets/models/open_nsfw.onnx') +MODELS : Dict[str, ModelValue] =\ +{ + 'open_nsfw': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx', + 'path': resolve_relative_path('../.assets/models/open_nsfw.onnx') + } +} MAX_PROBABILITY = 0.80 STREAM_COUNTER = 0 @@ -27,7 +33,8 @@ def get_predictor() -> Any: with THREAD_LOCK: if PREDICTOR is None: - PREDICTOR = onnxruntime.InferenceSession(MODEL_PATH, providers = facefusion.globals.execution_providers) + model_path = MODELS.get('open_nsfw').get('path') + PREDICTOR = onnxruntime.InferenceSession(model_path, providers = facefusion.globals.execution_providers) return PREDICTOR @@ -40,7 +47,8 @@ def clear_predictor() -> None: def pre_check() -> bool: if not facefusion.globals.skip_download: download_directory_path = resolve_relative_path('../.assets/models') - conditional_download(download_directory_path, [ MODEL_URL ]) + model_url = MODELS.get('open_nsfw').get('url') + conditional_download(download_directory_path, [ model_url ]) return True