mirror of
https://github.com/facefusion/facefusion.git
synced 2026-05-01 13:57:50 +02:00
ROCM and OpenVINO mapping for torch backends
This commit is contained in:
@@ -11,7 +11,7 @@ from facefusion import wording
|
||||
from facefusion.face_analyser import clear_face_analyser
|
||||
from facefusion.predictor import clear_predictor
|
||||
from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device, update_status
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, map_device, update_status
|
||||
from facefusion.vision import read_image, read_static_image, write_image
|
||||
from facefusion.processors.frame import globals as frame_processors_globals
|
||||
from facefusion.processors.frame import choices as frame_processors_choices
|
||||
@@ -58,7 +58,7 @@ def get_frame_processor() -> Any:
|
||||
num_out_ch = 3,
|
||||
scale = model_scale
|
||||
),
|
||||
device = get_device(facefusion.globals.execution_providers),
|
||||
device = map_device(facefusion.globals.execution_providers),
|
||||
scale = model_scale
|
||||
)
|
||||
return FRAME_PROCESSOR
|
||||
|
||||
@@ -231,11 +231,13 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [ execution_provider for execution_provider, encoded_execution_provider in zip(available_execution_providers, encoded_execution_providers) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers) ]
|
||||
|
||||
|
||||
def get_device(execution_providers : List[str]) -> str:
|
||||
if 'CUDAExecutionProvider' in execution_providers:
|
||||
return 'cuda'
|
||||
def map_device(execution_providers : List[str]) -> str:
|
||||
if 'CoreMLExecutionProvider' in execution_providers:
|
||||
return 'mps'
|
||||
if 'CUDAExecutionProvider' in execution_providers or 'ROCMExecutionProvider' in execution_providers :
|
||||
return 'cuda'
|
||||
if 'OpenVINOExecutionProvider' in execution_providers:
|
||||
return 'mkl'
|
||||
return 'cpu'
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user