diff --git a/facefusion/libraries/amd_smi.py b/facefusion/libraries/amd_smi.py index 07191fad..84614c2e 100644 --- a/facefusion/libraries/amd_smi.py +++ b/facefusion/libraries/amd_smi.py @@ -35,20 +35,6 @@ def define_driver_info() -> ctypes.Structure: })() -def define_rocm_version() -> ctypes.Structure: - return type('AMDSMI_VERSION', (ctypes.Structure,), - { - '_pack_': 1, - '_fields_': - [ - ('major', ctypes.c_uint32), - ('minor', ctypes.c_uint32), - ('patch', ctypes.c_uint32), - ('build', ctypes.c_char_p) - ] - })() - - def define_product_info() -> ctypes.Structure: return type('AMDSMI_ASIC_INFO', (ctypes.Structure,), { @@ -100,23 +86,18 @@ def define_device_utilization() -> ctypes.Structure: def init_ctypes(amd_smi : ctypes.CDLL) -> ctypes.CDLL: - void_pointer = ctypes.POINTER(None) - amd_smi.amdsmi_init.argtypes = [ ctypes.c_uint64 ] amd_smi.amdsmi_init.restype = ctypes.c_uint32 amd_smi.amdsmi_shut_down.argtypes = [] amd_smi.amdsmi_shut_down.restype = ctypes.c_uint32 - amd_smi.amdsmi_get_socket_handles.argtypes = [ ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(void_pointer) ] + amd_smi.amdsmi_get_socket_handles.argtypes = [ ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_void_p) ] amd_smi.amdsmi_get_socket_handles.restype = ctypes.c_uint32 - amd_smi.amdsmi_get_processor_handles.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(void_pointer) ] + amd_smi.amdsmi_get_processor_handles.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint32), ctypes.POINTER(ctypes.c_void_p) ] amd_smi.amdsmi_get_processor_handles.restype = ctypes.c_uint32 - amd_smi.amdsmi_get_lib_version.argtypes = [ ctypes.c_void_p ] - amd_smi.amdsmi_get_lib_version.restype = ctypes.c_uint32 - amd_smi.amdsmi_get_gpu_driver_info.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ] amd_smi.amdsmi_get_gpu_driver_info.restype = ctypes.c_uint32 diff --git a/facefusion/libraries/rocm_core.py b/facefusion/libraries/rocm_core.py new file mode 100644 index 00000000..04bfb66b --- /dev/null +++ b/facefusion/libraries/rocm_core.py @@ -0,0 +1,29 @@ +import ctypes +from functools import lru_cache +from typing import Optional + +from facefusion.common_helper import is_linux + + +def resolve_library_file() -> Optional[str]: + if is_linux(): + return 'librocm-core.so' + return None + + +@lru_cache +def create_static_library() -> Optional[ctypes.CDLL]: + library_file = resolve_library_file() + + if library_file: + rocm_core_library = ctypes.CDLL(library_file) + return init_ctypes(rocm_core_library) + + return None + + +def init_ctypes(rocm_core : ctypes.CDLL) -> ctypes.CDLL: + rocm_core.getROCmVersion.argtypes = [ ctypes.POINTER(ctypes.c_uint), ctypes.POINTER(ctypes.c_uint), ctypes.POINTER(ctypes.c_uint) ] + rocm_core.getROCmVersion.restype = ctypes.c_int + + return rocm_core diff --git a/facefusion/system.py b/facefusion/system.py index 1812b2d8..a4ff074a 100644 --- a/facefusion/system.py +++ b/facefusion/system.py @@ -6,7 +6,7 @@ from typing import List import psutil from facefusion import state_manager -from facefusion.libraries import amd_smi as amd_smi_module, nvidia_ml as nvidia_ml_module +from facefusion.libraries import amd_smi as amd_smi_module, nvidia_ml as nvidia_ml_module, rocm_core as rocm_core_module from facefusion.types import DiskMetrics, ExecutionProvider, GraphicDevice, MemoryMetrics, Metrics, NetworkMetrics, ProcessorMetrics @@ -123,8 +123,13 @@ def detect_amd_graphic_devices() -> List[GraphicDevice]: if amd_smi_library: amd_smi_library.amdsmi_init(ctypes.c_uint64(2)) - rocm_version = amd_smi_module.define_rocm_version() - amd_smi_library.amdsmi_get_lib_version(ctypes.byref(rocm_version)) + rocm_core_library = rocm_core_module.create_static_library() + rocm_major_version = ctypes.c_uint() + rocm_minor_version = ctypes.c_uint() + rocm_patch_version = ctypes.c_uint() + + if rocm_core_library: + rocm_core_library.getROCmVersion(ctypes.byref(rocm_major_version), ctypes.byref(rocm_minor_version), ctypes.byref(rocm_patch_version)) for device_handle in amd_smi_module.find_device_handles(amd_smi_library): driver_info = amd_smi_module.define_driver_info() @@ -148,7 +153,7 @@ def detect_amd_graphic_devices() -> List[GraphicDevice]: 'framework': { 'name': 'ROCm', - 'version': str(rocm_version.major) + '.' + str(rocm_version.minor) + '.' + str(rocm_version.patch) + 'version': str(rocm_major_version.value) + '.' + str(rocm_minor_version.value) + '.' + str(rocm_patch_version.value) }, 'product': {