Add streaming support and change the explorer push call to be async.

This commit is contained in:
Hemang
2025-02-04 12:09:14 +01:00
parent be934bc07c
commit 9c3937183e
2 changed files with 231 additions and 47 deletions
+203 -33
View File
@@ -4,6 +4,7 @@ import json
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.explorer import push_trace
ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"}
@@ -18,6 +19,7 @@ IGNORED_HEADERS = [
"x-forwarded-server",
"x-real-ip",
]
proxy = APIRouter()
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
@@ -54,41 +56,209 @@ async def openai_proxy(
}
headers["accept-encoding"] = "identity"
request_body = await request.body()
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
async with httpx.AsyncClient() as client:
open_ai_request = client.build_request(
"POST",
f"https://api.openai.com/v1/{endpoint}",
content=request_body,
headers=headers,
# Check if the request is for streaming
is_streaming = request_body_json.get("stream", False)
client = httpx.AsyncClient()
open_ai_request = client.build_request(
"POST",
f"https://api.openai.com/v1/{endpoint}",
content=request_body_bytes,
headers=headers,
)
if is_streaming:
return await stream_response(
client,
open_ai_request,
dataset_name,
request_body_json,
request.headers,
)
response = await client.send(open_ai_request)
try:
json_response = response.json()
# push messages to the Invariant Explorer
# use both the request and response messages
messages = json.loads(request_body).get("messages", [])
messages += [
choice["message"] for choice in json_response.get("choices", [])
]
_ = push_trace(
dataset_name=dataset_name,
messages=[messages],
invariant_authorization=request.headers.get("invariant-authorization"),
else:
async with client:
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, request.headers
)
except Exception as e:
raise HTTPException(
status_code=500, detail=FAILED_TO_PUSH_TRACE + str(e)
) from e
response_headers = dict(response.headers)
response_headers.pop("Content-Encoding", None)
response_headers.pop("Content-Length", None)
return Response(
content=json.dumps(response.json()),
status_code=response.status_code,
media_type="application/json",
headers=response_headers,
)
def _parse_sse_payload(raw_line):
    """Decode one complete event-stream line into a JSON object.

    Args:
        raw_line: Bytes of a single line, without its trailing newline.

    Returns:
        The parsed JSON object, or None when the line is blank, is the
        "[DONE]" sentinel, is not valid UTF-8/JSON, or is not a JSON object.
    """
    try:
        line = raw_line.decode("utf-8").strip()
    except UnicodeDecodeError:
        return None

    # Strip only the leading field name. A blanket str.replace would also
    # remove "data: " occurrences that are part of the generated text.
    payload = line[len("data:"):].strip() if line.startswith("data:") else line
    if not payload or payload == "[DONE]":
        return None

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _merge_stream_chunk(full_response, json_chunk, index_mapping, tool_call_mapping):
    """Fold one streamed completion chunk into the accumulated full response.

    Args:
        full_response: Accumulated response dict; mutated in place.
        json_chunk: One parsed chunk of the stream.
        index_mapping: Maps a choice index to its position in
            full_response["choices"]; mutated in place.
        tool_call_mapping: Maps (choice index, tool call index) to the
            accumulated tool call dict; mutated in place.
    """
    # Metadata: the first non-empty value seen wins.
    for field in ("id", "created", "model"):
        full_response[field] = full_response[field] or json_chunk.get(field)
    if json_chunk.get("usage") is not None:
        full_response["usage"] = json_chunk["usage"]

    for choice in json_chunk.get("choices") or []:
        index = choice.get("index", 0)

        # Ensure we have an accumulator for this choice index
        if index not in index_mapping:
            index_mapping[index] = len(full_response["choices"])
            full_response["choices"].append(
                {
                    "index": index,
                    "message": {"role": "assistant"},
                    "finish_reason": None,
                }
            )

        existing_choice = full_response["choices"][index_mapping[index]]
        message = existing_choice["message"]
        delta = choice.get("delta") or {}

        # Handle regular assistant messages
        content = delta.get("content")
        if content is not None:
            message["content"] = message.get("content", "") + content

        # Handle tool calls
        tool_calls = delta.get("tool_calls")
        if isinstance(tool_calls, list):
            message.setdefault("tool_calls", [])

            for tool in tool_calls:
                tool_index = tool.get("index")
                if tool_index is None:
                    continue

                function = tool.get("function") or {}
                tool_id = tool.get("id")
                tool_name = function.get("name")
                tool_arguments = function.get("arguments", "")

                # Key by choice as well: tool call indices restart for every
                # choice, so keying by tool index alone would merge the tool
                # calls of different choices into one entry.
                key = (index, tool_index)
                if key not in tool_call_mapping:
                    tool_call_mapping[key] = {
                        "index": tool_index,
                        "id": tool_id,
                        "type": "function",
                        "function": {
                            "name": tool_name,
                            "arguments": "",
                        },
                    }
                    message["tool_calls"].append(tool_call_mapping[key])

                tool_entry = tool_call_mapping[key]
                if tool_id:
                    tool_entry["id"] = tool_id
                if tool_name:
                    tool_entry["function"]["name"] = tool_name
                # Arguments arrive as string fragments; concatenate them
                if tool_arguments:
                    tool_entry["function"]["arguments"] += tool_arguments

        finish_reason = choice.get("finish_reason")
        if finish_reason is not None:
            existing_choice["finish_reason"] = finish_reason


async def stream_response(
    client, open_ai_request, dataset_name, request_body_json, request_headers
):
    """Stream the OpenAI response to the caller while collecting the full response.

    Every chunk received from upstream is forwarded unchanged. In parallel the
    chunks are merged into a single completion-shaped dict which is pushed to
    the Invariant Explorer once the stream has ended successfully.

    Args:
        client: The httpx async client used to open the upstream stream. It is
            closed when the stream ends, because the caller returns before the
            body is produced and therefore cannot close it itself.
        open_ai_request: The prepared upstream request; its url, headers and
            content are reused for the streaming call.
        dataset_name: Dataset the trace is pushed to.
        request_body_json: Parsed request body, used for the request messages.
        request_headers: Incoming request headers, forwarded to the push call.

    Returns:
        StreamingResponse: A "text/event-stream" response wrapping the stream.
    """

    async def event_generator():
        full_response = {
            "id": None,
            "object": "chat.completion",
            "created": None,
            "model": None,
            "choices": [],
            "usage": None,
        }
        # Tracks choice index to full_response index
        index_mapping = {}
        # Tracks tool calls by (choice index, tool call index)
        tool_call_mapping = {}
        # Bytes of the last, not yet newline-terminated line. Network chunks
        # do not respect event boundaries, so a line (or a multi-byte
        # character) may be split across two chunks.
        pending = b""

        try:
            async with client.stream(
                "POST",
                open_ai_request.url,
                headers=open_ai_request.headers,
                content=open_ai_request.content,
            ) as response:
                if response.status_code != 200:
                    error_message = json.dumps(
                        {"error": f"Failed to fetch response: {response.status_code}"}
                    ).encode()
                    yield error_message
                    return

                async for chunk in response.aiter_bytes():
                    if not chunk:
                        continue

                    # Yield chunk immediately to the client (proxy behavior).
                    # Whitespace-only chunks are forwarded too: they can carry
                    # the blank line that terminates an event.
                    yield chunk

                    pending += chunk
                    *complete_lines, pending = pending.split(b"\n")
                    for raw_line in complete_lines:
                        json_chunk = _parse_sse_payload(raw_line)
                        if json_chunk is not None:
                            _merge_stream_chunk(
                                full_response,
                                json_chunk,
                                index_mapping,
                                tool_call_mapping,
                            )

                # The stream may end without a trailing newline
                if pending:
                    json_chunk = _parse_sse_payload(pending)
                    if json_chunk is not None:
                        _merge_stream_chunk(
                            full_response, json_chunk, index_mapping, tool_call_mapping
                        )

            # Send full merged response to the explorer
            await push_to_explorer(
                dataset_name, full_response, request_headers, request_body_json
            )
        finally:
            # Nobody else closes the client on the streaming path
            await client.aclose()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
async def push_to_explorer(dataset_name, full_response, request_headers, request_body):
    """Push the combined request/response conversation to the Invariant Explorer.

    The trace consists of the messages sent in the request followed by the
    message of every choice in the response.

    Args:
        dataset_name: Dataset the trace is pushed to.
        full_response: Completion-shaped dict; the "message" of each entry in
            its "choices" list is appended to the trace.
        request_headers: Incoming request headers; the
            "invariant-authorization" value is passed on to push_trace.
        request_body: Parsed request body; its "messages" list starts the
            trace. It is not modified.
    """
    # Copy instead of `+=`: extending the list returned by .get() in place
    # would append the response messages to the caller's request body.
    messages = list(request_body.get("messages") or [])
    messages.extend(
        choice["message"] for choice in full_response.get("choices") or []
    )

    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=request_headers.get("invariant-authorization"),
    )
