Files
invariant-gateway/proxy/routes/open_ai.py
T
2025-02-04 14:24:37 +01:00

333 lines
11 KiB
Python

"""Proxy service to forward requests to the OpenAI APIs"""
import json
from typing import Any
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.explorer import push_trace
ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"}
IGNORED_HEADERS = [
"accept-encoding",
"host",
"invariant-authorization",
"x-forwarded-for",
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"x-forwarded-server",
"x-real-ip",
]
proxy = APIRouter()
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
MISSING_AUTH_HEADER = "Missing authorization header"
NOT_SUPPORTED_ENDPOINT = "Not supported OpenAI endpoint"
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
def validate_headers(
invariant_authorization: str = Header(None), authorization: str = Header(None)
):
"""Require the invariant-authorization and authorization headers to be present"""
if invariant_authorization is None:
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER)
if authorization is None:
raise HTTPException(status_code=400, detail=MISSING_AUTH_HEADER)
@proxy.post(
"/{dataset_name}/openai/{endpoint:path}",
dependencies=[Depends(validate_headers)],
)
async def openai_proxy(
request: Request,
dataset_name: str,
endpoint: str,
) -> Response:
"""Proxy calls to the OpenAI APIs"""
if endpoint not in ALLOWED_OPEN_AI_ENDPOINTS:
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
headers["accept-encoding"] = "identity"
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
# Check if the request is for streaming
is_streaming = request_body_json.get("stream", False)
invariant_authorization = request.headers.get("invariant-authorization")
client = httpx.AsyncClient()
open_ai_request = client.build_request(
"POST",
f"https://api.openai.com/v1/{endpoint}",
content=request_body_bytes,
headers=headers,
)
if is_streaming:
return await stream_response(
client,
open_ai_request,
dataset_name,
request_body_json,
invariant_authorization,
)
else:
async with client:
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
)
async def stream_response(
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
) -> StreamingResponse:
"""
Handles streaming the OpenAI response to the client while building a merged_response
The chunks are returned to the caller immediately
The merged_response is built from the chunks as they are received
It is sent to the Invariant Explorer at the end of the stream
"""
async def event_generator() -> Any:
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
# Each chunk in the stream contains a list called "choices" each entry in the list
# has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
tool_call_mapping_by_index = {}
async with client.stream(
"POST",
open_ai_request.url,
headers=open_ai_request.headers,
content=open_ai_request.content,
) as response:
if response.status_code != 200:
yield json.dumps(
{"error": f"Failed to fetch response: {response.status_code}"}
).encode()
return
async for chunk in response.aiter_bytes():
chunk_text = chunk.decode().strip()
if not chunk_text:
continue
# Yield chunk immediately to the client (proxy behavior)
yield chunk
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
merged_response,
choice_mapping_by_index,
tool_call_mapping_by_index,
)
# Send full merged response to the explorer
await push_to_explorer(
dataset_name,
merged_response,
request_body_json,
invariant_authorization,
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
def initialize_merged_response() -> dict[str, Any]:
"""Initializes the full response dictionary"""
return {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
def process_chunk_text(
chunk_text: str,
merged_response: dict[str, Any],
choice_mapping_by_index: dict[int, int],
tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
"""Processes the chunk text and updates the merged_response to be sent to the explorer"""
# Split the chunk text into individual JSON strings
# A single chunk can contain multiple "data: " sections
for json_string in chunk_text.split("\ndata: "):
json_string = json_string.replace("data: ", "").strip()
if not json_string or json_string == "[DONE]":
continue
try:
json_chunk = json.loads(json_string)
except json.JSONDecodeError:
continue
update_merged_response(
json_chunk,
merged_response,
choice_mapping_by_index,
tool_call_mapping_by_index,
)
def update_merged_response(
json_chunk: dict[str, Any],
merged_response: dict[str, Any],
choice_mapping_by_index: dict[int, int],
tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
"""Updates the merged_response with the data (content, tool_calls, etc.) from the JSON chunk"""
merged_response["id"] = merged_response["id"] or json_chunk.get("id")
merged_response["created"] = merged_response["created"] or json_chunk.get("created")
merged_response["model"] = merged_response["model"] or json_chunk.get("model")
for choice in json_chunk.get("choices", []):
index = choice.get("index", 0)
if index not in choice_mapping_by_index:
choice_mapping_by_index[index] = len(merged_response["choices"])
merged_response["choices"].append(
{
"index": index,
"message": {"role": "assistant"},
"finish_reason": None,
}
)
existing_choice = merged_response["choices"][choice_mapping_by_index[index]]
delta = choice.get("delta", {})
update_existing_choice_with_delta(
existing_choice, delta, tool_call_mapping_by_index, choice_index=index
)
def update_existing_choice_with_delta(
existing_choice: dict[str, Any],
delta: dict[str, Any],
tool_call_mapping_by_index: dict[str, dict[str, Any]],
choice_index: int,
) -> None:
"""Updates the choice with the data from the delta"""
content = delta.get("content")
if content is not None:
if "content" not in existing_choice["message"]:
existing_choice["message"]["content"] = ""
existing_choice["message"]["content"] += content
if isinstance(delta.get("tool_calls"), list):
if "tool_calls" not in existing_choice["message"]:
existing_choice["message"]["tool_calls"] = []
for tool in delta["tool_calls"]:
tool_index = tool.get("index")
tool_id = tool.get("id")
name = tool.get("function", {}).get("name")
arguments = tool.get("function", {}).get("arguments", "")
if tool_index is None:
continue
choice_with_tool_call_index = f"{choice_index}-{tool_index}"
if choice_with_tool_call_index not in tool_call_mapping_by_index:
tool_call_mapping_by_index[choice_with_tool_call_index] = {
"index": tool_index,
"id": tool_id,
"type": "function",
"function": {
"name": name,
"arguments": "",
},
}
existing_choice["message"]["tool_calls"].append(
tool_call_mapping_by_index[choice_with_tool_call_index]
)
tool_call_entry = tool_call_mapping_by_index[choice_with_tool_call_index]
if tool_id:
tool_call_entry["id"] = tool_id
if name:
tool_call_entry["function"]["name"] = name
if arguments:
tool_call_entry["function"]["arguments"] += arguments
finish_reason = delta.get("finish_reason")
if finish_reason is not None:
existing_choice["finish_reason"] = finish_reason
async def push_to_explorer(
dataset_name: str,
merged_response: dict[str, Any],
request_body: dict[str, Any],
invariant_authorization: str,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
# Combine the messages from the request body and the choices from the OpenAI response
messages = request_body.get("messages", [])
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=dataset_name,
messages=[messages],
invariant_authorization=invariant_authorization,
)
async def handle_non_streaming_response(
response: httpx.Response,
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
):
"""Handles non-streaming OpenAI responses"""
json_response = response.json()
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
)
response_headers = dict(response.headers)
response_headers.pop("Content-Encoding", None)
response_headers.pop("Content-Length", None)
return Response(
content=json.dumps(json_response),
status_code=response.status_code,
media_type="application/json",
headers=response_headers,
)