# --- facefusion/apis/endpoints/nodes.py (new file introduced by this patch, reconstructed with fixes) ---
# NOTE(review): this span of the patch also appends `Mount('/', StaticFiles(directory = static_path, html = True))`
# in facefusion/apis/core.py. StaticFiles raises at construction when the directory is missing,
# which would crash API startup in installs that ship without the static folder -- guard the
# mount with os.path.isdir(static_path) before appending it.
import traceback

import numpy
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR

from facefusion import state_manager
from facefusion.node import NODE_REGISTRY, decode_vision_frame, encode_vision_frame

# module-level guard so node-registering imports run exactly once per process
NODES_LOADED = False


def ensure_nodes_loaded() -> None:
    """
    Import the modules that register nodes (via the @node decorator) exactly once.
    """
    global NODES_LOADED

    if NODES_LOADED:
        return

    # flip the flag before importing so a failing import is not retried on every request
    NODES_LOADED = True

    # importing face_analyser registers the face_detector / face_landmarker nodes as a side effect
    import facefusion.face_analyser # noqa: F401
    # fix: hoisted out of the loop -- re-importing per iteration was pointless
    from facefusion.processors.core import load_processor_module

    processor_names =\
    [
        'face_debugger',
        'face_enhancer',
        'face_swapper'
    ]

    for processor_name in processor_names:
        try:
            load_processor_module(processor_name)
        except SystemExit:
            # load_processor_module() exits the process on validation failure;
            # one broken processor must not take the whole API down
            pass


async def list_nodes(request : Request) -> JSONResponse:
    """
    Return the schema (description, ports, state keys) of every registered node.
    """
    ensure_nodes_loaded()
    nodes = {}

    for name, registered in NODE_REGISTRY.items():
        schema = registered.schema
        nodes[name] =\
        {
            'name' : name,
            'description' : schema.description,
            'inputs' : [ { 'name' : port.name, 'type' : port.type, 'label' : port.label } for port in schema.inputs ],
            'outputs' : [ { 'name' : port.name, 'type' : port.type, 'label' : port.label } for port in schema.outputs ],
            'state_keys' : schema.state_keys
        }

    return JSONResponse(nodes, status_code = HTTP_200_OK)


async def execute_node(request : Request) -> JSONResponse:
    """
    Execute a registered node: decode image inputs, apply temporary state overrides,
    run the node, encode image outputs and restore the previous state.
    """
    ensure_nodes_loaded()
    node_name = request.path_params.get('node_name')

    if node_name not in NODE_REGISTRY:
        return JSONResponse(
        {
            'message' : 'node not found'
        }, status_code = HTTP_404_NOT_FOUND)

    registered = NODE_REGISTRY[node_name]
    schema = registered.schema

    try:
        body = await request.json()
    except ValueError:
        # fix: a malformed request body previously escaped as an unhandled 500
        return JSONResponse(
        {
            'message' : 'request body must be valid JSON'
        }, status_code = HTTP_400_BAD_REQUEST)

    raw_inputs = body.get('inputs', {})
    state_overrides = body.get('state', {})

    for key in state_overrides:
        if key not in schema.state_keys:
            return JSONResponse(
            {
                'message' : 'state key "' + key + '" not declared for node "' + node_name + '"'
            }, status_code = HTTP_400_BAD_REQUEST)

    decoded_inputs = decode_node_inputs(schema, raw_inputs)

    # NOTE(review): overrides are written into the process-global state manager and restored
    # afterwards -- concurrent requests with different overrides can observe each other's
    # values. Confirm the API runs single-worker, or scope state per request instead.
    saved_state = {}

    for key, value in state_overrides.items():
        saved_state[key] = state_manager.get_item(key)
        state_manager.set_item(key, value)

    try:
        result = registered.fn(decoded_inputs)
        return JSONResponse(encode_node_outputs(schema, result), status_code = HTTP_200_OK)
    except Exception as exception:
        # NOTE(review): returning the traceback aids local debugging but leaks internals
        # to clients -- consider gating it behind a debug flag before public exposure
        return JSONResponse(
        {
            'message' : str(exception),
            'traceback' : traceback.format_exc()
        }, status_code = HTTP_500_INTERNAL_SERVER_ERROR)
    finally:
        # always restore the previous state, even when the node raised
        for key, value in saved_state.items():
            state_manager.set_item(key, value)


