add has_amd_execution_provider and has_nvidia_execution_provider

This commit is contained in:
harisreedhar
2026-02-12 14:01:34 +05:30
committed by henryruhs
parent ea0df867e1
commit ceab5c4a63
+10 -2
View File
@@ -29,13 +29,21 @@ def detect_static_graphic_devices(execution_providers : Tuple[ExecutionProvider,
def detect_graphic_devices(execution_providers : Tuple[ExecutionProvider, ...]) -> List[GraphicDevice]:
if any(execution_provider in [ 'rocm', 'migraphx' ] for execution_provider in execution_providers):
if has_amd_execution_provider(execution_providers):
return detect_amd_graphic_devices()
if any(execution_provider in [ 'cuda', 'tensorrt' ] for execution_provider in execution_providers):
if has_nvidia_execution_provider(execution_providers):
return detect_nvidia_graphic_devices()
return []
def has_amd_execution_provider(execution_providers : Tuple[ExecutionProvider, ...]) -> bool:
return 'rocm' in execution_providers or 'migraphx' in execution_providers
def has_nvidia_execution_provider(execution_providers : Tuple[ExecutionProvider, ...]) -> bool:
return 'cuda' in execution_providers or 'tensorrt' in execution_providers
def detect_nvidia_graphic_devices() -> List[GraphicDevice]:
pynvml = importlib.import_module('pynvml')
graphic_devices : List[GraphicDevice] = []