From 2fe0f55cb3ec3ee86452e9db25279269d23c158e Mon Sep 17 00:00:00 2001 From: Hemang Date: Mon, 10 Mar 2025 16:21:42 +0100 Subject: [PATCH] Don't block on push to explorer call for the routes. --- gateway/routes/anthropic.py | 15 ++++++--------- gateway/routes/gemini.py | 14 +++++++++----- gateway/routes/open_ai.py | 7 +++++-- .../test_anthropic_header_with_invariant_key.py | 5 +++++ tests/anthropic/test_anthropic_with_tool_call.py | 10 ++++++++++ .../anthropic/test_anthropic_without_tool_call.py | 7 +++++++ .../test_generate_content_with_tool_calls.py | 4 ++++ .../test_generate_content_without_tool_calls.py | 13 ++++++++++++- tests/open_ai/test_chat_with_tool_call.py | 7 +++++++ tests/open_ai/test_chat_without_tool_calls.py | 13 ++++++++++++- 10 files changed, 77 insertions(+), 18 deletions(-) diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index e0b8ced..982aa3d 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -1,7 +1,8 @@ """Gateway service to forward requests to the Anthropic APIs""" +import asyncio import json -from typing import Any, Optional +from typing import Any import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -121,11 +122,9 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from Anthropic"), ) # Only push the trace to explorer if the last message is an end turn message + # Don't block on the response from explorer if context.dataset_name: - await push_to_explorer( - context, - json_response, - ) + asyncio.create_task(push_to_explorer(context, json_response)) return Response( content=json.dumps(json_response), status_code=response.status_code, @@ -161,10 +160,8 @@ async def handle_streaming_response( process_chunk_text(chunk_decode, merged_response) if context.dataset_name: - await push_to_explorer( - context, - merged_response[-1], - ) + # Push to Explorer - don't block on the response + asyncio.create_task(push_to_explorer(context, merged_response[-1])) generator = event_generator() diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index cd88584..8383857 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -1,5 +1,6 @@ """Gateway service to forward requests to the Gemini APIs""" +import asyncio import json from typing import Any @@ -120,10 +121,12 @@ async def stream_response( process_chunk_text(merged_response, chunk_text) if context.dataset_name: - # Push to Explorer - await push_to_explorer( - context, - merged_response, + # Push to Explorer - don't block on the response + asyncio.create_task( + push_to_explorer( + context, + merged_response, + ) ) return StreamingResponse(event_generator(), media_type="text/event-stream") @@ -210,7 +213,8 @@ async def handle_non_streaming_response( detail=response_json.get("error", "Unknown error from Gemini API"), ) if context.dataset_name: - await push_to_explorer(context, response_json) + # Push to Explorer - don't block on the response + asyncio.create_task(push_to_explorer(context, response_json)) return Response( content=json.dumps(response_json), diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index cbdb52c..0528cae 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -1,5 +1,6 @@ """Gateway service to forward requests to the OpenAI APIs""" +import asyncio import json from typing import Any @@ -141,8 +142,9 @@ async def stream_response( ) # Send full merged response to the explorer + # Don't block on the response from explorer if context.dataset_name: - await push_to_explorer(context, merged_response) + asyncio.create_task(push_to_explorer(context, merged_response)) return StreamingResponse(event_generator(), media_type="text/event-stream") @@ -317,7 +319,8 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from OpenAI API"), ) if context.dataset_name: - await push_to_explorer(context, json_response) + # Push to Explorer - don't block on its response + asyncio.create_task(push_to_explorer(context, json_response)) return Response( content=json.dumps(json_response), diff --git a/tests/anthropic/test_anthropic_header_with_invariant_key.py b/tests/anthropic/test_anthropic_header_with_invariant_key.py index 891a5b1..1b95900 100644 --- a/tests/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/anthropic/test_anthropic_header_with_invariant_key.py @@ -3,6 +3,7 @@ import datetime import os import sys +import time from unittest.mock import patch # Add tests folder (parent) to sys.path @@ -55,6 +56,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header( response_text = response.content[0].text assert "zurich" in response_text.lower() + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" ) diff --git a/tests/anthropic/test_anthropic_with_tool_call.py b/tests/anthropic/test_anthropic_with_tool_call.py index 53948ed..8e063d8 100644 --- a/tests/anthropic/test_anthropic_with_tool_call.py +++ b/tests/anthropic/test_anthropic_with_tool_call.py @@ -5,6 +5,7 @@ import datetime import json import os import sys +import time from pathlib import Path from typing import Dict, List @@ -209,6 +210,9 @@ async def test_response_with_tool_call( responses.append(response) if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" ) @@ -264,6 +268,9 @@ async def test_streaming_response_with_tool_call( assert city in response[1][0].text.lower() if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" ) @@ -338,6 +345,9 @@ async def test_response_with_tool_call_with_image( assert response[1].stop_reason == "end_turn" if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" ) diff --git a/tests/anthropic/test_anthropic_without_tool_call.py b/tests/anthropic/test_anthropic_without_tool_call.py index cdbd708..8baf955 100644 --- a/tests/anthropic/test_anthropic_without_tool_call.py +++ b/tests/anthropic/test_anthropic_without_tool_call.py @@ -3,6 +3,7 @@ import datetime import os import sys +import time # Add tests folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -60,6 +61,9 @@ async def test_response_without_tool_call( assert cities[queries.index(query)] in response_text.lower() if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" ) @@ -126,6 +130,9 @@ async def test_streaming_response_without_tool_call( assert cities[queries.index(query)] in response_text.lower() if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" ) diff --git a/tests/gemini/test_generate_content_with_tool_calls.py b/tests/gemini/test_generate_content_with_tool_calls.py index 8ae5a52..9700376 100644 --- a/tests/gemini/test_generate_content_with_tool_calls.py +++ b/tests/gemini/test_generate_content_with_tool_calls.py @@ -2,6 +2,7 @@ import os import sys +import time import uuid # Add tests folder (parent) to sys.path @@ -186,6 +187,9 @@ async def test_generate_content_with_tool_call( expected_final_assistant_message = full_response if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) await _verify_trace_from_explorer( context, explorer_api_url, dataset_name, expected_final_assistant_message ) diff --git a/tests/gemini/test_generate_content_without_tool_calls.py b/tests/gemini/test_generate_content_without_tool_calls.py index c1424eb..cd94e7c 100644 --- a/tests/gemini/test_generate_content_without_tool_calls.py +++ b/tests/gemini/test_generate_content_without_tool_calls.py @@ -2,6 +2,7 @@ import os import sys +import time import uuid from pathlib import Path from unittest.mock import patch @@ -73,6 +74,9 @@ async def test_generate_content( expected_assistant_message = full_response if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" @@ -135,10 +139,13 @@ async def test_generate_content_with_image( assert ( "TWO" in chat_response.candidates[0].content.parts[0].text.upper() - or 2 in chat_response.candidates[0].content.parts[0].text + or "2" in chat_response.candidates[0].content.parts[0].text ) if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" @@ -196,6 +203,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper() expected_assistant_message = chat_response.candidates[0].content.parts[0].text + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" diff --git a/tests/open_ai/test_chat_with_tool_call.py b/tests/open_ai/test_chat_with_tool_call.py index d49b199..1481842 100644 --- a/tests/open_ai/test_chat_with_tool_call.py +++ b/tests/open_ai/test_chat_with_tool_call.py @@ -3,6 +3,7 @@ import json import os import sys +import time import uuid # Add tests folder (parent) to sys.path @@ -101,6 +102,9 @@ async def test_chat_completion_with_tool_call_without_streaming( assert "15°C" in chat_response_final.choices[0].message.content if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" @@ -218,6 +222,9 @@ async def test_chat_completion_with_tool_call_with_streaming( final_response["content"] += chunk.choices[0].delta.content if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" diff --git a/tests/open_ai/test_chat_without_tool_calls.py b/tests/open_ai/test_chat_without_tool_calls.py index 0772b2a..d9ae738 100644 --- a/tests/open_ai/test_chat_without_tool_calls.py +++ b/tests/open_ai/test_chat_without_tool_calls.py @@ -3,6 +3,7 @@ import base64 import os import sys +import time import uuid from pathlib import Path from unittest.mock import patch @@ -61,6 +62,9 @@ async def test_chat_completion( expected_assistant_message = full_response if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" @@ -134,10 +138,13 @@ async def test_chat_completion_with_image( assert ( "TWO" in chat_response.choices[0].message.content.upper() - or 2 in chat_response.choices[0].message.content + or "2" in chat_response.choices[0].message.content ) if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" @@ -202,6 +209,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header( assert "PARIS" in chat_response.choices[0].message.content.upper() expected_assistant_message = chat_response.choices[0].message.content + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + # Fetch the trace ids for the dataset traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"