def decode_node_inputs(schema, raw_inputs):
    # decode base64 payloads into vision frames according to the declared input port types
    decoded_inputs = {}
    input_port_types = { port.name : port.type for port in schema.inputs }

    for field_name, value in raw_inputs.items():
        port_type = input_port_types.get(field_name, '')

        if port_type == 'image' and isinstance(value, str):
            decoded_inputs[field_name] = decode_vision_frame(value)
        elif port_type == 'image_list' and isinstance(value, list):
            decoded_inputs[field_name] = [ decode_vision_frame(item) for item in value ]
        else:
            decoded_inputs[field_name] = value

    return decoded_inputs


def encode_node_outputs(schema, result):
    # encode numpy frames back to base64 according to the declared output port types
    output_port_types = { port.name : port.type for port in schema.outputs }
    response = {}

    for key, value in result.items():
        port_type = output_port_types.get(key, '')

        if port_type == 'image' and isinstance(value, numpy.ndarray):
            response[key] = encode_vision_frame(value)
        elif port_type == 'image_list' and isinstance(value, list):
            response[key] = [ encode_vision_frame(item) for item in value if isinstance(item, numpy.ndarray) ]
        else:
            response[key] = value

    return response
--git a/facefusion/apis/static/index.html b/facefusion/apis/static/index.html new file mode 100644 index 00000000..41f6e73d --- /dev/null +++ b/facefusion/apis/static/index.html @@ -0,0 +1,1017 @@ + + + + + + Workflow Area + + + + + + + + + +
+ +
+
+
+ + Workflow Area +
+
+
+
+
+
No session
+ 100% + + +
+
+ +
+ + + + +
+
+ + + +
+ +
+
+
+

Connecting...

+
+
+
# --- facefusion/face_analyser.py (additions from this patch, reconstructed with fixes) ---
# Updated file-top imports added by this hunk (encode_vision_frame was imported but
# never used in the additions, so it is dropped):
#   import cv2
#   from typing import Any, Dict, List, Optional
#   from facefusion.node import NodeContext, NodePort, node


def crop_face(vision_frame : VisionFrame, face : Face, padding : int = 30) -> VisionFrame:
    """
    Return a copy of the face region of vision_frame, enlarged by padding pixels
    on every side and clamped to the frame boundaries.
    """
    height, width = vision_frame.shape[:2]
    bounding_box = face.bounding_box.astype(int)
    x_min = max(0, bounding_box[0] - padding)
    y_min = max(0, bounding_box[1] - padding)
    x_max = min(width, bounding_box[2] + padding)
    y_max = min(height, bounding_box[3] + padding)
    return vision_frame[y_min:y_max, x_min:x_max].copy()


def draw_bounding_boxes(vision_frame : VisionFrame, faces : List[Face]) -> VisionFrame:
    """
    Draw a green bounding box, plus a gender / age label when available, for every face.
    """
    result_frame = numpy.ascontiguousarray(vision_frame.copy())

    for face in faces:
        bounding_box = face.bounding_box.astype(int)
        cv2.rectangle(result_frame, (bounding_box[0], bounding_box[1]), (bounding_box[2], bounding_box[3]), (0, 255, 0), 2)
        label_parts = []

        if face.gender:
            label_parts.append(face.gender)
        if face.age:
            # face.age appears to be a range-like object -- TODO confirm start/stop semantics
            label_parts.append(f'{face.age.start}-{face.age.stop}')
        if label_parts:
            cv2.putText(result_frame, ' '.join(label_parts), (bounding_box[0], bounding_box[1] - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    return result_frame


@node(
    name = 'face_detector',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'face_image', type = 'image', label = 'Face Image'),
        NodePort(name = 'face_data', type = 'json', label = 'Face Data')
    ],
    state_keys =
    [
        'face_detector_model',
        'face_detector_size',
        'face_detector_margin',
        'face_detector_angles',
        'face_detector_score'
    ],
    description = 'Detect faces and output annotated image with cropped faces'
)
def detect_faces_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
    """
    Node entry point: detect faces, crop the first one and report per-face metadata.
    Falls back to the unmodified input image when no face is found.
    """
    vision_frame = inputs.get('image')
    faces = get_many_faces([ vision_frame ])
    face_image = crop_face(vision_frame, faces[0]) if faces else vision_frame
    face_data = []

    for face in faces:
        face_data.append(
        {
            'bounding_box' : face.bounding_box.tolist(),
            'score' : float(face.score_set.get('detector', 0)),
            'gender' : face.gender,
            'age' : { 'start' : face.age.start, 'stop' : face.age.stop } if face.age else None,
            'race' : face.race
        })

    return\
    {
        'face_image' : face_image,
        'face_data' : face_data
    }


@node(
    name = 'face_landmarker',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Landmark Image'),
        NodePort(name = 'landmark_data', type = 'json', label = 'Landmark Data')
    ],
    state_keys =
    [
        'face_landmarker_model',
        'face_landmarker_score'
    ],
    description = 'Detect face landmarks'
)
def detect_landmarks_node(inputs : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
    """
    Node entry point: draw the 68-point landmarks on the image and report both
    landmark sets per face.
    """
    vision_frame = inputs.get('image')
    faces = get_many_faces([ vision_frame ])
    result_frame = numpy.ascontiguousarray(vision_frame.copy())
    landmark_data = []

    for face in faces:
        # fix: landmark_set.get() may return None and numpy.any(None) relies on
        # fragile object-array coercion -- test for None explicitly
        face_landmark_5 = face.landmark_set.get('5')
        face_landmark_68 = face.landmark_set.get('68')
        has_landmark_5 = face_landmark_5 is not None and numpy.any(face_landmark_5)
        has_landmark_68 = face_landmark_68 is not None and numpy.any(face_landmark_68)

        if has_landmark_68:
            for point in face_landmark_68.astype(int):
                cv2.circle(result_frame, tuple(point), 2, (0, 255, 0), -1)

        landmark_data.append(
        {
            'landmark_5' : face_landmark_5.tolist() if has_landmark_5 else [],
            'landmark_68' : face_landmark_68.tolist() if has_landmark_68 else []
        })

    return\
    {
        'image' : result_frame,
        'landmark_data' : landmark_data
    }
# --- facefusion/node.py (new file introduced by this patch, reconstructed with fixes) ---
import base64
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

import cv2
import numpy
from numpy.typing import NDArray


@dataclass
class NodePort:
    # declared input/output port of a node ('image', 'image_list', 'json', ...)
    name : str
    type : str
    label : str = ''


@dataclass
class NodeSchema:
    # public contract of a node: its ports plus the state keys it honours
    name : str
    inputs : List[NodePort]
    outputs : List[NodePort]
    state_keys : List[str]
    description : str = ''


@dataclass
class RegisteredNode:
    schema : NodeSchema
    fn : Callable


# global registry populated by the @node decorator at import time
NODE_REGISTRY : Dict[str, RegisteredNode] = {}


class NodeContext:
    """
    Snapshot of state items taken when a node is invoked, with a fallback to the
    global state manager for keys missing from the snapshot.
    """

    def __init__(self, state : Dict[str, Any]) -> None:
        self._state = dict(state)

    def get_item(self, key : str) -> Any:
        # NOTE(review): a snapshot value of None also falls through to the global
        # state manager, so None can never act as an override value -- confirm intended
        value = self._state.get(key)

        if value is None:
            from facefusion import state_manager # local import avoids a circular import

            value = state_manager.get_item(key)

        return value

    def __getitem__(self, key : str) -> Any:
        return self.get_item(key)

    def __contains__(self, key : str) -> bool:
        return key in self._state

    def to_dict(self) -> Dict[str, Any]:
        return dict(self._state)


def node(name : str, inputs : List[NodePort], outputs : List[NodePort], state_keys : List[str], description : str = '') -> Callable:
    """
    Register the decorated function in NODE_REGISTRY under the given schema.

    The wrapper snapshots the declared state keys into a NodeContext when the
    caller does not supply one, so the node body always receives a context.
    """
    def decorator(fn : Callable) -> Callable:
        schema = NodeSchema(
            name = name,
            inputs = inputs,
            outputs = outputs,
            state_keys = state_keys,
            description = description
        )

        @wraps(fn)
        def wrapper(inputs_dict : Dict[str, Any], ctx : Optional[NodeContext] = None) -> Dict[str, Any]:
            if ctx is None:
                from facefusion import state_manager # local import avoids a circular import

                state_snapshot = { key : state_manager.get_item(key) for key in state_keys }
                ctx = NodeContext(state_snapshot)
            return fn(inputs_dict, ctx)

        wrapper.__node_schema__ = schema
        NODE_REGISTRY[name] = RegisteredNode(schema = schema, fn = wrapper)
        return wrapper

    return decorator


def get_node(name : str) -> Optional[RegisteredNode]:
    return NODE_REGISTRY.get(name)


def get_all_nodes() -> Dict[str, RegisteredNode]:
    return NODE_REGISTRY


def decode_vision_frame(b64_string : str) -> NDArray[Any]:
    """
    Decode a base64-encoded image into a BGR vision frame.
    """
    # fix: cv2.imdecode() returns None for undecodable payloads -- raise a clear
    # error instead of letting None propagate into downstream processing
    image_bytes = base64.b64decode(b64_string)
    vision_frame = cv2.imdecode(numpy.frombuffer(image_bytes, numpy.uint8), cv2.IMREAD_COLOR)

    if vision_frame is None:
        raise ValueError('unable to decode vision frame')
    return vision_frame


def encode_vision_frame(frame : NDArray[Any], fmt : str = '.jpg') -> str:
    """
    Encode a vision frame into a base64 string using the given image format.
    """
    # fix: cv2.imencode() signals failure through its first return value -- do not discard it
    is_success, buffer = cv2.imencode(fmt, frame)

    if not is_success:
        raise ValueError('unable to encode vision frame')
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
post_process() -> None: face_recognizer.clear_inference_pool() -def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: - face_debugger_items = state_manager.get_item('face_debugger_items') +def debug_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame: + face_debugger_items = ctx.get_item('face_debugger_items') if ctx else state_manager.get_item('face_debugger_items') if 'bounding-box' in face_debugger_items: temp_vision_frame = draw_bounding_box(target_face, temp_vision_frame) if 'face-mask' in face_debugger_items: - temp_vision_frame = draw_face_mask(target_face, temp_vision_frame) + temp_vision_frame = draw_face_mask(target_face, temp_vision_frame, ctx) if 'face-landmark-5' in face_debugger_items: temp_vision_frame = draw_face_landmark_5(target_face, temp_vision_frame) @@ -122,7 +124,8 @@ def draw_bounding_box(target_face : Face, temp_vision_frame : VisionFrame) -> Vi return temp_vision_frame -def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: +def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame, ctx : NodeContext = None) -> VisionFrame: + _get = ctx.get_item if ctx else state_manager.get_item crop_masks = [] temp_vision_frame = numpy.ascontiguousarray(temp_vision_frame) face_landmark_5 = target_face.landmark_set.get('5') @@ -136,21 +139,21 @@ def draw_face_mask(target_face : Face, temp_vision_frame : VisionFrame) -> Visio if numpy.array_equal(face_landmark_5, face_landmark_5_68): mask_color = 255, 255, 0 - if 'box' in state_manager.get_item('face_mask_types'): - box_mask = create_box_mask(crop_vision_frame, 0, state_manager.get_item('face_mask_padding')) + if 'box' in _get('face_mask_types'): + box_mask = create_box_mask(crop_vision_frame, 0, _get('face_mask_padding')) crop_masks.append(box_mask) - if 'occlusion' in state_manager.get_item('face_mask_types'): + if 'occlusion' in _get('face_mask_types'): occlusion_mask = 
# --- facefusion/processors/modules/face_debugger/core.py : process_frame (rewritten by this patch) ---
@node(
    name = 'face_debugger',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Debug Image')
    ],
    state_keys =
    [
        'face_debugger_items'
    ],
    description = 'Visualize face landmarks, bounding boxes and masks'
)
def process_frame(inputs, ctx = None):
    """
    Node entry point: draw debug overlays on every detected face.

    NOTE(review): this patch changes the return type from the legacy
    ProcessorOutputs tuple (frame, mask) to a dict -- confirm no pipeline
    caller still unpacks the tuple before merging.
    """
    from facefusion.face_analyser import get_many_faces

    vision_frame = inputs.get('image')

    # standalone node path: no reference frame supplied, detect directly on the input
    # (fix: replaced isinstance(..., type(None)) with `is None`; dropped the unused
    # extract_vision_mask() call and unused locals from the original patch)
    if inputs.get('reference_vision_frame') is None and vision_frame is not None:
        temp_vision_frame = vision_frame.copy()

        for target_face in get_many_faces([ vision_frame ]):
            temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)

        return { 'image' : temp_vision_frame }

    # legacy pipeline path: honour explicit target/temp frames when provided
    reference_vision_frame = inputs.get('reference_vision_frame')
    target_vision_frame = inputs.get('target_vision_frame', vision_frame)
    temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
    target_faces = select_faces(reference_vision_frame, target_vision_frame)

    if target_faces:
        for target_face in target_faces:
            target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
            temp_vision_frame = debug_face(target_face, temp_vision_frame, ctx)

    return { 'image' : temp_vision_frame }
