ROCM and OpenVINO mapping for torch backends

This commit is contained in:
henryruhs
2023-10-19 08:16:23 +02:00
parent 782299073c
commit d81371faea
2 changed files with 7 additions and 5 deletions
@@ -11,7 +11,7 @@ from facefusion import wording
from facefusion.face_analyser import clear_face_analyser
from facefusion.predictor import clear_predictor
from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device, update_status
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, map_device, update_status
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices
@@ -58,7 +58,7 @@ def get_frame_processor() -> Any:
num_out_ch = 3,
scale = model_scale
),
device = get_device(facefusion.globals.execution_providers),
device = map_device(facefusion.globals.execution_providers),
scale = model_scale
)
return FRAME_PROCESSOR
+5 -3
View File
@@ -231,11 +231,13 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [ execution_provider for execution_provider, encoded_execution_provider in zip(available_execution_providers, encoded_execution_providers) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers) ]
def get_device(execution_providers : List[str]) -> str:
if 'CUDAExecutionProvider' in execution_providers:
return 'cuda'
def map_device(execution_providers : List[str]) -> str:
if 'CoreMLExecutionProvider' in execution_providers:
return 'mps'
if 'CUDAExecutionProvider' in execution_providers or 'ROCMExecutionProvider' in execution_providers :
return 'cuda'
if 'OpenVINOExecutionProvider' in execution_providers:
return 'mkl'
return 'cpu'