Add node-based architecture with workflow UI
- Node decorator system (@node) with NodeContext for composable processing
- Node registry with typed input/output ports (image, json)
- API endpoints: GET /nodes (list), POST /nodes/{name} (execute)
- Nodes: face_detector, face_landmarker, face_debugger, face_enhancer, face_swapper
- Workflow UI served as static HTML from /
- Auto-executing nodes, type-enforced connections, session persistence
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
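For orientation, here is a minimal client sketch against the two endpoints named above; the base URL, the image file and the threshold value are illustrative assumptions, and session setup (POST /nodes/{name} sits behind the session guard) is omitted:

import base64

import requests

BASE_URL = 'http://127.0.0.1:8000' # assumed host and port

# discover registered nodes and their typed ports
nodes = requests.get(BASE_URL + '/nodes').json()

for name, schema in nodes.items():
	print(name, [ port['type'] for port in schema['inputs'] ])

# execute a node: 'image' ports travel as base64-encoded strings
with open('target.jpg', 'rb') as image_file:
	image_b64 = base64.b64encode(image_file.read()).decode('utf-8')

response = requests.post(BASE_URL + '/nodes/face_detector', json =
{
	'inputs' : { 'image' : image_b64 },
	'state' : { 'face_detector_score' : 0.5 }
})
print(response.json().get('face_data'))

The first diff below extends the API application factory (create_api) with the node routes and the static mount.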
@@ -1,11 +1,15 @@
+import os
+
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.cors import CORSMiddleware
-from starlette.routing import Route, WebSocketRoute
+from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.staticfiles import StaticFiles
 
 from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
 from facefusion.apis.endpoints.capabilities import get_capabilities
 from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
+from facefusion.apis.endpoints.nodes import execute_node, list_nodes
 from facefusion.apis.endpoints.ping import websocket_ping
 from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
 from facefusion.apis.endpoints.state import get_state, set_state
@@ -28,6 +32,8 @@ def create_api() -> Starlette:
 		Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
 		Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
 		Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
+		Route('/nodes', list_nodes, methods = [ 'GET' ]),
+		Route('/nodes/{node_name}', execute_node, methods = [ 'POST' ], middleware = [ session_guard ]),
 		Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
 		Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
 		WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
@@ -35,6 +41,9 @@ def create_api() -> Starlette:
 		WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
 	]
 
+	static_path = os.path.join(os.path.dirname(__file__), 'static')
+	routes.append(Mount('/', app = StaticFiles(directory = static_path, html = True)))
+
 	api = Starlette(routes = routes)
 	api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
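A detail worth noting in the hunk above: Starlette matches routes in registration order, so appending the catch-all Mount('/') after the API routes keeps /nodes, /metrics and friends reachable while everything else falls through to the static workflow UI. A standalone sketch of the same layout (module and directory names are illustrative, and a static/ directory must exist next to the module):

import os

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles


async def api_endpoint(request):
	return JSONResponse({ 'ok' : True })


# API routes first, catch-all static mount last: Starlette tries routes in order
routes = [ Route('/api', api_endpoint) ]
static_path = os.path.join(os.path.dirname(__file__), 'static')
routes.append(Mount('/', app = StaticFiles(directory = static_path, html = True)))

app = Starlette(routes = routes)

The next diff adds the endpoint module itself; the path is not shown in this view, but the import above suggests facefusion/apis/endpoints/nodes.py.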
@@ -0,0 +1,128 @@
+import numpy
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
+
+from facefusion import state_manager
+from facefusion.node import NODE_REGISTRY, NodeContext, decode_vision_frame, encode_vision_frame
+
+NODES_LOADED = False
+
+
+def ensure_nodes_loaded() -> None:
+	global NODES_LOADED
+
+	if NODES_LOADED:
+		return
+
+	NODES_LOADED = True
+
+	import facefusion.face_analyser
+
+	processor_names =\
+	[
+		'face_debugger',
+		'face_enhancer',
+		'face_swapper'
+	]
+
+	for processor_name in processor_names:
+		try:
+			from facefusion.processors.core import load_processor_module
+
+			load_processor_module(processor_name)
+		except SystemExit:
+			pass
+
+
+async def list_nodes(request : Request) -> JSONResponse:
+	ensure_nodes_loaded()
+	nodes = {}
+
+	for name, registered in NODE_REGISTRY.items():
+		schema = registered.schema
+		nodes[name] =\
+		{
+			'name' : name,
+			'description' : schema.description,
+			'inputs' : [ { 'name' : p.name, 'type' : p.type, 'label' : p.label } for p in schema.inputs ],
+			'outputs' : [ { 'name' : p.name, 'type' : p.type, 'label' : p.label } for p in schema.outputs ],
+			'state_keys' : schema.state_keys
+		}
+
+	return JSONResponse(nodes, status_code = HTTP_200_OK)
+
+
+async def execute_node(request : Request) -> JSONResponse:
+	ensure_nodes_loaded()
+	node_name = request.path_params.get('node_name')
+
+	if node_name not in NODE_REGISTRY:
+		return JSONResponse(
+		{
+			'message' : 'node not found'
+		}, status_code = HTTP_404_NOT_FOUND)
+
+	registered = NODE_REGISTRY[node_name]
+	schema = registered.schema
+	body = await request.json()
+	raw_inputs = body.get('inputs', {})
+	state_overrides = body.get('state', {})
+
+	for key in state_overrides:
+		if key not in schema.state_keys:
+			return JSONResponse(
+			{
+				'message' : 'state key "' + key + '" not declared for node "' + node_name + '"'
+			}, status_code = HTTP_400_BAD_REQUEST)
+
+	# Decode inputs based on port types
+	decoded_inputs = {}
+	input_port_types = { p.name: p.type for p in schema.inputs }
+
+	for field_name, value in raw_inputs.items():
+		port_type = input_port_types.get(field_name, '')
+
+		if port_type == 'image' and isinstance(value, str):
+			decoded_inputs[field_name] = decode_vision_frame(value)
+		elif port_type == 'image_list' and isinstance(value, list):
+			decoded_inputs[field_name] = [ decode_vision_frame(v) for v in value ]
+		else:
+			decoded_inputs[field_name] = value
+
+	# Apply state overrides temporarily
+	saved_state = {}
+
+	for key, value in state_overrides.items():
+		saved_state[key] = state_manager.get_item(key)
+		state_manager.set_item(key, value)
+
+	try:
+		result = registered.fn(decoded_inputs)
+
+		# Encode outputs based on port types
+		output_port_types = { p.name: p.type for p in schema.outputs }
+		response = {}
+
+		for key, value in result.items():
+			port_type = output_port_types.get(key, '')
+
+			if port_type == 'image' and isinstance(value, numpy.ndarray):
+				response[key] = encode_vision_frame(value)
+			elif port_type == 'image_list' and isinstance(value, list):
+				response[key] = [ encode_vision_frame(v) for v in value if isinstance(v, numpy.ndarray) ]
+			else:
+				response[key] = value
+
+		return JSONResponse(response, status_code = HTTP_200_OK)
+	except Exception as exception:
+		import traceback
+
+		return JSONResponse(
+		{
+			'message' : str(exception),
+			'traceback' : traceback.format_exc()
+		}, status_code = HTTP_500_INTERNAL_SERVER_ERROR)
+	finally:
+		for key, value in saved_state.items():
+			state_manager.set_item(key, value)

File diff suppressed because it is too large
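(The suppressed diff is presumably the static workflow UI page served by the new Mount.) The part of execute_node easiest to miss is the state handling: overrides are validated against the node's declared state_keys, applied to the global state_manager, and restored in the finally block even when execution fails. A distilled sketch of that save/override/restore idiom, using a plain dict in place of facefusion's state_manager:

from typing import Any, Callable, Dict

STATE : Dict[str, Any] = { 'face_detector_score' : 0.7 }


def run_with_overrides(overrides : Dict[str, Any], fn : Callable[[], Any]) -> Any:
	saved = { key : STATE.get(key) for key in overrides }
	STATE.update(overrides)

	try:
		return fn()
	finally:
		STATE.update(saved) # restore even when fn() raises

Since the overrides mutate process-global state for the duration of the call, two concurrent requests overriding the same key could observe each other's values; the try/finally guarantees restoration, not isolation.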
@@ -1,8 +1,10 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
+import cv2
 import numpy
 
 from facefusion import state_manager
+from facefusion.node import NodeContext, NodePort, encode_vision_frame, node
 from facefusion.common_helper import get_first
 from facefusion.face_classifier import classify_face
 from facefusion.face_detector import detect_faces, detect_faces_by_angle
@@ -141,3 +143,118 @@ def scale_face(target_face : Face, target_vision_frame : VisionFrame, temp_visio
 		bounding_box = bounding_box,
 		landmark_set = landmark_set
 	)
+
+
+def crop_face(vision_frame : VisionFrame, face : Face, padding : int = 30) -> VisionFrame:
+	h, w = vision_frame.shape[:2]
+	box = face.bounding_box.astype(int)
+	x1 = max(0, box[0] - padding)
+	y1 = max(0, box[1] - padding)
+	x2 = min(w, box[2] + padding)
+	y2 = min(h, box[3] + padding)
+	return vision_frame[y1:y2, x1:x2].copy()
+
+
+def draw_bounding_boxes(vision_frame : VisionFrame, faces : List[Face]) -> VisionFrame:
+	result = numpy.ascontiguousarray(vision_frame.copy())
+
+	for face in faces:
+		box = face.bounding_box.astype(int)
+		cv2.rectangle(result, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
+		label = ''
+		if face.gender:
+			label += face.gender
+		if face.age:
+			label += f' {face.age.start}-{face.age.stop}'
+		if label:
+			cv2.putText(result, label.strip(), (box[0], box[1] - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+	return result
+
+
+@node(
+	name = 'face_detector',
+	inputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Image')
+	],
+	outputs =
+	[
+		NodePort(name = 'face_image', type = 'image', label = 'Face Image'),
+		NodePort(name = 'face_data', type = 'json', label = 'Face Data')
+	],
+	state_keys =
+	[
+		'face_detector_model',
+		'face_detector_size',
+		'face_detector_margin',
+		'face_detector_angles',
+		'face_detector_score'
+	],
+	description = 'Detect faces and output annotated image with cropped faces'
+)
+def detect_faces_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+	vision_frame = inputs.get('image')
+	faces = get_many_faces([ vision_frame ])
+	face_image = crop_face(vision_frame, faces[0]) if faces else vision_frame
+
+	face_data = []
+	for face in faces:
+		face_data.append(
+		{
+			'bounding_box' : face.bounding_box.tolist(),
+			'score' : float(face.score_set.get('detector', 0)),
+			'gender' : face.gender,
+			'age' : { 'start' : face.age.start, 'stop' : face.age.stop } if face.age else None,
+			'race' : face.race
+		})
+
+	return\
+	{
+		'face_image' : face_image,
+		'face_data' : face_data
+	}
+
+
+@node(
+	name = 'face_landmarker',
+	inputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Image')
+	],
+	outputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Landmark Image'),
+		NodePort(name = 'landmark_data', type = 'json', label = 'Landmark Data')
+	],
+	state_keys =
+	[
+		'face_landmarker_model',
+		'face_landmarker_score'
+	],
+	description = 'Detect face landmarks'
+)
+def detect_landmarks_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+	vision_frame = inputs.get('image')
+	faces = get_many_faces([ vision_frame ])
+	result_frame = numpy.ascontiguousarray(vision_frame.copy())
+
+	landmark_data = []
+	for face in faces:
+		face_landmark_68 = face.landmark_set.get('68')
+
+		if numpy.any(face_landmark_68):
+			for point in face_landmark_68.astype(int):
+				cv2.circle(result_frame, tuple(point), 2, (0, 255, 0), -1)
+
+		landmark_data.append(
+		{
+			'landmark_5' : face.landmark_set.get('5').tolist() if numpy.any(face.landmark_set.get('5')) else [],
+			'landmark_68' : face.landmark_set.get('68').tolist() if numpy.any(face.landmark_set.get('68')) else []
+		})
+
+	return\
+	{
+		'image' : result_frame,
+		'landmark_data' : landmark_data
+	}
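Because the @node wrapper snapshots the declared state keys into a NodeContext when none is passed, the registered functions remain directly callable in-process. A sketch, assuming a facefusion environment with the detector models available (the image path is illustrative):

import cv2

from facefusion.face_analyser import detect_faces_node

vision_frame = cv2.imread('portrait.jpg') # BGR ndarray, matching the 'image' port
outputs = detect_faces_node({ 'image' : vision_frame })

for face in outputs['face_data']:
	print(face['bounding_box'], face['score'])

The next diff adds the registry module itself, imported throughout as facefusion.node.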
@@ -0,0 +1,101 @@
+import base64
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Type
+
+import cv2
+import numpy
+from numpy.typing import NDArray
+
+
+@dataclass
+class NodePort:
+	name : str
+	type : str # 'image', 'json', 'faces'
+	label : str = ''
+
+
+@dataclass
+class NodeSchema:
+	name : str
+	inputs : List[NodePort]
+	outputs : List[NodePort]
+	state_keys : List[str]
+	description : str = ''
+
+
+@dataclass
+class RegisteredNode:
+	schema : NodeSchema
+	fn : Callable
+
+
+NODE_REGISTRY : Dict[str, RegisteredNode] = {}
+
+
+class NodeContext:
+	def __init__(self, state : Dict[str, Any]) -> None:
+		self._state = dict(state)
+
+	def get_item(self, key : str) -> Any:
+		value = self._state.get(key)
+
+		if value is None:
+			from facefusion import state_manager
+
+			value = state_manager.get_item(key)
+
+		return value
+
+	def __getitem__(self, key : str) -> Any:
+		return self.get_item(key)
+
+	def __contains__(self, key : str) -> bool:
+		return key in self._state
+
+	def to_dict(self) -> Dict[str, Any]:
+		return dict(self._state)
+
+
+def node(name : str, inputs : List[NodePort], outputs : List[NodePort], state_keys : List[str], description : str = '') -> Callable:
+	def decorator(fn : Callable) -> Callable:
+		schema = NodeSchema(
+			name = name,
+			inputs = inputs,
+			outputs = outputs,
+			state_keys = state_keys,
+			description = description
+		)
+
+		@wraps(fn)
+		def wrapper(inputs_dict : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
+			if ctx is None:
+				from facefusion import state_manager
+
+				state_snapshot = { key: state_manager.get_item(key) for key in state_keys }
+				ctx = NodeContext(state_snapshot)
+			return fn(inputs_dict, ctx)
+
+		wrapper.__node_schema__ = schema
+		NODE_REGISTRY[name] = RegisteredNode(schema = schema, fn = wrapper)
+		return wrapper
+
+	return decorator
+
+
+def get_node(name : str) -> Optional[RegisteredNode]:
+	return NODE_REGISTRY.get(name)
+
+
+def get_all_nodes() -> Dict[str, RegisteredNode]:
+	return NODE_REGISTRY
+
+
+def decode_vision_frame(b64_string : str) -> NDArray[Any]:
+	image_bytes = base64.b64decode(b64_string)
+	return cv2.imdecode(numpy.frombuffer(image_bytes, numpy.uint8), cv2.IMREAD_COLOR)
+
+
+def encode_vision_frame(frame : NDArray[Any], fmt : str = '.jpg') -> str:
+	_, buffer = cv2.imencode(fmt, frame)
+	return base64.b64encode(buffer.tobytes()).decode('utf-8')
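With the registry in place, adding a node is one decorator away. A hypothetical grayscale node as a sketch (not part of this commit; importing the defining module is what registers it, after which it appears in GET /nodes):

from typing import Any, Dict, Optional

import cv2

from facefusion.node import NodeContext, NodePort, node


@node(
	name = 'grayscale', # hypothetical example node
	inputs = [ NodePort(name = 'image', type = 'image', label = 'Image') ],
	outputs = [ NodePort(name = 'image', type = 'image', label = 'Grayscale Image') ],
	state_keys = [],
	description = 'Convert an image to grayscale'
)
def grayscale_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
	gray = cv2.cvtColor(inputs.get('image'), cv2.COLOR_BGR2GRAY)
	return { 'image' : cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) } # keep 3 channels for downstream nodes

The remaining diffs retrofit the three processors, threading an optional NodeContext through their face-level functions.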
@@ -6,6 +6,7 @@ import numpy
 import facefusion.capability_store
 import facefusion.jobs.job_manager
 from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, logger, state_manager, translator, video_manager
+from facefusion.node import NodeContext, NodePort, node
 from facefusion.face_analyser import scale_face
 from facefusion.face_helper import warp_face_by_face_landmark_5
 from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask
@@ -14,6 +15,7 @@ from facefusion.filesystem import in_directory, is_image, is_video
 from facefusion.processors.modules.face_debugger import choices as face_debugger_choices
 from facefusion.processors.modules.face_debugger.types import FaceDebuggerInputs
 from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
+from typing import Optional
 from facefusion.program_helper import find_argument_group
 from facefusion.types import Args, Face, InferencePool, ProcessMode, VisionFrame
 from facefusion.vision import read_static_image, read_static_video_frame
@@ -77,14 +79,14 @@ def post_process() -> None:
 	face_recognizer.clear_inference_pool()
 
 
-def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
-	face_debugger_items = state_manager.get_item('face_debugger_items')
+def debug_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+	face_debugger_items = ctx.get_item('face_debugger_items') if ctx else state_manager.get_item('face_debugger_items')
 
 	if 'bounding-box' in face_debugger_items:
 		temp_vision_frame = draw_bounding_box(target_face, temp_vision_frame)
 
 	if 'face-mask' in face_debugger_items:
-		temp_vision_frame = draw_face_mask(target_face, temp_vision_frame)
+		temp_vision_frame = draw_face_mask(target_face, temp_vision_frame, ctx)
 
 	if 'face-landmark-5' in face_debugger_items:
 		temp_vision_frame = draw_face_landmark_5(target_face, temp_vision_frame)
@@ -122,7 +124,8 @@ def draw_bounding_box(target_face : Face, temp_vision_frame : VisionFrame) -> Vi
 	return temp_vision_frame
 
 
-def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+	_get = ctx.get_item if ctx else state_manager.get_item
 	crop_masks = []
 	temp_vision_frame = numpy.ascontiguousarray(temp_vision_frame)
 	face_landmark_5 = target_face.landmark_set.get('5')
@@ -136,21 +139,21 @@ def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> Visio
 	if numpy.array_equal(face_landmark_5, face_landmark_5_68):
 		mask_color = 255, 255, 0
 
-	if 'box' in state_manager.get_item('face_mask_types'):
-		box_mask = create_box_mask(crop_vision_frame, 0, state_manager.get_item('face_mask_padding'))
+	if 'box' in _get('face_mask_types'):
+		box_mask = create_box_mask(crop_vision_frame, 0, _get('face_mask_padding'))
 		crop_masks.append(box_mask)
 
-	if 'occlusion' in state_manager.get_item('face_mask_types'):
+	if 'occlusion' in _get('face_mask_types'):
 		occlusion_mask = create_occlusion_mask(crop_vision_frame)
 		crop_masks.append(occlusion_mask)
 
-	if 'area' in state_manager.get_item('face_mask_types'):
+	if 'area' in _get('face_mask_types'):
 		face_landmark_68 = cv2.transform(face_landmark_68.reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
-		area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
+		area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
 		crop_masks.append(area_mask)
 
-	if 'region' in state_manager.get_item('face_mask_types'):
-		region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
+	if 'region' in _get('face_mask_types'):
+		region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
 		crop_masks.append(region_mask)
 
 	crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -227,18 +230,51 @@ def draw_face_landmark_68_5(target_face : Face, temp_vision_frame : VisionFrame)
 	return temp_vision_frame
 
 
-def process_frame(inputs : FaceDebuggerInputs) -> ProcessorOutputs:
+@node(
+	name = 'face_debugger',
+	inputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Image')
+	],
+	outputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Debug Image')
+	],
+	state_keys =
+	[
+		'face_debugger_items'
+	],
+	description = 'Visualize face landmarks, bounding boxes and masks'
+)
+def process_frame(inputs, ctx = None):
+	from facefusion.face_analyser import get_many_faces
+	from facefusion.vision import extract_vision_mask
+
+	vision_frame = inputs.get('image')
+
+	if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
+		target_vision_frame = vision_frame
+		temp_vision_frame = vision_frame.copy()
+		temp_vision_mask = extract_vision_mask(temp_vision_frame)
+		target_faces = get_many_faces([ target_vision_frame ])
+
+		if target_faces:
+			for target_face in target_faces:
+				temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
+
+		return { 'image' : temp_vision_frame }
+
 	reference_vision_frame = inputs.get('reference_vision_frame')
-	target_vision_frame = inputs.get('target_vision_frame')
-	temp_vision_frame = inputs.get('temp_vision_frame')
+	target_vision_frame = inputs.get('target_vision_frame', vision_frame)
+	temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
 	temp_vision_mask = inputs.get('temp_vision_mask')
 	target_faces = select_faces(reference_vision_frame, target_vision_frame)
 
 	if target_faces:
 		for target_face in target_faces:
 			target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
-			temp_vision_frame = debug_face(target_face, temp_vision_frame)
+			temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)
 
-	return temp_vision_frame, temp_vision_mask
+	return { 'image' : temp_vision_frame }
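Since every 'image' port uses the same base64 encoding on the wire, one node's output plugs straight into the next node's input, which is what the workflow UI's typed connections build on. An assumed two-step chain over HTTP (image_b64 as in the earlier sketch; the base URL and debugger items are illustrative):

import requests

BASE_URL = 'http://127.0.0.1:8000' # assumed host and port

detected = requests.post(BASE_URL + '/nodes/face_detector', json =
{
	'inputs' : { 'image' : image_b64 }
}).json()

# the cropped face output feeds directly into the debugger's image input
debugged = requests.post(BASE_URL + '/nodes/face_debugger', json =
{
	'inputs' : { 'image' : detected['face_image'] },
	'state' : { 'face_debugger_items' : [ 'bounding-box', 'face-mask' ] }
}).json()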
@@ -14,8 +14,10 @@ from facefusion.face_masker import create_box_mask, create_occlusion_mask
 from facefusion.face_selector import select_faces
 from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path
 from facefusion.processors.modules.face_enhancer import choices as face_enhancer_choices
+from facefusion.node import NodeContext, NodePort, node
 from facefusion.processors.modules.face_enhancer.types import FaceEnhancerInputs, FaceEnhancerWeight
 from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
+from typing import Optional
 from facefusion.program_helper import find_argument_group
 from facefusion.thread_helper import thread_semaphore
 from facefusion.types import Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, VisionFrame
@@ -360,27 +362,28 @@ def post_process() -> None:
 	face_recognizer.clear_inference_pool()
 
 
-def enhance_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def enhance_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+	_get = ctx.get_item if ctx else state_manager.get_item
 	model_template = get_model_options().get('template')
 	model_size = get_model_options().get('size')
 	crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size)
-	box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0))
+	box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), (0, 0, 0, 0))
 	crop_masks =\
 	[
 		box_mask
 	]
 
-	if 'occlusion' in state_manager.get_item('face_mask_types'):
+	if 'occlusion' in _get('face_mask_types'):
 		occlusion_mask = create_occlusion_mask(crop_vision_frame)
 		crop_masks.append(occlusion_mask)
 
 	crop_vision_frame = prepare_crop_frame(crop_vision_frame)
-	face_enhancer_weight = numpy.array([ state_manager.get_item('face_enhancer_weight') ]).astype(numpy.double)
+	face_enhancer_weight = numpy.array([ _get('face_enhancer_weight') ]).astype(numpy.double)
 	crop_vision_frame = forward(crop_vision_frame, face_enhancer_weight)
 	crop_vision_frame = normalize_crop_frame(crop_vision_frame)
 	crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
 	paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
-	temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame)
+	temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame, ctx)
 	return temp_vision_frame
 
 
@@ -426,22 +429,57 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
 	return crop_vision_frame
 
 
-def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame) -> VisionFrame:
-	face_enhancer_blend = 1 - (state_manager.get_item('face_enhancer_blend') / 100)
+def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+	_get = ctx.get_item if ctx else state_manager.get_item
+	face_enhancer_blend = 1 - (_get('face_enhancer_blend') / 100)
 	temp_vision_frame = blend_frame(temp_vision_frame, paste_vision_frame, 1 - face_enhancer_blend)
 	return temp_vision_frame
 
 
-def process_frame(inputs : FaceEnhancerInputs) -> ProcessorOutputs:
+@node(
+	name = 'face_enhancer',
+	inputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Image')
+	],
+	outputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Enhanced Image')
+	],
+	state_keys =
+	[
+		'face_enhancer_model',
+		'face_enhancer_blend',
+		'face_enhancer_weight'
+	],
+	description = 'Enhance and restore face quality'
+)
+def process_frame(inputs, ctx = None):
+	from facefusion.face_analyser import get_many_faces
+	from facefusion.vision import extract_vision_mask
+
+	vision_frame = inputs.get('image')
+
+	if isinstance(inputs.get('reference_vision_frame'), type(None)) and vision_frame is not None:
+		target_vision_frame = vision_frame
+		temp_vision_frame = vision_frame.copy()
+		target_faces = get_many_faces([ target_vision_frame ])
+
+		if target_faces:
+			for target_face in target_faces:
+				temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
+
+		return { 'image' : temp_vision_frame }
+
 	reference_vision_frame = inputs.get('reference_vision_frame')
-	target_vision_frame = inputs.get('target_vision_frame')
-	temp_vision_frame = inputs.get('temp_vision_frame')
+	target_vision_frame = inputs.get('target_vision_frame', vision_frame)
+	temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
 	temp_vision_mask = inputs.get('temp_vision_mask')
 	target_faces = select_faces(reference_vision_frame, target_vision_frame)
 
 	if target_faces:
 		for target_face in target_faces:
 			target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
-			temp_vision_frame = enhance_face(target_face, temp_vision_frame)
+			temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)
 
-	return temp_vision_frame, temp_vision_mask
+	return { 'image' : temp_vision_frame }
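The line "_get = ctx.get_item if ctx else state_manager.get_item" repeated across these modules is the whole bridging trick: node executions read from a request-scoped NodeContext while the classic pipeline keeps reading global state, without forking any function bodies. Distilled into a self-contained sketch (the names here are illustrative stand-ins for facefusion's):

from typing import Any, Dict, Optional


class Context:
	def __init__(self, state : Dict[str, Any]) -> None:
		self._state = state

	def get_item(self, key : str) -> Any:
		return self._state.get(key)


GLOBAL_STATE = { 'face_enhancer_blend' : 80 }


def get_global_item(key : str) -> Any:
	return GLOBAL_STATE.get(key)


def blend_weight(ctx : Optional[Context] = None) -> float:
	# same dispatch as enhance_face and blend_paste_frame: ctx wins, global state is the fallback
	_get = ctx.get_item if ctx else get_global_item
	return _get('face_enhancer_blend') / 100


assert blend_weight() == 0.8
assert blend_weight(Context({ 'face_enhancer_blend' : 50 })) == 0.5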
@@ -19,6 +19,7 @@ from facefusion.face_selector import select_faces, sort_faces_by_order
 from facefusion.filesystem import filter_image_paths, has_image, in_directory, is_image, is_video, resolve_relative_path
 from facefusion.model_helper import get_static_model_initializer
 from facefusion.processors.modules.face_swapper import choices as face_swapper_choices
+from facefusion.node import NodeContext, NodePort, node
 from facefusion.processors.modules.face_swapper.types import FaceSwapperInputs
 from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel_boost
 from facefusion.processors.types import ApplyStateItem, ProcessorOutputs
@@ -600,20 +601,21 @@ def post_process() -> None:
 	face_recognizer.clear_inference_pool()
 
 
-def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
+def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame:
+	_get = ctx.get_item if ctx else state_manager.get_item
 	model_template = get_model_options().get('template')
 	model_size = get_model_options().get('size')
-	pixel_boost_size = unpack_resolution(state_manager.get_item('face_swapper_pixel_boost'))
+	pixel_boost_size = unpack_resolution(_get('face_swapper_pixel_boost'))
 	pixel_boost_total = pixel_boost_size[0] // model_size[0]
 	crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, pixel_boost_size)
 	temp_vision_frames = []
 	crop_masks = []
 
-	if 'box' in state_manager.get_item('face_mask_types'):
-		box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding'))
+	if 'box' in _get('face_mask_types'):
+		box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), _get('face_mask_padding'))
 		crop_masks.append(box_mask)
 
-	if 'occlusion' in state_manager.get_item('face_mask_types'):
+	if 'occlusion' in _get('face_mask_types'):
 		occlusion_mask = create_occlusion_mask(crop_vision_frame)
 		crop_masks.append(occlusion_mask)
 
@@ -625,13 +627,13 @@ def swap_face(source_face : Face, target_face : Face, temp_vision_frame : Vision
 		temp_vision_frames.append(pixel_boost_vision_frame)
 	crop_vision_frame = explode_pixel_boost(temp_vision_frames, pixel_boost_total, model_size, pixel_boost_size)
 
-	if 'area' in state_manager.get_item('face_mask_types'):
+	if 'area' in _get('face_mask_types'):
 		face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
-		area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas'))
+		area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas'))
 		crop_masks.append(area_mask)
 
-	if 'region' in state_manager.get_item('face_mask_types'):
-		region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions'))
+	if 'region' in _get('face_mask_types'):
+		region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions'))
 		crop_masks.append(region_mask)
 
 	crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
@@ -779,9 +781,42 @@ def extract_source_face(source_vision_frames : List[VisionFrame]) -> Optional[Fa
 	return get_average_face(source_faces)
 
 
-def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
-	reference_vision_frame = inputs.get('reference_vision_frame')
+@node(
+	name = 'face_swapper',
+	inputs =
+	[
+		NodePort(name = 'source', type = 'image', label = 'Source Face'),
+		NodePort(name = 'target', type = 'image', label = 'Target Image')
+	],
+	outputs =
+	[
+		NodePort(name = 'image', type = 'image', label = 'Swapped Image')
+	],
+	state_keys =
+	[
+		'face_swapper_model',
+		'face_swapper_pixel_boost',
+		'face_swapper_weight'
+	],
+	description = 'Swap faces from source to target image'
+)
+def process_frame(inputs, ctx = None):
+	source_frame = inputs.get('source')
+	target_frame = inputs.get('target')
+
+	if source_frame is not None and target_frame is not None:
+		source_face = extract_source_face([ source_frame ])
+		target_faces = get_many_faces([ target_frame ])
+		temp_vision_frame = target_frame.copy()
+
+		if source_face and target_faces:
+			for target_face in target_faces:
+				temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
+
+		return { 'image' : temp_vision_frame }
+
+	source_vision_frames = inputs.get('source_vision_frames')
+	reference_vision_frame = inputs.get('reference_vision_frame')
 	target_vision_frame = inputs.get('target_vision_frame')
 	temp_vision_frame = inputs.get('temp_vision_frame')
 	temp_vision_mask = inputs.get('temp_vision_mask')
@@ -791,6 +826,6 @@ def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs:
 	if source_face and target_faces:
 		for target_face in target_faces:
 			target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
-			temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame)
+			temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx)
 
-	return temp_vision_frame, temp_vision_mask
+	return { 'image' : temp_vision_frame }
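Putting it together, a full swap through the API; the file names, pixel boost value and base URL are illustrative assumptions, and session setup is again omitted:

import base64

import requests

BASE_URL = 'http://127.0.0.1:8000' # assumed host and port


def load_b64(path : str) -> str:
	with open(path, 'rb') as image_file:
		return base64.b64encode(image_file.read()).decode('utf-8')


response = requests.post(BASE_URL + '/nodes/face_swapper', json =
{
	'inputs' :
	{
		'source' : load_b64('source.jpg'),
		'target' : load_b64('target.jpg')
	},
	'state' : { 'face_swapper_pixel_boost' : '512x512' } # assumed resolution string
}).json()

with open('swapped.jpg', 'wb') as output_file:
	output_file.write(base64.b64decode(response['image']))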