diff --git a/facefusion/benchmarker.py b/facefusion/benchmarker.py index 07d88181..4ba16bd4 100644 --- a/facefusion/benchmarker.py +++ b/facefusion/benchmarker.py @@ -9,6 +9,7 @@ import facefusion.choices from facefusion import content_analyser, core, state_manager from facefusion.cli_helper import render_table from facefusion.download import conditional_download, resolve_download_url +from facefusion.face_store import clear_faces from facefusion.filesystem import get_file_extension from facefusion.types import BenchmarkCycleSet from facefusion.vision import count_video_frame_total, detect_video_fps @@ -63,6 +64,7 @@ def cycle(cycle_count : int) -> BenchmarkCycleSet: if state_manager.get_item('benchmark_mode') == 'cold': content_analyser.analyse_image.cache_clear() content_analyser.analyse_video.cache_clear() + clear_faces() start_time = perf_counter() core.conditional_process() diff --git a/facefusion/face_analyser.py b/facefusion/face_analyser.py index 135261ed..ea0acf9c 100644 --- a/facefusion/face_analyser.py +++ b/facefusion/face_analyser.py @@ -2,7 +2,7 @@ from typing import List, Optional import numpy -from facefusion import state_manager +from facefusion import state_manager, face_store from facefusion.common_helper import get_first from facefusion.face_classifier import classify_face from facefusion.face_detector import detect_faces, detect_faces_by_angle @@ -92,6 +92,21 @@ def get_average_face(faces : List[Face]) -> Optional[Face]: return None +def get_static_faces(vision_frames : List[VisionFrame]) -> List[Face]: + many_faces : List[Face] = [] + + for vision_frame in vision_frames: + faces = face_store.get_faces(vision_frame) + + if not faces: + faces = get_many_faces([ vision_frame ]) + face_store.set_faces(vision_frame, faces) + + many_faces.extend(faces) + + return many_faces + + def get_many_faces(vision_frames : List[VisionFrame]) -> List[Face]: many_faces : List[Face] = [] @@ -115,6 +130,7 @@ def get_many_faces(vision_frames : List[VisionFrame]) -> List[Face]: if faces: many_faces.extend(faces) + return many_faces diff --git a/facefusion/face_selector.py b/facefusion/face_selector.py index 6d602e1e..0edd9b2e 100644 --- a/facefusion/face_selector.py +++ b/facefusion/face_selector.py @@ -5,7 +5,7 @@ import numpy import facefusion.choices from facefusion import state_manager from facefusion.common_helper import get_first -from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_analyser import get_many_faces, get_one_face, get_static_faces from facefusion.types import Face, FaceSelectorOrder, Gender, Race, Score, VisionFrame @@ -14,7 +14,7 @@ def select_faces(reference_vision_frame : VisionFrame, source_vision_frames : Li target_faces = get_many_faces([ target_vision_frame ]) if state_manager.get_item('face_selector_gender') == 'auto' or state_manager.get_item('face_selector_race') == 'auto': - source_faces = get_many_faces(source_vision_frames) + source_faces = get_static_faces(source_vision_frames) if state_manager.get_item('face_selector_mode') == 'many': return sort_and_filter_faces(source_faces, target_faces) @@ -25,7 +25,7 @@ def select_faces(reference_vision_frame : VisionFrame, source_vision_frames : Li return [ target_face ] if state_manager.get_item('face_selector_mode') == 'reference': - reference_faces = get_many_faces([ reference_vision_frame ]) + reference_faces = get_static_faces([ reference_vision_frame ]) reference_faces = sort_and_filter_faces(source_faces, reference_faces) reference_face = get_one_face(reference_faces, state_manager.get_item('reference_face_position')) diff --git a/facefusion/face_store.py b/facefusion/face_store.py new file mode 100644 index 00000000..0356378e --- /dev/null +++ b/facefusion/face_store.py @@ -0,0 +1,21 @@ +from typing import List, Optional + +from facefusion.hash_helper import create_hash +from facefusion.types import Face, FaceStore, VisionFrame + +FACE_STORE : FaceStore = {} + + +def get_faces(vision_frame : VisionFrame) -> Optional[List[Face]]: + vision_hash = create_hash(vision_frame.tobytes()) + return FACE_STORE.get(vision_hash) + + +def set_faces(vision_frame : VisionFrame, faces : List[Face]) -> None: + vision_hash = create_hash(vision_frame.tobytes()) + if vision_hash: + FACE_STORE[vision_hash] = faces + + +def clear_faces() -> None: + FACE_STORE.clear() diff --git a/facefusion/processors/modules/face_swapper/core.py b/facefusion/processors/modules/face_swapper/core.py index 5d77407c..cae7b032 100755 --- a/facefusion/processors/modules/face_swapper/core.py +++ b/facefusion/processors/modules/face_swapper/core.py @@ -12,7 +12,7 @@ from facefusion import config, content_analyser, face_classifier, face_detector, from facefusion.common_helper import get_first, is_macos from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url from facefusion.execution import has_execution_provider -from facefusion.face_analyser import get_average_face, get_many_faces, get_one_face, scale_face +from facefusion.face_analyser import get_average_face, get_many_faces, get_one_face, get_static_faces, scale_face from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5 from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask from facefusion.face_selector import select_faces, sort_faces_by_order @@ -555,7 +555,7 @@ def pre_process(mode : ProcessMode) -> bool: source_image_paths = filter_image_paths(state_manager.get_item('source_paths')) source_vision_frames = read_static_images(source_image_paths) - source_faces = get_many_faces(source_vision_frames) + source_faces = get_static_faces(source_vision_frames) if not get_one_face(source_faces): logger.error(translator.get('no_source_face_detected') + translator.get('exclamation_mark'), __name__) @@ -758,7 +758,7 @@ def extract_source_face(source_vision_frames : List[VisionFrame]) -> Optional[Fa if source_vision_frames: for source_vision_frame in source_vision_frames: - temp_faces = get_many_faces([source_vision_frame]) + temp_faces = get_static_faces([source_vision_frame]) temp_faces = sort_faces_by_order(temp_faces, 'large-small') if temp_faces: diff --git a/facefusion/types.py b/facefusion/types.py index c90b427d..aa866cf2 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -49,6 +49,8 @@ Face = namedtuple('Face', 'gender', 'race' ]) +FaceStore : TypeAlias = Dict[str, List[Face]] + Language = Literal['en'] Locales : TypeAlias = Dict[Language, Dict[str, Any]] LocalePoolSet : TypeAlias = Dict[str, Locales] @@ -58,12 +60,12 @@ VideoWriterSet : TypeAlias = Dict[str, cv2.VideoWriter] CameraCaptureSet : TypeAlias = Dict[str, cv2.VideoCapture] VideoPoolSet = TypedDict('VideoPoolSet', { - 'capture': VideoCaptureSet, - 'writer': VideoWriterSet + 'capture' : VideoCaptureSet, + 'writer' : VideoWriterSet }) CameraPoolSet = TypedDict('CameraPoolSet', { - 'capture': CameraCaptureSet + 'capture' : CameraCaptureSet }) ColorMode = Literal['rgb', 'rgba'] @@ -344,7 +346,7 @@ State = TypedDict('State', 'benchmark_cycle_count' : int, 'face_detector_model' : FaceDetectorModel, 'face_detector_size' : str, - 'face_detector_margin': Margin, + 'face_detector_margin' : Margin, 'face_detector_angles' : List[Angle], 'face_detector_score' : Score, 'face_landmarker_model' : FaceLandmarkerModel, @@ -365,7 +367,7 @@ State = TypedDict('State', 'face_mask_regions' : List[FaceMaskRegion], 'face_mask_blur' : float, 'face_mask_padding' : Padding, - 'voice_extractor_model': VoiceExtractorModel, + 'voice_extractor_model' : VoiceExtractorModel, 'trim_frame_start' : int, 'trim_frame_end' : int, 'temp_frame_format' : TempFrameFormat, diff --git a/facefusion/uis/components/face_selector.py b/facefusion/uis/components/face_selector.py index 8422fe47..5e9829e2 100644 --- a/facefusion/uis/components/face_selector.py +++ b/facefusion/uis/components/face_selector.py @@ -9,6 +9,7 @@ from facefusion import state_manager, translator from facefusion.common_helper import calculate_float_step, calculate_int_step from facefusion.face_analyser import get_many_faces from facefusion.face_selector import sort_and_filter_faces +from facefusion.face_store import clear_faces from facefusion.filesystem import filter_image_paths, is_image, is_video from facefusion.types import FaceSelectorGender, FaceSelectorMode, FaceSelectorOrder, FaceSelectorRace, VisionFrame from facefusion.uis.core import get_ui_component, get_ui_components, register_ui_component @@ -194,6 +195,7 @@ def clear_reference_frame_number() -> None: def clear_and_update_reference_position_gallery() -> gradio.Gallery: + clear_faces() return update_reference_position_gallery() diff --git a/facefusion/uis/components/preview.py b/facefusion/uis/components/preview.py index b6be49f7..59eaa03d 100755 --- a/facefusion/uis/components/preview.py +++ b/facefusion/uis/components/preview.py @@ -11,6 +11,7 @@ from facefusion.common_helper import get_first from facefusion.content_analyser import analyse_frame from facefusion.face_analyser import get_one_face from facefusion.face_selector import select_faces +from facefusion.face_store import clear_faces from facefusion.filesystem import filter_audio_paths, is_image, is_video from facefusion.processors.core import get_processors_modules from facefusion.types import AudioFrame, Face, Mask, VisionFrame @@ -217,6 +218,7 @@ def update_preview_image(preview_mode : PreviewMode, preview_resolution : str, f def clear_and_update_preview_image(preview_mode : PreviewMode, preview_resolution : str, frame_number : int = 0) -> gradio.Image: + clear_faces() return update_preview_image(preview_mode, preview_resolution, frame_number) diff --git a/facefusion/uis/components/target.py b/facefusion/uis/components/target.py index 9a996635..e69e3c56 100644 --- a/facefusion/uis/components/target.py +++ b/facefusion/uis/components/target.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import gradio from facefusion import state_manager, translator +from facefusion.face_store import clear_faces from facefusion.filesystem import is_image, is_video from facefusion.uis.core import register_ui_component from facefusion.uis.types import ComponentOptions, File @@ -50,6 +51,8 @@ def listen() -> None: def update(file : File) -> Tuple[gradio.Image, gradio.Video]: + clear_faces() + if file and is_image(file.name): state_manager.set_item('target_path', file.name) return gradio.Image(value = file.name, visible = True), gradio.Video(value = None, visible = False) diff --git a/facefusion/uis/components/trim_frame.py b/facefusion/uis/components/trim_frame.py index c5e8341a..9ad37a6e 100644 --- a/facefusion/uis/components/trim_frame.py +++ b/facefusion/uis/components/trim_frame.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple from gradio_rangeslider import RangeSlider from facefusion import state_manager, translator +from facefusion.face_store import clear_faces from facefusion.filesystem import is_video from facefusion.uis.core import get_ui_components from facefusion.uis.types import ComponentOptions @@ -52,6 +53,7 @@ def remote_update() -> RangeSlider: def update_trim_frame(trim_frame : Tuple[float, float]) -> None: + clear_faces() trim_frame_start, trim_frame_end = trim_frame video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) trim_frame_start = int(trim_frame_start) if trim_frame_start > 0 else None diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py index 22d06ad5..ab9caac6 100644 --- a/tests/test_face_analyser.py +++ b/tests/test_face_analyser.py @@ -5,6 +5,7 @@ import pytest from facefusion import face_classifier, face_detector, face_landmarker, face_recognizer, state_manager from facefusion.download import conditional_download from facefusion.face_analyser import get_many_faces +from facefusion.face_store import clear_faces from facefusion.vision import read_static_image from .helper import get_test_example_file, get_test_examples_directory @@ -18,6 +19,7 @@ def before_all() -> None: subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ]) + state_manager.init_item('execution_device_ids', [ 0 ]) state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('download_providers', [ 'github' ]) @@ -26,6 +28,7 @@ def before_all() -> None: state_manager.init_item('face_detector_score', 0.5) state_manager.init_item('face_landmarker_model', 'many') state_manager.init_item('face_landmarker_score', 0.5) + face_classifier.pre_check() face_landmarker.pre_check() face_recognizer.pre_check() @@ -37,6 +40,7 @@ def before_each() -> None: face_detector.clear_inference_pool() face_landmarker.clear_inference_pool() face_recognizer.clear_inference_pool() + clear_faces() def test_get_one_face_with_retinaface() -> None: