mirror of
https://github.com/facefusion/facefusion.git
synced 2026-05-01 13:57:50 +02:00
Uniform model handling for predictor
This commit is contained in:
+14
-6
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
|
||||
@@ -9,15 +9,21 @@ from tqdm import tqdm
|
||||
|
||||
import facefusion.globals
|
||||
from facefusion import wording
|
||||
from facefusion.typing import Frame
|
||||
from facefusion.typing import Frame, ModelValue
|
||||
from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_fps
|
||||
from facefusion.utilities import resolve_relative_path, conditional_download
|
||||
|
||||
PREDICTOR = None
|
||||
THREAD_LOCK : threading.Lock = threading.Lock()
|
||||
NAME = 'FACEFUSION.PREDICTOR'
|
||||
MODEL_URL = 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx'
|
||||
MODEL_PATH = resolve_relative_path('../.assets/models/open_nsfw.onnx')
|
||||
MODELS : Dict[str, ModelValue] =\
|
||||
{
|
||||
'open_nsfw':
|
||||
{
|
||||
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx',
|
||||
'path': resolve_relative_path('../.assets/models/open_nsfw.onnx')
|
||||
}
|
||||
}
|
||||
MAX_PROBABILITY = 0.80
|
||||
STREAM_COUNTER = 0
|
||||
|
||||
@@ -27,7 +33,8 @@ def get_predictor() -> Any:
|
||||
|
||||
with THREAD_LOCK:
|
||||
if PREDICTOR is None:
|
||||
PREDICTOR = onnxruntime.InferenceSession(MODEL_PATH, providers = facefusion.globals.execution_providers)
|
||||
model_path = MODELS.get('open_nsfw').get('path')
|
||||
PREDICTOR = onnxruntime.InferenceSession(model_path, providers = facefusion.globals.execution_providers)
|
||||
return PREDICTOR
|
||||
|
||||
|
||||
@@ -40,7 +47,8 @@ def clear_predictor() -> None:
|
||||
def pre_check() -> bool:
|
||||
if not facefusion.globals.skip_download:
|
||||
download_directory_path = resolve_relative_path('../.assets/models')
|
||||
conditional_download(download_directory_path, [ MODEL_URL ])
|
||||
model_url = MODELS.get('open_nsfw').get('url')
|
||||
conditional_download(download_directory_path, [ model_url ])
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user