mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-06 20:53:54 +02:00
resolve static inference providers to fix macos (#1127)
* resolve static inference providers to fix macos * fix lint * restore old behaviour * restore old behaviour * handle ghost and uniface as well * adjust condition for ghost and uniface
This commit is contained in:
@@ -1,16 +1,14 @@
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import numpy
|
||||
from tqdm import tqdm
|
||||
|
||||
from facefusion import inference_manager, state_manager, translator
|
||||
from facefusion.common_helper import is_macos
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
from facefusion.filesystem import resolve_relative_path
|
||||
from facefusion.thread_helper import conditional_thread_semaphore
|
||||
from facefusion.types import Detection, DownloadScope, DownloadSet, ExecutionProvider, Fps, InferencePool, ModelSet, VisionFrame
|
||||
from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame
|
||||
from facefusion.vision import detect_video_fps, fit_contain_frame, read_image, read_video_frame
|
||||
|
||||
STREAM_COUNTER = 0
|
||||
@@ -119,12 +117,6 @@ def clear_inference_pool() -> None:
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||
if is_macos() and has_execution_provider('coreml'):
|
||||
return [ 'cpu' ]
|
||||
return state_manager.get_item('execution_providers')
|
||||
|
||||
|
||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
model_set = create_static_model_set('full')
|
||||
model_hash_set = {}
|
||||
|
||||
+1
-1
@@ -115,7 +115,7 @@ def common_pre_check() -> bool:
|
||||
content_analyser_content = inspect.getsource(content_analyser).encode()
|
||||
content_analyser_hash = hash_helper.create_hash(content_analyser_content)
|
||||
|
||||
return all(module.pre_check() for module in common_modules) and content_analyser_hash == 'b14e7b92'
|
||||
return all(module.pre_check() for module in common_modules) and content_analyser_hash == '05843613'
|
||||
|
||||
|
||||
def processors_pre_check() -> bool:
|
||||
|
||||
@@ -13,7 +13,7 @@ from facefusion.execution import create_inference_providers, has_execution_provi
|
||||
from facefusion.exit_helper import fatal_exit
|
||||
from facefusion.filesystem import get_file_name, is_file
|
||||
from facefusion.time_helper import calculate_end_time
|
||||
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
|
||||
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet, InferenceProvider
|
||||
|
||||
INFERENCE_POOL_SET : InferencePoolSet =\
|
||||
{
|
||||
@@ -26,7 +26,7 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_
|
||||
while process_manager.is_checking():
|
||||
sleep(0.5)
|
||||
execution_device_ids = state_manager.get_item('execution_device_ids')
|
||||
execution_providers = resolve_static_execution_providers(module_name)
|
||||
execution_providers = state_manager.get_item('execution_providers')
|
||||
app_context = detect_app_context()
|
||||
|
||||
for execution_device_id in execution_device_ids:
|
||||
@@ -37,26 +37,27 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_
|
||||
if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context):
|
||||
INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context)
|
||||
if not INFERENCE_POOL_SET.get(app_context).get(inference_context):
|
||||
INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers)
|
||||
inference_providers = resolve_static_inference_providers(module_name, execution_device_id)
|
||||
INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, inference_providers)
|
||||
|
||||
current_inference_context = get_inference_context(module_name, model_names, random.choice(execution_device_ids), execution_providers)
|
||||
return INFERENCE_POOL_SET.get(app_context).get(current_inference_context)
|
||||
|
||||
|
||||
def create_inference_pool(model_source_set : DownloadSet, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferencePool:
|
||||
def create_inference_pool(model_source_set : DownloadSet, inference_providers : List[InferenceProvider]) -> InferencePool:
|
||||
inference_pool : InferencePool = {}
|
||||
|
||||
for model_name in model_source_set.keys():
|
||||
model_path = model_source_set.get(model_name).get('path')
|
||||
if is_file(model_path):
|
||||
inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers)
|
||||
inference_pool[model_name] = create_inference_session(model_path, inference_providers)
|
||||
|
||||
return inference_pool
|
||||
|
||||
|
||||
def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
|
||||
execution_device_ids = state_manager.get_item('execution_device_ids')
|
||||
execution_providers = resolve_static_execution_providers(module_name)
|
||||
execution_providers = state_manager.get_item('execution_providers')
|
||||
app_context = detect_app_context()
|
||||
|
||||
if is_windows() and has_execution_provider('directml'):
|
||||
@@ -68,12 +69,11 @@ def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
|
||||
del INFERENCE_POOL_SET[app_context][inference_context]
|
||||
|
||||
|
||||
def create_inference_session(model_path : str, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferenceSession:
|
||||
def create_inference_session(model_path : str, inference_providers : List[InferenceProvider]) -> InferenceSession:
|
||||
model_file_name = get_file_name(model_path)
|
||||
start_time = time()
|
||||
|
||||
try:
|
||||
inference_providers = create_inference_providers(execution_device_id, execution_providers)
|
||||
inference_session = InferenceSession(model_path, providers = inference_providers)
|
||||
logger.debug(translator.get('loading_model_succeeded').format(model_name = model_file_name, seconds = calculate_end_time(start_time)), __name__)
|
||||
return inference_session
|
||||
@@ -89,9 +89,14 @@ def get_inference_context(module_name : str, model_names : List[str], execution_
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def resolve_static_execution_providers(module_name : str) -> List[ExecutionProvider]:
|
||||
def resolve_static_inference_providers(module_name : str, execution_device_id : int) -> List[InferenceProvider]:
|
||||
module = importlib.import_module(module_name)
|
||||
execution_providers = state_manager.get_item('execution_providers')
|
||||
|
||||
if hasattr(module, 'resolve_execution_providers'):
|
||||
return getattr(module, 'resolve_execution_providers')()
|
||||
return state_manager.get_item('execution_providers')
|
||||
if hasattr(module, 'resolve_inference_providers'):
|
||||
inference_providers = getattr(module, 'resolve_inference_providers')()
|
||||
|
||||
if inference_providers:
|
||||
return inference_providers
|
||||
|
||||
return create_inference_providers(execution_device_id, execution_providers)
|
||||
|
||||
@@ -8,9 +8,8 @@ import facefusion.choices
|
||||
import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, state_manager, translator, video_manager
|
||||
from facefusion.common_helper import create_int_metavar, is_macos
|
||||
from facefusion.common_helper import create_int_metavar
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
from facefusion.face_analyser import scale_face
|
||||
from facefusion.face_helper import merge_matrix, paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5
|
||||
from facefusion.face_masker import create_box_mask, create_occlusion_mask
|
||||
@@ -231,9 +230,6 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame,
|
||||
age_modifier = get_inference_pool().get('age_modifier')
|
||||
age_modifier_inputs = {}
|
||||
|
||||
if is_macos() and has_execution_provider('coreml'):
|
||||
age_modifier.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])
|
||||
|
||||
for age_modifier_input in age_modifier.get_inputs():
|
||||
if age_modifier_input.name == 'target':
|
||||
age_modifier_inputs[age_modifier_input.name] = crop_vision_frame
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Tuple
|
||||
import cv2
|
||||
import numpy
|
||||
|
||||
import facefusion.choices
|
||||
import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
|
||||
@@ -19,7 +20,7 @@ from facefusion.processors.types import ProcessorOutputs
|
||||
from facefusion.program_helper import find_argument_group
|
||||
from facefusion.sanitizer import sanitize_int_range
|
||||
from facefusion.thread_helper import thread_semaphore
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, Mask, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.vision import read_static_image, read_static_video_frame
|
||||
|
||||
|
||||
@@ -477,12 +478,13 @@ def clear_inference_pool() -> None:
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||
def resolve_inference_providers() -> List[InferenceProvider]:
|
||||
model_type = get_model_options().get('type')
|
||||
|
||||
if is_macos() and has_execution_provider('coreml') or is_windows() and has_execution_provider('directml') and model_type == 'corridor_key':
|
||||
return [ 'cpu' ]
|
||||
return state_manager.get_item('execution_providers')
|
||||
return [ facefusion.choices.execution_provider_set.get('cpu') ]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -24,7 +24,7 @@ from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel
|
||||
from facefusion.processors.types import ProcessorOutputs
|
||||
from facefusion.program_helper import find_argument_group
|
||||
from facefusion.thread_helper import conditional_thread_semaphore
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.vision import read_static_image, read_static_images, read_static_video_frame, unpack_resolution
|
||||
|
||||
|
||||
@@ -246,6 +246,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/hyperswap_1a_256.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'type': 'hyperswap',
|
||||
'template': 'arcface_128',
|
||||
'size': (256, 256),
|
||||
@@ -276,6 +277,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/hyperswap_1b_256.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'type': 'hyperswap',
|
||||
'template': 'arcface_128',
|
||||
'size': (256, 256),
|
||||
@@ -306,6 +308,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/hyperswap_1c_256.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'type': 'hyperswap',
|
||||
'template': 'arcface_128',
|
||||
'size': (256, 256),
|
||||
@@ -366,6 +369,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'type': 'inswapper',
|
||||
'template': 'arcface_128',
|
||||
'size': (128, 128),
|
||||
@@ -486,28 +490,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ get_model_name() ]
|
||||
model_names = [ state_manager.get_item('face_swapper_model') ]
|
||||
model_source_set = get_model_options().get('sources')
|
||||
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_names = [ get_model_name() ]
|
||||
model_names = [ state_manager.get_item('face_swapper_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_inference_providers() -> List[InferenceProvider]:
|
||||
model_precision = get_model_options().get('precision')
|
||||
model_type = get_model_options().get('type')
|
||||
|
||||
if is_macos() and has_execution_provider('coreml'):
|
||||
if model_type in [ 'ghost', 'uniface' ] or model_precision == 'fp16':
|
||||
return\
|
||||
[
|
||||
(facefusion.choices.execution_provider_set.get('coreml'),
|
||||
{
|
||||
'ModelFormat': 'MLProgram',
|
||||
'SpecializationStrategy': 'FastPrediction'
|
||||
})
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
model_name = get_model_name()
|
||||
return create_static_model_set('full').get(model_name)
|
||||
|
||||
|
||||
def get_model_name() -> str:
|
||||
model_name = state_manager.get_item('face_swapper_model')
|
||||
|
||||
if is_macos() and has_execution_provider('coreml') and model_name == 'inswapper_128_fp16':
|
||||
return 'inswapper_128'
|
||||
return model_name
|
||||
return create_static_model_set('full').get(model_name)
|
||||
|
||||
|
||||
def register_args(program : ArgumentParser) -> None:
|
||||
@@ -622,9 +636,6 @@ def forward_swap_face(source_face : Face, target_face : Face, crop_vision_frame
|
||||
model_type = get_model_options().get('type')
|
||||
face_swapper_inputs = {}
|
||||
|
||||
if is_macos() and has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]:
|
||||
face_swapper.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])
|
||||
|
||||
for face_swapper_input in face_swapper.get_inputs():
|
||||
if face_swapper_input.name == 'source':
|
||||
if model_type in [ 'blendswap', 'uniface' ]:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List
|
||||
import cv2
|
||||
import numpy
|
||||
|
||||
import facefusion.choices
|
||||
import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
|
||||
@@ -17,7 +18,7 @@ from facefusion.processors.modules.frame_colorizer.types import FrameColorizerIn
|
||||
from facefusion.processors.types import ProcessorOutputs
|
||||
from facefusion.program_helper import find_argument_group
|
||||
from facefusion.thread_helper import thread_semaphore
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.vision import blend_frame, read_static_image, read_static_video_frame, unpack_resolution
|
||||
|
||||
|
||||
@@ -170,10 +171,11 @@ def clear_inference_pool() -> None:
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||
def resolve_inference_providers() -> List[InferenceProvider]:
|
||||
if is_macos() and has_execution_provider('coreml'):
|
||||
return [ 'cpu' ]
|
||||
return state_manager.get_item('execution_providers')
|
||||
return [ facefusion.choices.execution_provider_set.get('cpu') ]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from argparse import ArgumentParser
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
|
||||
import facefusion.choices
|
||||
import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
from facefusion import config, content_analyser, inference_manager, logger, state_manager, translator, video_manager
|
||||
@@ -16,7 +18,7 @@ from facefusion.processors.modules.frame_enhancer.types import FrameEnhancerInpu
|
||||
from facefusion.processors.types import ProcessorOutputs
|
||||
from facefusion.program_helper import find_argument_group
|
||||
from facefusion.thread_helper import conditional_thread_semaphore
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.types import ApplyStateItem, Args, DownloadScope, InferencePool, InferenceProvider, ModelOptions, ModelSet, ProcessMode, VisionFrame
|
||||
from facefusion.vision import blend_frame, create_tile_frames, merge_tile_frames, read_static_image, read_static_video_frame
|
||||
|
||||
|
||||
@@ -156,6 +158,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/real_esrgan_x2_fp16.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'size': (256, 16, 8),
|
||||
'scale': 2
|
||||
},
|
||||
@@ -210,6 +213,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/real_esrgan_x4_fp16.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'size': (256, 16, 8),
|
||||
'scale': 4
|
||||
},
|
||||
@@ -264,6 +268,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/real_esrgan_x8_fp16.onnx')
|
||||
}
|
||||
},
|
||||
'precision': 'fp16',
|
||||
'size': (256, 16, 8),
|
||||
'scale': 8
|
||||
},
|
||||
@@ -541,35 +546,38 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ get_frame_enhancer_model() ]
|
||||
model_names = [ state_manager.get_item('frame_enhancer_model') ]
|
||||
model_source_set = get_model_options().get('sources')
|
||||
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_names = [ get_frame_enhancer_model() ]
|
||||
model_names = [ state_manager.get_item('frame_enhancer_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_inference_providers() -> List[InferenceProvider]:
|
||||
model_precision = get_model_options().get('precision')
|
||||
|
||||
if is_macos() and has_execution_provider('coreml') and model_precision == 'fp16':
|
||||
return\
|
||||
[
|
||||
(facefusion.choices.execution_provider_set.get('coreml'),
|
||||
{
|
||||
'ModelFormat': 'MLProgram',
|
||||
'SpecializationStrategy': 'FastPrediction'
|
||||
})
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
model_name = get_frame_enhancer_model()
|
||||
model_name = state_manager.get_item('frame_enhancer_model')
|
||||
return create_static_model_set('full').get(model_name)
|
||||
|
||||
|
||||
def get_frame_enhancer_model() -> str:
|
||||
frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
|
||||
|
||||
if is_macos() and has_execution_provider('coreml'):
|
||||
if frame_enhancer_model == 'real_esrgan_x2_fp16':
|
||||
return 'real_esrgan_x2'
|
||||
if frame_enhancer_model == 'real_esrgan_x4_fp16':
|
||||
return 'real_esrgan_x4'
|
||||
if frame_enhancer_model == 'real_esrgan_x8_fp16':
|
||||
return 'real_esrgan_x8'
|
||||
return frame_enhancer_model
|
||||
|
||||
|
||||
def register_args(program : ArgumentParser) -> None:
|
||||
group_processors = find_argument_group(program, 'processors')
|
||||
if group_processors:
|
||||
|
||||
Reference in New Issue
Block a user