mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-25 10:55:57 +02:00
Refactor reusable workflow tasks (#980)
* Refactor reusable workflow tasks * Refactor reusable workflow tasks * Make it borderline again
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from facefusion import logger, process_manager, translator
|
||||
from facefusion import logger, process_manager, state_manager, translator
|
||||
from facefusion.temp_helper import clear_temp_directory, create_temp_directory
|
||||
from facefusion.types import ErrorCode
|
||||
|
||||
|
||||
def is_process_stopping() -> bool:
|
||||
@@ -6,3 +8,15 @@ def is_process_stopping() -> bool:
|
||||
process_manager.end()
|
||||
logger.info(translator.get('processing_stopped'), __name__)
|
||||
return process_manager.is_pending()
|
||||
|
||||
|
||||
def setup() -> ErrorCode:
|
||||
create_temp_directory(state_manager.get_item('target_path'))
|
||||
logger.debug(translator.get('creating_temp'), __name__)
|
||||
return 0
|
||||
|
||||
|
||||
def clear() -> ErrorCode:
|
||||
clear_temp_directory(state_manager.get_item('target_path'))
|
||||
logger.debug(translator.get('clearing_temp'), __name__)
|
||||
return 0
|
||||
|
||||
@@ -1,26 +1,28 @@
|
||||
from functools import partial
|
||||
|
||||
from facefusion import ffmpeg
|
||||
from facefusion import logger, process_manager, state_manager, translator
|
||||
from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator
|
||||
from facefusion.audio import create_empty_audio_frame
|
||||
from facefusion.content_analyser import analyse_image
|
||||
from facefusion.filesystem import is_image
|
||||
from facefusion.processors.core import get_processors_modules
|
||||
from facefusion.temp_helper import clear_temp_directory, create_temp_directory, get_temp_file_path
|
||||
from facefusion.temp_helper import get_temp_file_path
|
||||
from facefusion.time_helper import calculate_end_time
|
||||
from facefusion.types import ErrorCode
|
||||
from facefusion.vision import conditional_merge_vision_mask, detect_image_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, restrict_image_resolution, scale_resolution, write_image
|
||||
from facefusion.workflows.core import is_process_stopping
|
||||
from facefusion.workflows.core import clear, is_process_stopping, setup
|
||||
|
||||
|
||||
def process(start_time : float) -> ErrorCode:
|
||||
tasks =\
|
||||
[
|
||||
analyse_image,
|
||||
clear,
|
||||
setup,
|
||||
prepare_image,
|
||||
process_image,
|
||||
partial(finalize_image, start_time)
|
||||
partial(finalize_image, start_time),
|
||||
clear
|
||||
]
|
||||
|
||||
process_manager.start()
|
||||
|
||||
for task in tasks:
|
||||
@@ -34,14 +36,9 @@ def process(start_time : float) -> ErrorCode:
|
||||
return 0
|
||||
|
||||
|
||||
def setup() -> ErrorCode:
|
||||
if analyse_image(state_manager.get_item('target_path')):
|
||||
def analyse_image() -> ErrorCode:
|
||||
if content_analyser.analyse_image(state_manager.get_item('target_path')):
|
||||
return 3
|
||||
|
||||
logger.debug(translator.get('clearing_temp'), __name__)
|
||||
clear_temp_directory(state_manager.get_item('target_path'))
|
||||
logger.debug(translator.get('creating_temp'), __name__)
|
||||
create_temp_directory(state_manager.get_item('target_path'))
|
||||
return 0
|
||||
|
||||
|
||||
@@ -102,9 +99,6 @@ def finalize_image(start_time : float) -> ErrorCode:
|
||||
else:
|
||||
logger.warn(translator.get('finalizing_image_skipped'), __name__)
|
||||
|
||||
logger.debug(translator.get('clearing_temp'), __name__)
|
||||
clear_temp_directory(state_manager.get_item('target_path'))
|
||||
|
||||
if is_image(state_manager.get_item('output_path')):
|
||||
logger.info(translator.get('processing_image_succeeded').format(seconds = calculate_end_time(start_time)), __name__)
|
||||
else:
|
||||
|
||||
@@ -4,30 +4,32 @@ from functools import partial
|
||||
import numpy
|
||||
from tqdm import tqdm
|
||||
|
||||
from facefusion import ffmpeg
|
||||
from facefusion import logger, process_manager, state_manager, translator, video_manager
|
||||
from facefusion import content_analyser, ffmpeg, logger, process_manager, state_manager, translator, video_manager
|
||||
from facefusion.audio import create_empty_audio_frame, get_audio_frame, get_voice_frame
|
||||
from facefusion.common_helper import get_first
|
||||
from facefusion.content_analyser import analyse_video
|
||||
from facefusion.filesystem import filter_audio_paths, is_video
|
||||
from facefusion.processors.core import get_processors_modules
|
||||
from facefusion.temp_helper import clear_temp_directory, create_temp_directory, move_temp_file, resolve_temp_frame_paths
|
||||
from facefusion.temp_helper import move_temp_file, resolve_temp_frame_paths
|
||||
from facefusion.time_helper import calculate_end_time
|
||||
from facefusion.types import ErrorCode
|
||||
from facefusion.vision import conditional_merge_vision_mask, detect_video_resolution, extract_vision_mask, pack_resolution, read_static_image, read_static_images, read_static_video_frame, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, scale_resolution, write_image
|
||||
from facefusion.workflows.core import is_process_stopping
|
||||
from facefusion.workflows.core import clear, is_process_stopping, setup
|
||||
|
||||
|
||||
def process(start_time : float) -> ErrorCode:
|
||||
tasks =\
|
||||
[
|
||||
analyse_video,
|
||||
clear,
|
||||
setup,
|
||||
extract_frames,
|
||||
process_video,
|
||||
merge_frames,
|
||||
restore_audio,
|
||||
partial(finalize_video, start_time)
|
||||
partial(finalize_video, start_time),
|
||||
clear
|
||||
]
|
||||
|
||||
process_manager.start()
|
||||
|
||||
for task in tasks:
|
||||
@@ -41,16 +43,11 @@ def process(start_time : float) -> ErrorCode:
|
||||
return 0
|
||||
|
||||
|
||||
def setup() -> ErrorCode:
|
||||
def analyse_video() -> ErrorCode:
|
||||
trim_frame_start, trim_frame_end = restrict_trim_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end'))
|
||||
|
||||
if analyse_video(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end):
|
||||
if content_analyser.analyse_video(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end):
|
||||
return 3
|
||||
|
||||
logger.debug(translator.get('clearing_temp'), __name__)
|
||||
clear_temp_directory(state_manager.get_item('target_path'))
|
||||
logger.debug(translator.get('creating_temp'), __name__)
|
||||
create_temp_directory(state_manager.get_item('target_path'))
|
||||
return 0
|
||||
|
||||
|
||||
@@ -105,6 +102,39 @@ def process_video() -> ErrorCode:
|
||||
return 0
|
||||
|
||||
|
||||
def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool:
|
||||
reference_vision_frame = read_static_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number'))
|
||||
source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
|
||||
source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths')))
|
||||
temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps'))
|
||||
target_vision_frame = read_static_image(temp_frame_path, 'rgba')
|
||||
temp_vision_frame = target_vision_frame.copy()
|
||||
temp_vision_mask = extract_vision_mask(temp_vision_frame)
|
||||
|
||||
source_audio_frame = get_audio_frame(source_audio_path, temp_video_fps, frame_number)
|
||||
source_voice_frame = get_voice_frame(source_audio_path, temp_video_fps, frame_number)
|
||||
|
||||
if not numpy.any(source_audio_frame):
|
||||
source_audio_frame = create_empty_audio_frame()
|
||||
if not numpy.any(source_voice_frame):
|
||||
source_voice_frame = create_empty_audio_frame()
|
||||
|
||||
for processor_module in get_processors_modules(state_manager.get_item('processors')):
|
||||
temp_vision_frame, temp_vision_mask = processor_module.process_frame(
|
||||
{
|
||||
'reference_vision_frame': reference_vision_frame,
|
||||
'source_vision_frames': source_vision_frames,
|
||||
'source_audio_frame': source_audio_frame,
|
||||
'source_voice_frame': source_voice_frame,
|
||||
'target_vision_frame': target_vision_frame[:, :, :3],
|
||||
'temp_vision_frame': temp_vision_frame[:, :, :3],
|
||||
'temp_vision_mask': temp_vision_mask
|
||||
})
|
||||
|
||||
temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask)
|
||||
return write_image(temp_frame_path, temp_vision_frame)
|
||||
|
||||
|
||||
def merge_frames() -> ErrorCode:
|
||||
trim_frame_start, trim_frame_end = restrict_trim_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end'))
|
||||
output_video_resolution = scale_resolution(detect_video_resolution(state_manager.get_item('target_path')), state_manager.get_item('output_video_scale'))
|
||||
@@ -186,9 +216,6 @@ def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool:
|
||||
|
||||
|
||||
def finalize_video(start_time : float) -> ErrorCode:
|
||||
logger.debug(translator.get('clearing_temp'), __name__)
|
||||
clear_temp_directory(state_manager.get_item('target_path'))
|
||||
|
||||
if is_video(state_manager.get_item('output_path')):
|
||||
logger.info(translator.get('processing_video_succeeded').format(seconds = calculate_end_time(start_time)), __name__)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user