diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 655ed5d8..7a3cd4e1 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -1,16 +1,14 @@ from functools import lru_cache -from typing import List, Tuple +from typing import Tuple import numpy from tqdm import tqdm from facefusion import inference_manager, state_manager, translator -from facefusion.common_helper import is_macos from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url -from facefusion.execution import has_execution_provider from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore -from facefusion.types import Detection, DownloadScope, DownloadSet, ExecutionProvider, Fps, InferencePool, ModelSet, VisionFrame +from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame from facefusion.vision import detect_video_fps, fit_contain_frame, read_image, read_video_frame STREAM_COUNTER = 0 @@ -119,12 +117,6 @@ def clear_inference_pool() -> None: inference_manager.clear_inference_pool(__name__, model_names) -def resolve_execution_providers() -> List[ExecutionProvider]: - if is_macos() and has_execution_provider('coreml'): - return [ 'cpu' ] - return state_manager.get_item('execution_providers') - - def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: model_set = create_static_model_set('full') model_hash_set = {} diff --git a/facefusion/core.py b/facefusion/core.py index 9d290880..2027d8c1 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -115,7 +115,7 @@ def common_pre_check() -> bool: content_analyser_content = inspect.getsource(content_analyser).encode() content_analyser_hash = hash_helper.create_hash(content_analyser_content) - return all(module.pre_check() for module in common_modules) and content_analyser_hash == 'b14e7b92' + return all(module.pre_check() for module in common_modules) and content_analyser_hash == '05843613' def processors_pre_check() -> bool: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index cde6865b..3b10eed3 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -13,7 +13,7 @@ from facefusion.execution import create_inference_providers, has_execution_provi from facefusion.exit_helper import fatal_exit from facefusion.filesystem import get_file_name, is_file from facefusion.time_helper import calculate_end_time -from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet +from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet, InferenceProvider INFERENCE_POOL_SET : InferencePoolSet =\ { @@ -26,7 +26,7 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_ while process_manager.is_checking(): sleep(0.5) execution_device_ids = state_manager.get_item('execution_device_ids') - execution_providers = resolve_static_execution_providers(module_name) + execution_providers = state_manager.get_item('execution_providers') app_context = detect_app_context() for execution_device_id in execution_device_ids: @@ -37,26 +37,27 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_ if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context): INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context) if not INFERENCE_POOL_SET.get(app_context).get(inference_context): - INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers) + inference_providers = resolve_static_inference_providers(module_name, execution_device_id) + INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, inference_providers) current_inference_context = get_inference_context(module_name, model_names, random.choice(execution_device_ids), execution_providers) return INFERENCE_POOL_SET.get(app_context).get(current_inference_context) -def create_inference_pool(model_source_set : DownloadSet, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferencePool: +def create_inference_pool(model_source_set : DownloadSet, inference_providers : List[InferenceProvider]) -> InferencePool: inference_pool : InferencePool = {} for model_name in model_source_set.keys(): model_path = model_source_set.get(model_name).get('path') if is_file(model_path): - inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers) + inference_pool[model_name] = create_inference_session(model_path, inference_providers) return inference_pool def clear_inference_pool(module_name : str, model_names : List[str]) -> None: execution_device_ids = state_manager.get_item('execution_device_ids') - execution_providers = resolve_static_execution_providers(module_name) + execution_providers = state_manager.get_item('execution_providers') app_context = detect_app_context() if is_windows() and has_execution_provider('directml'): @@ -68,12 +69,11 @@ def clear_inference_pool(module_name : str, model_names : List[str]) -> None: del INFERENCE_POOL_SET[app_context][inference_context] -def create_inference_session(model_path : str, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferenceSession: +def create_inference_session(model_path : str, inference_providers : List[InferenceProvider]) -> InferenceSession: model_file_name = get_file_name(model_path) start_time = time() try: - inference_providers = create_inference_providers(execution_device_id, execution_providers) inference_session = InferenceSession(model_path, providers = inference_providers) logger.debug(translator.get('loading_model_succeeded').format(model_name = model_file_name, seconds = calculate_end_time(start_time)), __name__) return inference_session @@ -89,9 +89,14 @@ def get_inference_context(module_name : str, model_names : List[str], execution_ @lru_cache() -def resolve_static_execution_providers(module_name : str) -> List[ExecutionProvider]: +def resolve_static_inference_providers(module_name : str, execution_device_id : int) -> List[InferenceProvider]: module = importlib.import_module(module_name) + execution_providers = state_manager.get_item('execution_providers') - if hasattr(module, 'resolve_execution_providers'): - return getattr(module, 'resolve_execution_providers')() - return state_manager.get_item('execution_providers') + if hasattr(module, 'resolve_inference_providers'): + inference_providers = getattr(module, 'resolve_inference_providers')() + + if inference_providers: + return inference_providers + + return create_inference_providers(execution_device_id, execution_providers) diff --git a/facefusion/processors/modules/age_modifier/core.py b/facefusion/processors/modules/age_modifier/core.py index 385c2a56..413c4333 100755 --- a/facefusion/processors/modules/age_modifier/core.py +++ b/facefusion/processors/modules/age_modifier/core.py @@ -8,9 +8,8 @@ import facefusion.choices import facefusion.jobs.job_manager import facefusion.jobs.job_store from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, state_manager, translator, video_manager -from facefusion.common_helper import create_int_metavar, is_macos +from facefusion.common_helper import create_int_metavar from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url -from facefusion.execution import has_execution_provider from facefusion.face_analyser import scale_face from facefusion.face_helper import merge_matrix, paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5 from facefusion.face_masker import create_box_mask, create_occlusion_mask @@ -231,9 +230,6 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame, age_modifier = get_inference_pool().get('age_modifier') age_modifier_inputs = {} - if is_macos() and has_execution_provider('coreml'): - age_modifier.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ]) - for age_modifier_input in age_modifier.get_inputs(): if age_modifier_input.name == 'target': age_modifier_inputs[age_modifier_input.name] = crop_vision_frame diff --git a/facefusion/processors/modules/background_remover/core.py b/facefusion/processors/modules/background_remover/core.py index e1dbe82c..b605b48e 100644 --- a/facefusion/processors/modules/background_remover/core.py +++ b/facefusion/processors/modules/background_remover/core.py @@ -5,6 +5,7 @@ from typing import List, Tuple import cv2 import numpy +import facefusion.choices import facefusion.jobs.job_manager import facefusion.jobs.job_store from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager @@ -19,7 +20,7 @@ from facefusion.processors.types import ProcessorOutputs from facefusion.program_helper import find_argument_group from facefusion.sanitizer import sanitize_int_range from facefusion.thread_helper import thread_semaphore -from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame +from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame from facefusion.vision import read_static_image, read_static_video_frame @@ -477,12 +478,13 @@ def clear_inference_pool() -> None: inference_manager.clear_inference_pool(__name__, model_names) -def resolve_execution_providers() -> List[ExecutionProvider]: +def resolve_inference_providers() -> List[InferenceProvider]: model_type = get_model_options().get('type') if is_macos() and has_execution_provider('coreml') or is_windows() and has_execution_provider('directml') and model_type == 'corridor_key': - return [ 'cpu' ] - return state_manager.get_item('execution_providers') + return [ facefusion.choices.execution_provider_set.get('cpu') ] + + return [] def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_swapper/core.py b/facefusion/processors/modules/face_swapper/core.py index 2ab17194..dee7d8e8 100755 --- a/facefusion/processors/modules/face_swapper/core.py +++ b/facefusion/processors/modules/face_swapper/core.py @@ -24,7 +24,7 @@ from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel from facefusion.processors.types import ProcessorOutputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import conditional_thread_semaphore -from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame +from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame from facefusion.vision import read_static_image, read_static_images, read_static_video_frame, unpack_resolution @@ -246,6 +246,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/hyperswap_1a_256.onnx') } }, + 'precision': 'fp16', 'type': 'hyperswap', 'template': 'arcface_128', 'size': (256, 256), @@ -276,6 +277,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/hyperswap_1b_256.onnx') } }, + 'precision': 'fp16', 'type': 'hyperswap', 'template': 'arcface_128', 'size': (256, 256), @@ -306,6 +308,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/hyperswap_1c_256.onnx') } }, + 'precision': 'fp16', 'type': 'hyperswap', 'template': 'arcface_128', 'size': (256, 256), @@ -366,6 +369,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx') } }, + 'precision': 'fp16', 'type': 'inswapper', 'template': 'arcface_128', 'size': (128, 128), @@ -486,28 +490,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: - model_names = [ get_model_name() ] + model_names = [ state_manager.get_item('face_swapper_model') ] model_source_set = get_model_options().get('sources') return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: - model_names = [ get_model_name() ] + model_names = [ state_manager.get_item('face_swapper_model') ] inference_manager.clear_inference_pool(__name__, model_names) +def resolve_inference_providers() -> List[InferenceProvider]: + model_precision = get_model_options().get('precision') + model_type = get_model_options().get('type') + + if is_macos() and has_execution_provider('coreml'): + if model_type in [ 'ghost', 'uniface' ] or model_precision == 'fp16': + return\ + [ + (facefusion.choices.execution_provider_set.get('coreml'), + { + 'ModelFormat': 'MLProgram', + 'SpecializationStrategy': 'FastPrediction' + }) + ] + + return [] + + def get_model_options() -> ModelOptions: - model_name = get_model_name() - return create_static_model_set('full').get(model_name) - - -def get_model_name() -> str: model_name = state_manager.get_item('face_swapper_model') - - if is_macos() and has_execution_provider('coreml') and model_name == 'inswapper_128_fp16': - return 'inswapper_128' - return model_name + return create_static_model_set('full').get(model_name) def register_args(program : ArgumentParser) -> None: @@ -622,9 +636,6 @@ def forward_swap_face(source_face : Face, target_face : Face, crop_vision_frame model_type = get_model_options().get('type') face_swapper_inputs = {} - if is_macos() and has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]: - face_swapper.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ]) - for face_swapper_input in face_swapper.get_inputs(): if face_swapper_input.name == 'source': if model_type in [ 'blendswap', 'uniface' ]: diff --git a/facefusion/processors/modules/frame_colorizer/core.py b/facefusion/processors/modules/frame_colorizer/core.py index d10ed18b..9658bff0 100644 --- a/facefusion/processors/modules/frame_colorizer/core.py +++ b/facefusion/processors/modules/frame_colorizer/core.py @@ -5,6 +5,7 @@ from typing import List import cv2 import numpy +import facefusion.choices import facefusion.jobs.job_manager import facefusion.jobs.job_store from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager @@ -17,7 +18,7 @@ from facefusion.processors.modules.frame_colorizer.types import FrameColorizerIn from facefusion.processors.types import ProcessorOutputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import thread_semaphore -from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame +from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame from facefusion.vision import blend_frame, read_static_image, read_static_video_frame, unpack_resolution @@ -170,10 +171,11 @@ def clear_inference_pool() -> None: inference_manager.clear_inference_pool(__name__, model_names) -def resolve_execution_providers() -> List[ExecutionProvider]: +def resolve_inference_providers() -> List[InferenceProvider]: if is_macos() and has_execution_provider('coreml'): - return [ 'cpu' ] - return state_manager.get_item('execution_providers') + return [ facefusion.choices.execution_provider_set.get('cpu') ] + + return [] def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/frame_enhancer/core.py b/facefusion/processors/modules/frame_enhancer/core.py index fdd980e3..f329d9c3 100644 --- a/facefusion/processors/modules/frame_enhancer/core.py +++ b/facefusion/processors/modules/frame_enhancer/core.py @@ -1,9 +1,11 @@ from argparse import ArgumentParser from functools import lru_cache +from typing import List import cv2 import numpy +import facefusion.choices import facefusion.jobs.job_manager import facefusion.jobs.job_store from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager @@ -16,7 +18,7 @@ from facefusion.processors.modules.frame_enhancer.types import FrameEnhancerInpu from facefusion.processors.types import ProcessorOutputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import conditional_thread_semaphore -from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame +from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame from facefusion.vision import blend_frame, create_tile_frames, merge_tile_frames, read_static_image, read_static_video_frame @@ -156,6 +158,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/real_esrgan_x2_fp16.onnx') } }, + 'precision': 'fp16', 'size': (256, 16, 8), 'scale': 2 }, @@ -210,6 +213,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/real_esrgan_x4_fp16.onnx') } }, + 'precision': 'fp16', 'size': (256, 16, 8), 'scale': 4 }, @@ -264,6 +268,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'path': resolve_relative_path('../.assets/models/real_esrgan_x8_fp16.onnx') } }, + 'precision': 'fp16', 'size': (256, 16, 8), 'scale': 8 }, @@ -541,35 +546,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: - model_names = [ get_frame_enhancer_model() ] + model_names = [ state_manager.get_item('frame_enhancer_model') ] model_source_set = get_model_options().get('sources') return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: - model_names = [ get_frame_enhancer_model() ] + model_names = [ state_manager.get_item('frame_enhancer_model') ] inference_manager.clear_inference_pool(__name__, model_names) +def resolve_inference_providers() -> List[InferenceProvider]: + model_precision = get_model_options().get('precision') + + if is_macos() and has_execution_provider('coreml') and model_precision == 'fp16': + return\ + [ + (facefusion.choices.execution_provider_set.get('coreml'), + { + 'ModelFormat': 'MLProgram', + 'SpecializationStrategy': 'FastPrediction' + }) + ] + + return [] + + def get_model_options() -> ModelOptions: - model_name = get_frame_enhancer_model() + model_name = state_manager.get_item('frame_enhancer_model') return create_static_model_set('full').get(model_name) -def get_frame_enhancer_model() -> str: - frame_enhancer_model = state_manager.get_item('frame_enhancer_model') - - if is_macos() and has_execution_provider('coreml'): - if frame_enhancer_model == 'real_esrgan_x2_fp16': - return 'real_esrgan_x2' - if frame_enhancer_model == 'real_esrgan_x4_fp16': - return 'real_esrgan_x4' - if frame_enhancer_model == 'real_esrgan_x8_fp16': - return 'real_esrgan_x8' - return frame_enhancer_model - - def register_args(program : ArgumentParser) -> None: group_processors = find_argument_group(program, 'processors') if group_processors: