mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-25 08:14:02 +02:00
Refactor the code and add type hints.
This commit is contained in:
+185
-117
@@ -1,6 +1,7 @@
|
||||
"""Proxy service to forward requests to the OpenAI APIs"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
@@ -46,7 +47,7 @@ async def openai_proxy(
|
||||
request: Request,
|
||||
dataset_name: str,
|
||||
endpoint: str,
|
||||
):
|
||||
) -> Response:
|
||||
"""Proxy calls to the OpenAI APIs"""
|
||||
if endpoint not in ALLOWED_OPEN_AI_ENDPOINTS:
|
||||
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
|
||||
@@ -61,6 +62,7 @@ async def openai_proxy(
|
||||
|
||||
# Check if the request is for streaming
|
||||
is_streaming = request_body_json.get("stream", False)
|
||||
invariant_authorization = request.headers.get("invariant-authorization")
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
open_ai_request = client.build_request(
|
||||
@@ -75,23 +77,34 @@ async def openai_proxy(
|
||||
open_ai_request,
|
||||
dataset_name,
|
||||
request_body_json,
|
||||
request.headers,
|
||||
invariant_authorization,
|
||||
)
|
||||
else:
|
||||
async with client:
|
||||
response = await client.send(open_ai_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, request.headers
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
|
||||
async def stream_response(
|
||||
client, open_ai_request, dataset_name, request_body_json, request_headers
|
||||
):
|
||||
"""Handles streaming the OpenAI response to the client while collecting full response"""
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
dataset_name: str,
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Handles streaming the OpenAI response to the client while building a merged_response
|
||||
The chunks are returned to the caller immediately
|
||||
The merged_response is built from the chunks as they are received
|
||||
It is sent to the Invariant Explorer at the end of the stream
|
||||
"""
|
||||
|
||||
async def event_generator():
|
||||
full_response = {
|
||||
async def event_generator() -> Any:
|
||||
# merged_response will be updated with the data from the chunks in the stream
|
||||
# At the end of the stream, this will be sent to the explorer
|
||||
merged_response = {
|
||||
"id": None,
|
||||
"object": "chat.completion",
|
||||
"created": None,
|
||||
@@ -99,11 +112,13 @@ async def stream_response(
|
||||
"choices": [],
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
# Tracks choice index to full_response index
|
||||
index_mapping = {}
|
||||
# Tracks tool calls by index
|
||||
tool_call_mapping = {}
|
||||
# Each chunk in the stream contains a list called "choices" each entry in the list
|
||||
# has an index.
|
||||
# A choice has a field called "delta" which may contain a list called "tool_calls".
|
||||
# Maps the choice index in the stream to the index in the merged_response["choices"] list
|
||||
choice_mapping_by_index = {}
|
||||
# Combines the choice index and tool call index to uniquely identify a tool call
|
||||
tool_call_mapping_by_index = {}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
@@ -112,10 +127,9 @@ async def stream_response(
|
||||
content=open_ai_request.content,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_message = json.dumps(
|
||||
yield json.dumps(
|
||||
{"error": f"Failed to fetch response: {response.status_code}"}
|
||||
).encode()
|
||||
yield error_message
|
||||
return
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
@@ -126,130 +140,184 @@ async def stream_response(
|
||||
# Yield chunk immediately to the client (proxy behavior)
|
||||
yield chunk
|
||||
|
||||
# There can be multiple "data: " chunks in a single response
|
||||
for json_string in chunk_text.split("\ndata: "):
|
||||
# Remove first "data: " prefix
|
||||
json_string = json_string.replace("data: ", "").strip()
|
||||
|
||||
if not json_string or json_string == "[DONE]":
|
||||
continue
|
||||
|
||||
try:
|
||||
json_chunk = json.loads(json_string)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Extract metadata safely
|
||||
full_response["id"] = full_response["id"] or json_chunk.get("id")
|
||||
full_response["created"] = full_response[
|
||||
"created"
|
||||
] or json_chunk.get("created")
|
||||
full_response["model"] = full_response["model"] or json_chunk.get(
|
||||
"model"
|
||||
)
|
||||
|
||||
for choice in json_chunk.get("choices", []):
|
||||
index = choice.get("index", 0)
|
||||
|
||||
# Ensure we have a mapping for this index
|
||||
if index not in index_mapping:
|
||||
index_mapping[index] = len(full_response["choices"])
|
||||
full_response["choices"].append(
|
||||
{
|
||||
"index": index,
|
||||
"message": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
)
|
||||
|
||||
existing_choice = full_response["choices"][index_mapping[index]]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Handle regular assistant messages
|
||||
content = delta.get("content")
|
||||
if content is not None:
|
||||
if "content" not in existing_choice["message"]:
|
||||
existing_choice["message"]["content"] = ""
|
||||
existing_choice["message"]["content"] += content
|
||||
|
||||
# Handle tool calls
|
||||
if isinstance(delta.get("tool_calls"), list):
|
||||
if "tool_calls" not in existing_choice["message"]:
|
||||
existing_choice["message"]["tool_calls"] = []
|
||||
|
||||
for tool in delta["tool_calls"]:
|
||||
tool_index = tool.get("index")
|
||||
tool_id = tool.get("id")
|
||||
tool_name = tool.get("function", {}).get("name")
|
||||
tool_arguments = tool.get("function", {}).get(
|
||||
"arguments", ""
|
||||
)
|
||||
|
||||
if tool_index is None:
|
||||
continue
|
||||
|
||||
# Find or create tool call by index
|
||||
if tool_index not in tool_call_mapping:
|
||||
tool_call_mapping[tool_index] = {
|
||||
"index": tool_index,
|
||||
"id": tool_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
existing_choice["message"]["tool_calls"].append(
|
||||
tool_call_mapping[tool_index]
|
||||
)
|
||||
|
||||
tool_entry = tool_call_mapping[tool_index]
|
||||
|
||||
if tool_id:
|
||||
tool_entry["id"] = tool_id
|
||||
|
||||
if tool_name:
|
||||
tool_entry["function"]["name"] = tool_name
|
||||
|
||||
# Append arguments if they exist
|
||||
if tool_arguments:
|
||||
tool_entry["function"]["arguments"] += (
|
||||
tool_arguments
|
||||
)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
if finish_reason is not None:
|
||||
existing_choice["finish_reason"] = finish_reason
|
||||
# Process the chunk
|
||||
# This will update merged_response with the data from the chunk
|
||||
process_chunk_text(
|
||||
chunk_text,
|
||||
merged_response,
|
||||
choice_mapping_by_index,
|
||||
tool_call_mapping_by_index,
|
||||
)
|
||||
|
||||
# Send full merged response to the explorer
|
||||
await push_to_explorer(
|
||||
dataset_name, full_response, request_headers, request_body_json
|
||||
dataset_name,
|
||||
merged_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def push_to_explorer(dataset_name, full_response, request_headers, request_body):
|
||||
def initialize_merged_response() -> dict[str, Any]:
    """Return a fresh, empty chat-completion skeleton.

    Every field starts as None except "object", which is fixed to
    "chat.completion", and "choices", which is a new empty list on each call
    so callers never share state.
    """
    field_names = ("id", "object", "created", "model", "choices", "usage")
    skeleton: dict[str, Any] = dict.fromkeys(field_names)
    skeleton["object"] = "chat.completion"
    skeleton["choices"] = []
    return skeleton
|
||||
|
||||
|
||||
def process_chunk_text(
    chunk_text: str,
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Parse the "data: " events in chunk_text and merge them into merged_response.

    Args:
        chunk_text: Decoded text of one network chunk; it may hold several
            "data: " sections.
        merged_response: Accumulated response, updated in place.
        choice_mapping_by_index: Stream choice index -> position in
            merged_response["choices"]; updated in place.
        tool_call_mapping_by_index: "<choice index>-<tool index>" -> tool call
            entry; updated in place.

    Sections that are empty, equal to "[DONE]", not valid JSON, or not a JSON
    object are skipped.
    """
    for section in chunk_text.split("\ndata: "):
        # Only the leading field name is removed. A blanket replace() would
        # also delete "data: " occurring inside the JSON payload itself
        # (e.g. in message content or tool arguments).
        payload = section.strip().removeprefix("data: ").strip()

        if not payload or payload == "[DONE]":
            continue

        try:
            json_chunk = json.loads(payload)
        except json.JSONDecodeError:
            # Best effort: a section cut in half by a network chunk boundary
            # cannot be parsed on its own and is dropped.
            continue

        # Scalars, lists and null parse as valid JSON but carry nothing to merge.
        if not isinstance(json_chunk, dict):
            continue

        update_merged_response(
            json_chunk,
            merged_response,
            choice_mapping_by_index,
            tool_call_mapping_by_index,
        )
|
||||
|
||||
|
||||
def update_merged_response(
    json_chunk: dict[str, Any],
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Fold one parsed stream chunk into merged_response.

    Args:
        json_chunk: One decoded JSON object from the stream.
        merged_response: Accumulated response, updated in place. It must
            already contain the "id", "created", "model" and "choices" keys.
        choice_mapping_by_index: Stream choice index -> position in
            merged_response["choices"]; updated in place.
        tool_call_mapping_by_index: "<choice index>-<tool index>" -> tool call
            entry; updated in place.
    """
    # The first chunk that carries a value wins; later chunks never overwrite it.
    for field in ("id", "created", "model"):
        merged_response[field] = merged_response[field] or json_chunk.get(field)

    for choice in json_chunk.get("choices") or []:
        index = choice.get("index", 0)

        if index not in choice_mapping_by_index:
            choice_mapping_by_index[index] = len(merged_response["choices"])
            merged_response["choices"].append(
                {
                    "index": index,
                    "message": {"role": "assistant"},
                    "finish_reason": None,
                }
            )

        existing_choice = merged_response["choices"][choice_mapping_by_index[index]]
        # Treat an explicit `"delta": null` like a missing delta.
        delta = choice.get("delta") or {}

        update_existing_choice_with_delta(
            existing_choice, delta, tool_call_mapping_by_index, choice_index=index
        )

        # finish_reason sits on the choice, next to "delta", not inside it.
        # Reading it only from the delta leaves the merged value at None.
        finish_reason = choice.get("finish_reason")
        if finish_reason is not None:
            existing_choice["finish_reason"] = finish_reason


def update_existing_choice_with_delta(
    existing_choice: dict[str, Any],
    delta: dict[str, Any],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
    choice_index: int,
) -> None:
    """Apply one delta (content and tool call fragments) to existing_choice.

    Args:
        existing_choice: Entry of merged_response["choices"]; its "message"
            dict is updated in place.
        delta: The "delta" object of a streamed choice.
        tool_call_mapping_by_index: "<choice index>-<tool index>" -> tool call
            entry; updated in place. The entries are the same objects that are
            appended to the message's "tool_calls" list.
        choice_index: Index of the choice the delta belongs to; keeps tool
            calls of different choices apart.
    """
    message = existing_choice["message"]

    content = delta.get("content")
    if content is not None:
        message["content"] = message.get("content", "") + content

    tool_calls = delta.get("tool_calls")
    if isinstance(tool_calls, list):
        merged_tool_calls = message.setdefault("tool_calls", [])

        for tool in tool_calls:
            tool_index = tool.get("index")
            if tool_index is None:
                # Without an index the fragment cannot be matched to a call.
                continue

            tool_id = tool.get("id")
            function = tool.get("function") or {}
            name = function.get("name")
            arguments = function.get("arguments", "")

            choice_with_tool_call_index = f"{choice_index}-{tool_index}"
            tool_call_entry = tool_call_mapping_by_index.get(
                choice_with_tool_call_index
            )
            if tool_call_entry is None:
                tool_call_entry = {
                    "index": tool_index,
                    "id": tool_id,
                    "type": "function",
                    "function": {"name": name, "arguments": ""},
                }
                tool_call_mapping_by_index[choice_with_tool_call_index] = (
                    tool_call_entry
                )
                merged_tool_calls.append(tool_call_entry)

            if tool_id:
                tool_call_entry["id"] = tool_id
            if name:
                tool_call_entry["function"]["name"] = name
            if arguments:
                # Arguments arrive as string fragments; concatenate in order.
                tool_call_entry["function"]["arguments"] += arguments

    # Kept for callers that pass a finish_reason inside the delta.
    finish_reason = delta.get("finish_reason")
    if finish_reason is not None:
        existing_choice["finish_reason"] = finish_reason
|
||||
|
||||
|
||||
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Push the full trace (request messages + response messages) to the Invariant Explorer.

    Args:
        dataset_name: Dataset the trace is pushed to.
        merged_response: Response whose choices[*]["message"] entries are
            appended after the request messages.
        request_body: Parsed request body; its "messages" list is read but
            not modified.
        invariant_authorization: Value forwarded to push_trace as the
            authorization token.
    """
    # Build a new list rather than extending request_body["messages"] in place,
    # so the caller's request body is left untouched.
    messages = [
        *request_body.get("messages", []),
        *(choice["message"] for choice in merged_response.get("choices", [])),
    ]

    # push_trace takes a list of traces; this call sends exactly one.
    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, request_headers
|
||||
response: httpx.Response,
|
||||
dataset_name: str,
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
):
|
||||
"""Handles non-streaming OpenAI responses"""
|
||||
json_response = response.json()
|
||||
await push_to_explorer(
|
||||
dataset_name, json_response, request_headers, request_body_json
|
||||
dataset_name, json_response, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
response_headers = dict(response.headers)
|
||||
|
||||
@@ -11,14 +11,14 @@ PUSH_ENDPOINT = "/api/v1/push/trace"
|
||||
|
||||
|
||||
async def push_trace(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: List[List[Dict[str, Any]]],
|
||||
dataset_name: str,
|
||||
invariant_authorization: str,
|
||||
) -> Dict[str, str]:
|
||||
"""Pushes traces to the dataset on the Invariant Explorer.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to push.
|
||||
messages (List[List[Dict[str, Any]]]): List of messages to push.
|
||||
dataset_name (str): Name of the dataset.
|
||||
invariant_authorization (str): Authorization token from the
|
||||
invariant-authorization header.
|
||||
@@ -27,8 +27,12 @@ async def push_trace(
|
||||
Dict[str, str]: Response containing the trace ID.
|
||||
"""
|
||||
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
|
||||
request = PushTracesRequest(messages=messages, dataset=dataset_name)
|
||||
# Remove any None values from the messages
|
||||
update_messages = [
|
||||
[{k: v for k, v in msg.items() if v is not None} for msg in msg_list]
|
||||
for msg_list in messages
|
||||
]
|
||||
request = PushTracesRequest(messages=update_messages, dataset=dataset_name)
|
||||
async with httpx.AsyncClient() as client:
|
||||
explorer_push_request = client.build_request(
|
||||
"POST",
|
||||
|
||||
Reference in New Issue
Block a user