diff --git a/facefusion/core.py b/facefusion/core.py index 2027d8c1..be9398ad 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -5,7 +5,7 @@ import signal import sys from time import time -from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, state_manager, translator, voice_extractor +from facefusion import benchmarker, cli_helper, content_analyser, hash_helper, logger, state_manager, translator from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args from facefusion.download import conditional_download_hashes, conditional_download_sources from facefusion.exit_helper import hard_exit, signal_exit @@ -101,21 +101,10 @@ def pre_check() -> bool: def common_pre_check() -> bool: - common_modules =\ - [ - content_analyser, - face_classifier, - face_detector, - face_landmarker, - face_masker, - face_recognizer, - voice_extractor - ] - content_analyser_content = inspect.getsource(content_analyser).encode() content_analyser_hash = hash_helper.create_hash(content_analyser_content) - return all(module.pre_check() for module in common_modules) and content_analyser_hash == '05843613' + return content_analyser_hash == '05843613' def processors_pre_check() -> bool: @@ -126,22 +115,19 @@ def processors_pre_check() -> bool: def force_download() -> ErrorCode: - common_modules =\ - [ - content_analyser, - face_classifier, - face_detector, - face_landmarker, - face_masker, - face_recognizer, - voice_extractor - ] + download_scope = state_manager.get_item('download_scope') available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] processor_modules = get_processors_modules(available_processors) + common_modules = [] + + for processor_module in processor_modules: + for common_module in processor_module.get_common_modules(): + if common_module not in common_modules: + common_modules.append(common_module) for module in common_modules + processor_modules: if hasattr(module, 'create_static_model_set'): - for model in module.create_static_model_set(state_manager.get_item('download_scope')).values(): + for model in module.create_static_model_set(download_scope).values(): model_hash_set = model.get('hashes') model_source_set = model.get('sources') diff --git a/facefusion/processors/core.py b/facefusion/processors/core.py index 09d45e5e..fe395445 100644 --- a/facefusion/processors/core.py +++ b/facefusion/processors/core.py @@ -12,6 +12,7 @@ PROCESSORS_METHODS =\ 'clear_inference_pool', 'register_args', 'apply_args', + 'get_common_modules', 'pre_check', 'pre_process', 'post_process', diff --git a/facefusion/processors/modules/age_modifier/core.py b/facefusion/processors/modules/age_modifier/core.py index 6ef85479..99bd921b 100755 --- a/facefusion/processors/modules/age_modifier/core.py +++ b/facefusion/processors/modules/age_modifier/core.py @@ -1,5 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType +from typing import List import cv2 import numpy @@ -133,10 +135,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('age_modifier_direction', args.get('age_modifier_direction')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -157,15 +167,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def modify_age(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/background_remover/core.py b/facefusion/processors/modules/background_remover/core.py index b605b48e..2ff10aa9 100644 --- a/facefusion/processors/modules/background_remover/core.py +++ b/facefusion/processors/modules/background_remover/core.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser from functools import lru_cache, partial +from types import ModuleType from typing import List, Tuple import cv2 @@ -507,10 +508,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('background_remover_despill_color', normalize_color(args.get('background_remover_despill_color'))) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -531,10 +540,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def remove_background(temp_vision_frame : VisionFrame) -> Tuple[VisionFrame, Mask]: diff --git a/facefusion/processors/modules/deep_swapper/core.py b/facefusion/processors/modules/deep_swapper/core.py index fcf79452..fa53477a 100755 --- a/facefusion/processors/modules/deep_swapper/core.py +++ b/facefusion/processors/modules/deep_swapper/core.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache -from typing import Tuple +from types import ModuleType +from typing import List, Tuple import cv2 import numpy @@ -286,10 +287,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('deep_swapper_morph', args.get('deep_swapper_morph')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + if model_hash_set and model_source_set: return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) return True @@ -312,15 +321,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/expression_restorer/core.py b/facefusion/processors/modules/expression_restorer/core.py index b659eec2..255bc61b 100755 --- a/facefusion/processors/modules/expression_restorer/core.py +++ b/facefusion/processors/modules/expression_restorer/core.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache -from typing import Tuple +from types import ModuleType +from typing import List, Tuple import cv2 import numpy @@ -111,10 +112,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('expression_restorer_areas', args.get('expression_restorer_areas')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -138,15 +147,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def restore_expression(target_face : Face, target_vision_frame : VisionFrame, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/face_debugger/core.py b/facefusion/processors/modules/face_debugger/core.py index 61f8567c..20502ef7 100755 --- a/facefusion/processors/modules/face_debugger/core.py +++ b/facefusion/processors/modules/face_debugger/core.py @@ -1,4 +1,6 @@ from argparse import ArgumentParser +from types import ModuleType +from typing import List import cv2 import numpy @@ -38,7 +40,14 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('face_debugger_items', args.get('face_debugger_items')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False return True @@ -59,13 +68,10 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/face_editor/core.py b/facefusion/processors/modules/face_editor/core.py index 55b3743b..43162874 100755 --- a/facefusion/processors/modules/face_editor/core.py +++ b/facefusion/processors/modules/face_editor/core.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache -from typing import Tuple +from types import ModuleType +from typing import List, Tuple import cv2 import numpy @@ -165,10 +166,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('face_editor_head_roll', args.get('face_editor_head_roll')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -189,15 +198,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def edit_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/face_enhancer/core.py b/facefusion/processors/modules/face_enhancer/core.py index ef5d0f0a..3b69a8fa 100755 --- a/facefusion/processors/modules/face_enhancer/core.py +++ b/facefusion/processors/modules/face_enhancer/core.py @@ -1,5 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType +from typing import List import numpy @@ -304,10 +306,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('face_enhancer_weight', args.get('face_enhancer_weight')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -328,15 +338,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def enhance_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/face_swapper/core.py b/facefusion/processors/modules/face_swapper/core.py index 88eb469b..216a4af2 100755 --- a/facefusion/processors/modules/face_swapper/core.py +++ b/facefusion/processors/modules/face_swapper/core.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType from typing import List, Optional, Tuple import cv2 @@ -541,10 +542,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('face_swapper_weight', args.get('face_swapper_weight')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -580,16 +589,14 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: get_static_model_initializer.cache_clear() clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def swap_face(source_face : Face, target_face : Face, source_vision_frame : VisionFrame, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/frame_colorizer/core.py b/facefusion/processors/modules/frame_colorizer/core.py index 9658bff0..87a9a357 100644 --- a/facefusion/processors/modules/frame_colorizer/core.py +++ b/facefusion/processors/modules/frame_colorizer/core.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType from typing import List import cv2 @@ -198,10 +199,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('frame_colorizer_size', args.get('frame_colorizer_size')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -222,10 +231,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/frame_enhancer/core.py b/facefusion/processors/modules/frame_enhancer/core.py index f329d9c3..8b1506a5 100644 --- a/facefusion/processors/modules/frame_enhancer/core.py +++ b/facefusion/processors/modules/frame_enhancer/core.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType from typing import List import cv2 @@ -591,10 +592,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('frame_enhancer_blend', args.get('frame_enhancer_blend')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -615,10 +624,13 @@ def post_process() -> None: read_static_image.cache_clear() read_static_video_frame.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def enhance_frame(temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/processors/modules/lip_syncer/core.py b/facefusion/processors/modules/lip_syncer/core.py index 9851114b..03f9d163 100755 --- a/facefusion/processors/modules/lip_syncer/core.py +++ b/facefusion/processors/modules/lip_syncer/core.py @@ -1,5 +1,7 @@ from argparse import ArgumentParser from functools import lru_cache +from types import ModuleType +from typing import List import cv2 import numpy @@ -142,10 +144,18 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('lip_syncer_weight', args.get('lip_syncer_weight')) +def get_common_modules() -> List[ModuleType]: + return [ content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, voice_extractor ] + + def pre_check() -> bool: model_hash_set = get_model_options().get('hashes') model_source_set = get_model_options().get('sources') + for common_module in get_common_modules(): + if not common_module.pre_check(): + return False + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @@ -161,16 +171,13 @@ def post_process() -> None: read_static_video_frame.cache_clear() read_static_voice.cache_clear() video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': - content_analyser.clear_inference_pool() - face_classifier.clear_inference_pool() - face_detector.clear_inference_pool() - face_landmarker.clear_inference_pool() - face_masker.clear_inference_pool() - face_recognizer.clear_inference_pool() - voice_extractor.clear_inference_pool() + for common_module in get_common_modules(): + common_module.clear_inference_pool() def sync_lip(target_face : Face, source_voice_frame : AudioFrame, temp_vision_frame : VisionFrame) -> VisionFrame: diff --git a/facefusion/uis/components/source.py b/facefusion/uis/components/source.py index 68d6ad1d..26089880 100644 --- a/facefusion/uis/components/source.py +++ b/facefusion/uis/components/source.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple import gradio -from facefusion import state_manager, translator +from facefusion import state_manager, translator, voice_extractor from facefusion.common_helper import get_first from facefusion.filesystem import filter_audio_paths, filter_image_paths, has_audio, has_image from facefusion.uis.core import register_ui_component @@ -55,6 +55,10 @@ def update(files : List[File]) -> Tuple[gradio.Audio, gradio.Image]: source_audio_path = get_first(filter_audio_paths(file_names)) source_image_path = get_first(filter_image_paths(file_names)) state_manager.set_item('source_paths', file_names) + + if has_source_audio: + voice_extractor.pre_check() + return gradio.Audio(value = source_audio_path, visible = has_source_audio), gradio.Image(value = source_image_path, visible = has_source_image) state_manager.clear_item('source_paths') diff --git a/facefusion/uis/layouts/default.py b/facefusion/uis/layouts/default.py index ef1a0727..c415e4a5 100755 --- a/facefusion/uis/layouts/default.py +++ b/facefusion/uis/layouts/default.py @@ -1,10 +1,21 @@ import gradio +import facefusion.face_classifier +import facefusion.face_detector +import facefusion.face_landmarker +import facefusion.face_masker +import facefusion.face_recognizer from facefusion import state_manager from facefusion.uis.components import about, age_modifier_options, background_remover_options, common_options, deep_swapper_options, download, execution, execution_thread_count, expression_restorer_options, face_debugger_options, face_detector, face_editor_options, face_enhancer_options, face_landmarker, face_masker, face_selector, face_swapper_options, frame_colorizer_options, frame_enhancer_options, instant_runner, job_manager, job_runner, lip_syncer_options, memory, output, output_options, preview, preview_options, processors, source, target, temp_frame, terminal, trim_frame, ui_workflow, voice_extractor def pre_check() -> bool: + common_modules = [ facefusion.face_classifier, facefusion.face_detector, facefusion.face_landmarker, facefusion.face_masker, facefusion.face_recognizer ] + + for common_module in common_modules: + if not common_module.pre_check(): + return False + return True