mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 13:05:59 +02:00
Fix bad caching of graphic devices (#1064)
* fix bad caching of graphic devices * restore without cache * restore without cache * restore without cache * remove detect_static_graphic_devices and add resolve_static_cudnn_conv_algo_search
This commit is contained in:
+10
-5
@@ -1,12 +1,12 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
|
||||
import onnxruntime
|
||||
|
||||
import facefusion.choices
|
||||
from facefusion.filesystem import create_directory, is_directory
|
||||
from facefusion.system import detect_static_graphic_devices
|
||||
from facefusion.system import detect_graphic_devices
|
||||
from facefusion.types import ExecutionProvider, InferenceOptionSet, InferenceProvider
|
||||
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
@@ -37,7 +37,7 @@ def create_inference_providers(execution_device_id : int, execution_providers :
|
||||
inference_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
||||
{
|
||||
'device_id': execution_device_id,
|
||||
'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search(execution_providers)
|
||||
'cudnn_conv_algo_search': resolve_static_cudnn_conv_algo_search(tuple(execution_providers))
|
||||
}))
|
||||
|
||||
if execution_provider == 'tensorrt':
|
||||
@@ -110,9 +110,14 @@ def resolve_cache_path() -> str:
|
||||
return os.path.join('.caches', onnxruntime.get_version_string())
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def resolve_static_cudnn_conv_algo_search(execution_providers : Tuple[ExecutionProvider, ...]) -> str:
|
||||
return resolve_cudnn_conv_algo_search(list(execution_providers))
|
||||
|
||||
|
||||
def resolve_cudnn_conv_algo_search(execution_providers : List[ExecutionProvider]) -> str:
|
||||
if has_execution_provider('cuda') or has_execution_provider('tensorrt'):
|
||||
graphic_devices = detect_static_graphic_devices(tuple(execution_providers))
|
||||
graphic_devices = detect_graphic_devices(execution_providers)
|
||||
product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660')
|
||||
|
||||
for graphic_device in graphic_devices:
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import importlib
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -23,12 +22,7 @@ def get_metrics_set() -> Metrics:
|
||||
}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
|
||||
return detect_graphic_devices(execution_providers)
|
||||
|
||||
|
||||
def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
|
||||
def detect_graphic_devices(execution_providers : List[ExecutionProvider]) -> List[GraphicDevice]:
|
||||
if 'rocm' in execution_providers or 'migraphx' in execution_providers:
|
||||
return detect_amd_graphic_devices()
|
||||
if 'cuda' in execution_providers or 'tensorrt' in execution_providers:
|
||||
|
||||
Reference in New Issue
Block a user