From e5460a06d25ab34ac0666f1c382a2ceeb4466aa5 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 5 Dec 2025 12:53:27 +0530 Subject: [PATCH] changes restructure conditional methods to a fall-through pattern common process_temp_frame for all workflow --- facefusion/workflows/audio_to_image.py | 40 ++------------- facefusion/workflows/core.py | 70 +++++++++++++++++++++++++- facefusion/workflows/image_to_image.py | 33 ++---------- facefusion/workflows/image_to_video.py | 39 +------------- 4 files changed, 77 insertions(+), 105 deletions(-) diff --git a/facefusion/workflows/audio_to_image.py b/facefusion/workflows/audio_to_image.py index d7a608a5..9bbe6d98 100644 --- a/facefusion/workflows/audio_to_image.py +++ b/facefusion/workflows/audio_to_image.py @@ -1,19 +1,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial -import numpy from tqdm import tqdm from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator -from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame, restrict_trim_audio_frame +from facefusion.audio import restrict_trim_audio_frame from facefusion.common_helper import get_first from facefusion.filesystem import filter_audio_paths, is_video from facefusion.processors.core import get_processors_modules from facefusion.temp_helper import move_temp_file, resolve_temp_frame_paths from facefusion.time_helper import calculate_end_time from facefusion.types import ErrorCode -from facefusion.vision import conditional_merge_vision_mask, detect_image_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, restrict_image_resolution, restrict_trim_video_frame, scale_resolution, write_image -from facefusion.workflows.core import clear, is_process_stopping, setup +from facefusion.vision import detect_image_resolution, pack_resolution, restrict_image_resolution, restrict_trim_video_frame, scale_resolution +from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup def process(start_time : float) -> ErrorCode: @@ -100,39 +99,6 @@ def process_image() -> ErrorCode: return 0 -def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: # TODO refinement like to_video.py file. - output_video_fps = state_manager.get_item('output_video_fps') - reference_vision_frame = read_static_image(state_manager.get_item('target_path')) - source_vision_frames = read_static_images(state_manager.get_item('source_paths')) - source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) - target_vision_frame = read_static_image(state_manager.get_item('target_path'), 'rgba') - temp_vision_frame = target_vision_frame.copy() - temp_vision_mask = extract_vision_mask(temp_vision_frame) - - source_audio_frame = get_audio_frame(source_audio_path, output_video_fps, frame_number) - source_voice_frame = get_voice_frame(source_audio_path, output_video_fps, frame_number) - - if not numpy.any(source_audio_frame): - source_audio_frame = create_empty_audio_frame() - if not numpy.any(source_voice_frame): - source_voice_frame = create_empty_audio_frame() - - for processor_module in get_processors_modules(state_manager.get_item('processors')): - temp_vision_frame, temp_vision_mask = processor_module.process_frame( - { - 'reference_vision_frame': reference_vision_frame, - 'source_vision_frames': source_vision_frames, - 'source_audio_frame': source_audio_frame, - 'source_voice_frame': source_voice_frame, - 'target_vision_frame': target_vision_frame[:, :, :3], - 'temp_vision_frame': temp_vision_frame[:, :, :3], - 'temp_vision_mask': temp_vision_mask - }) - - temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) - return write_image(temp_frame_path, temp_vision_frame) - - def merge_frames() -> ErrorCode: trim_frame_start, trim_frame_end = restrict_trim_video_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')) output_video_resolution = scale_resolution(detect_image_resolution(state_manager.get_item('target_path')), state_manager.get_item('output_image_scale')) diff --git a/facefusion/workflows/core.py b/facefusion/workflows/core.py index d1717570..e8b4edd5 100644 --- a/facefusion/workflows/core.py +++ b/facefusion/workflows/core.py @@ -1,6 +1,13 @@ +import numpy + from facefusion import logger, process_manager, state_manager, translator +from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame +from facefusion.common_helper import get_first +from facefusion.filesystem import filter_audio_paths +from facefusion.processors.core import get_processors_modules from facefusion.temp_helper import clear_temp_directory, create_temp_directory -from facefusion.types import ErrorCode +from facefusion.types import AudioFrame, ErrorCode, VisionFrame +from facefusion.vision import conditional_merge_vision_mask, extract_vision_mask, read_static_image, read_static_images, read_static_video_frame, restrict_video_fps, write_image def is_process_stopping() -> bool: @@ -20,3 +27,64 @@ def clear() -> ErrorCode: clear_temp_directory(state_manager.get_temp_path(), state_manager.get_item('output_path')) logger.debug(translator.get('clearing_temp'), __name__) return 0 + + +def conditional_get_source_audio_frame(frame_number : int) -> AudioFrame: + if state_manager.get_item('workflow') in [ 'audio-to-image', 'image-to-video' ]: + source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) + output_video_fps = state_manager.get_item('output_video_fps') + + if state_manager.get_item('workflow') == 'image-to-video': + output_video_fps = restrict_video_fps(state_manager.get_item('target_path'), output_video_fps) + source_audio_frame = get_audio_frame(source_audio_path, output_video_fps, frame_number) + + if numpy.any(source_audio_frame): + return source_audio_frame + + return create_empty_audio_frame() + + +def conditional_get_source_voice_frame(frame_number: int) -> AudioFrame: + if state_manager.get_item('workflow') in [ 'audio-to-image', 'image-to-video' ]: + source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) + output_video_fps = state_manager.get_item('output_video_fps') + + if state_manager.get_item('workflow') == 'image-to-video': + output_video_fps = restrict_video_fps(state_manager.get_item('target_path'), output_video_fps) + source_voice_frame = get_voice_frame(source_audio_path, output_video_fps, frame_number) + + if numpy.any(source_voice_frame): + return source_voice_frame + + return create_empty_audio_frame() + + +def conditional_get_reference_vision_frame() -> VisionFrame: + if state_manager.get_item('workflow') == 'image-to-video': + return read_static_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) + return read_static_image(state_manager.get_item('target_path')) + + +def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: + reference_vision_frame = conditional_get_reference_vision_frame() + source_vision_frames = read_static_images(state_manager.get_item('source_paths')) + target_vision_frame = read_static_image(temp_frame_path, 'rgba') + source_audio_frame = conditional_get_source_audio_frame(frame_number) + source_voice_frame = conditional_get_source_voice_frame(frame_number) + temp_vision_frame = target_vision_frame.copy() + temp_vision_mask = extract_vision_mask(temp_vision_frame) + + for processor_module in get_processors_modules(state_manager.get_item('processors')): + temp_vision_frame, temp_vision_mask = processor_module.process_frame( + { + 'reference_vision_frame': reference_vision_frame, + 'source_vision_frames': source_vision_frames, + 'source_audio_frame': source_audio_frame, + 'source_voice_frame': source_voice_frame, + 'target_vision_frame': target_vision_frame[:, :, :3], + 'temp_vision_frame': temp_vision_frame[:, :, :3], + 'temp_vision_mask': temp_vision_mask + }) + + temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) + return write_image(temp_frame_path, temp_vision_frame) diff --git a/facefusion/workflows/image_to_image.py b/facefusion/workflows/image_to_image.py index 33661ceb..01fed565 100644 --- a/facefusion/workflows/image_to_image.py +++ b/facefusion/workflows/image_to_image.py @@ -1,14 +1,12 @@ from functools import partial from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator -from facefusion.audio import create_empty_audio_frame from facefusion.filesystem import is_image -from facefusion.processors.core import get_processors_modules from facefusion.temp_helper import get_temp_file_path from facefusion.time_helper import calculate_end_time from facefusion.types import ErrorCode -from facefusion.vision import conditional_merge_vision_mask, detect_image_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, restrict_image_resolution, scale_resolution, write_image -from facefusion.workflows.core import clear, is_process_stopping, setup +from facefusion.vision import detect_image_resolution, pack_resolution, restrict_image_resolution, scale_resolution +from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup def process(start_time : float) -> ErrorCode: @@ -58,32 +56,7 @@ def prepare_image() -> ErrorCode: def process_image() -> ErrorCode: temp_image_path = get_temp_file_path(state_manager.get_temp_path(), state_manager.get_item('output_path')) - reference_vision_frame = read_static_image(state_manager.get_item('target_path')) - source_vision_frames = read_static_images(state_manager.get_item('source_paths')) - source_audio_frame = create_empty_audio_frame() - source_voice_frame = create_empty_audio_frame() - target_vision_frame = read_static_image(state_manager.get_item('target_path'), 'rgba') - temp_vision_frame = target_vision_frame.copy() - temp_vision_mask = extract_vision_mask(temp_vision_frame) - - for processor_module in get_processors_modules(state_manager.get_item('processors')): - logger.info(translator.get('processing'), processor_module.__name__) - - temp_vision_frame, temp_vision_mask = processor_module.process_frame( - { - 'reference_vision_frame': reference_vision_frame, - 'source_vision_frames': source_vision_frames, - 'source_audio_frame': source_audio_frame, - 'source_voice_frame': source_voice_frame, - 'target_vision_frame': target_vision_frame[:, :, :3], - 'temp_vision_frame': temp_vision_frame[:, :, :3], - 'temp_vision_mask': temp_vision_mask - }) - - processor_module.post_process() - - temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) - write_image(temp_image_path, temp_vision_frame) + process_temp_frame(temp_image_path, 0) if is_process_stopping(): return 4 diff --git a/facefusion/workflows/image_to_video.py b/facefusion/workflows/image_to_video.py index 4d075af1..18a5a863 100644 --- a/facefusion/workflows/image_to_video.py +++ b/facefusion/workflows/image_to_video.py @@ -1,19 +1,17 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial -import numpy from tqdm import tqdm from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator, video_manager -from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame from facefusion.common_helper import get_first from facefusion.filesystem import filter_audio_paths, is_video from facefusion.processors.core import get_processors_modules from facefusion.temp_helper import move_temp_file, resolve_temp_frame_paths from facefusion.time_helper import calculate_end_time from facefusion.types import ErrorCode -from facefusion.vision import conditional_merge_vision_mask, detect_video_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, read_static_video_frame, restrict_trim_video_frame, restrict_video_fps, restrict_video_resolution, scale_resolution, write_image -from facefusion.workflows.core import clear, is_process_stopping, setup +from facefusion.vision import detect_video_resolution, pack_resolution, restrict_trim_video_frame, restrict_video_fps, restrict_video_resolution, scale_resolution +from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup def process(start_time : float) -> ErrorCode: @@ -149,39 +147,6 @@ def restore_audio() -> ErrorCode: return 0 -def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: - reference_vision_frame = read_static_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) - source_vision_frames = read_static_images(state_manager.get_item('source_paths')) - source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) - temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) - target_vision_frame = read_static_image(temp_frame_path, 'rgba') - temp_vision_frame = target_vision_frame.copy() - temp_vision_mask = extract_vision_mask(temp_vision_frame) - - source_audio_frame = get_audio_frame(source_audio_path, temp_video_fps, frame_number) - source_voice_frame = get_voice_frame(source_audio_path, temp_video_fps, frame_number) - - if not numpy.any(source_audio_frame): - source_audio_frame = create_empty_audio_frame() - if not numpy.any(source_voice_frame): - source_voice_frame = create_empty_audio_frame() - - for processor_module in get_processors_modules(state_manager.get_item('processors')): - temp_vision_frame, temp_vision_mask = processor_module.process_frame( - { - 'reference_vision_frame': reference_vision_frame, - 'source_vision_frames': source_vision_frames, - 'source_audio_frame': source_audio_frame, - 'source_voice_frame': source_voice_frame, - 'target_vision_frame': target_vision_frame[:, :, :3], - 'temp_vision_frame': temp_vision_frame[:, :, :3], - 'temp_vision_mask': temp_vision_mask - }) - - temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) - return write_image(temp_frame_path, temp_vision_frame) - - def finalize_video(start_time : float) -> ErrorCode: if is_video(state_manager.get_item('output_path')): logger.info(translator.get('processing_video_succeeded').format(seconds = calculate_end_time(start_time)), __name__)