Formatting changes.

This commit is contained in:
Hemang
2025-03-10 14:21:22 +01:00
committed by Hemang Sarkar
parent abbc80890d
commit 20f8a12032
10 changed files with 111 additions and 103 deletions
+11
View File
@@ -0,0 +1,11 @@
"""Common Request context data class."""
from dataclasses import dataclass
from typing import Any, Dict, Optional
@dataclass(frozen=True)
class RequestContextData:
"""Request context data class."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
+28 -31
View File
@@ -1,7 +1,7 @@
"""Gateway service to forward requests to the Gemini APIs"""
import json
from typing import Any, Optional
from typing import Any
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -12,6 +12,7 @@ from common.constants import (
IGNORED_HEADERS,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.explorer import push_trace
@@ -52,7 +53,7 @@ async def gemini_generate_content_gateway(
headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
request_json = json.loads(request_body_bytes)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
@@ -65,26 +66,29 @@ async def gemini_generate_content_gateway(
headers=headers,
)
context = RequestContextData(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
)
if alt == "sse" or endpoint == "streamGenerateContent":
return await stream_response(
context,
client,
gemini_request,
dataset_name,
request_body_json,
invariant_authorization,
)
response = await client.send(gemini_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
context,
response,
)
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
) -> Response:
"""Handles streaming the Gemini response to the client"""
@@ -115,13 +119,11 @@ async def stream_response(
# Parse and update merged_response incrementally
process_chunk_text(merged_response, chunk_text)
if dataset_name:
if context.dataset_name:
# Push to Explorer
await push_to_explorer(
dataset_name,
context,
merged_response,
request_body_json,
invariant_authorization,
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -175,31 +177,28 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) ->
if "finishReason" in candidate:
merged_response["candidates"][0]["finishReason"] = candidate["finishReason"]
async def push_to_explorer(
dataset_name: str,
merged_response: dict[str, Any],
request_body: dict[str, Any],
invariant_authorization: str,
context: RequestContextData,
response_json: dict[str, Any],
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
converted_requests = convert_request(request_body)
converted_responses = convert_response(merged_response)
converted_requests = convert_request(context.request_json)
converted_responses = convert_response(response_json)
_ = await push_trace(
dataset_name=dataset_name,
dataset_name=context.dataset_name,
messages=[converted_requests + converted_responses],
invariant_authorization=invariant_authorization,
invariant_authorization=context.invariant_authorization,
)
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
) -> Response:
"""Handles non-streaming Gemini responses"""
try:
json_response = response.json()
response_json = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
@@ -208,15 +207,13 @@ async def handle_non_streaming_response(
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from Gemini API"),
)
if dataset_name:
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
detail=response_json.get("error", "Unknown error from Gemini API"),
)
if context.dataset_name:
await push_to_explorer(context, response_json)
return Response(
content=json.dumps(json_response),
content=json.dumps(response_json),
status_code=response.status_code,
media_type="application/json",
headers=dict(response.headers),
+24 -34
View File
@@ -1,7 +1,7 @@
"""Gateway service to forward requests to the OpenAI APIs"""
import json
from typing import Any, Optional
from typing import Any
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -13,6 +13,7 @@ from common.constants import (
)
from integrations.explorer import push_trace
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
gateway = APIRouter()
@@ -52,9 +53,7 @@ async def openai_chat_completions_gateway(
headers[OPENAI_AUTHORIZATION_HEADER] = openai_api_key
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
# Check if the request is for streaming
is_streaming = request_body_json.get("stream", False)
request_json = json.loads(request_body_bytes)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
open_ai_request = client.build_request(
@@ -63,26 +62,30 @@ async def openai_chat_completions_gateway(
content=request_body_bytes,
headers=headers,
)
if is_streaming:
context = RequestContextData(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
)
if request_json.get("stream", False):
return await stream_response(
context,
client,
open_ai_request,
dataset_name,
request_body_json,
invariant_authorization,
)
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
context,
response,
)
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
) -> Response:
"""
Handles streaming the OpenAI response to the client while building a merged_response
@@ -138,13 +141,8 @@ async def stream_response(
)
# Send full merged response to the explorer
if dataset_name:
await push_to_explorer(
dataset_name,
merged_response,
request_body_json,
invariant_authorization,
)
if context.dataset_name:
await push_to_explorer(context, merged_response)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -282,10 +280,7 @@ def update_existing_choice_with_delta(
async def push_to_explorer(
dataset_name: str,
merged_response: dict[str, Any],
request_body: dict[str, Any],
invariant_authorization: str,
context: RequestContextData, merged_response: dict[str, Any]
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
# Only push the trace to explorer if the message is an end turn message
@@ -296,20 +291,17 @@ async def push_to_explorer(
):
return
# Combine the messages from the request body and the choices from the OpenAI response
messages = request_body.get("messages", [])
messages = context.request_json.get("messages", [])
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=dataset_name,
dataset_name=context.dataset_name,
messages=[messages],
invariant_authorization=invariant_authorization,
invariant_authorization=context.invariant_authorization,
)
async def handle_non_streaming_response(
response: httpx.Response,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
context: RequestContextData, response: httpx.Response
) -> Response:
"""Handles non-streaming OpenAI responses"""
try:
@@ -324,10 +316,8 @@ async def handle_non_streaming_response(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
if dataset_name:
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
)
if context.dataset_name:
await push_to_explorer(context, json_response)
return Response(
content=json.dumps(json_response),
@@ -5,14 +5,16 @@ import os
import sys
from unittest.mock import patch
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
from httpx import Client
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -8,14 +8,16 @@ import sys
from pathlib import Path
from typing import Dict, List
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
from httpx import Client
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -4,14 +4,16 @@ import datetime
import os
import sys
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
from httpx import Client
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -3,19 +3,19 @@
import os
import sys
import uuid
import pytest
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
from google import genai
from google.genai import types
# add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
def set_light_values(brightness: int, color_temp: str) -> dict[str, int | str]:
"""Set the brightness and color temperature of a room light. (mock API).
@@ -6,17 +6,16 @@ import uuid
from pathlib import Path
from unittest.mock import patch
import pytest
from google import genai
import PIL.Image
# add tests folder (parent) to sys.path
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import PIL.Image
from google import genai
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -134,7 +133,10 @@ async def test_generate_content_with_image(
config={"maxOutputTokens": 100},
)
assert "TWO" in chat_response.candidates[0].content.parts[0].text.upper()
assert (
"TWO" in chat_response.candidates[0].content.parts[0].text.upper()
or 2 in chat_response.candidates[0].content.parts[0].text
)
if push_to_explorer:
# Fetch the trace ids for the dataset
+7 -7
View File
@@ -5,16 +5,16 @@ import os
import sys
import uuid
import pytest
from httpx import Client
# add tests folder (parent) to sys.path
from openai import OpenAI
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
import pytest
from httpx import Client
from openai import OpenAI
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
+10 -8
View File
@@ -7,17 +7,16 @@ import uuid
from pathlib import Path
from unittest.mock import patch
import pytest
from httpx import Client
# add tests folder (parent) to sys.path
from openai import NotFoundError, OpenAI
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
from httpx import Client
from openai import NotFoundError, OpenAI
from util import * # needed for pytest fixtures
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -133,7 +132,10 @@ async def test_chat_completion_with_image(
max_tokens=100,
)
assert "TWO" in chat_response.choices[0].message.content.upper()
assert (
"TWO" in chat_response.choices[0].message.content.upper()
or 2 in chat_response.choices[0].message.content
)
if push_to_explorer:
# Fetch the trace ids for the dataset