diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py
index 5be40df6..6b077c0a 100644
--- a/facefusion/apis/core.py
+++ b/facefusion/apis/core.py
@@ -1,11 +1,15 @@
+import os
+
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
-from starlette.routing import Route, WebSocketRoute
+from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.staticfiles import StaticFiles
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
from facefusion.apis.endpoints.capabilities import get_capabilities
from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
+from facefusion.apis.endpoints.nodes import execute_node, list_nodes
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
@@ -28,6 +32,8 @@ def create_api() -> Starlette:
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
+ Route('/nodes', list_nodes, methods = [ 'GET' ]),
+ Route('/nodes/{node_name}', execute_node, methods = [ 'POST' ], middleware = [ session_guard ]),
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
@@ -35,6 +41,9 @@ def create_api() -> Starlette:
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
]
+ static_path = os.path.join(os.path.dirname(__file__), 'static')
+ routes.append(Mount('/', app = StaticFiles(directory = static_path, html = True)))
+
api = Starlette(routes = routes)
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
diff --git a/facefusion/apis/endpoints/nodes.py b/facefusion/apis/endpoints/nodes.py
new file mode 100644
index 00000000..64cf89e5
--- /dev/null
+++ b/facefusion/apis/endpoints/nodes.py
@@ -0,0 +1,128 @@
+import numpy
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
+
+from facefusion import state_manager
+from facefusion.node import NODE_REGISTRY, NodeContext, decode_vision_frame, encode_vision_frame
+
+NODES_LOADED = False
+
+
+def ensure_nodes_loaded() -> None:
+ global NODES_LOADED
+
+ if NODES_LOADED:
+ return
+
+ NODES_LOADED = True
+
+ import facefusion.face_analyser
+
+ processor_names =\
+ [
+ 'face_debugger',
+ 'face_enhancer',
+ 'face_swapper'
+ ]
+
+ for processor_name in processor_names:
+ try:
+ from facefusion.processors.core import load_processor_module
+
+ load_processor_module(processor_name)
+ except SystemExit:
+ pass
+
+
+async def list_nodes(request : Request) -> JSONResponse:
+ ensure_nodes_loaded()
+ nodes = {}
+
+ for name, registered in NODE_REGISTRY.items():
+ schema = registered.schema
+ nodes[name] =\
+ {
+ 'name' : name,
+ 'description' : schema.description,
+ 'inputs' : [ { 'name' : p.name, 'type' : p.type, 'label' : p.label } for p in schema.inputs ],
+ 'outputs' : [ { 'name' : p.name, 'type' : p.type, 'label' : p.label } for p in schema.outputs ],
+ 'state_keys' : schema.state_keys
+ }
+
+ return JSONResponse(nodes, status_code = HTTP_200_OK)
+
+
+async def execute_node(request : Request) -> JSONResponse:
+ ensure_nodes_loaded()
+ node_name = request.path_params.get('node_name')
+
+ if node_name not in NODE_REGISTRY:
+ return JSONResponse(
+ {
+ 'message' : 'node not found'
+ }, status_code = HTTP_404_NOT_FOUND)
+
+ registered = NODE_REGISTRY[node_name]
+ schema = registered.schema
+ body = await request.json()
+ raw_inputs = body.get('inputs', {})
+ state_overrides = body.get('state', {})
+
+ for key in state_overrides:
+ if key not in schema.state_keys:
+ return JSONResponse(
+ {
+ 'message' : 'state key "' + key + '" not declared for node "' + node_name + '"'
+ }, status_code = HTTP_400_BAD_REQUEST)
+
+ # Decode inputs based on port types
+ decoded_inputs = {}
+ input_port_types = { p.name: p.type for p in schema.inputs }
+
+ for field_name, value in raw_inputs.items():
+ port_type = input_port_types.get(field_name, '')
+
+ if port_type == 'image' and isinstance(value, str):
+ decoded_inputs[field_name] = decode_vision_frame(value)
+ elif port_type == 'image_list' and isinstance(value, list):
+ decoded_inputs[field_name] = [ decode_vision_frame(v) for v in value ]
+ else:
+ decoded_inputs[field_name] = value
+
+ # Apply state overrides temporarily
+ saved_state = {}
+
+ for key, value in state_overrides.items():
+ saved_state[key] = state_manager.get_item(key)
+ state_manager.set_item(key, value)
+
+ try:
+ result = registered.fn(decoded_inputs)
+
+ # Encode outputs based on port types
+ output_port_types = { p.name: p.type for p in schema.outputs }
+ response = {}
+
+ for key, value in result.items():
+ port_type = output_port_types.get(key, '')
+
+ if port_type == 'image' and isinstance(value, numpy.ndarray):
+ response[key] = encode_vision_frame(value)
+ elif port_type == 'image_list' and isinstance(value, list):
+ response[key] = [ encode_vision_frame(v) for v in value if isinstance(v, numpy.ndarray) ]
+ else:
+ response[key] = value
+
+ return JSONResponse(response, status_code = HTTP_200_OK)
+ except Exception as exception:
+ import traceback
+
+ return JSONResponse(
+ {
+ 'message' : str(exception),
+ 'traceback' : traceback.format_exc()
+ }, status_code = HTTP_500_INTERNAL_SERVER_ERROR)
+ finally:
+ for key, value in saved_state.items():
+ state_manager.set_item(key, value)
diff --git a/facefusion/apis/static/index.html b/facefusion/apis/static/index.html
new file mode 100644
index 00000000..41f6e73d
--- /dev/null
+++ b/facefusion/apis/static/index.html
@@ -0,0 +1,1017 @@
+
+
+
+
+
+ Workflow Area
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/facefusion/face_analyser.py b/facefusion/face_analyser.py
index 76b9b621..ff407e16 100644
--- a/facefusion/face_analyser.py
+++ b/facefusion/face_analyser.py
@@ -1,8 +1,10 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
+import cv2
import numpy
from facefusion import state_manager
+from facefusion.node import NodeContext, NodePort, encode_vision_frame, node
from facefusion.common_helper import get_first
from facefusion.face_classifier import classify_face
from facefusion.face_detector import detect_faces, detect_faces_by_angle
@@ -141,3 +143,118 @@ def scale_face(target_face : Face, target_vision_frame : VisionFrame, temp_visio
bounding_box = bounding_box,
landmark_set = landmark_set
)
+
+
+def crop_face(vision_frame : VisionFrame, face : Face, padding : int = 30) -> VisionFrame:
+ h, w = vision_frame.shape[:2]
+ box = face.bounding_box.astype(int)
+ x1 = max(0, box[0] - padding)
+ y1 = max(0, box[1] - padding)
+ x2 = min(w, box[2] + padding)
+ y2 = min(h, box[3] + padding)
+ return vision_frame[y1:y2, x1:x2].copy()
+
+
+def draw_bounding_boxes(vision_frame : VisionFrame, faces : List[Face]) -> VisionFrame:
+ result = numpy.ascontiguousarray(vision_frame.copy())
+
+ for face in faces:
+ box = face.bounding_box.astype(int)
+ cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
+ label = ''
+ if face.gender:
+ label += face.gender
+ if face.age:
+ label += f' {face.age.start}-{face.age.stop}'
+ if label:
+ cv2.putText(result, label.strip(), (box[0], box[1] - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+ return result
+
+
+@node(
+ name = 'face_detector',
+ inputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Image')
+ ],
+ outputs =
+ [
+ NodePort(name = 'face_image', type = 'image', label = 'Face Image'),
+ NodePort(name = 'face_data', type = 'json', label = 'Face Data')
+ ],
+ state_keys =
+ [
+ 'face_detector_model',
+ 'face_detector_size',
+ 'face_detector_margin',
+ 'face_detector_angles',
+ 'face_detector_score'
+ ],
+ description = 'Detect faces and output annotated image with cropped faces'
+)
+def detect_faces_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+ vision_frame = inputs.get('image')
+ faces = get_many_faces([ vision_frame ])
+ face_image = crop_face(vision_frame, faces[0]) if faces else vision_frame
+
+ face_data = []
+ for face in faces:
+ face_data.append(
+ {
+ 'bounding_box' : face.bounding_box.tolist(),
+ 'score' : float(face.score_set.get('detector', 0)),
+ 'gender' : face.gender,
+ 'age' : { 'start' : face.age.start, 'stop' : face.age.stop } if face.age else None,
+ 'race' : face.race
+ })
+
+ return\
+ {
+ 'face_image' : face_image,
+ 'face_data' : face_data
+ }
+
+
+@node(
+ name = 'face_landmarker',
+ inputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Image')
+ ],
+ outputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Landmark Image'),
+ NodePort(name = 'landmark_data', type = 'json', label = 'Landmark Data')
+ ],
+ state_keys =
+ [
+ 'face_landmarker_model',
+ 'face_landmarker_score'
+ ],
+ description = 'Detect face landmarks'
+)
+def detect_landmarks_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+ vision_frame = inputs.get('image')
+ faces = get_many_faces([ vision_frame ])
+ result_frame = numpy.ascontiguousarray(vision_frame.copy())
+
+ landmark_data = []
+ for face in faces:
+ face_landmark_68 = face.landmark_set.get('68')
+
+ if numpy.any(face_landmark_68):
+ for point in face_landmark_68.astype(int):
+ cv2.circle(result_frame, tuple(point), 2, (0, 255, 0), -1)
+
+ landmark_data.append(
+ {
+ 'landmark_5' : face.landmark_set.get('5').tolist() if numpy.any(face.landmark_set.get('5')) else [],
+ 'landmark_68' : face.landmark_set.get('68').tolist() if numpy.any(face.landmark_set.get('68')) else []
+ })
+
+ return\
+ {
+ 'image' : result_frame,
+ 'landmark_data' : landmark_data
+ }
diff --git a/facefusion/node.py b/facefusion/node.py
new file mode 100644
index 00000000..3d099e7f
--- /dev/null
+++ b/facefusion/node.py
@@ -0,0 +1,101 @@
+import base64
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Type
+
+import cv2
+import numpy
+from numpy.typing import NDArray
+
+
+@dataclass
+class NodePort:
+ name : str
+ type : str # 'image', 'json', 'faces'
+ label : str = ''
+
+
+@dataclass
+class NodeSchema:
+ name : str
+ inputs : List[NodePort]
+ outputs : List[NodePort]
+ state_keys : List[str]
+ description : str = ''
+
+
+@dataclass
+class RegisteredNode:
+ schema : NodeSchema
+ fn : Callable
+
+
+NODE_REGISTRY : Dict[str, RegisteredNode] = {}
+
+
+class NodeContext:
+ def __init__(self, state : Dict[str, Any]) -> None:
+ self._state = dict(state)
+
+ def get_item(self, key : str) -> Any:
+ value = self._state.get(key)
+
+ if value is None:
+ from facefusion import state_manager
+
+ value = state_manager.get_item(key)
+
+ return value
+
+ def __getitem__(self, key : str) -> Any:
+ return self.get_item(key)
+
+ def __contains__(self, key : str) -> bool:
+ return key in self._state
+
+ def to_dict(self) -> Dict[str, Any]:
+ return dict(self._state)
+
+
+def node(name : str, inputs : List[NodePort], outputs : List[NodePort], state_keys : List[str], description : str = '') -> Callable:
+ def decorator(fn : Callable) -> Callable:
+ schema = NodeSchema(
+ name = name,
+ inputs = inputs,
+ outputs = outputs,
+ state_keys = state_keys,
+ description = description
+ )
+
+ @wraps(fn)
+ def wrapper(inputs_dict : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+ if ctx is None:
+ from facefusion import state_manager
+
+ state_snapshot = { key: state_manager.get_item(key) for key in state_keys }
+ ctx = NodeContext(state_snapshot)
+ return fn(inputs_dict, ctx)
+
+ wrapper.__node_schema__ = schema
+ NODE_REGISTRY[name] = RegisteredNode(schema = schema, fn = wrapper)
+ return wrapper
+
+ return decorator
+
+
+def get_node(name : str) -> Optional[RegisteredNode]:
+ return NODE_REGISTRY.get(name)
+
+
+def get_all_nodes() -> Dict[str, RegisteredNode]:
+ return NODE_REGISTRY
+
+
+def decode_vision_frame(b64_string : str) -> NDArray[Any]:
+ image_bytes = base64.b64decode(b64_string)
+ return cv2.imdecode(numpy.frombuffer(image_bytes, numpy.uint8), cv2.IMREAD_COLOR)
+
+
+def encode_vision_frame(frame : NDArray[Any], fmt : str = '.jpg') -> str:
+ _, buffer = cv2.imencode(fmt, frame)
+ return base64.b64encode(buffer.tobytes()).decode('utf-8')
diff --git a/facefusion/processors/modules/face_debugger/core.py b/facefusion/processors/modules/face_debugger/core.py
index fc05e7b8..0c5366a4 100755
--- a/facefusion/processors/modules/face_debugger/core.py
+++ b/facefusion/processors/modules/face_debugger/core.py
@@ -6,6 +6,7 @@ import numpy
import facefusion.capability_store
import facefusion.jobs.job_manager
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, logger, state_manager, translator, video_manager
+from facefusion.node import NodeContext, NodePort, node
from facefusion.face_analyser import scale_face
from facefusion.face_helper import warp_face_by_face_landmark_5
from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask
@@ -14,6 +15,7 @@ from facefusion.filesystem import in_directory, is_image, is_video
from facefusion.processors.modules.face_debugger import choices as face_debugger_choices
from facefusion.processors.modules.face_debugger.types import FaceDebuggerInputs
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
+from typing import Optional
from facefusion.program_helper import find_argument_group
from facefusion.types import Args, Face, InferencePool, ProcessMode, VisionFrame
from facefusion.vision import read_static_image, read_static_video_frame
@@ -77,14 +79,14 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
-def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
- face_debugger_items = state_manager.get_item('face_debugger_items')
+def debug_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+ face_debugger_items = ctx.get_item('face_debugger_items') if ctx else state_manager.get_item('face_debugger_items')
if 'bounding-box' in face_debugger_items:
temp_vision_frame = draw_bounding_box(target_face, temp_vision_frame)
if 'face-mask' in face_debugger_items:
- temp_vision_frame = draw_face_mask(target_face, temp_vision_frame)
+ temp_vision_frame = draw_face_mask(target_face, temp_vision_frame, ctx)
if 'face-landmark-5' in face_debugger_items:
temp_vision_frame = draw_face_landmark_5(target_face, temp_vision_frame)
@@ -122,7 +124,8 @@ def draw_bounding_box(target_face : Face, temp_vision_frame : VisionFrame) -> Vi
return temp_vision_frame
-def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+ _get = ctx.get_item if ctx else state_manager.get_item
crop_masks = []
temp_vision_frame = numpy.ascontiguousarray(temp_vision_frame)
face_landmark_5 = target_face.landmark_set.get('5')
@@ -136,21 +139,21 @@ def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> Visio
if numpy.array_equal(face_landmark_5, face_landmark_5_68):
mask_color = 255, 255, 0
- if 'box' in state_manager.get_item('face_mask_types'):
- box_mask = create_box_mask(crop_vision_frame, 0, state_manager.get_item('face_mask_padding'))
+ if 'box' in _get('face_mask_types'):
+ box_mask = create_box_mask(crop_vision_frame, 0, _get('face_mask_padding'))
crop_masks.append(box_mask)
- if 'occlusion' in state_manager.get_item('face_mask_types'):
+ if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
- if 'area' in state_manager.get_item('face_mask_types'):
+ if 'area' in _get('face_mask_types'):
face_landmark_68 = cv2.transform(face_landmark_68.reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
- area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
+ area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
crop_masks.append(area_mask)
- if 'region' in state_manager.get_item('face_mask_types'):
- region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
+ if 'region' in _get('face_mask_types'):
+ region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
crop_masks.append(region_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -227,18 +230,51 @@ def draw_face_landmark_68_5(target_face : Face, temp_vision_frame : VisionFrame)
return temp_vision_frame
-def process_frame(inputs : FaceDebuggerInputs) -> ProcessorOutputs:
+@node(
+ name = 'face_debugger',
+ inputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Image')
+ ],
+ outputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Debug Image')
+ ],
+ state_keys =
+ [
+ 'face_debugger_items'
+ ],
+ description = 'Visualize face landmarks, bounding boxes and masks'
+)
+def process_frame(inputs, ctx = None):
+ from facefusion.face_analyser import get_many_faces
+ from facefusion.vision import extract_vision_mask
+
+ vision_frame = inputs.get('image')
+
+ if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
+ target_vision_frame = vision_frame
+ temp_vision_frame = vision_frame.copy()
+ temp_vision_mask = extract_vision_mask(temp_vision_frame)
+ target_faces = get_many_faces([ target_vision_frame ])
+
+ if target_faces:
+ for target_face in target_faces:
+ temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
+
+ return { 'image' : temp_vision_frame }
+
reference_vision_frame = inputs.get('reference_vision_frame')
- target_vision_frame = inputs.get('target_vision_frame')
- temp_vision_frame = inputs.get('temp_vision_frame')
+ target_vision_frame = inputs.get('target_vision_frame', vision_frame)
+ temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
temp_vision_mask = inputs.get('temp_vision_mask')
target_faces = select_faces(reference_vision_frame, target_vision_frame)
if target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
- temp_vision_frame = debug_face(target_face, temp_vision_frame)
+ temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
- return temp_vision_frame, temp_vision_mask
+ return { 'image' : temp_vision_frame }
diff --git a/facefusion/processors/modules/face_enhancer/core.py b/facefusion/processors/modules/face_enhancer/core.py
index bc40538d..9aa026fa 100755
--- a/facefusion/processors/modules/face_enhancer/core.py
+++ b/facefusion/processors/modules/face_enhancer/core.py
@@ -14,8 +14,10 @@ from facefusion.face_masker import create_box_mask, create_occlusion_mask
from facefusion.face_selector import select_faces
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path
from facefusion.processors.modules.face_enhancer import choices as face_enhancer_choices
+from facefusion.node import NodeContext, NodePort, node
from facefusion.processors.modules.face_enhancer.types import FaceEnhancerInputs, FaceEnhancerWeight
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
+from typing import Optional
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore
from facefusion.types import Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
@@ -360,27 +362,28 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
-def enhance_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def enhance_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+ _get = ctx.get_item if ctx else state_manager.get_item
model_template = get_model_options().get('template')
model_size = get_model_options().get('size')
crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size)
- box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0))
+ box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), (0, 0, 0, 0))
crop_masks =\
[
box_mask
]
- if 'occlusion' in state_manager.get_item('face_mask_types'):
+ if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
crop_vision_frame = prepare_crop_frame(crop_vision_frame)
- face_enhancer_weight = numpy.array([ state_manager.get_item('face_enhancer_weight') ]).astype(numpy.double)
+ face_enhancer_weight = numpy.array([ _get('face_enhancer_weight') ]).astype(numpy.double)
crop_vision_frame = forward(crop_vision_frame, face_enhancer_weight)
crop_vision_frame = normalize_crop_frame(crop_vision_frame)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
- temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame)
+ temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame, ctx)
return temp_vision_frame
@@ -426,22 +429,57 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
return crop_vision_frame
-def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame) -> VisionFrame:
- face_enhancer_blend = 1 - (state_manager.get_item('face_enhancer_blend') / 100)
+def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+ _get = ctx.get_item if ctx else state_manager.get_item
+ face_enhancer_blend = 1 - (_get('face_enhancer_blend') / 100)
temp_vision_frame = blend_frame(temp_vision_frame, paste_vision_frame, 1 - face_enhancer_blend)
return temp_vision_frame
-def process_frame(inputs : FaceEnhancerInputs) -> ProcessorOutputs:
+@node(
+ name = 'face_enhancer',
+ inputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Image')
+ ],
+ outputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Enhanced Image')
+ ],
+ state_keys =
+ [
+ 'face_enhancer_model',
+ 'face_enhancer_blend',
+ 'face_enhancer_weight'
+ ],
+ description = 'Enhance and restore face quality'
+)
+def process_frame(inputs, ctx = None):
+ from facefusion.face_analyser import get_many_faces
+ from facefusion.vision import extract_vision_mask
+
+ vision_frame = inputs.get('image')
+
+ if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
+ target_vision_frame = vision_frame
+ temp_vision_frame = vision_frame.copy()
+ target_faces = get_many_faces([ target_vision_frame ])
+
+ if target_faces:
+ for target_face in target_faces:
+ temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
+
+ return { 'image' : temp_vision_frame }
+
reference_vision_frame = inputs.get('reference_vision_frame')
- target_vision_frame = inputs.get('target_vision_frame')
- temp_vision_frame = inputs.get('temp_vision_frame')
+ target_vision_frame = inputs.get('target_vision_frame', vision_frame)
+ temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
temp_vision_mask = inputs.get('temp_vision_mask')
target_faces = select_faces(reference_vision_frame, target_vision_frame)
if target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
- temp_vision_frame = enhance_face(target_face, temp_vision_frame)
+ temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
- return temp_vision_frame, temp_vision_mask
+ return { 'image' : temp_vision_frame }
diff --git a/facefusion/processors/modules/face_swapper/core.py b/facefusion/processors/modules/face_swapper/core.py
index 4122a0af..ab5ac3e4 100755
--- a/facefusion/processors/modules/face_swapper/core.py
+++ b/facefusion/processors/modules/face_swapper/core.py
@@ -19,6 +19,7 @@ from facefusion.face_selector import select_faces, sort_faces_by_order
from facefusion.filesystem import filter_image_paths, has_image, in_directory, is_image, is_video, resolve_relative_path
from facefusion.model_helper import get_static_model_initializer
from facefusion.processors.modules.face_swapper import choices as face_swapper_choices
+from facefusion.node import NodeContext, NodePort, node
from facefusion.processors.modules.face_swapper.types import FaceSwapperInputs
from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel_boost
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
@@ -600,20 +601,21 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
-def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+ _get = ctx.get_item if ctx else state_manager.get_item
model_template = get_model_options().get('template')
model_size = get_model_options().get('size')
- pixel_boost_size = unpack_resolution(state_manager.get_item('face_swapper_pixel_boost'))
+ pixel_boost_size = unpack_resolution(_get('face_swapper_pixel_boost'))
pixel_boost_total = pixel_boost_size[0] // model_size[0]
crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, pixel_boost_size)
temp_vision_frames = []
crop_masks = []
- if 'box' in state_manager.get_item('face_mask_types'):
- box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding'))
+ if 'box' in _get('face_mask_types'):
+ box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), _get('face_mask_padding'))
crop_masks.append(box_mask)
- if 'occlusion' in state_manager.get_item('face_mask_types'):
+ if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
@@ -625,13 +627,13 @@ def swap_face(source_face : Face, target_face : Face, temp_vision_frame : Vision
temp_vision_frames.append(pixel_boost_vision_frame)
crop_vision_frame = explode_pixel_boost(temp_vision_frames, pixel_boost_total, model_size, pixel_boost_size)
- if 'area' in state_manager.get_item('face_mask_types'):
+ if 'area' in _get('face_mask_types'):
face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
- area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
+ area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
crop_masks.append(area_mask)
- if 'region' in state_manager.get_item('face_mask_types'):
- region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
+ if 'region' in _get('face_mask_types'):
+ region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
crop_masks.append(region_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -779,9 +781,42 @@ def extract_source_face(source_vision_frames : List[VisionFrame]) -> Optional[Fa
return get_average_face(source_faces)
-def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
- reference_vision_frame = inputs.get('reference_vision_frame')
+@node(
+ name = 'face_swapper',
+ inputs =
+ [
+ NodePort(name = 'source', type = 'image', label = 'Source Face'),
+ NodePort(name = 'target', type = 'image', label = 'Target Image')
+ ],
+ outputs =
+ [
+ NodePort(name = 'image', type = 'image', label = 'Swapped Image')
+ ],
+ state_keys =
+ [
+ 'face_swapper_model',
+ 'face_swapper_pixel_boost',
+ 'face_swapper_weight'
+ ],
+ description = 'Swap faces from source to target image'
+)
+def process_frame(inputs, ctx = None):
+ source_frame = inputs.get('source')
+ target_frame = inputs.get('target')
+
+ if source_frame is not None and target_frame is not None:
+ source_face = extract_source_face([ source_frame ])
+ target_faces = get_many_faces([ target_frame ])
+ temp_vision_frame = target_frame.copy()
+
+ if source_face and target_faces:
+ for target_face in target_faces:
+ temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
+
+ return { 'image' : temp_vision_frame }
+
source_vision_frames = inputs.get('source_vision_frames')
+ reference_vision_frame = inputs.get('reference_vision_frame')
target_vision_frame = inputs.get('target_vision_frame')
temp_vision_frame = inputs.get('temp_vision_frame')
temp_vision_mask = inputs.get('temp_vision_mask')
@@ -791,6 +826,6 @@ def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
if source_face and target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
- temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame)
+ temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
- return temp_vision_frame, temp_vision_mask
+ return { 'image' : temp_vision_frame }