# --- facefusion/processors/modules/face_enhancer/core.py : functions changed by this patch ---
def enhance_face(target_face : Face, temp_vision_frame : VisionFrame, ctx : Optional[NodeContext] = None) -> VisionFrame:
    """
    Enhance a single face and paste it back into temp_vision_frame.

    State items come from ctx when a node context is supplied, otherwise
    from the global state manager.
    """
    _get = ctx.get_item if ctx else state_manager.get_item
    model_template = get_model_options().get('template')
    model_size = get_model_options().get('size')
    crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size)
    box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), (0, 0, 0, 0))
    crop_masks =\
    [
        box_mask
    ]

    if 'occlusion' in _get('face_mask_types'):
        occlusion_mask = create_occlusion_mask(crop_vision_frame)
        crop_masks.append(occlusion_mask)

    crop_vision_frame = prepare_crop_frame(crop_vision_frame)
    face_enhancer_weight = numpy.array([ _get('face_enhancer_weight') ]).astype(numpy.double)
    crop_vision_frame = forward(crop_vision_frame, face_enhancer_weight)
    crop_vision_frame = normalize_crop_frame(crop_vision_frame)
    crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
    paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
    temp_vision_frame = blend_paste_frame(temp_vision_frame, paste_vision_frame, ctx)
    return temp_vision_frame


