Move the predictor to ONNX to avoid TensorFlow; use video ranges for prediction

This commit is contained in:
henryruhs
2023-10-14 23:26:44 +02:00
parent c83219eb7c
commit 27e506f5ac
9 changed files with 68 additions and 48 deletions
+4 -17
View File
@@ -9,15 +9,14 @@ import warnings
import platform
import shutil
import onnxruntime
import tensorflow
from argparse import ArgumentParser, HelpFormatter
import facefusion.choices
import facefusion.globals
from facefusion import metadata, wording
from facefusion import metadata, predictor, wording
from facefusion.predictor import predict_image, predict_video
from facefusion.processors.frame.core import get_frame_processors_modules, load_frame_processor_module
from facefusion.utilities import is_image, is_video, detect_fps, compress_image, merge_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clear_temp, list_module_names, encode_execution_providers, decode_execution_providers, normalize_output_path
from facefusion.utilities import is_image, is_video, detect_fps, compress_image, merge_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clear_temp, list_module_names, encode_execution_providers, decode_execution_providers, normalize_output_path, update_status
warnings.filterwarnings('ignore', category = FutureWarning, module = 'insightface')
warnings.filterwarnings('ignore', category = UserWarning, module = 'torchvision')
@@ -125,7 +124,7 @@ def apply_args(program : ArgumentParser) -> None:
def run(program : ArgumentParser) -> None:
apply_args(program)
limit_resources()
if not pre_check():
if not pre_check() or not predictor.pre_check():
return
for frame_processor_module in get_frame_processors_modules(facefusion.globals.frame_processors):
if not frame_processor_module.pre_check():
@@ -148,14 +147,6 @@ def destroy() -> None:
def limit_resources() -> None:
# prevent tensorflow memory leak
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tensorflow.config.experimental.set_virtual_device_configuration(gpu,
[
tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit = 512)
])
# limit memory usage
if facefusion.globals.max_memory:
memory = facefusion.globals.max_memory * 1024 ** 3
if platform.system().lower() == 'darwin':
@@ -210,7 +201,7 @@ def process_image() -> None:
def process_video() -> None:
if predict_video(facefusion.globals.target_path):
if predict_video(facefusion.globals.target_path, facefusion.globals.trim_frame_start, facefusion.globals.trim_frame_end):
return
fps = detect_fps(facefusion.globals.target_path) if facefusion.globals.keep_fps else 25.0
# create temp
@@ -251,7 +242,3 @@ def process_video() -> None:
update_status(wording.get('processing_video_succeed'))
else:
update_status(wording.get('processing_video_failed'))
def update_status(message : str, scope : str = 'FACEFUSION.CORE') -> None:
    # print a scoped status line, e.g. "[FACEFUSION.CORE] <message>"
    print('[' + scope + '] ' + message)
+43 -17
View File
@@ -1,26 +1,33 @@
from typing import Any
import threading
from functools import lru_cache
import cv2
import numpy
import opennsfw2
from PIL import Image
from keras import Model
import onnxruntime
from tqdm import tqdm
import facefusion.globals
from facefusion import wording
from facefusion.typing import Frame
from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_fps
from facefusion.utilities import resolve_relative_path, conditional_download
PREDICTOR = None
THREAD_LOCK : threading.Lock = threading.Lock()
MAX_PROBABILITY = 0.75
FRAME_INTERVAL = 25
NAME = 'FACEFUSION.PREDICTOR'
MODEL_URL = 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx'
MODEL_PATH = resolve_relative_path('../.assets/models/_open_nsfw.onnx')
MAX_PROBABILITY = 0.80
STREAM_COUNTER = 0
def get_predictor() -> Model:
def get_predictor() -> Any:
    """Return the shared ONNX NSFW predictor session, creating it on first use.

    Creation is guarded by THREAD_LOCK so concurrent callers never build the
    InferenceSession twice.
    """
    global PREDICTOR

    with THREAD_LOCK:
        if PREDICTOR is None:
            # NOTE(review): assumes the model at MODEL_PATH was already fetched by pre_check() — confirm call order
            PREDICTOR = onnxruntime.InferenceSession(MODEL_PATH, providers = facefusion.globals.execution_providers)
    return PREDICTOR
@@ -30,29 +37,48 @@ def clear_predictor() -> None:
PREDICTOR = None
def predict_stream(frame : Frame) -> bool:
def pre_check() -> bool:
    """Ensure the NSFW model file is available before prediction runs.

    Downloads MODEL_URL into the assets directory unless the user opted out
    via skip_download. Always returns True; presumably conditional_download
    is a no-op when the file already exists — verify in utilities.
    """
    if not facefusion.globals.skip_download:
        download_directory_path = resolve_relative_path('../.assets/models')
        conditional_download(download_directory_path, [ MODEL_URL ])
    return True
def predict_stream(frame : Frame, fps : float) -> bool:
global STREAM_COUNTER
STREAM_COUNTER = STREAM_COUNTER + 1
if STREAM_COUNTER % FRAME_INTERVAL == 0:
if STREAM_COUNTER % fps == 0:
return predict_frame(frame)
return False
def predict_frame(frame : Frame) -> bool:
image = Image.fromarray(frame)
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
views = numpy.expand_dims(image, axis = 0)
_, probability = get_predictor().predict(views)[0]
predictor = get_predictor()
frame = cv2.resize(frame, (224, 224)).astype(numpy.float32)
frame -= numpy.array([ 104, 117, 123 ], dtype = numpy.float32)
frame = numpy.expand_dims(frame, axis = 0)
probability = predictor.run(None,
{
'input:0': frame
})[0][0][1]
return probability > MAX_PROBABILITY
@lru_cache(maxsize = None)
def predict_image(image_path : str) -> bool:
return opennsfw2.predict_image(image_path) > MAX_PROBABILITY
frame = read_image(image_path)
return predict_frame(frame)
@lru_cache(maxsize = None)
def predict_video(video_path : str) -> bool:
_, probabilities = opennsfw2.predict_video_frames(video_path = video_path, frame_interval = FRAME_INTERVAL)
return any(probability > MAX_PROBABILITY for probability in probabilities)
def predict_video(video_path : str, start_frame : int, end_frame : int) -> bool:
    """Sample the trimmed frame range of a video and report NSFW content.

    :param video_path: path of the video to analyse
    :param start_frame: first frame of the range (falsy falls back to 0)
    :param end_frame: end of the range, exclusive (falsy falls back to the total frame count)
    :return: True as soon as any sampled frame is predicted NSFW, else False
    """
    video_frame_total = count_video_frame_total(video_path)
    fps = detect_fps(video_path)
    # falsy trim values expand the range to the whole video
    frame_range = range(start_frame or 0, end_frame or video_frame_total)
    for frame_number in tqdm(frame_range, desc = wording.get('analysing')):
        # analyse roughly one frame per second instead of every frame
        if frame_number % int(fps) == 0:
            frame = get_video_frame(video_path, frame_number)
            if predict_frame(frame):
                return True
    return False
