diff --git a/facefusion/execution.py b/facefusion/execution.py index 7be33fc5..30cab3c0 100644 --- a/facefusion/execution.py +++ b/facefusion/execution.py @@ -1,12 +1,12 @@ import os -from typing import List - +from functools import lru_cache +from typing import List, Tuple import onnxruntime import facefusion.choices from facefusion.filesystem import create_directory, is_directory -from facefusion.system import detect_static_graphic_devices +from facefusion.system import detect_graphic_devices from facefusion.types import ExecutionProvider, InferenceOptionSet, InferenceProvider onnxruntime.set_default_logger_severity(3) @@ -37,7 +37,7 @@ def create_inference_providers(execution_device_id : int, execution_providers : inference_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), { 'device_id': execution_device_id, - 'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search(execution_providers) + 'cudnn_conv_algo_search': resolve_static_cudnn_conv_algo_search(tuple(execution_providers)) })) if execution_provider == 'tensorrt': @@ -110,9 +110,14 @@ def resolve_cache_path() -> str: return os.path.join('.caches', onnxruntime.get_version_string()) +@lru_cache() +def resolve_static_cudnn_conv_algo_search(execution_providers : Tuple[ExecutionProvider, ...]) -> str: + return resolve_cudnn_conv_algo_search(list(execution_providers)) + + def resolve_cudnn_conv_algo_search(execution_providers : List[ExecutionProvider]) -> str: if has_execution_provider('cuda') or has_execution_provider('tensorrt'): - graphic_devices = detect_static_graphic_devices(tuple(execution_providers)) + graphic_devices = detect_graphic_devices(execution_providers) product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660') for graphic_device in graphic_devices: diff --git a/facefusion/system.py b/facefusion/system.py index 95af94ec..3966b928 100644 --- a/facefusion/system.py +++ b/facefusion/system.py @@ -1,8 +1,7 @@ import importlib import shutil -from functools import lru_cache from pathlib import Path -from typing import List, Tuple +from typing import List import psutil @@ -23,12 +22,7 @@ def get_metrics_set() -> Metrics: } -@lru_cache() -def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]: - return detect_graphic_devices(execution_providers) - - -def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]: +def detect_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]: if 'rocm' in execution_providers or 'migraphx' in execution_providers: return detect_amd_graphic_devices() if 'cuda' in execution_providers or 'tensorrt' in execution_providers: