restructure conditional methods to a fall-through pattern

common process_temp_frame for all workflows
This commit is contained in:
harisreedhar
2025-12-05 12:53:27 +05:30
committed by henryruhs
parent a829742cba
commit e5460a06d2
4 changed files with 77 additions and 105 deletions
+3 -37
View File
@@ -1,19 +1,18 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
import numpy
from tqdm import tqdm
from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator
from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame, restrict_trim_audio_frame
from facefusion.audio import restrict_trim_audio_frame
from facefusion.common_helper import get_first
from facefusion.filesystem import filter_audio_paths, is_video
from facefusion.processors.core import get_processors_modules
from facefusion.temp_helper import move_temp_file, resolve_temp_frame_paths
from facefusion.time_helper import calculate_end_time
from facefusion.types import ErrorCode
from facefusion.vision import conditional_merge_vision_mask, detect_image_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, restrict_image_resolution, restrict_trim_video_frame, scale_resolution, write_image
from facefusion.workflows.core import clear, is_process_stopping, setup
from facefusion.vision import detect_image_resolution, pack_resolution, restrict_image_resolution, restrict_trim_video_frame, scale_resolution
from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup
def process(start_time : float) -> ErrorCode:
@@ -100,39 +99,6 @@ def process_image() -> ErrorCode:
return 0
def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: # TODO refinement like to_video.py file.
	# Render a single temp frame for the audio-to-image workflow: feed the target
	# image plus the audio/voice frames for frame_number through every active
	# processor, then merge the mask and persist the result.
	output_video_fps = state_manager.get_item('output_video_fps')
	reference_vision_frame = read_static_image(state_manager.get_item('target_path'))
	source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
	source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths')))
	target_vision_frame = read_static_image(state_manager.get_item('target_path'), 'rgba')
	source_audio_frame = get_audio_frame(source_audio_path, output_video_fps, frame_number)
	source_voice_frame = get_voice_frame(source_audio_path, output_video_fps, frame_number)
	# fall back to silent frames when extraction yielded nothing
	source_audio_frame = source_audio_frame if numpy.any(source_audio_frame) else create_empty_audio_frame()
	source_voice_frame = source_voice_frame if numpy.any(source_voice_frame) else create_empty_audio_frame()
	temp_vision_frame = target_vision_frame.copy()
	temp_vision_mask = extract_vision_mask(temp_vision_frame)
	for processor_module in get_processors_modules(state_manager.get_item('processors')):
		temp_vision_frame, temp_vision_mask = processor_module.process_frame(
		{
			'reference_vision_frame': reference_vision_frame,
			'source_vision_frames': source_vision_frames,
			'source_audio_frame': source_audio_frame,
			'source_voice_frame': source_voice_frame,
			'target_vision_frame': target_vision_frame[:, :, :3],
			'temp_vision_frame': temp_vision_frame[:, :, :3],
			'temp_vision_mask': temp_vision_mask
		})
	temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask)
	return write_image(temp_frame_path, temp_vision_frame)
def merge_frames() -> ErrorCode:
trim_frame_start, trim_frame_end = restrict_trim_video_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end'))
output_video_resolution = scale_resolution(detect_image_resolution(state_manager.get_item('target_path')), state_manager.get_item('output_image_scale'))
+69 -1
View File
@@ -1,6 +1,13 @@
import numpy
from facefusion import logger, process_manager, state_manager, translator
from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame
from facefusion.common_helper import get_first
from facefusion.filesystem import filter_audio_paths
from facefusion.processors.core import get_processors_modules
from facefusion.temp_helper import clear_temp_directory, create_temp_directory
from facefusion.types import ErrorCode
from facefusion.types import AudioFrame, ErrorCode, VisionFrame
from facefusion.vision import conditional_merge_vision_mask, extract_vision_mask, read_static_image, read_static_images, read_static_video_frame, restrict_video_fps, write_image
def is_process_stopping() -> bool:
@@ -20,3 +27,64 @@ def clear() -> ErrorCode:
clear_temp_directory(state_manager.get_temp_path(), state_manager.get_item('output_path'))
logger.debug(translator.get('clearing_temp'), __name__)
return 0
def conditional_get_source_audio_frame(frame_number : int) -> AudioFrame:
	# resolve the source audio frame for the given frame number, falling
	# through to an empty audio frame for workflows without source audio
	workflow = state_manager.get_item('workflow')
	if workflow not in [ 'audio-to-image', 'image-to-video' ]:
		return create_empty_audio_frame()
	source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths')))
	output_video_fps = state_manager.get_item('output_video_fps')
	if workflow == 'image-to-video':
		output_video_fps = restrict_video_fps(state_manager.get_item('target_path'), output_video_fps)
	source_audio_frame = get_audio_frame(source_audio_path, output_video_fps, frame_number)
	if numpy.any(source_audio_frame):
		return source_audio_frame
	return create_empty_audio_frame()
def conditional_get_source_voice_frame(frame_number : int) -> AudioFrame:
	# resolve the source voice frame for the given frame number, falling
	# through to an empty audio frame for workflows without source audio;
	# mirrors conditional_get_source_audio_frame() but extracts the voice track
	workflow = state_manager.get_item('workflow')
	if workflow in [ 'audio-to-image', 'image-to-video' ]:
		source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths')))
		output_video_fps = state_manager.get_item('output_video_fps')
		if workflow == 'image-to-video':
			# clamp the fps to the target video so frame numbers map to the right timestamps
			output_video_fps = restrict_video_fps(state_manager.get_item('target_path'), output_video_fps)
		source_voice_frame = get_voice_frame(source_audio_path, output_video_fps, frame_number)
		if numpy.any(source_voice_frame):
			return source_voice_frame
	return create_empty_audio_frame()
def conditional_get_reference_vision_frame() -> VisionFrame:
	# video workflows take the reference from a chosen video frame,
	# image workflows fall through to the target image itself
	target_path = state_manager.get_item('target_path')
	if state_manager.get_item('workflow') == 'image-to-video':
		return read_static_video_frame(target_path, state_manager.get_item('reference_frame_number'))
	return read_static_image(target_path)
def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool:
	# shared per-frame pipeline for all workflows: gather the reference, source
	# and audio inputs, run every active processor over the temp frame, merge
	# the resulting mask and write the frame back in place
	reference_vision_frame = conditional_get_reference_vision_frame()
	source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
	source_audio_frame = conditional_get_source_audio_frame(frame_number)
	source_voice_frame = conditional_get_source_voice_frame(frame_number)
	target_vision_frame = read_static_image(temp_frame_path, 'rgba')
	temp_vision_frame = target_vision_frame.copy()
	temp_vision_mask = extract_vision_mask(temp_vision_frame)
	for processor_module in get_processors_modules(state_manager.get_item('processors')):
		temp_vision_frame, temp_vision_mask = processor_module.process_frame(
		{
			'reference_vision_frame': reference_vision_frame,
			'source_vision_frames': source_vision_frames,
			'source_audio_frame': source_audio_frame,
			'source_voice_frame': source_voice_frame,
			'target_vision_frame': target_vision_frame[:, :, :3],
			'temp_vision_frame': temp_vision_frame[:, :, :3],
			'temp_vision_mask': temp_vision_mask
		})
	temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask)
	return write_image(temp_frame_path, temp_vision_frame)
+3 -30
View File
@@ -1,14 +1,12 @@
from functools import partial
from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator
from facefusion.audio import create_empty_audio_frame
from facefusion.filesystem import is_image
from facefusion.processors.core import get_processors_modules
from facefusion.temp_helper import get_temp_file_path
from facefusion.time_helper import calculate_end_time
from facefusion.types import ErrorCode
from facefusion.vision import conditional_merge_vision_mask, detect_image_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, restrict_image_resolution, scale_resolution, write_image
from facefusion.workflows.core import clear, is_process_stopping, setup
from facefusion.vision import detect_image_resolution, pack_resolution, restrict_image_resolution, scale_resolution
from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup
def process(start_time : float) -> ErrorCode:
@@ -58,32 +56,7 @@ def prepare_image() -> ErrorCode:
def process_image() -> ErrorCode:
temp_image_path = get_temp_file_path(state_manager.get_temp_path(), state_manager.get_item('output_path'))
reference_vision_frame = read_static_image(state_manager.get_item('target_path'))
source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
source_audio_frame = create_empty_audio_frame()
source_voice_frame = create_empty_audio_frame()
target_vision_frame = read_static_image(state_manager.get_item('target_path'), 'rgba')
temp_vision_frame = target_vision_frame.copy()
temp_vision_mask = extract_vision_mask(temp_vision_frame)
for processor_module in get_processors_modules(state_manager.get_item('processors')):
logger.info(translator.get('processing'), processor_module.__name__)
temp_vision_frame, temp_vision_mask = processor_module.process_frame(
{
'reference_vision_frame': reference_vision_frame,
'source_vision_frames': source_vision_frames,
'source_audio_frame': source_audio_frame,
'source_voice_frame': source_voice_frame,
'target_vision_frame': target_vision_frame[:, :, :3],
'temp_vision_frame': temp_vision_frame[:, :, :3],
'temp_vision_mask': temp_vision_mask
})
processor_module.post_process()
temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask)
write_image(temp_image_path, temp_vision_frame)
process_temp_frame(temp_image_path, 0)
if is_process_stopping():
return 4
+2 -37
View File
@@ -1,19 +1,17 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
import numpy
from tqdm import tqdm
from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator, video_manager
from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame
from facefusion.common_helper import get_first
from facefusion.filesystem import filter_audio_paths, is_video
from facefusion.processors.core import get_processors_modules
from facefusion.temp_helper import move_temp_file, resolve_temp_frame_paths
from facefusion.time_helper import calculate_end_time
from facefusion.types import ErrorCode
from facefusion.vision import conditional_merge_vision_mask, detect_video_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, read_static_video_frame, restrict_trim_video_frame, restrict_video_fps, restrict_video_resolution, scale_resolution, write_image
from facefusion.workflows.core import clear, is_process_stopping, setup
from facefusion.vision import detect_video_resolution, pack_resolution, restrict_trim_video_frame, restrict_video_fps, restrict_video_resolution, scale_resolution
from facefusion.workflows.core import clear, is_process_stopping, process_temp_frame, setup
def process(start_time : float) -> ErrorCode:
@@ -149,39 +147,6 @@ def restore_audio() -> ErrorCode:
return 0
def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool:
	# render a single temp frame for the image-to-video workflow: combine the
	# reference video frame, source images and the audio/voice frames for
	# frame_number through every active processor, then write the frame back
	reference_vision_frame = read_static_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number'))
	source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
	source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths')))
	temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps'))
	target_vision_frame = read_static_image(temp_frame_path, 'rgba')
	source_audio_frame = get_audio_frame(source_audio_path, temp_video_fps, frame_number)
	source_voice_frame = get_voice_frame(source_audio_path, temp_video_fps, frame_number)
	# fall back to silent frames when extraction yielded nothing
	source_audio_frame = source_audio_frame if numpy.any(source_audio_frame) else create_empty_audio_frame()
	source_voice_frame = source_voice_frame if numpy.any(source_voice_frame) else create_empty_audio_frame()
	temp_vision_frame = target_vision_frame.copy()
	temp_vision_mask = extract_vision_mask(temp_vision_frame)
	for processor_module in get_processors_modules(state_manager.get_item('processors')):
		temp_vision_frame, temp_vision_mask = processor_module.process_frame(
		{
			'reference_vision_frame': reference_vision_frame,
			'source_vision_frames': source_vision_frames,
			'source_audio_frame': source_audio_frame,
			'source_voice_frame': source_voice_frame,
			'target_vision_frame': target_vision_frame[:, :, :3],
			'temp_vision_frame': temp_vision_frame[:, :, :3],
			'temp_vision_mask': temp_vision_mask
		})
	temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask)
	return write_image(temp_frame_path, temp_vision_frame)
def finalize_video(start_time : float) -> ErrorCode:
if is_video(state_manager.get_item('output_path')):
logger.info(translator.get('processing_video_succeeded').format(seconds = calculate_end_time(start_time)), __name__)