fix cache

This commit is contained in:
harisreedhar
2026-02-12 13:28:10 +05:30
committed by henryruhs
parent b4138a8f12
commit ea0df867e1
2 changed files with 6 additions and 7 deletions
+3 -4
View File
@@ -8,7 +8,6 @@ import pynvml
import onnxruntime
import facefusion.choices
from facefusion import state_manager
from facefusion.system import detect_static_graphic_devices
from facefusion.types import ExecutionProvider, InferenceSessionProvider
@@ -40,7 +39,7 @@ def create_inference_providers(execution_device_id : int, execution_providers :
inference_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id,
'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search()
'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search(execution_providers)
}))
if execution_provider == 'tensorrt':
@@ -113,8 +112,8 @@ def resolve_cache_path() -> str:
return os.path.join('.caches', onnxruntime.get_version_string())
def resolve_cudnn_conv_algo_search() -> str:
execution_devices = detect_static_graphic_devices(state_manager.get_item('execution_providers'))
def resolve_cudnn_conv_algo_search(execution_providers : List[ExecutionProvider]) -> str:
execution_devices = detect_static_graphic_devices(tuple(execution_providers))
product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660')
for execution_device in execution_devices:
+3 -3
View File
@@ -2,7 +2,7 @@ import importlib
import shutil
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import List, Tuple
import psutil
@@ -24,11 +24,11 @@ def get_metrics_set() -> Metrics:
@lru_cache()
def detect_static_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]:
def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
return detect_graphic_devices(execution_providers)
def detect_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]:
def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
if any(execution_provider in [ 'rocm', 'migraphx' ] for execution_provider in execution_providers):
return detect_amd_graphic_devices()
if any(execution_provider in [ 'cuda', 'tensorrt' ] for execution_provider in execution_providers):