def blend_paste_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame, ctx : Optional[NodeContext] = None) -> VisionFrame:
    """
    Blend the enhanced frame over the original by face_enhancer_blend percent.
    """
    # fix: the original computed 1 - (1 - blend/100), which is just blend/100 --
    # pass the ratio directly (identical value, no double negation)
    _get = ctx.get_item if ctx else state_manager.get_item
    face_enhancer_blend = _get('face_enhancer_blend') / 100
    return blend_frame(temp_vision_frame, paste_vision_frame, face_enhancer_blend)


@node(
    name = 'face_enhancer',
    inputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Image')
    ],
    outputs =
    [
        NodePort(name = 'image', type = 'image', label = 'Enhanced Image')
    ],
    state_keys =
    [
        'face_enhancer_model',
        'face_enhancer_blend',
        'face_enhancer_weight'
    ],
    description = 'Enhance and restore face quality'
)
def process_frame(inputs, ctx = None):
    """
    Node entry point: enhance every detected face in the input image.

    NOTE(review): this patch changes the return type from the legacy
    ProcessorOutputs tuple (frame, mask) to a dict -- confirm no pipeline
    caller still unpacks the tuple before merging.
    """
    from facefusion.face_analyser import get_many_faces

    vision_frame = inputs.get('image')

    # standalone node path: no reference frame supplied, detect directly on the input
    # (fix: replaced isinstance(..., type(None)) with `is None`; dropped the unused
    # extract_vision_mask import and unused temp_vision_mask locals from the original patch)
    if inputs.get('reference_vision_frame') is None and vision_frame is not None:
        temp_vision_frame = vision_frame.copy()

        for target_face in get_many_faces([ vision_frame ]):
            temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)

        return { 'image' : temp_vision_frame }

    # legacy pipeline path: honour explicit target/temp frames when provided
    reference_vision_frame = inputs.get('reference_vision_frame')
    target_vision_frame = inputs.get('target_vision_frame', vision_frame)
    temp_vision_frame = inputs.get('temp_vision_frame', vision_frame.copy() if vision_frame is not None else None)
    target_faces = select_faces(reference_vision_frame, target_vision_frame)

    if target_faces:
        for target_face in target_faces:
            target_face = scale_face(target_face, target_vision_frame, temp_vision_frame)
            temp_vision_frame = enhance_face(target_face, temp_vision_frame, ctx)

    return { 'image' : temp_vision_frame }
