diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index d0cc2bdb..3e1b6137 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -8,24 +8,25 @@ from facefusion.apis.session_helper import extract_access_token from facefusion.apis.stream_helper import add_rtc_viewer, handle_image_stream, handle_video_stream -# TODO: reject websocket with invalid or missing stream mode async def websocket_stream(websocket : WebSocket) -> None: stream_mode = websocket.query_params.get('mode') if stream_mode == 'image': - await handle_image_stream(websocket) + return await handle_image_stream(websocket) if stream_mode == 'video': - await handle_video_stream(websocket) + return await handle_video_stream(websocket) + + return await websocket.close(1008) -# TODO: validate content type is application/sdp, sanitize sdp input async def post_stream(request : Request) -> Response: + content_type = request.headers.get('content-type') access_token = extract_access_token(request.scope) session_id = session_manager.find_session_id(access_token) session_context.set_session_id(session_id) - if session_id: + if content_type == 'application/sdp' and session_id: sdp_offer = await request.body() sdp_answer = add_rtc_viewer(session_id, sdp_offer.decode())