From 847a01db84d2cafd5788afd85a1c29b641ea7b79 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 4 Feb 2025 14:24:37 +0100 Subject: [PATCH] Refactor the code and add type hints. --- proxy/routes/open_ai.py | 302 ++++++++++++++++++++++++---------------- proxy/utils/explorer.py | 12 +- 2 files changed, 193 insertions(+), 121 deletions(-) diff --git a/proxy/routes/open_ai.py b/proxy/routes/open_ai.py index 02a752a..34f709d 100644 --- a/proxy/routes/open_ai.py +++ b/proxy/routes/open_ai.py @@ -1,6 +1,7 @@ """Proxy service to forward requests to the OpenAI APIs""" import json +from typing import Any import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response @@ -46,7 +47,7 @@ async def openai_proxy( request: Request, dataset_name: str, endpoint: str, -): +) -> Response: """Proxy calls to the OpenAI APIs""" if endpoint not in ALLOWED_OPEN_AI_ENDPOINTS: raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT) @@ -61,6 +62,7 @@ async def openai_proxy( # Check if the request is for streaming is_streaming = request_body_json.get("stream", False) + invariant_authorization = request.headers.get("invariant-authorization") client = httpx.AsyncClient() open_ai_request = client.build_request( @@ -75,23 +77,34 @@ async def openai_proxy( open_ai_request, dataset_name, request_body_json, - request.headers, + invariant_authorization, ) else: async with client: response = await client.send(open_ai_request) return await handle_non_streaming_response( - response, dataset_name, request_body_json, request.headers + response, dataset_name, request_body_json, invariant_authorization ) async def stream_response( - client, open_ai_request, dataset_name, request_body_json, request_headers -): - """Handles streaming the OpenAI response to the client while collecting full response""" + client: httpx.AsyncClient, + open_ai_request: httpx.Request, + dataset_name: str, + request_body_json: dict[str, Any], + invariant_authorization: str, +) -> StreamingResponse: + """ + Handles streaming the OpenAI response to the client while building a merged_response + The chunks are returned to the caller immediately + The merged_response is built from the chunks as they are received + It is sent to the Invariant Explorer at the end of the stream + """ - async def event_generator(): - full_response = { + async def event_generator() -> Any: + # merged_response will be updated with the data from the chunks in the stream + # At the end of the stream, this will be sent to the explorer + merged_response = { "id": None, "object": "chat.completion", "created": None, @@ -99,11 +112,13 @@ async def stream_response( "choices": [], "usage": None, } - - # Tracks choice index to full_response index - index_mapping = {} - # Tracks tool calls by index - tool_call_mapping = {} + # Each chunk in the stream contains a list called "choices" each entry in the list + # has an index. + # A choice has a field called "delta" which may contain a list called "tool_calls". + # Maps the choice index in the stream to the index in the merged_response["choices"] list + choice_mapping_by_index = {} + # Combines the choice index and tool call index to uniquely identify a tool call + tool_call_mapping_by_index = {} async with client.stream( "POST", @@ -112,10 +127,9 @@ async def stream_response( content=open_ai_request.content, ) as response: if response.status_code != 200: - error_message = json.dumps( + yield json.dumps( {"error": f"Failed to fetch response: {response.status_code}"} ).encode() - yield error_message return async for chunk in response.aiter_bytes(): @@ -126,130 +140,184 @@ async def stream_response( # Yield chunk immediately to the client (proxy behavior) yield chunk - # There can be multiple "data: " chunks in a single response - for json_string in chunk_text.split("\ndata: "): - # Remove first "data: " prefix - json_string = json_string.replace("data: ", "").strip() - - if not json_string or json_string == "[DONE]": - continue - - try: - json_chunk = json.loads(json_string) - except json.JSONDecodeError: - continue - - # Extract metadata safely - full_response["id"] = full_response["id"] or json_chunk.get("id") - full_response["created"] = full_response[ - "created" - ] or json_chunk.get("created") - full_response["model"] = full_response["model"] or json_chunk.get( - "model" - ) - - for choice in json_chunk.get("choices", []): - index = choice.get("index", 0) - - # Ensure we have a mapping for this index - if index not in index_mapping: - index_mapping[index] = len(full_response["choices"]) - full_response["choices"].append( - { - "index": index, - "message": {"role": "assistant"}, - "finish_reason": None, - } - ) - - existing_choice = full_response["choices"][index_mapping[index]] - delta = choice.get("delta", {}) - - # Handle regular assistant messages - content = delta.get("content") - if content is not None: - if "content" not in existing_choice["message"]: - existing_choice["message"]["content"] = "" - existing_choice["message"]["content"] += content - - # Handle tool calls - if isinstance(delta.get("tool_calls"), list): - if "tool_calls" not in existing_choice["message"]: - existing_choice["message"]["tool_calls"] = [] - - for tool in delta["tool_calls"]: - tool_index = tool.get("index") - tool_id = tool.get("id") - tool_name = tool.get("function", {}).get("name") - tool_arguments = tool.get("function", {}).get( - "arguments", "" - ) - - if tool_index is None: - continue - - # Find or create tool call by index - if tool_index not in tool_call_mapping: - tool_call_mapping[tool_index] = { - "index": tool_index, - "id": tool_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": "", - }, - } - existing_choice["message"]["tool_calls"].append( - tool_call_mapping[tool_index] - ) - - tool_entry = tool_call_mapping[tool_index] - - if tool_id: - tool_entry["id"] = tool_id - - if tool_name: - tool_entry["function"]["name"] = tool_name - - # Append arguments if they exist - if tool_arguments: - tool_entry["function"]["arguments"] += ( - tool_arguments - ) - - finish_reason = choice.get("finish_reason") - if finish_reason is not None: - existing_choice["finish_reason"] = finish_reason + # Process the chunk + # This will update merged_response with the data from the chunk + process_chunk_text( + chunk_text, + merged_response, + choice_mapping_by_index, + tool_call_mapping_by_index, + ) # Send full merged response to the explorer await push_to_explorer( - dataset_name, full_response, request_headers, request_body_json + dataset_name, + merged_response, + request_body_json, + invariant_authorization, ) return StreamingResponse(event_generator(), media_type="text/event-stream") -async def push_to_explorer(dataset_name, full_response, request_headers, request_body): +def initialize_merged_response() -> dict[str, Any]: + """Initializes the full response dictionary""" + return { + "id": None, + "object": "chat.completion", + "created": None, + "model": None, + "choices": [], + "usage": None, + } + + +def process_chunk_text( + chunk_text: str, + merged_response: dict[str, Any], + choice_mapping_by_index: dict[int, int], + tool_call_mapping_by_index: dict[str, dict[str, Any]], +) -> None: + """Processes the chunk text and updates the merged_response to be sent to the explorer""" + # Split the chunk text into individual JSON strings + # A single chunk can contain multiple "data: " sections + for json_string in chunk_text.split("\ndata: "): + json_string = json_string.replace("data: ", "").strip() + + if not json_string or json_string == "[DONE]": + continue + + try: + json_chunk = json.loads(json_string) + except json.JSONDecodeError: + continue + + update_merged_response( + json_chunk, + merged_response, + choice_mapping_by_index, + tool_call_mapping_by_index, + ) + + +def update_merged_response( + json_chunk: dict[str, Any], + merged_response: dict[str, Any], + choice_mapping_by_index: dict[int, int], + tool_call_mapping_by_index: dict[str, dict[str, Any]], +) -> None: + """Updates the merged_response with the data (content, tool_calls, etc.) from the JSON chunk""" + merged_response["id"] = merged_response["id"] or json_chunk.get("id") + merged_response["created"] = merged_response["created"] or json_chunk.get("created") + merged_response["model"] = merged_response["model"] or json_chunk.get("model") + + for choice in json_chunk.get("choices", []): + index = choice.get("index", 0) + + if index not in choice_mapping_by_index: + choice_mapping_by_index[index] = len(merged_response["choices"]) + merged_response["choices"].append( + { + "index": index, + "message": {"role": "assistant"}, + "finish_reason": None, + } + ) + + existing_choice = merged_response["choices"][choice_mapping_by_index[index]] + delta = choice.get("delta", {}) + + update_existing_choice_with_delta( + existing_choice, delta, tool_call_mapping_by_index, choice_index=index + ) + + +def update_existing_choice_with_delta( + existing_choice: dict[str, Any], + delta: dict[str, Any], + tool_call_mapping_by_index: dict[str, dict[str, Any]], + choice_index: int, +) -> None: + """Updates the choice with the data from the delta""" + content = delta.get("content") + if content is not None: + if "content" not in existing_choice["message"]: + existing_choice["message"]["content"] = "" + existing_choice["message"]["content"] += content + + if isinstance(delta.get("tool_calls"), list): + if "tool_calls" not in existing_choice["message"]: + existing_choice["message"]["tool_calls"] = [] + + for tool in delta["tool_calls"]: + tool_index = tool.get("index") + tool_id = tool.get("id") + name = tool.get("function", {}).get("name") + arguments = tool.get("function", {}).get("arguments", "") + + if tool_index is None: + continue + + choice_with_tool_call_index = f"{choice_index}-{tool_index}" + + if choice_with_tool_call_index not in tool_call_mapping_by_index: + tool_call_mapping_by_index[choice_with_tool_call_index] = { + "index": tool_index, + "id": tool_id, + "type": "function", + "function": { + "name": name, + "arguments": "", + }, + } + existing_choice["message"]["tool_calls"].append( + tool_call_mapping_by_index[choice_with_tool_call_index] + ) + + tool_call_entry = tool_call_mapping_by_index[choice_with_tool_call_index] + + if tool_id: + tool_call_entry["id"] = tool_id + + if name: + tool_call_entry["function"]["name"] = name + + if arguments: + tool_call_entry["function"]["arguments"] += arguments + + finish_reason = delta.get("finish_reason") + if finish_reason is not None: + existing_choice["finish_reason"] = finish_reason + + +async def push_to_explorer( + dataset_name: str, + merged_response: dict[str, Any], + request_body: dict[str, Any], + invariant_authorization: str, +) -> None: """Pushes the full trace to the Invariant Explorer""" - # Combine messages from the request and the response - # to push the full trace to the Invariant Explorer + # Combine the messages from the request body and the choices from the OpenAI response messages = request_body.get("messages", []) - messages += [choice["message"] for choice in full_response.get("choices", [])] + messages += [choice["message"] for choice in merged_response.get("choices", [])] _ = await push_trace( dataset_name=dataset_name, messages=[messages], - invariant_authorization=request_headers.get("invariant-authorization"), + invariant_authorization=invariant_authorization, ) async def handle_non_streaming_response( - response, dataset_name, request_body_json, request_headers + response: httpx.Response, + dataset_name: str, + request_body_json: dict[str, Any], + invariant_authorization: str, ): """Handles non-streaming OpenAI responses""" json_response = response.json() await push_to_explorer( - dataset_name, json_response, request_headers, request_body_json + dataset_name, json_response, request_body_json, invariant_authorization ) response_headers = dict(response.headers) diff --git a/proxy/utils/explorer.py b/proxy/utils/explorer.py index f9eea8e..d54b4e6 100644 --- a/proxy/utils/explorer.py +++ b/proxy/utils/explorer.py @@ -11,14 +11,14 @@ PUSH_ENDPOINT = "/api/v1/push/trace" async def push_trace( - messages: List[Dict[str, Any]], + messages: List[List[Dict[str, Any]]], dataset_name: str, invariant_authorization: str, ) -> Dict[str, str]: """Pushes traces to the dataset on the Invariant Explorer. Args: - messages (List[Dict[str, Any]]): List of messages to push. + messages (List[List[Dict[str, Any]]]): List of messages to push. dataset_name (str): Name of the dataset. invariant_authorization (str): Authorization token from the invariant-authorization header. @@ -27,8 +27,12 @@ async def push_trace( Dict[str, str]: Response containing the trace ID. """ api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/") - - request = PushTracesRequest(messages=messages, dataset=dataset_name) + # Remove any None values from the messages + update_messages = [ + [{k: v for k, v in msg.items() if v is not None} for msg in msg_list] + for msg_list in messages + ] + request = PushTracesRequest(messages=update_messages, dataset=dataset_name) async with httpx.AsyncClient() as client: explorer_push_request = client.build_request( "POST",