mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 04:55:57 +02:00
fix cache
This commit is contained in:
@@ -8,7 +8,6 @@ import pynvml
|
||||
import onnxruntime
|
||||
|
||||
import facefusion.choices
|
||||
from facefusion import state_manager
|
||||
from facefusion.system import detect_static_graphic_devices
|
||||
from facefusion.types import ExecutionProvider, InferenceSessionProvider
|
||||
|
||||
@@ -40,7 +39,7 @@ def create_inference_providers(execution_device_id : int, execution_providers :
|
||||
inference_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
||||
{
|
||||
'device_id': execution_device_id,
|
||||
'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search()
|
||||
'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search(execution_providers)
|
||||
}))
|
||||
|
||||
if execution_provider == 'tensorrt':
|
||||
@@ -113,8 +112,8 @@ def resolve_cache_path() -> str:
|
||||
return os.path.join('.caches', onnxruntime.get_version_string())
|
||||
|
||||
|
||||
def resolve_cudnn_conv_algo_search() -> str:
|
||||
execution_devices = detect_static_graphic_devices(state_manager.get_item('execution_providers'))
|
||||
def resolve_cudnn_conv_algo_search(execution_providers : List[ExecutionProvider]) -> str:
|
||||
execution_devices = detect_static_graphic_devices(tuple(execution_providers))
|
||||
product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660')
|
||||
|
||||
for execution_device in execution_devices:
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -24,11 +24,11 @@ def get_metrics_set() -> Metrics:
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def detect_static_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]:
|
||||
def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
|
||||
return detect_graphic_devices(execution_providers)
|
||||
|
||||
|
||||
def detect_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]:
|
||||
def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
|
||||
if any(execution_provider in [ 'rocm', 'migraphx' ] for execution_provider in execution_providers):
|
||||
return detect_amd_graphic_devices()
|
||||
if any(execution_provider in [ 'cuda', 'tensorrt' ] for execution_provider in execution_providers):
|
||||
|
||||
Reference in New Issue
Block a user