Refactor the code and add type hints.

This commit is contained in:
Hemang
2025-02-04 14:24:37 +01:00
parent 9c3937183e
commit 847a01db84
2 changed files with 193 additions and 121 deletions
+185 -117
View File
@@ -1,6 +1,7 @@
"""Proxy service to forward requests to the OpenAI APIs"""
import json
from typing import Any
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
@@ -46,7 +47,7 @@ async def openai_proxy(
request: Request,
dataset_name: str,
endpoint: str,
):
) -> Response:
"""Proxy calls to the OpenAI APIs"""
if endpoint not in ALLOWED_OPEN_AI_ENDPOINTS:
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
@@ -61,6 +62,7 @@ async def openai_proxy(
# Check if the request is for streaming
is_streaming = request_body_json.get("stream", False)
invariant_authorization = request.headers.get("invariant-authorization")
client = httpx.AsyncClient()
open_ai_request = client.build_request(
@@ -75,23 +77,34 @@ async def openai_proxy(
open_ai_request,
dataset_name,
request_body_json,
request.headers,
invariant_authorization,
)
else:
async with client:
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, request.headers
response, dataset_name, request_body_json, invariant_authorization
)
async def stream_response(
client, open_ai_request, dataset_name, request_body_json, request_headers
):
"""Handles streaming the OpenAI response to the client while collecting full response"""
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
) -> StreamingResponse:
"""
Handles streaming the OpenAI response to the client while building a merged_response
The chunks are returned to the caller immediately
The merged_response is built from the chunks as they are received
It is sent to the Invariant Explorer at the end of the stream
"""
async def event_generator():
full_response = {
async def event_generator() -> Any:
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
@@ -99,11 +112,13 @@ async def stream_response(
"choices": [],
"usage": None,
}
# Tracks choice index to full_response index
index_mapping = {}
# Tracks tool calls by index
tool_call_mapping = {}
# Each chunk in the stream contains a list called "choices" each entry in the list
# has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
tool_call_mapping_by_index = {}
async with client.stream(
"POST",
@@ -112,10 +127,9 @@ async def stream_response(
content=open_ai_request.content,
) as response:
if response.status_code != 200:
error_message = json.dumps(
yield json.dumps(
{"error": f"Failed to fetch response: {response.status_code}"}
).encode()
yield error_message
return
async for chunk in response.aiter_bytes():
@@ -126,130 +140,184 @@ async def stream_response(
# Yield chunk immediately to the client (proxy behavior)
yield chunk
# There can be multiple "data: " chunks in a single response
for json_string in chunk_text.split("\ndata: "):
# Remove first "data: " prefix
json_string = json_string.replace("data: ", "").strip()
if not json_string or json_string == "[DONE]":
continue
try:
json_chunk = json.loads(json_string)
except json.JSONDecodeError:
continue
# Extract metadata safely
full_response["id"] = full_response["id"] or json_chunk.get("id")
full_response["created"] = full_response[
"created"
] or json_chunk.get("created")
full_response["model"] = full_response["model"] or json_chunk.get(
"model"
)
for choice in json_chunk.get("choices", []):
index = choice.get("index", 0)
# Ensure we have a mapping for this index
if index not in index_mapping:
index_mapping[index] = len(full_response["choices"])
full_response["choices"].append(
{
"index": index,
"message": {"role": "assistant"},
"finish_reason": None,
}
)
existing_choice = full_response["choices"][index_mapping[index]]
delta = choice.get("delta", {})
# Handle regular assistant messages
content = delta.get("content")
if content is not None:
if "content" not in existing_choice["message"]:
existing_choice["message"]["content"] = ""
existing_choice["message"]["content"] += content
# Handle tool calls
if isinstance(delta.get("tool_calls"), list):
if "tool_calls" not in existing_choice["message"]:
existing_choice["message"]["tool_calls"] = []
for tool in delta["tool_calls"]:
tool_index = tool.get("index")
tool_id = tool.get("id")
tool_name = tool.get("function", {}).get("name")
tool_arguments = tool.get("function", {}).get(
"arguments", ""
)
if tool_index is None:
continue
# Find or create tool call by index
if tool_index not in tool_call_mapping:
tool_call_mapping[tool_index] = {
"index": tool_index,
"id": tool_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": "",
},
}
existing_choice["message"]["tool_calls"].append(
tool_call_mapping[tool_index]
)
tool_entry = tool_call_mapping[tool_index]
if tool_id:
tool_entry["id"] = tool_id
if tool_name:
tool_entry["function"]["name"] = tool_name
# Append arguments if they exist
if tool_arguments:
tool_entry["function"]["arguments"] += (
tool_arguments
)
finish_reason = choice.get("finish_reason")
if finish_reason is not None:
existing_choice["finish_reason"] = finish_reason
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
merged_response,
choice_mapping_by_index,
tool_call_mapping_by_index,
)
# Send full merged response to the explorer
await push_to_explorer(
dataset_name, full_response, request_headers, request_body_json
dataset_name,
merged_response,
request_body_json,
invariant_authorization,
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
async def push_to_explorer(dataset_name, full_response, request_headers, request_body):
def initialize_merged_response() -> dict[str, Any]:
    """Builds an empty merged_response skeleton

    Every call returns a new dictionary (with its own "choices" list) shaped
    like a non-streaming chat completion; all fields start out as None
    except "object" and "choices".
    """
    # Start with every field set to None, in the order they are reported
    skeleton: dict[str, Any] = dict.fromkeys(
        ("id", "object", "created", "model", "choices", "usage")
    )
    # The two fields that have a fixed, non-None starting value
    skeleton["object"] = "chat.completion"
    skeleton["choices"] = []
    return skeleton
def process_chunk_text(
    chunk_text: str,
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Processes the chunk text and updates the merged_response to be sent to the explorer

    Args:
        chunk_text: Decoded text of one chunk read from the stream. It may
            contain several "data: " sections.
        merged_response: Accumulated response; updated in place.
        choice_mapping_by_index: Maps a choice index from the stream to its
            position in merged_response["choices"]; updated in place.
        tool_call_mapping_by_index: Tool call entries keyed by
            "<choice index>-<tool call index>"; updated in place.

    Sections that are empty, equal to "[DONE]", not valid JSON, or not a JSON
    object are skipped.
    """
    # Split the chunk text into individual JSON strings
    # A single chunk can contain multiple "data: " sections
    for section in chunk_text.split("\ndata: "):
        json_string = section.strip()
        # Only the first section can still carry the "data:" field name (the
        # split consumed it for the others). Strip it as a prefix only: a
        # blanket str.replace would also delete "data: " occurring inside the
        # JSON payload itself (e.g. in the assistant content).
        if json_string.startswith("data:"):
            json_string = json_string[len("data:"):].strip()

        if not json_string or json_string == "[DONE]":
            continue

        try:
            json_chunk = json.loads(json_string)
        except json.JSONDecodeError:
            # NOTE(review): a section cut in half at a chunk boundary ends up
            # here and is dropped; handling that would need buffering across
            # calls, which this function's interface does not provide.
            continue

        # Valid JSON that is not an object cannot be merged
        if not isinstance(json_chunk, dict):
            continue

        update_merged_response(
            json_chunk,
            merged_response,
            choice_mapping_by_index,
            tool_call_mapping_by_index,
        )
def update_merged_response(
    json_chunk: dict[str, Any],
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Updates the merged_response with the data (content, tool_calls, etc.) from the JSON chunk

    Args:
        json_chunk: One parsed JSON object from the stream.
        merged_response: Accumulated response; updated in place.
        choice_mapping_by_index: Maps a choice index from the stream to its
            position in merged_response["choices"]; updated in place.
        tool_call_mapping_by_index: Tool call entries keyed by
            "<choice index>-<tool call index>"; updated in place.
    """
    # Metadata: keep the first non-empty value seen in the stream
    for key in ("id", "created", "model"):
        merged_response[key] = merged_response.get(key) or json_chunk.get(key)

    # Record token usage whenever a chunk carries it
    usage = json_chunk.get("usage")
    if usage is not None:
        merged_response["usage"] = usage

    merged_choices = merged_response.setdefault("choices", [])

    # "choices" may be missing or null, so fall back to an empty list
    for choice in json_chunk.get("choices") or []:
        index = choice.get("index", 0)

        # First time this choice index is seen: create its slot
        if index not in choice_mapping_by_index:
            choice_mapping_by_index[index] = len(merged_choices)
            merged_choices.append(
                {
                    "index": index,
                    "message": {"role": "assistant"},
                    "finish_reason": None,
                }
            )

        existing_choice = merged_choices[choice_mapping_by_index[index]]
        delta = choice.get("delta") or {}

        update_existing_choice_with_delta(
            existing_choice, delta, tool_call_mapping_by_index, choice_index=index
        )

        # finish_reason is a field of the choice itself, not of its delta,
        # so it has to be read here rather than inside the delta handler
        finish_reason = choice.get("finish_reason")
        if finish_reason is not None:
            existing_choice["finish_reason"] = finish_reason


def update_existing_choice_with_delta(
    existing_choice: dict[str, Any],
    delta: dict[str, Any],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
    choice_index: int,
) -> None:
    """Updates the choice with the data from the delta

    Args:
        existing_choice: Entry of merged_response["choices"]; its "message"
            dictionary is updated in place.
        delta: The "delta" object of the streamed choice.
        tool_call_mapping_by_index: Tool call entries keyed by
            "<choice index>-<tool call index>"; updated in place.
        choice_index: Index of the streamed choice, used to build that key.
    """
    message = existing_choice["message"]

    # Text content is streamed in fragments: concatenate them
    content = delta.get("content")
    if content is not None:
        message["content"] = message.get("content", "") + content

    tool_calls = delta.get("tool_calls")
    if isinstance(tool_calls, list):
        merged_tool_calls = message.setdefault("tool_calls", [])

        for tool in tool_calls:
            tool_index = tool.get("index")
            if tool_index is None:
                continue

            # "function" may be missing or null in a fragment
            function = tool.get("function") or {}
            tool_id = tool.get("id")
            name = function.get("name")
            arguments = function.get("arguments") or ""

            # Tool call indices are only unique within one choice, so the
            # lookup key combines the choice index and the tool call index
            choice_with_tool_call_index = f"{choice_index}-{tool_index}"

            tool_call_entry = tool_call_mapping_by_index.get(
                choice_with_tool_call_index
            )
            if tool_call_entry is None:
                tool_call_entry = {
                    "index": tool_index,
                    "id": tool_id,
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": "",
                    },
                }
                tool_call_mapping_by_index[choice_with_tool_call_index] = (
                    tool_call_entry
                )
                # The same dict object is stored in the mapping and in the
                # message, so later fragments update both at once
                merged_tool_calls.append(tool_call_entry)

            if tool_id:
                tool_call_entry["id"] = tool_id
            if name:
                tool_call_entry["function"]["name"] = name
            if arguments:
                tool_call_entry["function"]["arguments"] += arguments

    # Kept from the previous implementation: honour a finish_reason if one
    # is present on the delta itself
    finish_reason = delta.get("finish_reason")
    if finish_reason is not None:
        existing_choice["finish_reason"] = finish_reason
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Pushes the full trace to the Invariant Explorer

    Args:
        dataset_name: Name of the dataset the trace is pushed to.
        merged_response: The OpenAI response; the "message" of each of its
            "choices" is appended to the trace.
        request_body: The request body; its "messages" form the start of
            the trace. It is not modified.
        invariant_authorization: Value of the invariant-authorization header,
            forwarded to push_trace.
    """
    # Combine the messages from the request body and the choices from the OpenAI response.
    # Work on a copy: extending request_body["messages"] in place would
    # silently modify the caller's request body.
    messages = list(request_body.get("messages") or [])
    messages.extend(
        choice["message"] for choice in merged_response.get("choices") or []
    )
    # push_trace takes a list of traces, so the single trace is wrapped in a list
    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
async def handle_non_streaming_response(
response, dataset_name, request_body_json, request_headers
response: httpx.Response,
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
):
"""Handles non-streaming OpenAI responses"""
json_response = response.json()
await push_to_explorer(
dataset_name, json_response, request_headers, request_body_json
dataset_name, json_response, request_body_json, invariant_authorization
)
response_headers = dict(response.headers)
+8 -4
View File
@@ -11,14 +11,14 @@ PUSH_ENDPOINT = "/api/v1/push/trace"
async def push_trace(
messages: List[Dict[str, Any]],
messages: List[List[Dict[str, Any]]],
dataset_name: str,
invariant_authorization: str,
) -> Dict[str, str]:
"""Pushes traces to the dataset on the Invariant Explorer.
Args:
messages (List[Dict[str, Any]]): List of messages to push.
messages (List[List[Dict[str, Any]]]): List of messages to push.
dataset_name (str): Name of the dataset.
invariant_authorization (str): Authorization token from the
invariant-authorization header.
@@ -27,8 +27,12 @@ async def push_trace(
Dict[str, str]: Response containing the trace ID.
"""
api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/")
request = PushTracesRequest(messages=messages, dataset=dataset_name)
# Remove any None values from the messages
update_messages = [
[{k: v for k, v in msg.items() if v is not None} for msg in msg_list]
for msg_list in messages
]
request = PushTracesRequest(messages=update_messages, dataset=dataset_name)
async with httpx.AsyncClient() as client:
explorer_push_request = client.build_request(
"POST",