resolve static inference providers to fix macos (#1127)

* resolve static inference providers to fix macos

* fix lint

* restore old behaviour

* restore old behaviour

* handle ghost and uniface as well

* adjust condition for ghost and uniface
This commit is contained in:
Henry Ruhs
2026-05-27 12:45:23 +02:00
committed by GitHub
parent 815baabf26
commit 73c3899e9d
8 changed files with 84 additions and 68 deletions
+2 -10
View File
@@ -1,16 +1,14 @@
from functools import lru_cache
from typing import List, Tuple
from typing import Tuple
import numpy
from tqdm import tqdm
from facefusion import inference_manager, state_manager, translator
from facefusion.common_helper import is_macos
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.execution import has_execution_provider
from facefusion.filesystem import resolve_relative_path
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.types import Detection, DownloadScope, DownloadSet, ExecutionProvider, Fps, InferencePool, ModelSet, VisionFrame
from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame
from facefusion.vision import detect_video_fps, fit_contain_frame, read_image, read_video_frame
STREAM_COUNTER = 0
@@ -119,12 +117,6 @@ def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__, model_names)
def resolve_execution_providers() -> List[ExecutionProvider]:
if is_macos() and has_execution_provider('coreml'):
return [ 'cpu' ]
return state_manager.get_item('execution_providers')
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_set = create_static_model_set('full')
model_hash_set = {}
+1 -1
View File
@@ -115,7 +115,7 @@ def common_pre_check() -> bool:
content_analyser_content = inspect.getsource(content_analyser).encode()
content_analyser_hash = hash_helper.create_hash(content_analyser_content)
return all(module.pre_check() for module in common_modules) and content_analyser_hash == 'b14e7b92'
return all(module.pre_check() for module in common_modules) and content_analyser_hash == '05843613'
def processors_pre_check() -> bool:
+17 -12
View File
@@ -13,7 +13,7 @@ from facefusion.execution import create_inference_providers, has_execution_provi
from facefusion.exit_helper import fatal_exit
from facefusion.filesystem import get_file_name, is_file
from facefusion.time_helper import calculate_end_time
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet, InferenceProvider
INFERENCE_POOL_SET : InferencePoolSet =\
{
@@ -26,7 +26,7 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_
while process_manager.is_checking():
sleep(0.5)
execution_device_ids = state_manager.get_item('execution_device_ids')
execution_providers = resolve_static_execution_providers(module_name)
execution_providers = state_manager.get_item('execution_providers')
app_context = detect_app_context()
for execution_device_id in execution_device_ids:
@@ -37,26 +37,27 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_
if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context):
INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context)
if not INFERENCE_POOL_SET.get(app_context).get(inference_context):
INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers)
inference_providers = resolve_static_inference_providers(module_name, execution_device_id)
INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, inference_providers)
current_inference_context = get_inference_context(module_name, model_names, random.choice(execution_device_ids), execution_providers)
return INFERENCE_POOL_SET.get(app_context).get(current_inference_context)
def create_inference_pool(model_source_set : DownloadSet, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferencePool:
def create_inference_pool(model_source_set : DownloadSet, inference_providers : List[InferenceProvider]) -> InferencePool:
inference_pool : InferencePool = {}
for model_name in model_source_set.keys():
model_path = model_source_set.get(model_name).get('path')
if is_file(model_path):
inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers)
inference_pool[model_name] = create_inference_session(model_path, inference_providers)
return inference_pool
def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
execution_device_ids = state_manager.get_item('execution_device_ids')
execution_providers = resolve_static_execution_providers(module_name)
execution_providers = state_manager.get_item('execution_providers')
app_context = detect_app_context()
if is_windows() and has_execution_provider('directml'):
@@ -68,12 +69,11 @@ def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
del INFERENCE_POOL_SET[app_context][inference_context]
def create_inference_session(model_path : str, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferenceSession:
def create_inference_session(model_path : str, inference_providers : List[InferenceProvider]) -> InferenceSession:
model_file_name = get_file_name(model_path)
start_time = time()
try:
inference_providers = create_inference_providers(execution_device_id, execution_providers)
inference_session = InferenceSession(model_path, providers = inference_providers)
logger.debug(translator.get('loading_model_succeeded').format(model_name = model_file_name, seconds = calculate_end_time(start_time)), __name__)
return inference_session
@@ -89,9 +89,14 @@ def get_inference_context(module_name : str, model_names : List[str], execution_
@lru_cache()
def resolve_static_execution_providers(module_name : str) -> List[ExecutionProvider]:
def resolve_static_inference_providers(module_name : str, execution_device_id : int) -> List[InferenceProvider]:
module = importlib.import_module(module_name)
execution_providers = state_manager.get_item('execution_providers')
if hasattr(module, 'resolve_execution_providers'):
return getattr(module, 'resolve_execution_providers')()
return state_manager.get_item('execution_providers')
if hasattr(module, 'resolve_inference_providers'):
inference_providers = getattr(module, 'resolve_inference_providers')()
if inference_providers:
return inference_providers
return create_inference_providers(execution_device_id, execution_providers)
@@ -8,9 +8,8 @@ import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, state_manager, translator, video_manager
from facefusion.common_helper import create_int_metavar, is_macos
from facefusion.common_helper import create_int_metavar
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.execution import has_execution_provider
from facefusion.face_analyser import scale_face
from facefusion.face_helper import merge_matrix, paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5
from facefusion.face_masker import create_box_mask, create_occlusion_mask
@@ -231,9 +230,6 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame,
age_modifier = get_inference_pool().get('age_modifier')
age_modifier_inputs = {}
if is_macos() and has_execution_provider('coreml'):
age_modifier.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])
for age_modifier_input in age_modifier.get_inputs():
if age_modifier_input.name == 'target':
age_modifier_inputs[age_modifier_input.name] = crop_vision_frame
@@ -5,6 +5,7 @@ from typing import List, Tuple
import cv2
import numpy
import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
@@ -19,7 +20,7 @@ from facefusion.processors.types import ProcessorOutputs
from facefusion.program_helper import find_argument_group
from facefusion.sanitizer import sanitize_int_range
from facefusion.thread_helper import thread_semaphore
from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.vision import read_static_image, read_static_video_frame
@@ -477,12 +478,13 @@ def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__, model_names)
def resolve_execution_providers() -> List[ExecutionProvider]:
def resolve_inference_providers() -> List[InferenceProvider]:
model_type = get_model_options().get('type')
if is_macos() and has_execution_provider('coreml') or is_windows() and has_execution_provider('directml') and model_type == 'corridor_key':
return [ 'cpu' ]
return state_manager.get_item('execution_providers')
return [ facefusion.choices.execution_provider_set.get('cpu') ]
return []
def get_model_options() -> ModelOptions:
@@ -24,7 +24,7 @@ from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel
from facefusion.processors.types import ProcessorOutputs
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.vision import read_static_image, read_static_images, read_static_video_frame, unpack_resolution
@@ -246,6 +246,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/hyperswap_1a_256.onnx')
}
},
'precision': 'fp16',
'type': 'hyperswap',
'template': 'arcface_128',
'size': (256, 256),
@@ -276,6 +277,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/hyperswap_1b_256.onnx')
}
},
'precision': 'fp16',
'type': 'hyperswap',
'template': 'arcface_128',
'size': (256, 256),
@@ -306,6 +308,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/hyperswap_1c_256.onnx')
}
},
'precision': 'fp16',
'type': 'hyperswap',
'template': 'arcface_128',
'size': (256, 256),
@@ -366,6 +369,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx')
}
},
'precision': 'fp16',
'type': 'inswapper',
'template': 'arcface_128',
'size': (128, 128),
@@ -486,28 +490,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ get_model_name() ]
model_names = [ state_manager.get_item('face_swapper_model') ]
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
model_names = [ get_model_name() ]
model_names = [ state_manager.get_item('face_swapper_model') ]
inference_manager.clear_inference_pool(__name__, model_names)
def resolve_inference_providers() -> List[InferenceProvider]:
model_precision = get_model_options().get('precision')
model_type = get_model_options().get('type')
if is_macos() and has_execution_provider('coreml'):
if model_type in [ 'ghost', 'uniface' ] or model_precision == 'fp16':
return\
[
(facefusion.choices.execution_provider_set.get('coreml'),
{
'ModelFormat': 'MLProgram',
'SpecializationStrategy': 'FastPrediction'
})
]
return []
def get_model_options() -> ModelOptions:
model_name = get_model_name()
return create_static_model_set('full').get(model_name)
def get_model_name() -> str:
model_name = state_manager.get_item('face_swapper_model')
if is_macos() and has_execution_provider('coreml') and model_name == 'inswapper_128_fp16':
return 'inswapper_128'
return model_name
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:
@@ -622,9 +636,6 @@ def forward_swap_face(source_face : Face, target_face : Face, crop_vision_frame
model_type = get_model_options().get('type')
face_swapper_inputs = {}
if is_macos() and has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]:
face_swapper.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])
for face_swapper_input in face_swapper.get_inputs():
if face_swapper_input.name == 'source':
if model_type in [ 'blendswap', 'uniface' ]:
@@ -5,6 +5,7 @@ from typing import List
import cv2
import numpy
import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
@@ -17,7 +18,7 @@ from facefusion.processors.modules.frame_colorizer.types import FrameColorizerIn
from facefusion.processors.types import ProcessorOutputs
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore
from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.vision import blend_frame, read_static_image, read_static_video_frame, unpack_resolution
@@ -170,10 +171,11 @@ def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__, model_names)
def resolve_execution_providers() -> List[ExecutionProvider]:
def resolve_inference_providers() -> List[InferenceProvider]:
if is_macos() and has_execution_provider('coreml'):
return [ 'cpu' ]
return state_manager.get_item('execution_providers')
return [ facefusion.choices.execution_provider_set.get('cpu') ]
return []
def get_model_options() -> ModelOptions:
@@ -1,9 +1,11 @@
from argparse import ArgumentParser
from functools import lru_cache
from typing import List
import cv2
import numpy
import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
@@ -16,7 +18,7 @@ from facefusion.processors.modules.frame_enhancer.types import FrameEnhancerInpu
from facefusion.processors.types import ProcessorOutputs
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
from facefusion.vision import blend_frame, create_tile_frames, merge_tile_frames, read_static_image, read_static_video_frame
@@ -156,6 +158,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/real_esrgan_x2_fp16.onnx')
}
},
'precision': 'fp16',
'size': (256, 16, 8),
'scale': 2
},
@@ -210,6 +213,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/real_esrgan_x4_fp16.onnx')
}
},
'precision': 'fp16',
'size': (256, 16, 8),
'scale': 4
},
@@ -264,6 +268,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/real_esrgan_x8_fp16.onnx')
}
},
'precision': 'fp16',
'size': (256, 16, 8),
'scale': 8
},
@@ -541,35 +546,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ get_frame_enhancer_model() ]
model_names = [ state_manager.get_item('frame_enhancer_model') ]
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
model_names = [ get_frame_enhancer_model() ]
model_names = [ state_manager.get_item('frame_enhancer_model') ]
inference_manager.clear_inference_pool(__name__, model_names)
def resolve_inference_providers() -> List[InferenceProvider]:
model_precision = get_model_options().get('precision')
if is_macos() and has_execution_provider('coreml') and model_precision == 'fp16':
return\
[
(facefusion.choices.execution_provider_set.get('coreml'),
{
'ModelFormat': 'MLProgram',
'SpecializationStrategy': 'FastPrediction'
})
]
return []
def get_model_options() -> ModelOptions:
model_name = get_frame_enhancer_model()
model_name = state_manager.get_item('frame_enhancer_model')
return create_static_model_set('full').get(model_name)
def get_frame_enhancer_model() -> str:
frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
if is_macos() and has_execution_provider('coreml'):
if frame_enhancer_model == 'real_esrgan_x2_fp16':
return 'real_esrgan_x2'
if frame_enhancer_model == 'real_esrgan_x4_fp16':
return 'real_esrgan_x4'
if frame_enhancer_model == 'real_esrgan_x8_fp16':
return 'real_esrgan_x8'
return frame_enhancer_model
def register_args(program : ArgumentParser) -> None:
group_processors = find_argument_group(program, 'processors')
if group_processors: