From 9c3937183ea96ad1e69ccbfd4cf7dc2ee1421cf4 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 4 Feb 2025 12:09:14 +0100 Subject: [PATCH] Add streaming support and change the explorer push call to be async. --- proxy/routes/open_ai.py | 236 ++++++++++++++++++++++++++++++++++------ proxy/utils/explorer.py | 42 ++++--- 2 files changed, 231 insertions(+), 47 deletions(-) diff --git a/proxy/routes/open_ai.py b/proxy/routes/open_ai.py index 883bdfc..02a752a 100644 --- a/proxy/routes/open_ai.py +++ b/proxy/routes/open_ai.py @@ -4,6 +4,7 @@ import json import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response +from starlette.responses import StreamingResponse from utils.explorer import push_trace ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"} @@ -18,6 +19,7 @@ IGNORED_HEADERS = [ "x-forwarded-server", "x-real-ip", ] + proxy = APIRouter() MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header" @@ -54,41 +56,209 @@ async def openai_proxy( } headers["accept-encoding"] = "identity" - request_body = await request.body() + request_body_bytes = await request.body() + request_body_json = json.loads(request_body_bytes) - async with httpx.AsyncClient() as client: - open_ai_request = client.build_request( - "POST", - f"https://api.openai.com/v1/{endpoint}", - content=request_body, - headers=headers, + # Check if the request is for streaming + is_streaming = request_body_json.get("stream", False) + + client = httpx.AsyncClient() + open_ai_request = client.build_request( + "POST", + f"https://api.openai.com/v1/{endpoint}", + content=request_body_bytes, + headers=headers, + ) + if is_streaming: + return await stream_response( + client, + open_ai_request, + dataset_name, + request_body_json, + request.headers, ) - response = await client.send(open_ai_request) - try: - json_response = response.json() - # push messages to the Invariant Explorer - # use both the request and response messages - messages = json.loads(request_body).get("messages", []) - messages += [ - choice["message"] for choice in json_response.get("choices", []) - ] - _ = push_trace( - dataset_name=dataset_name, - messages=[messages], - invariant_authorization=request.headers.get("invariant-authorization"), + else: + async with client: + response = await client.send(open_ai_request) + return await handle_non_streaming_response( + response, dataset_name, request_body_json, request.headers ) - except Exception as e: - raise HTTPException( - status_code=500, detail=FAILED_TO_PUSH_TRACE + str(e) - ) from e - response_headers = dict(response.headers) - response_headers.pop("Content-Encoding", None) - response_headers.pop("Content-Length", None) - return Response( - content=json.dumps(response.json()), - status_code=response.status_code, - media_type="application/json", - headers=response_headers, - ) +async def stream_response( + client, open_ai_request, dataset_name, request_body_json, request_headers +): + """Handles streaming the OpenAI response to the client while collecting full response""" + + async def event_generator(): + full_response = { + "id": None, + "object": "chat.completion", + "created": None, + "model": None, + "choices": [], + "usage": None, + } + + # Tracks choice index to full_response index + index_mapping = {} + # Tracks tool calls by index + tool_call_mapping = {} + + async with client.stream( + "POST", + open_ai_request.url, + headers=open_ai_request.headers, + content=open_ai_request.content, + ) as response: + if response.status_code != 200: + error_message = json.dumps( + {"error": f"Failed to fetch response: {response.status_code}"} + ).encode() + yield error_message + return + + async for chunk in response.aiter_bytes(): + chunk_text = chunk.decode().strip() + if not chunk_text: + continue + + # Yield chunk immediately to the client (proxy behavior) + yield chunk + + # There can be multiple "data: " chunks in a single response + for json_string in chunk_text.split("\ndata: "): + # Remove first "data: " prefix + json_string = json_string.replace("data: ", "").strip() + + if not json_string or json_string == "[DONE]": + continue + + try: + json_chunk = json.loads(json_string) + except json.JSONDecodeError: + continue + + # Extract metadata safely + full_response["id"] = full_response["id"] or json_chunk.get("id") + full_response["created"] = full_response[ + "created" + ] or json_chunk.get("created") + full_response["model"] = full_response["model"] or json_chunk.get( + "model" + ) + + for choice in json_chunk.get("choices", []): + index = choice.get("index", 0) + + # Ensure we have a mapping for this index + if index not in index_mapping: + index_mapping[index] = len(full_response["choices"]) + full_response["choices"].append( + { + "index": index, + "message": {"role": "assistant"}, + "finish_reason": None, + } + ) + + existing_choice = full_response["choices"][index_mapping[index]] + delta = choice.get("delta", {}) + + # Handle regular assistant messages + content = delta.get("content") + if content is not None: + if "content" not in existing_choice["message"]: + existing_choice["message"]["content"] = "" + existing_choice["message"]["content"] += content + + # Handle tool calls + if isinstance(delta.get("tool_calls"), list): + if "tool_calls" not in existing_choice["message"]: + existing_choice["message"]["tool_calls"] = [] + + for tool in delta["tool_calls"]: + tool_index = tool.get("index") + tool_id = tool.get("id") + tool_name = tool.get("function", {}).get("name") + tool_arguments = tool.get("function", {}).get( + "arguments", "" + ) + + if tool_index is None: + continue + + # Find or create tool call by index + if tool_index not in tool_call_mapping: + tool_call_mapping[tool_index] = { + "index": tool_index, + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": "", + }, + } + existing_choice["message"]["tool_calls"].append( + tool_call_mapping[tool_index] + ) + + tool_entry = tool_call_mapping[tool_index] + + if tool_id: + tool_entry["id"] = tool_id + + if tool_name: + tool_entry["function"]["name"] = tool_name + + # Append arguments if they exist + if tool_arguments: + tool_entry["function"]["arguments"] += ( + tool_arguments + ) + + finish_reason = choice.get("finish_reason") + if finish_reason is not None: + existing_choice["finish_reason"] = finish_reason + + # Send full merged response to the explorer + await push_to_explorer( + dataset_name, full_response, request_headers, request_body_json + ) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +async def push_to_explorer(dataset_name, full_response, request_headers, request_body): + """Pushes the full trace to the Invariant Explorer""" + # Combine messages from the request and the response + # to push the full trace to the Invariant Explorer + messages = request_body.get("messages", []) + messages += [choice["message"] for choice in full_response.get("choices", [])] + + _ = await push_trace( + dataset_name=dataset_name, + messages=[messages], + invariant_authorization=request_headers.get("invariant-authorization"), + ) + + +async def handle_non_streaming_response( + response, dataset_name, request_body_json, request_headers +): + """Handles non-streaming OpenAI responses""" + json_response = response.json() + await push_to_explorer( + dataset_name, json_response, request_headers, request_body_json + ) + + response_headers = dict(response.headers) + response_headers.pop("Content-Encoding", None) + response_headers.pop("Content-Length", None) + + return Response( + content=json.dumps(json_response), + status_code=response.status_code, + media_type="application/json", + headers=response_headers, + ) diff --git a/proxy/utils/explorer.py b/proxy/utils/explorer.py index 4b561f8..f9eea8e 100644 --- a/proxy/utils/explorer.py +++ b/proxy/utils/explorer.py @@ -3,13 +3,14 @@ import os from typing import Any, Dict, List -from fastapi import HTTPException -from invariant_sdk.client import Client +import httpx +from invariant_sdk.types.push_traces import PushTracesRequest DEFAULT_API_URL = "https://explorer.invariantlabs.ai" +PUSH_ENDPOINT = "/api/v1/push/trace" -def push_trace( +async def push_trace( messages: List[Dict[str, Any]], dataset_name: str, invariant_authorization: str, @@ -19,19 +20,32 @@ def push_trace( Args: messages (List[Dict[str, Any]]): List of messages to push. dataset_name (str): Name of the dataset. - invariant_authorization (str): Authorization token. + invariant_authorization (str): Authorization token from the + invariant-authorization header. Returns: Dict[str, str]: Response containing the trace ID. """ - api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL) - api_key = invariant_authorization.split("Bearer ")[1] - client = Client(api_url=api_url, api_key=api_key) - try: - # TODO: Change this to the async version once that is available - push_trace_response = client.create_request_and_push_trace( - messages=messages, dataset=dataset_name + api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/") + + request = PushTracesRequest(messages=messages, dataset=dataset_name) + async with httpx.AsyncClient() as client: + explorer_push_request = client.build_request( + "POST", + f"{api_url}{PUSH_ENDPOINT}", + json=request.to_json(), + headers={ + "Authorization": f"{invariant_authorization}", + "Accept": "application/json", + }, ) - return {"trace_id": push_trace_response.id[0]} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e + try: + response = await client.send(explorer_push_request) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + print(f"Failed to push trace: {e.response.text}") + return {"error": str(e)} + except Exception as e: + print(f"Unexpected error pushing trace: {str(e)}") + return {"error": str(e)}