Add node-based architecture with workflow UI

- Node decorator system (@node) with NodeContext for composable processing
- Node registry with typed input/output ports (image, json)
- API endpoints: GET /nodes (list), POST /nodes/{name} (execute)
- Nodes: face_detector, face_landmarker, face_debugger, face_enhancer, face_swapper
- Workflow UI served as static HTML from /
- Auto-executing nodes, type-enforced connections, session persistence

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
henryruhs
2026-04-07 10:32:43 +02:00
parent cdd7c25586
commit 5fe245b1fa
8 changed files with 1524 additions and 43 deletions
+10 -1
View File
@@ -1,11 +1,15 @@
import os
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route, WebSocketRoute
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.staticfiles import StaticFiles
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
from facefusion.apis.endpoints.capabilities import get_capabilities
from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
from facefusion.apis.endpoints.nodes import execute_node, list_nodes
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
@@ -28,6 +32,8 @@ def create_api() -> Starlette:
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
Route('/nodes', list_nodes, methods = [ 'GET' ]),
Route('/nodes/{node_name}', execute_node, methods = [ 'POST' ], middleware = [ session_guard ]),
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
@@ -35,6 +41,9 @@ def create_api() -> Starlette:
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
]
static_path = os.path.join(os.path.dirname(__file__), 'static')
routes.append(Mount('/', app = StaticFiles(directory = static_path, html = True)))
api = Starlette(routes = routes)
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
+128
View File
@@ -0,0 +1,128 @@
import numpy
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
from facefusion import state_manager
from facefusion.node import NODE_REGISTRY, NodeContext, decode_vision_frame, encode_vision_frame
# Guards against repeated (and re-entrant) node-module loading.
NODES_LOADED = False


def ensure_nodes_loaded() -> None:
    """
    Import every module that registers nodes via the @node decorator.

    Importing populates NODE_REGISTRY as a side effect. Runs the expensive
    imports only once per process; subsequent calls return immediately.
    """
    global NODES_LOADED
    if NODES_LOADED:
        return
    # Set the flag before loading so a re-entrant call during import does not recurse.
    NODES_LOADED = True
    # Hoisted out of the loop: the import is loop-invariant.
    from facefusion.processors.core import load_processor_module
    import facefusion.face_analyser
    processor_names =\
    [
        'face_debugger',
        'face_enhancer',
        'face_swapper'
    ]
    for processor_name in processor_names:
        try:
            load_processor_module(processor_name)
        except SystemExit:
            # load_processor_module exits when a processor is unavailable;
            # swallow it so one missing processor does not block the rest.
            pass
async def list_nodes(request : Request) -> JSONResponse:
    """Return the schema of every registered node as a JSON object keyed by name."""
    ensure_nodes_loaded()

    def describe_ports(ports : list) -> list:
        # Serialize NodePort objects into plain JSON-safe dicts.
        return [ { 'name' : port.name, 'type' : port.type, 'label' : port.label } for port in ports ]

    payload =\
    {
        name :
        {
            'name' : name,
            'description' : registered.schema.description,
            'inputs' : describe_ports(registered.schema.inputs),
            'outputs' : describe_ports(registered.schema.outputs),
            'state_keys' : registered.schema.state_keys
        }
        for name, registered in NODE_REGISTRY.items()
    }
    return JSONResponse(payload, status_code = HTTP_200_OK)
def _decode_node_inputs(raw_inputs : dict, input_ports : list) -> dict:
    # Decode base64 image payloads into vision frames according to the declared port types.
    port_types = { port.name : port.type for port in input_ports }
    decoded_inputs = {}
    for field_name, value in raw_inputs.items():
        port_type = port_types.get(field_name, '')
        if port_type == 'image' and isinstance(value, str):
            decoded_inputs[field_name] = decode_vision_frame(value)
        elif port_type == 'image_list' and isinstance(value, list):
            decoded_inputs[field_name] = [ decode_vision_frame(item) for item in value ]
        else:
            decoded_inputs[field_name] = value
    return decoded_inputs


def _encode_node_outputs(result : dict, output_ports : list) -> dict:
    # Encode vision frames back to base64 strings according to the declared port types.
    port_types = { port.name : port.type for port in output_ports }
    response = {}
    for key, value in result.items():
        port_type = port_types.get(key, '')
        if port_type == 'image' and isinstance(value, numpy.ndarray):
            response[key] = encode_vision_frame(value)
        elif port_type == 'image_list' and isinstance(value, list):
            response[key] = [ encode_vision_frame(item) for item in value if isinstance(item, numpy.ndarray) ]
        else:
            response[key] = value
    return response


async def execute_node(request : Request) -> JSONResponse:
    """
    Execute a registered node by name.

    Body: { 'inputs' : {...}, 'state' : {...} }. State overrides must be
    declared in the node's state_keys; they are applied to the global
    state_manager for the duration of the call and restored afterwards.

    Returns 404 for unknown nodes, 400 for bad JSON or undeclared state
    keys, 500 (with traceback) when the node itself raises.
    """
    ensure_nodes_loaded()
    node_name = request.path_params.get('node_name')
    if node_name not in NODE_REGISTRY:
        return JSONResponse(
        {
            'message' : 'node not found'
        }, status_code = HTTP_404_NOT_FOUND)
    registered = NODE_REGISTRY[node_name]
    schema = registered.schema
    try:
        body = await request.json()
    except ValueError:
        # Previously an invalid body escaped as an unhandled 500; report it as a client error.
        return JSONResponse(
        {
            'message' : 'request body is not valid JSON'
        }, status_code = HTTP_400_BAD_REQUEST)
    raw_inputs = body.get('inputs', {})
    state_overrides = body.get('state', {})
    for key in state_overrides:
        if key not in schema.state_keys:
            return JSONResponse(
            {
                'message' : 'state key "' + key + '" not declared for node "' + node_name + '"'
            }, status_code = HTTP_400_BAD_REQUEST)
    decoded_inputs = _decode_node_inputs(raw_inputs, schema.inputs)
    # Apply state overrides temporarily, remembering prior values for restore.
    saved_state = {}
    for key, value in state_overrides.items():
        saved_state[key] = state_manager.get_item(key)
        state_manager.set_item(key, value)
    try:
        result = registered.fn(decoded_inputs)
        return JSONResponse(_encode_node_outputs(result, schema.outputs), status_code = HTTP_200_OK)
    except Exception as exception:
        import traceback
        # NOTE(review): returning the full traceback to the client leaks
        # implementation detail — acceptable for a local tool, confirm before
        # exposing this API publicly.
        return JSONResponse(
        {
            'message' : str(exception),
            'traceback' : traceback.format_exc()
        }, status_code = HTTP_500_INTERNAL_SERVER_ERROR)
    finally:
        # Restore prior state even on failure so overrides never leak across requests.
        for key, value in saved_state.items():
            state_manager.set_item(key, value)
File diff suppressed because it is too large Load Diff
+118 -1
View File
@@ -1,8 +1,10 @@
from typing import List, Optional
from typing import Any, Dict, List, Optional
import cv2
import numpy
from facefusion import state_manager
from facefusion.node import NodeContext, NodePort, encode_vision_frame, node
from facefusion.common_helper import get_first
from facefusion.face_classifier import classify_face
from facefusion.face_detector import detect_faces, detect_faces_by_angle
@@ -141,3 +143,118 @@ def scale_face(target_face : Face, target_vision_frame : VisionFrame, temp_visio
bounding_box = bounding_box,
landmark_set = landmark_set
)
def crop_face(vision_frame : VisionFrame, face : Face, padding : int = 30) -> VisionFrame:
    """Return a padded copy of the face's bounding-box region, clamped to the frame."""
    frame_height, frame_width = vision_frame.shape[:2]
    start_x, start_y, end_x, end_y = face.bounding_box.astype(int)
    # Expand by the padding but never step outside the frame.
    left = max(0, start_x - padding)
    top = max(0, start_y - padding)
    right = min(frame_width, end_x + padding)
    bottom = min(frame_height, end_y + padding)
    return vision_frame[top:bottom, left:right].copy()
def draw_bounding_boxes(vision_frame : VisionFrame, faces : List[Face]) -> VisionFrame:
    """Draw a green rectangle, plus an optional gender/age label, for every face."""
    annotated_frame = numpy.ascontiguousarray(vision_frame.copy())
    for face in faces:
        left, top, right, bottom = face.bounding_box.astype(int)
        cv2.rectangle(annotated_frame, (left, top), (right, bottom), (0, 255, 0), 2)
        label_parts = []
        if face.gender:
            label_parts.append(face.gender)
        if face.age:
            label_parts.append(f'{face.age.start}-{face.age.stop}')
        if label_parts:
            # Place the label just above the top-left corner of the box.
            cv2.putText(annotated_frame, ' '.join(label_parts), (left, top - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    return annotated_frame
@node(
    name = 'face_detector',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'face_image', type = 'image', label = 'Face Image'),
        NodePort(name = 'face_data', type = 'json', label = 'Face Data')
    ],
    state_keys =
    [
        'face_detector_model',
        'face_detector_size',
        'face_detector_margin',
        'face_detector_angles',
        'face_detector_score'
    ],
    description = 'Detect faces and output annotated image with cropped faces'
)
def detect_faces_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
    """
    Detect faces in the input image.

    Returns a crop of the first detected face (or the unchanged image when no
    face is found) plus a JSON summary of every detection.
    """
    vision_frame = inputs.get('image')
    detected_faces = get_many_faces([ vision_frame ])

    def summarize(face : Face) -> Dict[str, Any]:
        # JSON-safe summary of a single detection.
        return\
        {
            'bounding_box' : face.bounding_box.tolist(),
            'score' : float(face.score_set.get('detector', 0)),
            'gender' : face.gender,
            'age' : { 'start' : face.age.start, 'stop' : face.age.stop } if face.age else None,
            'race' : face.race
        }

    cropped_frame = crop_face(vision_frame, detected_faces[0]) if detected_faces else vision_frame
    return\
    {
        'face_image' : cropped_frame,
        'face_data' : [ summarize(face) for face in detected_faces ]
    }
@node(
    name = 'face_landmarker',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Landmark Image'),
        NodePort(name = 'landmark_data', type = 'json', label = 'Landmark Data')
    ],
    state_keys =
    [
        'face_landmarker_model',
        'face_landmarker_score'
    ],
    description = 'Detect face landmarks'
)
def detect_landmarks_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
    """
    Draw the 68-point landmark set of every detected face onto a copy of the
    input image and return it together with the raw landmark coordinates.
    """
    vision_frame = inputs.get('image')
    faces = get_many_faces([ vision_frame ])
    # Contiguous copy so cv2 drawing does not modify the caller's frame.
    result_frame = numpy.ascontiguousarray(vision_frame.copy())
    landmark_data = []
    for face in faces:
        face_landmark_68 = face.landmark_set.get('68')
        # presumably numpy.any(None) is falsy, so faces without a 68-point set
        # are skipped for drawing — TODO confirm
        if numpy.any(face_landmark_68):
            for point in face_landmark_68.astype(int):
                cv2.circle(result_frame, tuple(point), 2, (0, 255, 0), -1)
        landmark_data.append(
        {
            'landmark_5' : face.landmark_set.get('5').tolist() if numpy.any(face.landmark_set.get('5')) else [],
            'landmark_68' : face.landmark_set.get('68').tolist() if numpy.any(face.landmark_set.get('68')) else []
        })
    return\
    {
        'image' : result_frame,
        'landmark_data' : landmark_data
    }