async def handle_non_streaming_response(
    response, dataset_name, request_body_json, request_headers
):
    """Handle a non-streaming OpenAI response.

    Pushes the request and response messages to the Invariant Explorer and
    returns the upstream JSON body to the caller with the upstream status
    code and headers.

    Args:
        response: The completed upstream response.
        dataset_name: Dataset the trace is pushed to.
        request_body_json: Parsed request body, used for the request messages.
        request_headers: Incoming request headers, forwarded to the push call.

    Returns:
        Response: A JSON response mirroring the upstream status and headers.
    """
    json_response = response.json()
    await push_to_explorer(
        dataset_name, json_response, request_headers, request_body_json
    )

    # The body is re-serialized below, so the upstream length/encoding headers
    # no longer describe it. Compare names case-insensitively: httpx reports
    # header names lower-cased, so popping "Content-Length" from the plain
    # dict would never match and the stale value would be forwarded.
    stale_headers = {"content-encoding", "content-length"}
    response_headers = {
        name: value
        for name, value in dict(response.headers).items()
        if name.lower() not in stale_headers
    }

    return Response(
        content=json.dumps(json_response),
        status_code=response.status_code,
        media_type="application/json",
        headers=response_headers,
    )
+28 -14
View File
@@ -3,13 +3,14 @@
import os
from typing import Any, Dict, List
from fastapi import HTTPException
from invariant_sdk.client import Client
import httpx
from invariant_sdk.types.push_traces import PushTracesRequest
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
PUSH_ENDPOINT = "/api/v1/push/trace"
def push_trace(
async def push_trace(
messages: List[Dict[str, Any]],
dataset_name: str,
invariant_authorization: str,
@@ -19,19 +20,32 @@ def push_trace(
Args:
messages (List[Dict[str, Any]]): List of messages to push.
dataset_name (str): Name of the dataset.
invariant_authorization (str): Authorization token.
invariant_authorization (str): Authorization token from the
invariant-authorization header.
Returns:
Dict[str, str]: Response containing the trace ID.
"""
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL)
api_key = invariant_authorization.split("Bearer ")[1]
client = Client(api_url=api_url, api_key=api_key)
try:
# TODO: Change this to the async version once that is available
push_trace_response = client.create_request_and_push_trace(
messages=messages, dataset=dataset_name
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/")
request = PushTracesRequest(messages=messages, dataset=dataset_name)
async with httpx.AsyncClient() as client:
explorer_push_request = client.build_request(
"POST",
f"{api_url}{PUSH_ENDPOINT}",
json=request.to_json(),
headers={
"Authorization": f"{invariant_authorization}",
"Accept": "application/json",
},
)
return {"trace_id": push_trace_response.id[0]}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
try:
response = await client.send(explorer_push_request)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
print(f"Failed to push trace: {e.response.text}")
return {"error": str(e)}
except Exception as e:
print(f"Unexpected error pushing trace: {str(e)}")
return {"error": str(e)}