model_template, pixel_boost_size) temp_vision_frames = [] crop_masks = [] - if 'box' in state_manager.get_item('face_mask_types'): - box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding')) + if 'box' in _get('face_mask_types'): + box_mask = create_box_mask(crop_vision_frame, _get('face_mask_blur'), _get('face_mask_padding')) crop_masks.append(box_mask) - if 'occlusion' in state_manager.get_item('face_mask_types'): + if 'occlusion' in _get('face_mask_types'): occlusion_mask = create_occlusion_mask(crop_vision_frame) crop_masks.append(occlusion_mask) @@ -625,13 +627,13 @@ def swap_face(source_face : Face, target_face : Face, temp_vision_frame : Vision temp_vision_frames.append(pixel_boost_vision_frame) crop_vision_frame = explode_pixel_boost(temp_vision_frames, pixel_boost_total, model_size, pixel_boost_size) - if 'area' in state_manager.get_item('face_mask_types'): + if 'area' in _get('face_mask_types'): face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2) - area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas')) + area_mask = create_area_mask(crop_vision_frame, face_landmark_68, _get('face_mask_areas')) crop_masks.append(area_mask) - if 'region' in state_manager.get_item('face_mask_types'): - region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions')) + if 'region' in _get('face_mask_types'): + region_mask = create_region_mask(crop_vision_frame, _get('face_mask_regions')) crop_masks.append(region_mask) crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) @@ -779,9 +781,42 @@ def extract_source_face(source_vision_frames : List[VisionFrame]) -> Optional[Fa return get_average_face(source_faces) -def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs: - reference_vision_frame = inputs.get('reference_vision_frame') +@node( + name 
= 'face_swapper', + inputs = + [ + NodePort(name = 'source', type = 'image', label = 'Source Face'), + NodePort(name = 'target', type = 'image', label = 'Target Image') + ], + outputs = + [ + NodePort(name = 'image', type = 'image', label = 'Swapped Image') + ], + state_keys = + [ + 'face_swapper_model', + 'face_swapper_pixel_boost', + 'face_swapper_weight' + ], + description = 'Swap faces from source to target image' +) +def process_frame(inputs, ctx = None): + source_frame = inputs.get('source') + target_frame = inputs.get('target') + + if source_frame is not None and target_frame is not None: + source_face = extract_source_face([ source_frame ]) + target_faces = get_many_faces([ target_frame ]) + temp_vision_frame = target_frame.copy() + + if source_face and target_faces: + for target_face in target_faces: + temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx) + + return { 'image' : temp_vision_frame } + source_vision_frames = inputs.get('source_vision_frames') + reference_vision_frame = inputs.get('reference_vision_frame') target_vision_frame = inputs.get('target_vision_frame') temp_vision_frame = inputs.get('temp_vision_frame') temp_vision_mask = inputs.get('temp_vision_mask') @@ -791,6 +826,6 @@ def process_frame(inputs : FaceSwapperInputs) -> ProcessorOutputs: if source_face and target_faces: for target_face in target_faces: target_face = scale_face(target_face, target_vision_frame, temp_vision_frame) - temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame) + temp_vision_frame = swap_face(source_face, target_face, temp_vision_frame, ctx) - return temp_vision_frame, temp_vision_mask + return { 'image' : temp_vision_frame }