mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 09:26:02 +02:00
add has_amd_execution_provider and has_nvidia_execution_provider
This commit is contained in:
+10
-2
@@ -29,13 +29,21 @@ def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider,
|
||||
|
||||
|
||||
def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
|
||||
if any(execution_provider in [ 'rocm', 'migraphx' ] for execution_provider in execution_providers):
|
||||
if has_amd_execution_provider(execution_providers):
|
||||
return detect_amd_graphic_devices()
|
||||
if any(execution_provider in [ 'cuda', 'tensorrt' ] for execution_provider in execution_providers):
|
||||
if has_nvidia_execution_provider(execution_providers):
|
||||
return detect_nvidia_graphic_devices()
|
||||
return []
|
||||
|
||||
|
||||
def has_amd_execution_provider(execution_providers : Tuple[ExecutionProvider, ...]) -> bool:
|
||||
return 'rocm' in execution_providers or 'migraphx' in execution_providers
|
||||
|
||||
|
||||
def has_nvidia_execution_provider(execution_providers : Tuple[ExecutionProvider, ...]) -> bool:
|
||||
return 'cuda' in execution_providers or 'tensorrt' in execution_providers
|
||||
|
||||
|
||||
def detect_nvidia_graphic_devices() -> List[GraphicDevice]:
|
||||
pynvml = importlib.import_module('pynvml')
|
||||
graphic_devices : List[GraphicDevice] = []
|
||||
|
||||
Reference in New Issue
Block a user