@@ -7,16 +7,17 @@ import onnxruntime
import facefusion.globals
from facefusion import wording
from facefusion.core import update_status
from facefusion.face_analyser import get_many_faces, clear_face_analyser
from facefusion.face_helper import warp_face, paste_back
from facefusion.predictor import clear_predictor
from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = 'FACEFUSION.FRAME_PROCESSOR.FACE_ENHANCER'
MODELS : Dict[str, ModelValue] =\
@@ -123,6 +124,7 @@ def pre_process(mode : ProcessMode) -> bool:
def post_process() -> None:
clear_frame_processor()
clear_face_analyser()
clear_predictor()
read_static_image.cache_clear()
@@ -136,7 +138,8 @@ def enhance_face(target_face: Face, temp_frame: Frame) -> Frame:
frame_processor_inputs[frame_processor_input.name] = crop_frame
if frame_processor_input.name == 'weight':
frame_processor_inputs[frame_processor_input.name] = numpy.array([ 1 ], dtype = numpy.double)
crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
with THREAD_SEMAPHORE:
crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
crop_frame = normalize_crop_frame(crop_frame)
paste_frame = paste_back(temp_frame, crop_frame, affine_matrix)
temp_frame = blend_frame(temp_frame, paste_frame)
@@ -10,12 +10,12 @@ from onnx import numpy_helper
import facefusion.globals
import facefusion.processors.frame.core as frame_processors
from facefusion import wording
from facefusion.core import update_status
from facefusion.face_analyser import get_one_face, get_many_faces, find_similar_faces, clear_face_analyser
from facefusion.face_helper import warp_face, paste_back
from facefusion.face_reference import get_face_reference, set_face_reference
from facefusion.predictor import clear_predictor
from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices
@@ -135,6 +135,7 @@ def post_process() -> None:
clear_frame_processor()
clear_model_matrix()
clear_face_analyser()
clear_predictor()
read_static_image.cache_clear()
@@ -8,10 +8,10 @@ from realesrgan import RealESRGANer
import facefusion.globals
import facefusion.processors.frame.core as frame_processors
from facefusion import wording
from facefusion.core import update_status
from facefusion.face_analyser import clear_face_analyser
from facefusion.predictor import clear_predictor
from facefusion.typing import Frame, Face, Update_Process, ProcessMode, ModelValue, OptionsWithModel
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device
from facefusion.utilities import conditional_download, resolve_relative_path, is_file, is_download_done, get_device, update_status
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices
@@ -124,6 +124,7 @@ def pre_process(mode : ProcessMode) -> bool:
def post_process() -> None:
clear_frame_processor()
clear_face_analyser()
clear_predictor()
read_static_image.cache_clear()
+4 -4
View File
@@ -60,7 +60,7 @@ def listen() -> None:
getattr(source_image, method)(stop, cancels = start_event)
def start(mode: WebcamMode, resolution: str, fps: float) -> Generator[Frame, None, None]:
def start(mode : WebcamMode, resolution : str, fps : float) -> Generator[Frame, None, None]:
facefusion.globals.face_recognition = 'many'
source_face = get_one_face(read_static_image(facefusion.globals.source_path))
stream = None
@@ -68,20 +68,20 @@ def start(mode: WebcamMode, resolution: str, fps: float) -> Generator[Frame, Non
stream = open_stream(mode, resolution, fps) # type: ignore[arg-type]
capture = capture_webcam(resolution, fps)
if capture.isOpened():
for capture_frame in multi_process_capture(source_face, capture):
for capture_frame in multi_process_capture(source_face, capture, fps):
if stream is not None:
stream.stdin.write(capture_frame.tobytes())
yield normalize_frame_color(capture_frame)
def multi_process_capture(source_face: Face, capture : cv2.VideoCapture) -> Generator[Frame, None, None]:
def multi_process_capture(source_face : Face, capture : cv2.VideoCapture, fps : float) -> Generator[Frame, None, None]:
progress = tqdm(desc = wording.get('processing'), unit = 'frame', dynamic_ncols = True)
with ThreadPoolExecutor(max_workers = facefusion.globals.execution_thread_count) as executor:
futures = []
deque_capture_frames : Deque[Frame] = deque()
while True:
_, capture_frame = capture.read()
if predict_stream(capture_frame):
if predict_stream(capture_frame, fps):
return
future = executor.submit(process_stream_frame, source_face, capture_frame)
futures.append(future)
+4
View File
@@ -237,3 +237,7 @@ def get_device(execution_providers : List[str]) -> str:
if 'CoreMLExecutionProvider' in execution_providers:
return 'mps'
return 'cpu'
def update_status(message : str, scope : str = 'FACEFUSION.CORE') -> None:
    """Print a status line to stdout, prefixed with its bracketed scope."""
    line = '[' + scope + '] ' + message
    print(line)
+1
View File
@@ -35,6 +35,7 @@ WORDING =\
'headless_help': 'run the program in headless mode',
'creating_temp': 'Creating temporary resources',
'extracting_frames_fps': 'Extracting frames with {fps} FPS',
'analysing': 'Analysing',
'processing': 'Processing',
'downloading': 'Downloading',
'temp_frames_not_found': 'Temporary frames not found',
-3
View File
@@ -5,11 +5,8 @@ numpy==1.24.3
onnx==1.14.1
onnxruntime==1.16.0
opencv-python==4.8.1.78
opennsfw2==0.10.2
pillow==10.0.1
protobuf==4.24.2
psutil==5.9.5
realesrgan==0.3.0
tensorflow==2.13.0
torch==2.1.0
tqdm==4.66.1