mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 15:29:43 +02:00
Add streaming support and change the explorer push call to be async.
This commit is contained in:
+203
-33
@@ -4,6 +4,7 @@ import json
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
from utils.explorer import push_trace
|
||||
|
||||
ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"}
|
||||
@@ -18,6 +19,7 @@ IGNORED_HEADERS = [
|
||||
"x-forwarded-server",
|
||||
"x-real-ip",
|
||||
]
|
||||
|
||||
proxy = APIRouter()
|
||||
|
||||
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
|
||||
@@ -54,41 +56,209 @@ async def openai_proxy(
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
|
||||
request_body = await request.body()
|
||||
request_body_bytes = await request.body()
|
||||
request_body_json = json.loads(request_body_bytes)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
open_ai_request = client.build_request(
|
||||
"POST",
|
||||
f"https://api.openai.com/v1/{endpoint}",
|
||||
content=request_body,
|
||||
headers=headers,
|
||||
# Check if the request is for streaming
|
||||
is_streaming = request_body_json.get("stream", False)
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
open_ai_request = client.build_request(
|
||||
"POST",
|
||||
f"https://api.openai.com/v1/{endpoint}",
|
||||
content=request_body_bytes,
|
||||
headers=headers,
|
||||
)
|
||||
if is_streaming:
|
||||
return await stream_response(
|
||||
client,
|
||||
open_ai_request,
|
||||
dataset_name,
|
||||
request_body_json,
|
||||
request.headers,
|
||||
)
|
||||
response = await client.send(open_ai_request)
|
||||
try:
|
||||
json_response = response.json()
|
||||
# push messages to the Invariant Explorer
|
||||
# use both the request and response messages
|
||||
messages = json.loads(request_body).get("messages", [])
|
||||
messages += [
|
||||
choice["message"] for choice in json_response.get("choices", [])
|
||||
]
|
||||
_ = push_trace(
|
||||
dataset_name=dataset_name,
|
||||
messages=[messages],
|
||||
invariant_authorization=request.headers.get("invariant-authorization"),
|
||||
else:
|
||||
async with client:
|
||||
response = await client.send(open_ai_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, request.headers
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=FAILED_TO_PUSH_TRACE + str(e)
|
||||
) from e
|
||||
|
||||
response_headers = dict(response.headers)
|
||||
response_headers.pop("Content-Encoding", None)
|
||||
response_headers.pop("Content-Length", None)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(response.json()),
|
||||
status_code=response.status_code,
|
||||
media_type="application/json",
|
||||
headers=response_headers,
|
||||
)
|
||||
async def stream_response(
|
||||
client, open_ai_request, dataset_name, request_body_json, request_headers
|
||||
):
|
||||
"""Handles streaming the OpenAI response to the client while collecting full response"""
|
||||
|
||||
async def event_generator():
|
||||
full_response = {
|
||||
"id": None,
|
||||
"object": "chat.completion",
|
||||
"created": None,
|
||||
"model": None,
|
||||
"choices": [],
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
# Tracks choice index to full_response index
|
||||
index_mapping = {}
|
||||
# Tracks tool calls by index
|
||||
tool_call_mapping = {}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
open_ai_request.url,
|
||||
headers=open_ai_request.headers,
|
||||
content=open_ai_request.content,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_message = json.dumps(
|
||||
{"error": f"Failed to fetch response: {response.status_code}"}
|
||||
).encode()
|
||||
yield error_message
|
||||
return
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
continue
|
||||
|
||||
# Yield chunk immediately to the client (proxy behavior)
|
||||
yield chunk
|
||||
|
||||
# There can be multiple "data: " chunks in a single response
|
||||
for json_string in chunk_text.split("\ndata: "):
|
||||
# Remove first "data: " prefix
|
||||
json_string = json_string.replace("data: ", "").strip()
|
||||
|
||||
if not json_string or json_string == "[DONE]":
|
||||
continue
|
||||
|
||||
try:
|
||||
json_chunk = json.loads(json_string)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Extract metadata safely
|
||||
full_response["id"] = full_response["id"] or json_chunk.get("id")
|
||||
full_response["created"] = full_response[
|
||||
"created"
|
||||
] or json_chunk.get("created")
|
||||
full_response["model"] = full_response["model"] or json_chunk.get(
|
||||
"model"
|
||||
)
|
||||
|
||||
for choice in json_chunk.get("choices", []):
|
||||
index = choice.get("index", 0)
|
||||
|
||||
# Ensure we have a mapping for this index
|
||||
if index not in index_mapping:
|
||||
index_mapping[index] = len(full_response["choices"])
|
||||
full_response["choices"].append(
|
||||
{
|
||||
"index": index,
|
||||
"message": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
)
|
||||
|
||||
existing_choice = full_response["choices"][index_mapping[index]]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Handle regular assistant messages
|
||||
content = delta.get("content")
|
||||
if content is not None:
|
||||
if "content" not in existing_choice["message"]:
|
||||
existing_choice["message"]["content"] = ""
|
||||
existing_choice["message"]["content"] += content
|
||||
|
||||
# Handle tool calls
|
||||
if isinstance(delta.get("tool_calls"), list):
|
||||
if "tool_calls" not in existing_choice["message"]:
|
||||
existing_choice["message"]["tool_calls"] = []
|
||||
|
||||
for tool in delta["tool_calls"]:
|
||||
tool_index = tool.get("index")
|
||||
tool_id = tool.get("id")
|
||||
tool_name = tool.get("function", {}).get("name")
|
||||
tool_arguments = tool.get("function", {}).get(
|
||||
"arguments", ""
|
||||
)
|
||||
|
||||
if tool_index is None:
|
||||
continue
|
||||
|
||||
# Find or create tool call by index
|
||||
if tool_index not in tool_call_mapping:
|
||||
tool_call_mapping[tool_index] = {
|
||||
"index": tool_index,
|
||||
"id": tool_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
existing_choice["message"]["tool_calls"].append(
|
||||
tool_call_mapping[tool_index]
|
||||
)
|
||||
|
||||
tool_entry = tool_call_mapping[tool_index]
|
||||
|
||||
if tool_id:
|
||||
tool_entry["id"] = tool_id
|
||||
|
||||
if tool_name:
|
||||
tool_entry["function"]["name"] = tool_name
|
||||
|
||||
# Append arguments if they exist
|
||||
if tool_arguments:
|
||||
tool_entry["function"]["arguments"] += (
|
||||
tool_arguments
|
||||
)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
if finish_reason is not None:
|
||||
existing_choice["finish_reason"] = finish_reason
|
||||
|
||||
# Send full merged response to the explorer
|
||||
await push_to_explorer(
|
||||
dataset_name, full_response, request_headers, request_body_json
|
||||
)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def push_to_explorer(dataset_name, full_response, request_headers, request_body):
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
# Combine messages from the request and the response
|
||||
# to push the full trace to the Invariant Explorer
|
||||
messages = request_body.get("messages", [])
|
||||
messages += [choice["message"] for choice in full_response.get("choices", [])]
|
||||
|
||||
_ = await push_trace(
|
||||
dataset_name=dataset_name,
|
||||
messages=[messages],
|
||||
invariant_authorization=request_headers.get("invariant-authorization"),
|
||||
)
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, request_headers
|
||||
):
|
||||
"""Handles non-streaming OpenAI responses"""
|
||||
json_response = response.json()
|
||||
await push_to_explorer(
|
||||
dataset_name, json_response, request_headers, request_body_json
|
||||
)
|
||||
|
||||
response_headers = dict(response.headers)
|
||||
response_headers.pop("Content-Encoding", None)
|
||||
response_headers.pop("Content-Length", None)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
status_code=response.status_code,
|
||||
media_type="application/json",
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
+28
-14
@@ -3,13 +3,14 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import HTTPException
|
||||
from invariant_sdk.client import Client
|
||||
import httpx
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
PUSH_ENDPOINT = "/api/v1/push/trace"
|
||||
|
||||
|
||||
def push_trace(
|
||||
async def push_trace(
|
||||
messages: List[Dict[str, Any]],
|
||||
dataset_name: str,
|
||||
invariant_authorization: str,
|
||||
@@ -19,19 +20,32 @@ def push_trace(
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to push.
|
||||
dataset_name (str): Name of the dataset.
|
||||
invariant_authorization (str): Authorization token.
|
||||
invariant_authorization (str): Authorization token from the
|
||||
invariant-authorization header.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Response containing the trace ID.
|
||||
"""
|
||||
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL)
|
||||
api_key = invariant_authorization.split("Bearer ")[1]
|
||||
client = Client(api_url=api_url, api_key=api_key)
|
||||
try:
|
||||
# TODO: Change this to the async version once that is available
|
||||
push_trace_response = client.create_request_and_push_trace(
|
||||
messages=messages, dataset=dataset_name
|
||||
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
|
||||
request = PushTracesRequest(messages=messages, dataset=dataset_name)
|
||||
async with httpx.AsyncClient() as client:
|
||||
explorer_push_request = client.build_request(
|
||||
"POST",
|
||||
f"{api_url}{PUSH_ENDPOINT}",
|
||||
json=request.to_json(),
|
||||
headers={
|
||||
"Authorization": f"{invariant_authorization}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
return {"trace_id": push_trace_response.id[0]}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
try:
|
||||
response = await client.send(explorer_push_request)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
print(f"Failed to push trace: {e.response.text}")
|
||||
return {"error": str(e)}
|
||||
except Exception as e:
|
||||
print(f"Unexpected error pushing trace: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
Reference in New Issue
Block a user