diff --git a/facefusion/processors/frame/modules/frame_enhancer.py b/facefusion/processors/frame/modules/frame_enhancer.py index 4f08c01a..a529971c 100644 --- a/facefusion/processors/frame/modules/frame_enhancer.py +++ b/facefusion/processors/frame/modules/frame_enhancer.py @@ -11,7 +11,7 @@ from facefusion import wording from facefusion.face_analyser import clear_face_analyser from facefusion.predictor import clear_predictor from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel -from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device, update_status +from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, map_device, update_status from facefusion.vision import read_image, read_static_image, write_image from facefusion.processors.frame import globals as frame_processors_globals from facefusion.processors.frame import choices as frame_processors_choices @@ -58,7 +58,7 @@ def get_frame_processor() -> Any: num_out_ch = 3, scale = model_scale ), - device = get_device(facefusion.globals.execution_providers), + device = map_device(facefusion.globals.execution_providers), scale = model_scale ) return FRAME_PROCESSOR diff --git a/facefusion/utilities.py b/facefusion/utilities.py index b4e52c22..ce319c5d 100644 --- a/facefusion/utilities.py +++ b/facefusion/utilities.py @@ -231,11 +231,13 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]: return [ execution_provider for execution_provider, encoded_execution_provider in zip(available_execution_providers, encoded_execution_providers) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers) ] -def get_device(execution_providers : List[str]) -> str: - if 'CUDAExecutionProvider' in execution_providers: - return 'cuda' +def map_device(execution_providers : List[str]) -> str: if 'CoreMLExecutionProvider' in execution_providers: return 'mps' + if 'CUDAExecutionProvider' in execution_providers or 'ROCMExecutionProvider' in execution_providers : + return 'cuda' + if 'OpenVINOExecutionProvider' in execution_providers: + return 'mkl' return 'cpu'