From 20f8a12032265fee75c3dad0efbdd3102501d647 Mon Sep 17 00:00:00 2001 From: Hemang Date: Mon, 10 Mar 2025 14:21:22 +0100 Subject: [PATCH] Formatting changes. --- gateway/common/request_context_data.py | 11 ++++ gateway/routes/gemini.py | 59 +++++++++---------- gateway/routes/open_ai.py | 58 ++++++++---------- ...est_anthropic_header_with_invariant_key.py | 8 ++- .../test_anthropic_with_tool_call.py | 8 ++- .../test_anthropic_without_tool_call.py | 8 ++- .../test_generate_content_with_tool_calls.py | 12 ++-- ...est_generate_content_without_tool_calls.py | 18 +++--- tests/open_ai/test_chat_with_tool_call.py | 14 ++--- tests/open_ai/test_chat_without_tool_calls.py | 18 +++--- 10 files changed, 111 insertions(+), 103 deletions(-) create mode 100644 gateway/common/request_context_data.py diff --git a/gateway/common/request_context_data.py b/gateway/common/request_context_data.py new file mode 100644 index 0000000..d59fb03 --- /dev/null +++ b/gateway/common/request_context_data.py @@ -0,0 +1,11 @@ +"""Common Request context data class.""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +@dataclass(frozen=True) +class RequestContextData: + """Request context data class.""" + request_json: Dict[str, Any] + dataset_name: Optional[str] = None + invariant_authorization: Optional[str] = None diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index a7bc1dd..cd88584 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -1,7 +1,7 @@ """Gateway service to forward requests to the Gemini APIs""" import json -from typing import Any, Optional +from typing import Any import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -12,6 +12,7 @@ from common.constants import ( IGNORED_HEADERS, ) from common.authorization import extract_authorization_from_headers +from common.request_context_data import RequestContextData from converters.gemini_to_invariant import convert_request, convert_response from integrations.explorer import push_trace @@ -52,7 +53,7 @@ async def gemini_generate_content_gateway( headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key request_body_bytes = await request.body() - request_body_json = json.loads(request_body_bytes) + request_json = json.loads(request_body_bytes) client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}" @@ -65,26 +66,29 @@ async def gemini_generate_content_gateway( headers=headers, ) + context = RequestContextData( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, + ) + if alt == "sse" or endpoint == "streamGenerateContent": return await stream_response( + context, client, gemini_request, - dataset_name, - request_body_json, - invariant_authorization, ) response = await client.send(gemini_request) return await handle_non_streaming_response( - response, dataset_name, request_body_json, invariant_authorization + context, + response, ) async def stream_response( + context: RequestContextData, client: httpx.AsyncClient, gemini_request: httpx.Request, - dataset_name: Optional[str], - request_body_json: dict[str, Any], - invariant_authorization: Optional[str], ) -> Response: """Handles streaming the Gemini response to the client""" @@ -115,13 +119,11 @@ async def stream_response( # Parse and update merged_response incrementally process_chunk_text(merged_response, chunk_text) - if dataset_name: + if context.dataset_name: # Push to Explorer await push_to_explorer( - dataset_name, + context, merged_response, - request_body_json, - invariant_authorization, ) return StreamingResponse(event_generator(), media_type="text/event-stream") @@ -175,31 +177,28 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) -> if "finishReason" in candidate: merged_response["candidates"][0]["finishReason"] = candidate["finishReason"] + async def push_to_explorer( - dataset_name: str, - merged_response: dict[str, Any], - request_body: dict[str, Any], - invariant_authorization: str, + context: RequestContextData, + response_json: dict[str, Any], ) -> None: """Pushes the full trace to the Invariant Explorer""" - converted_requests = convert_request(request_body) - converted_responses = convert_response(merged_response) + converted_requests = convert_request(context.request_json) + converted_responses = convert_response(response_json) _ = await push_trace( - dataset_name=dataset_name, + dataset_name=context.dataset_name, messages=[converted_requests + converted_responses], - invariant_authorization=invariant_authorization, + invariant_authorization=context.invariant_authorization, ) async def handle_non_streaming_response( + context: RequestContextData, response: httpx.Response, - dataset_name: Optional[str], - request_body_json: dict[str, Any], - invariant_authorization: Optional[str], ) -> Response: """Handles non-streaming Gemini responses""" try: - json_response = response.json() + response_json = response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=response.status_code, @@ -208,15 +207,13 @@ async def handle_non_streaming_response( if response.status_code != 200: raise HTTPException( status_code=response.status_code, - detail=json_response.get("error", "Unknown error from Gemini API"), - ) - if dataset_name: - await push_to_explorer( - dataset_name, json_response, request_body_json, invariant_authorization + detail=response_json.get("error", "Unknown error from Gemini API"), ) + if context.dataset_name: + await push_to_explorer(context, response_json) return Response( - content=json.dumps(json_response), + content=json.dumps(response_json), status_code=response.status_code, media_type="application/json", headers=dict(response.headers), diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 49b01ac..cbdb52c 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -1,7 +1,7 @@ """Gateway service to forward requests to the OpenAI APIs""" import json -from typing import Any, Optional +from typing import Any import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -13,6 +13,7 @@ from common.constants import ( ) from integrations.explorer import push_trace from common.authorization import extract_authorization_from_headers +from common.request_context_data import RequestContextData gateway = APIRouter() @@ -52,9 +53,7 @@ async def openai_chat_completions_gateway( headers[OPENAI_AUTHORIZATION_HEADER] = openai_api_key request_body_bytes = await request.body() - request_body_json = json.loads(request_body_bytes) - # Check if the request is for streaming - is_streaming = request_body_json.get("stream", False) + request_json = json.loads(request_body_bytes) client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) open_ai_request = client.build_request( @@ -63,26 +62,30 @@ async def openai_chat_completions_gateway( content=request_body_bytes, headers=headers, ) - if is_streaming: + + context = RequestContextData( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, + ) + + if request_json.get("stream", False): return await stream_response( + context, client, open_ai_request, - dataset_name, - request_body_json, - invariant_authorization, ) response = await client.send(open_ai_request) return await handle_non_streaming_response( - response, dataset_name, request_body_json, invariant_authorization + context, + response, ) async def stream_response( + context: RequestContextData, client: httpx.AsyncClient, open_ai_request: httpx.Request, - dataset_name: Optional[str], - request_body_json: dict[str, Any], - invariant_authorization: Optional[str], ) -> Response: """ Handles streaming the OpenAI response to the client while building a merged_response @@ -138,13 +141,8 @@ async def stream_response( ) # Send full merged response to the explorer - if dataset_name: - await push_to_explorer( - dataset_name, - merged_response, - request_body_json, - invariant_authorization, - ) + if context.dataset_name: + await push_to_explorer(context, merged_response) return StreamingResponse(event_generator(), media_type="text/event-stream") @@ -282,10 +280,7 @@ def update_existing_choice_with_delta( async def push_to_explorer( - dataset_name: str, - merged_response: dict[str, Any], - request_body: dict[str, Any], - invariant_authorization: str, + context: RequestContextData, merged_response: dict[str, Any] ) -> None: """Pushes the full trace to the Invariant Explorer""" # Only push the trace to explorer if the message is an end turn message @@ -296,20 +291,17 @@ async def push_to_explorer( ): return # Combine the messages from the request body and the choices from the OpenAI response - messages = request_body.get("messages", []) + messages = context.request_json.get("messages", []) messages += [choice["message"] for choice in merged_response.get("choices", [])] _ = await push_trace( - dataset_name=dataset_name, + dataset_name=context.dataset_name, messages=[messages], - invariant_authorization=invariant_authorization, + invariant_authorization=context.invariant_authorization, ) async def handle_non_streaming_response( - response: httpx.Response, - dataset_name: Optional[str], - request_body_json: dict[str, Any], - invariant_authorization: Optional[str], + context: RequestContextData, response: httpx.Response ) -> Response: """Handles non-streaming OpenAI responses""" try: @@ -324,10 +316,8 @@ async def handle_non_streaming_response( status_code=response.status_code, detail=json_response.get("error", "Unknown error from OpenAI API"), ) - if dataset_name: - await push_to_explorer( - dataset_name, json_response, request_body_json, invariant_authorization - ) + if context.dataset_name: + await push_to_explorer(context, json_response) return Response( content=json.dumps(json_response), diff --git a/tests/anthropic/test_anthropic_header_with_invariant_key.py b/tests/anthropic/test_anthropic_header_with_invariant_key.py index cbc512e..891a5b1 100644 --- a/tests/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/anthropic/test_anthropic_header_with_invariant_key.py @@ -5,14 +5,16 @@ import os import sys from unittest.mock import patch +# Add tests folder (parent) to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + import anthropic import pytest from httpx import Client -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) diff --git a/tests/anthropic/test_anthropic_with_tool_call.py b/tests/anthropic/test_anthropic_with_tool_call.py index 12a25d8..53948ed 100644 --- a/tests/anthropic/test_anthropic_with_tool_call.py +++ b/tests/anthropic/test_anthropic_with_tool_call.py @@ -8,14 +8,16 @@ import sys from pathlib import Path from typing import Dict, List +# Add tests folder (parent) to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + import anthropic import pytest from httpx import Client -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) diff --git a/tests/anthropic/test_anthropic_without_tool_call.py b/tests/anthropic/test_anthropic_without_tool_call.py index 53484c8..cdbd708 100644 --- a/tests/anthropic/test_anthropic_without_tool_call.py +++ b/tests/anthropic/test_anthropic_without_tool_call.py @@ -4,14 +4,16 @@ import datetime import os import sys +# Add tests folder (parent) to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + import anthropic import pytest from httpx import Client -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) diff --git a/tests/gemini/test_generate_content_with_tool_calls.py b/tests/gemini/test_generate_content_with_tool_calls.py index 4e61ed6..8ae5a52 100644 --- a/tests/gemini/test_generate_content_with_tool_calls.py +++ b/tests/gemini/test_generate_content_with_tool_calls.py @@ -3,19 +3,19 @@ import os import sys import uuid -import pytest +# Add tests folder (parent) to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import pytest from google import genai from google.genai import types -# add tests folder (parent) to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) - def set_light_values(brightness: int, color_temp: str) -> dict[str, int | str]: """Set the brightness and color temperature of a room light. (mock API). diff --git a/tests/gemini/test_generate_content_without_tool_calls.py b/tests/gemini/test_generate_content_without_tool_calls.py index 28a0e68..c1424eb 100644 --- a/tests/gemini/test_generate_content_without_tool_calls.py +++ b/tests/gemini/test_generate_content_without_tool_calls.py @@ -6,17 +6,16 @@ import uuid from pathlib import Path from unittest.mock import patch -import pytest - -from google import genai -import PIL.Image - -# add tests folder (parent) to sys.path +# Add tests folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import pytest +import PIL.Image +from google import genai -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -134,7 +133,10 @@ async def test_generate_content_with_image( config={"maxOutputTokens": 100}, ) - assert "TWO" in chat_response.candidates[0].content.parts[0].text.upper() + assert ( + "TWO" in chat_response.candidates[0].content.parts[0].text.upper() + or 2 in chat_response.candidates[0].content.parts[0].text + ) if push_to_explorer: # Fetch the trace ids for the dataset diff --git a/tests/open_ai/test_chat_with_tool_call.py b/tests/open_ai/test_chat_with_tool_call.py index 5afaa28..d49b199 100644 --- a/tests/open_ai/test_chat_with_tool_call.py +++ b/tests/open_ai/test_chat_with_tool_call.py @@ -5,16 +5,16 @@ import os import sys import uuid -import pytest -from httpx import Client - -# add tests folder (parent) to sys.path -from openai import OpenAI - +# Add tests folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from util import * # needed for pytest fixtures +import pytest +from httpx import Client +from openai import OpenAI +from util import * # Needed for pytest fixtures + +# Pytest plugins pytest_plugins = ("pytest_asyncio",) diff --git a/tests/open_ai/test_chat_without_tool_calls.py b/tests/open_ai/test_chat_without_tool_calls.py index b2b9200..0772b2a 100644 --- a/tests/open_ai/test_chat_without_tool_calls.py +++ b/tests/open_ai/test_chat_without_tool_calls.py @@ -7,17 +7,16 @@ import uuid from pathlib import Path from unittest.mock import patch -import pytest -from httpx import Client - -# add tests folder (parent) to sys.path -from openai import NotFoundError, OpenAI - +# Add tests folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import pytest +from httpx import Client +from openai import NotFoundError, OpenAI -from util import * # needed for pytest fixtures +from util import * # Needed for pytest fixtures +# Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -133,7 +132,10 @@ async def test_chat_completion_with_image( max_tokens=100, ) - assert "TWO" in chat_response.choices[0].message.content.upper() + assert ( + "TWO" in chat_response.choices[0].message.content.upper() + or 2 in chat_response.choices[0].message.content + ) if push_to_explorer: # Fetch the trace ids for the dataset