+101
View File
@@ -0,0 +1,101 @@
import base64
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type
import cv2
import numpy
from numpy.typing import NDArray
@dataclass
class NodePort:
    """A single typed input or output port on a node."""
    name : str # port identifier; used as the key in the node's input/output dicts
    type : str # 'image', 'json', 'faces'
    label : str = '' # human-readable label for UIs
@dataclass
class NodeSchema:
    """Declarative description of a node: its ports and the state keys it reads."""
    name : str # unique registry key
    inputs : List[NodePort] # input ports
    outputs : List[NodePort] # output ports
    state_keys : List[str] # state_manager keys the node may be overridden with
    description : str = '' # human-readable summary
@dataclass
class RegisteredNode:
    """Pairs a node's schema with its executable wrapper function."""
    schema : NodeSchema
    fn : Callable # the @node wrapper; called as fn(inputs_dict, ctx=None)
# Global registry of node name -> RegisteredNode, populated by the @node decorator.
NODE_REGISTRY : Dict[str, RegisteredNode] = {}
class NodeContext:
    """
    Read-only view over a snapshot of state values for one node execution.
    Keys whose snapshot value is None fall back to the global state_manager.
    """

    def __init__(self, state : Dict[str, Any]) -> None:
        # Defensive copy: later mutation of the caller's dict has no effect.
        self._state = dict(state)

    def get_item(self, key : str) -> Any:
        snapshot_value = self._state.get(key)
        if snapshot_value is not None:
            return snapshot_value
        # Missing or None: defer to the process-wide state.
        from facefusion import state_manager
        return state_manager.get_item(key)

    def __getitem__(self, key : str) -> Any:
        return self.get_item(key)

    def __contains__(self, key : str) -> bool:
        # Membership reflects the snapshot only, not the global fallback.
        return key in self._state

    def to_dict(self) -> Dict[str, Any]:
        return dict(self._state)
