Uniform model handling for predictor

This commit is contained in:
henryruhs
2023-10-16 11:24:00 +02:00
parent 8cf09f62bb
commit 401aa780b0
+14 -6
View File
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict
import threading
from functools import lru_cache
@@ -9,15 +9,21 @@ from tqdm import tqdm
import facefusion.globals
from facefusion import wording
from facefusion.typing import Frame
from facefusion.typing import Frame, ModelValue
from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_fps
from facefusion.utilities import resolve_relative_path, conditional_download
PREDICTOR = None
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = 'FACEFUSION.PREDICTOR'
MODEL_URL = 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx'
MODEL_PATH = resolve_relative_path('../.assets/models/open_nsfw.onnx')
MODELS : Dict[str, ModelValue] =\
{
'open_nsfw':
{
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx',
'path': resolve_relative_path('../.assets/models/open_nsfw.onnx')
}
}
MAX_PROBABILITY = 0.80
STREAM_COUNTER = 0
@@ -27,7 +33,8 @@ def get_predictor() -> Any:
with THREAD_LOCK:
if PREDICTOR is None:
PREDICTOR = onnxruntime.InferenceSession(MODEL_PATH, providers = facefusion.globals.execution_providers)
model_path = MODELS.get('open_nsfw').get('path')
PREDICTOR = onnxruntime.InferenceSession(model_path, providers = facefusion.globals.execution_providers)
return PREDICTOR
@@ -40,7 +47,8 @@ def clear_predictor() -> None:
def pre_check() -> bool:
if not facefusion.globals.skip_download:
download_directory_path = resolve_relative_path('../.assets/models')
conditional_download(download_directory_path, [ MODEL_URL ])
model_url = MODELS.get('open_nsfw').get('url')
conditional_download(download_directory_path, [ model_url ])
return True