Mirror of https://github.com/facefusion/facefusion.git (synced 2026-05-11 18:12:36 +02:00).
Commit: Move the predictor to ONNX to avoid tensorflow; use video ranges for prediction.
This commit is contained in:
+4
-17
@@ -9,15 +9,14 @@ import warnings
|
||||
import platform
|
||||
import shutil
|
||||
import onnxruntime
|
||||
import tensorflow
|
||||
from argparse import ArgumentParser, HelpFormatter
|
||||
|
||||
import facefusion.choices
|
||||
import facefusion.globals
|
||||
from facefusion import metadata, wording
|
||||
from facefusion import metadata, predictor, wording
|
||||
from facefusion.predictor import predict_image, predict_video
|
||||
from facefusion.processors.frame.core import get_frame_processors_modules, load_frame_processor_module
|
||||
from facefusion.utilities import is_image, is_video, detect_fps, compress_image, merge_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clear_temp, list_module_names, encode_execution_providers, decode_execution_providers, normalize_output_path
|
||||
from facefusion.utilities import is_image, is_video, detect_fps, compress_image, merge_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clear_temp, list_module_names, encode_execution_providers, decode_execution_providers, normalize_output_path, update_status
|
||||
|
||||
warnings.filterwarnings('ignore', category = FutureWarning, module = 'insightface')
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torchvision')
|
||||
@@ -125,7 +124,7 @@ def apply_args(program : ArgumentParser) -> None:
|
||||
def run(program : ArgumentParser) -> None:
|
||||
apply_args(program)
|
||||
limit_resources()
|
||||
if not pre_check():
|
||||
if not pre_check() or not predictor.pre_check():
|
||||
return
|
||||
for frame_processor_module in get_frame_processors_modules(facefusion.globals.frame_processors):
|
||||
if not frame_processor_module.pre_check():
|
||||
@@ -148,14 +147,6 @@ def destroy() -> None:
|
||||
|
||||
|
||||
def limit_resources() -> None:
|
||||
# prevent tensorflow memory leak
|
||||
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tensorflow.config.experimental.set_virtual_device_configuration(gpu,
|
||||
[
|
||||
tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit = 512)
|
||||
])
|
||||
# limit memory usage
|
||||
if facefusion.globals.max_memory:
|
||||
memory = facefusion.globals.max_memory * 1024 ** 3
|
||||
if platform.system().lower() == 'darwin':
|
||||
@@ -210,7 +201,7 @@ def process_image() -> None:
|
||||
|
||||
|
||||
def process_video() -> None:
|
||||
if predict_video(facefusion.globals.target_path):
|
||||
if predict_video(facefusion.globals.target_path, facefusion.globals.trim_frame_start, facefusion.globals.trim_frame_end):
|
||||
return
|
||||
fps = detect_fps(facefusion.globals.target_path) if facefusion.globals.keep_fps else 25.0
|
||||
# create temp
|
||||
@@ -251,7 +242,3 @@ def process_video() -> None:
|
||||
update_status(wording.get('processing_video_succeed'))
|
||||
else:
|
||||
update_status(wording.get('processing_video_failed'))
|
||||
|
||||
|
||||
def update_status(message : str, scope : str = 'FACEFUSION.CORE') -> None:
|
||||
print('[' + scope + '] ' + message)
|
||||
|
||||
+43
-17
@@ -1,26 +1,33 @@
|
||||
from typing import Any
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import opennsfw2
|
||||
from PIL import Image
|
||||
from keras import Model
|
||||
import onnxruntime
|
||||
from tqdm import tqdm
|
||||
|
||||
import facefusion.globals
|
||||
from facefusion import wording
|
||||
from facefusion.typing import Frame
|
||||
from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_fps
|
||||
from facefusion.utilities import resolve_relative_path, conditional_download
|
||||
|
||||
PREDICTOR = None
|
||||
THREAD_LOCK : threading.Lock = threading.Lock()
|
||||
MAX_PROBABILITY = 0.75
|
||||
FRAME_INTERVAL = 25
|
||||
NAME = 'FACEFUSION.PREDICTOR'
|
||||
MODEL_URL = 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx'
|
||||
MODEL_PATH = resolve_relative_path('../.assets/models/_open_nsfw.onnx')
|
||||
MAX_PROBABILITY = 0.80
|
||||
STREAM_COUNTER = 0
|
||||
|
||||
|
||||
def get_predictor() -> Model:
|
||||
def get_predictor() -> Any:
|
||||
global PREDICTOR
|
||||
|
||||
with THREAD_LOCK:
|
||||
if PREDICTOR is None:
|
||||
PREDICTOR = opennsfw2.make_open_nsfw_model()
|
||||
PREDICTOR = onnxruntime.InferenceSession(MODEL_PATH, providers = facefusion.globals.execution_providers)
|
||||
return PREDICTOR
|
||||
|
||||
|
||||
@@ -30,29 +37,48 @@ def clear_predictor() -> None:
|
||||
PREDICTOR = None
|
||||
|
||||
|
||||
def predict_stream(frame : Frame) -> bool:
|
||||
def pre_check() -> bool:
|
||||
if not facefusion.globals.skip_download:
|
||||
download_directory_path = resolve_relative_path('../.assets/models')
|
||||
conditional_download(download_directory_path, [ MODEL_URL ])
|
||||
return True
|
||||
|
||||
|
||||
def predict_stream(frame : Frame, fps : float) -> bool:
|
||||
global STREAM_COUNTER
|
||||
|
||||
STREAM_COUNTER = STREAM_COUNTER + 1
|
||||
if STREAM_COUNTER % FRAME_INTERVAL == 0:
|
||||
if STREAM_COUNTER % fps == 0:
|
||||
return predict_frame(frame)
|
||||
return False
|
||||
|
||||
|
||||
def predict_frame(frame : Frame) -> bool:
|
||||
image = Image.fromarray(frame)
|
||||
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
|
||||
views = numpy.expand_dims(image, axis = 0)
|
||||
_, probability = get_predictor().predict(views)[0]
|
||||
predictor = get_predictor()
|
||||
frame = cv2.resize(frame, (224, 224)).astype(numpy.float32)
|
||||
frame -= numpy.array([ 104, 117, 123 ], dtype = numpy.float32)
|
||||
frame = numpy.expand_dims(frame, axis = 0)
|
||||
probability = predictor.run(None,
|
||||
{
|
||||
'input:0': frame
|
||||
})[0][0][1]
|
||||
return probability > MAX_PROBABILITY
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def predict_image(image_path : str) -> bool:
|
||||
return opennsfw2.predict_image(image_path) > MAX_PROBABILITY
|
||||
frame = read_image(image_path)
|
||||
return predict_frame(frame)
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def predict_video(video_path : str) -> bool:
|
||||
_, probabilities = opennsfw2.predict_video_frames(video_path = video_path, frame_interval = FRAME_INTERVAL)
|
||||
return any(probability > MAX_PROBABILITY for probability in probabilities)
|
||||
def predict_video(video_path : str, start_frame : int, end_frame : int) -> bool:
|
||||
video_frame_total = count_video_frame_total(video_path)
|
||||
fps = detect_fps(video_path)
|
||||
frame_range = range(start_frame or 0, end_frame or video_frame_total)
|
||||
for frame_number in tqdm(frame_range, desc = wording.get('analysing')):
|
||||
if frame_number % int(fps) == 0:
|
||||
frame = get_video_frame(video_path, frame_number)
|
||||
if predict_frame(frame):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -7,16 +7,17 @@ import onnxruntime
|
||||
|
||||
import facefusion.globals
|
||||
from facefusion import wording
|
||||
from facefusion.core import update_status
|
||||
from facefusion.face_analyser import get_many_faces, clear_face_analyser
|
||||
from facefusion.face_helper import warp_face, paste_back
|
||||
from facefusion.predictor import clear_predictor
|
||||
from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status
|
||||
from facefusion.vision import read_image, read_static_image, write_image
|
||||
from facefusion.processors.frame import globals as frame_processors_globals
|
||||
from facefusion.processors.frame import choices as frame_processors_choices
|
||||
|
||||
FRAME_PROCESSOR = None
|
||||
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
|
||||
THREAD_LOCK : threading.Lock = threading.Lock()
|
||||
NAME = 'FACEFUSION.FRAME_PROCESSOR.FACE_ENHANCER'
|
||||
MODELS : Dict[str, ModelValue] =\
|
||||
@@ -123,6 +124,7 @@ def pre_process(mode : ProcessMode) -> bool:
|
||||
def post_process() -> None:
|
||||
clear_frame_processor()
|
||||
clear_face_analyser()
|
||||
clear_predictor()
|
||||
read_static_image.cache_clear()
|
||||
|
||||
|
||||
@@ -136,7 +138,8 @@ def enhance_face(target_face: Face, temp_frame: Frame) -> Frame:
|
||||
frame_processor_inputs[frame_processor_input.name] = crop_frame
|
||||
if frame_processor_input.name == 'weight':
|
||||
frame_processor_inputs[frame_processor_input.name] = numpy.array([ 1 ], dtype = numpy.double)
|
||||
crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
|
||||
with THREAD_SEMAPHORE:
|
||||
crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
|
||||
crop_frame = normalize_crop_frame(crop_frame)
|
||||
paste_frame = paste_back(temp_frame, crop_frame, affine_matrix)
|
||||
temp_frame = blend_frame(temp_frame, paste_frame)
|
||||
|
||||
@@ -10,12 +10,12 @@ from onnx import numpy_helper
|
||||
import facefusion.globals
|
||||
import facefusion.processors.frame.core as frame_processors
|
||||
from facefusion import wording
|
||||
from facefusion.core import update_status
|
||||
from facefusion.face_analyser import get_one_face, get_many_faces, find_similar_faces, clear_face_analyser
|
||||
from facefusion.face_helper import warp_face, paste_back
|
||||
from facefusion.face_reference import get_face_reference, set_face_reference
|
||||
from facefusion.predictor import clear_predictor
|
||||
from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status
|
||||
from facefusion.vision import read_image, read_static_image, write_image
|
||||
from facefusion.processors.frame import globals as frame_processors_globals
|
||||
from facefusion.processors.frame import choices as frame_processors_choices
|
||||
@@ -135,6 +135,7 @@ def post_process() -> None:
|
||||
clear_frame_processor()
|
||||
clear_model_matrix()
|
||||
clear_face_analyser()
|
||||
clear_predictor()
|
||||
read_static_image.cache_clear()
|
||||
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from realesrgan import RealESRGANer
|
||||
import facefusion.globals
|
||||
import facefusion.processors.frame.core as frame_processors
|
||||
from facefusion import wording
|
||||
from facefusion.core import update_status
|
||||
from facefusion.face_analyser import clear_face_analyser
|
||||
from facefusion.predictor import clear_predictor
|
||||
from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device
|
||||
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device, update_status
|
||||
from facefusion.vision import read_image, read_static_image, write_image
|
||||
from facefusion.processors.frame import globals as frame_processors_globals
|
||||
from facefusion.processors.frame import choices as frame_processors_choices
|
||||
@@ -124,6 +124,7 @@ def pre_process(mode : ProcessMode) -> bool:
|
||||
def post_process() -> None:
|
||||
clear_frame_processor()
|
||||
clear_face_analyser()
|
||||
clear_predictor()
|
||||
read_static_image.cache_clear()
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def listen() -> None:
|
||||
getattr(source_image, method)(stop, cancels = start_event)
|
||||
|
||||
|
||||
def start(mode: WebcamMode, resolution: str, fps: float) -> Generator[Frame, None, None]:
|
||||
def start(mode : WebcamMode, resolution : str, fps : float) -> Generator[Frame, None, None]:
|
||||
facefusion.globals.face_recognition = 'many'
|
||||
source_face = get_one_face(read_static_image(facefusion.globals.source_path))
|
||||
stream = None
|
||||
@@ -68,20 +68,20 @@ def start(mode: WebcamMode, resolution: str, fps: float) -> Generator[Frame, Non
|
||||
stream = open_stream(mode, resolution, fps) # type: ignore[arg-type]
|
||||
capture = capture_webcam(resolution, fps)
|
||||
if capture.isOpened():
|
||||
for capture_frame in multi_process_capture(source_face, capture):
|
||||
for capture_frame in multi_process_capture(source_face, capture, fps):
|
||||
if stream is not None:
|
||||
stream.stdin.write(capture_frame.tobytes())
|
||||
yield normalize_frame_color(capture_frame)
|
||||
|
||||
|
||||
def multi_process_capture(source_face: Face, capture : cv2.VideoCapture) -> Generator[Frame, None, None]:
|
||||
def multi_process_capture(source_face : Face, capture : cv2.VideoCapture, fps : float) -> Generator[Frame, None, None]:
|
||||
progress = tqdm(desc = wording.get('processing'), unit = 'frame', dynamic_ncols = True)
|
||||
with ThreadPoolExecutor(max_workers = facefusion.globals.execution_thread_count) as executor:
|
||||
futures = []
|
||||
deque_capture_frames : Deque[Frame] = deque()
|
||||
while True:
|
||||
_, capture_frame = capture.read()
|
||||
if predict_stream(capture_frame):
|
||||
if predict_stream(capture_frame, fps):
|
||||
return
|
||||
future = executor.submit(process_stream_frame, source_face, capture_frame)
|
||||
futures.append(future)
|
||||
|
||||
@@ -237,3 +237,7 @@ def get_device(execution_providers : List[str]) -> str:
|
||||
if 'CoreMLExecutionProvider' in execution_providers:
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
|
||||
def update_status(message : str, scope : str = 'FACEFUSION.CORE') -> None:
|
||||
print('[' + scope + '] ' + message)
|
||||
|
||||
@@ -35,6 +35,7 @@ WORDING =\
|
||||
'headless_help': 'run the program in headless mode',
|
||||
'creating_temp': 'Creating temporary resources',
|
||||
'extracting_frames_fps': 'Extracting frames with {fps} FPS',
|
||||
'analysing': 'Analysing',
|
||||
'processing': 'Processing',
|
||||
'downloading': 'Downloading',
|
||||
'temp_frames_not_found': 'Temporary frames not found',
|
||||
|
||||
@@ -5,11 +5,8 @@ numpy==1.24.3
|
||||
onnx==1.14.1
|
||||
onnxruntime==1.16.0
|
||||
opencv-python==4.8.1.78
|
||||
opennsfw2==0.10.2
|
||||
pillow==10.0.1
|
||||
protobuf==4.24.2
|
||||
psutil==5.9.5
|
||||
realesrgan==0.3.0
|
||||
tensorflow==2.13.0
|
||||
torch==2.1.0
|
||||
tqdm==4.66.1
|
||||
|
||||
Reference in New Issue
Block a user