def node(name : str, inputs : List[NodePort], outputs : List[NodePort], state_keys : List[str], description : str = '') -> Callable:
    """
    Decorator that registers a processing node under the given name.

    The wrapped function is called as fn(inputs_dict, ctx). When the caller
    supplies no NodeContext, one is built from a snapshot of the declared
    state keys taken from the global state_manager.
    """
    def decorator(fn : Callable) -> Callable:
        @wraps(fn)
        def wrapper(inputs_dict : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
            if ctx is not None:
                return fn(inputs_dict, ctx)
            # No context supplied: snapshot the declared state keys.
            from facefusion import state_manager
            snapshot = {}
            for key in state_keys:
                snapshot[key] = state_manager.get_item(key)
            return fn(inputs_dict, NodeContext(snapshot))

        schema = NodeSchema(
            name = name,
            inputs = inputs,
            outputs = outputs,
            state_keys = state_keys,
            description = description
        )
        # Expose the schema on the wrapper and register it globally.
        wrapper.__node_schema__ = schema
        NODE_REGISTRY[name] = RegisteredNode(schema = schema, fn = wrapper)
        return wrapper
    return decorator
def get_node(name : str) -> Optional[RegisteredNode]:
    """Look up a registered node by name; None when no such node exists."""
    try:
        return NODE_REGISTRY[name]
    except KeyError:
        return None
def get_all_nodes() -> Dict[str, RegisteredNode]:
    """Return the live node registry (not a copy; callers must not mutate it)."""
    return NODE_REGISTRY
def decode_vision_frame(b64_string : str) -> NDArray[Any]:
    """Decode a base64-encoded image string into a BGR vision frame."""
    raw_bytes = base64.b64decode(b64_string)
    byte_buffer = numpy.frombuffer(raw_bytes, numpy.uint8)
    return cv2.imdecode(byte_buffer, cv2.IMREAD_COLOR)
def encode_vision_frame(frame : NDArray[Any], fmt : str = '.jpg') -> str:
    """
    Encode a vision frame to a base64 string.

    :param frame: image array to encode
    :param fmt: OpenCV encoder extension, e.g. '.jpg' or '.png'
    :raises ValueError: when OpenCV cannot encode the frame
    """
    is_success, buffer = cv2.imencode(fmt, frame)
    if not is_success:
        # cv2.imencode signals failure via its boolean return value; the
        # original ignored it and would have emitted an invalid payload.
        raise ValueError('could not encode vision frame with format ' + fmt)
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
@@ -6,6 +6,7 @@ import numpy
import facefusion.capability_store
import facefusion.jobs.job_manager
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, logger, state_manager, translator, video_manager
from facefusion.node import NodeContext, NodePort, node
from facefusion.face_analyser import scale_face
from facefusion.face_helper import warp_face_by_face_landmark_5
from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask
@@ -14,6 +15,7 @@ from facefusion.filesystem import in_directory, is_image, is_video
from facefusion.processors.modules.face_debugger import choices as face_debugger_choices
from facefusion.processors.modules.face_debugger.types import FaceDebuggerInputs
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
from typing import Optional
from facefusion.program_helper import find_argument_group
from facefusion.types import Args, Face, InferencePool, ProcessMode, VisionFrame
from facefusion.vision import read_static_image, read_static_video_frame
@@ -77,14 +79,14 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
face_debugger_items = state_manager.get_item('face_debugger_items')
def debug_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
face_debugger_items = ctx.get_item('face_debugger_items') if ctx else state_manager.get_item('face_debugger_items')
if 'bounding-box' in face_debugger_items:
temp_vision_frame = draw_bounding_box(target_face, temp_vision_frame)
if 'face-mask' in face_debugger_items:
temp_vision_frame = draw_face_mask(target_face, temp_vision_frame)
temp_vision_frame = draw_face_mask(target_face, temp_vision_frame, ctx)
if 'face-landmark-5' in face_debugger_items:
temp_vision_frame = draw_face_landmark_5(target_face, temp_vision_frame)
@@ -122,7 +124,8 @@ def draw_bounding_box(target_face : Face, temp_vision_frame : VisionFrame) -> Vi
return temp_vision_frame
def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
_get = ctx.get_item if ctx else state_manager.get_item
crop_masks = []
temp_vision_frame = numpy.ascontiguousarray(temp_vision_frame)
face_landmark_5 = target_face.landmark_set.get('5')
@@ -136,21 +139,21 @@ def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> Visio
if numpy.array_equal(face_landmark_5, face_landmark_5_68):
mask_color = 255, 255, 0
if 'box' in state_manager.get_item('face_mask_types'):
box_mask = create_box_mask(crop_vision_frame, 0, state_manager.get_item('face_mask_padding'))
if 'box' in _get('face_mask_types'):
box_mask = create_box_mask(crop_vision_frame, 0, _get('face_mask_padding'))
crop_masks.append(box_mask)
if 'occlusion' in state_manager.get_item('face_mask_types'):
if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
if 'area' in state_manager.get_item('face_mask_types'):
if 'area' in _get('face_mask_types'):
face_landmark_68 = cv2.transform(face_landmark_68.reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
crop_masks.append(area_mask)
if 'region' in state_manager.get_item('face_mask_types'):
region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
if 'region' in _get('face_mask_types'):
region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
crop_masks.append(region_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -227,18 +230,51 @@ def draw_face_landmark_68_5(target_face : Face, temp_vision_frame : VisionFrame)
return temp_vision_frame
def process_frame(inputs : FaceDebuggerInputs) -> ProcessorOutputs:
@node(
name = 'face_debugger',
inputs =
[
NodePort(name = 'image', type = 'image', label = 'Image')
],
outputs =
[
NodePort(name = 'image', type = 'image', label = 'Debug Image')
],
state_keys =
[
'face_debugger_items'
],
description = 'Visualize face landmarks, bounding boxes and masks'
)
def process_frame(inputs, ctx = None):
from facefusion.face_analyser import get_many_faces
from facefusion.vision import extract_vision_mask
vision_frame = inputs.get('image')
if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
target_vision_frame = vision_frame
temp_vision_frame = vision_frame.copy()
temp_vision_mask = extract_vision_mask(temp_vision_frame)
target_faces = get_many_faces([ target_vision_frame ])
if target_faces:
for target_face in target_faces:
temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
return { 'image' : temp_vision_frame }
reference_vision_frame = inputs.get('reference_vision_frame')
target_vision_frame = inputs.get('target_vision_frame')
temp_vision_frame = inputs.get('temp_vision_frame')
target_vision_frame = inputs.get('target_vision_frame', vision_frame)
temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
temp_vision_mask = inputs.get('temp_vision_mask')
target_faces = select_faces(reference_vision_frame, target_vision_frame)
if target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
temp_vision_frame = debug_face(target_face, temp_vision_frame)
temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
return temp_vision_frame, temp_vision_mask
return { 'image' : temp_vision_frame }
@@ -14,8 +14,10 @@ from facefusion.face_masker import create_box_mask, create_occlusion_mask
from facefusion.face_selector import select_faces
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path
from facefusion.processors.modules.face_enhancer import choices as face_enhancer_choices
from facefusion.node import NodeContext, NodePort, node
from facefusion.processors.modules.face_enhancer.types import FaceEnhancerInputs, FaceEnhancerWeight
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
from typing import Optional
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore
from facefusion.types import Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
@@ -360,27 +362,28 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
def enhance_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
def enhance_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
_get = ctx.get_item if ctx else state_manager.get_item
model_template = get_model_options().get('template')
model_size = get_model_options().get('size')
crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size)
box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0))
box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), (0, 0, 0, 0))
crop_masks =\
[
box_mask
]
if 'occlusion' in state_manager.get_item('face_mask_types'):
if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
crop_vision_frame = prepare_crop_frame(crop_vision_frame)
face_enhancer_weight = numpy.array([ state_manager.get_item('face_enhancer_weight') ]).astype(numpy.double)
face_enhancer_weight = numpy.array([ _get('face_enhancer_weight') ]).astype(numpy.double)
crop_vision_frame = forward(crop_vision_frame, face_enhancer_weight)
crop_vision_frame = normalize_crop_frame(crop_vision_frame)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame)
temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame, ctx)
return temp_vision_frame
@@ -426,22 +429,57 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
return crop_vision_frame
def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame) -> VisionFrame:
face_enhancer_blend = 1 - (state_manager.get_item('face_enhancer_blend') / 100)
def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
_get = ctx.get_item if ctx else state_manager.get_item
face_enhancer_blend = 1 - (_get('face_enhancer_blend') / 100)
temp_vision_frame = blend_frame(temp_vision_frame, paste_vision_frame, 1 - face_enhancer_blend)
return temp_vision_frame
def process_frame(inputs : FaceEnhancerInputs) -> ProcessorOutputs:
@node(
name = 'face_enhancer',
inputs =
[
NodePort(name = 'image', type = 'image', label = 'Image')
],
outputs =
[
NodePort(name = 'image', type = 'image', label = 'Enhanced Image')
],
state_keys =
[
'face_enhancer_model',
'face_enhancer_blend',
'face_enhancer_weight'
],
description = 'Enhance and restore face quality'
)
def process_frame(inputs, ctx = None):
from facefusion.face_analyser import get_many_faces
from facefusion.vision import extract_vision_mask
vision_frame = inputs.get('image')
if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
target_vision_frame = vision_frame
temp_vision_frame = vision_frame.copy()
target_faces = get_many_faces([ target_vision_frame ])
if target_faces:
for target_face in target_faces:
temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
return { 'image' : temp_vision_frame }
reference_vision_frame = inputs.get('reference_vision_frame')
target_vision_frame = inputs.get('target_vision_frame')
temp_vision_frame = inputs.get('temp_vision_frame')
target_vision_frame = inputs.get('target_vision_frame', vision_frame)
temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
temp_vision_mask = inputs.get('temp_vision_mask')
target_faces = select_faces(reference_vision_frame, target_vision_frame)
if target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
temp_vision_frame = enhance_face(target_face, temp_vision_frame)
temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
return temp_vision_frame, temp_vision_mask
return { 'image' : temp_vision_frame }
@@ -19,6 +19,7 @@ from facefusion.face_selector import select_faces, sort_faces_by_order
from facefusion.filesystem import filter_image_paths, has_image, in_directory, is_image, is_video, resolve_relative_path
from facefusion.model_helper import get_static_model_initializer
from facefusion.processors.modules.face_swapper import choices as face_swapper_choices
from facefusion.node import NodeContext, NodePort, node
from facefusion.processors.modules.face_swapper.types import FaceSwapperInputs
from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel_boost
from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
@@ -600,20 +601,21 @@ def post_process() -> None:
face_recognizer.clear_inference_pool()
def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
_get = ctx.get_item if ctx else state_manager.get_item
model_template = get_model_options().get('template')
model_size = get_model_options().get('size')
pixel_boost_size = unpack_resolution(state_manager.get_item('face_swapper_pixel_boost'))
pixel_boost_size = unpack_resolution(_get('face_swapper_pixel_boost'))
pixel_boost_total = pixel_boost_size[0] // model_size[0]
crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, pixel_boost_size)
temp_vision_frames = []
crop_masks = []
if 'box' in state_manager.get_item('face_mask_types'):
box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding'))
if 'box' in _get('face_mask_types'):
box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), _get('face_mask_padding'))
crop_masks.append(box_mask)
if 'occlusion' in state_manager.get_item('face_mask_types'):
if 'occlusion' in _get('face_mask_types'):
occlusion_mask = create_occlusion_mask(crop_vision_frame)
crop_masks.append(occlusion_mask)
@@ -625,13 +627,13 @@ def swap_face(source_face : Face, target_face : Face, temp_vision_frame : Vision
temp_vision_frames.append(pixel_boost_vision_frame)
crop_vision_frame = explode_pixel_boost(temp_vision_frames, pixel_boost_total, model_size, pixel_boost_size)
if 'area' in state_manager.get_item('face_mask_types'):
if 'area' in _get('face_mask_types'):
face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
crop_masks.append(area_mask)
if 'region' in state_manager.get_item('face_mask_types'):
region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
if 'region' in _get('face_mask_types'):
region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
crop_masks.append(region_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -779,9 +781,42 @@ def extract_source_face(source_vision_frames : List[VisionFrame]) -> Optional[Fa
return get_average_face(source_faces)
def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
reference_vision_frame = inputs.get('reference_vision_frame')
@node(
name = 'face_swapper',
inputs =
[
NodePort(name = 'source', type = 'image', label = 'Source Face'),
NodePort(name = 'target', type = 'image', label = 'Target Image')
],
outputs =
[
NodePort(name = 'image', type = 'image', label = 'Swapped Image')
],
state_keys =
[
'face_swapper_model',
'face_swapper_pixel_boost',
'face_swapper_weight'
],
description = 'Swap faces from source to target image'
)
def process_frame(inputs, ctx = None):
source_frame = inputs.get('source')
target_frame = inputs.get('target')
if source_frame is not None and target_frame is not None:
source_face = extract_source_face([ source_frame ])
target_faces = get_many_faces([ target_frame ])
temp_vision_frame = target_frame.copy()
if source_face and target_faces:
for target_face in target_faces:
temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
return { 'image' : temp_vision_frame }
source_vision_frames = inputs.get('source_vision_frames')
reference_vision_frame = inputs.get('reference_vision_frame')
target_vision_frame = inputs.get('target_vision_frame')
temp_vision_frame = inputs.get('temp_vision_frame')
temp_vision_mask = inputs.get('temp_vision_mask')
@@ -791,6 +826,6 @@ def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
if source_face and target_faces:
for target_face in target_faces:
target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame)
temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
return temp_vision_frame, temp_vision_mask
return { 'image' : temp_vision_frame }