Don't block the route handlers on the push-to-Explorer call.

This commit is contained in:
Hemang
2025-03-10 16:21:42 +01:00
committed by Hemang Sarkar
parent 7d96ae7af3
commit 2fe0f55cb3
10 changed files with 77 additions and 18 deletions
+6 -9
View File
@@ -1,7 +1,8 @@
"""Gateway service to forward requests to the Anthropic APIs"""
import asyncio
import json
from typing import Any, Optional
from typing import Any
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -121,11 +122,9 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
# Don't block on the response from explorer
if context.dataset_name:
await push_to_explorer(
context,
json_response,
)
asyncio.create_task(push_to_explorer(context, json_response))
return Response(
content=json.dumps(json_response),
status_code=response.status_code,
@@ -161,10 +160,8 @@ async def handle_streaming_response(
process_chunk_text(chunk_decode, merged_response)
if context.dataset_name:
await push_to_explorer(
context,
merged_response[-1],
)
# Push to Explorer - don't block on the response
asyncio.create_task(push_to_explorer(context, merged_response[-1]))
generator = event_generator()
+9 -5
View File
@@ -1,5 +1,6 @@
"""Gateway service to forward requests to the Gemini APIs"""
import asyncio
import json
from typing import Any
@@ -120,10 +121,12 @@ async def stream_response(
process_chunk_text(merged_response, chunk_text)
if context.dataset_name:
# Push to Explorer
await push_to_explorer(
context,
merged_response,
# Push to Explorer - don't block on the response
asyncio.create_task(
push_to_explorer(
context,
merged_response,
)
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -210,7 +213,8 @@ async def handle_non_streaming_response(
detail=response_json.get("error", "Unknown error from Gemini API"),
)
if context.dataset_name:
await push_to_explorer(context, response_json)
# Push to Explorer - don't block on the response
asyncio.create_task(push_to_explorer(context, response_json))
return Response(
content=json.dumps(response_json),
+5 -2
View File
@@ -1,5 +1,6 @@
"""Gateway service to forward requests to the OpenAI APIs"""
import asyncio
import json
from typing import Any
@@ -141,8 +142,9 @@ async def stream_response(
)
# Send full merged response to the explorer
# Don't block on the response from explorer
if context.dataset_name:
await push_to_explorer(context, merged_response)
asyncio.create_task(push_to_explorer(context, merged_response))
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -317,7 +319,8 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
if context.dataset_name:
await push_to_explorer(context, json_response)
# Push to Explorer - don't block on its response
asyncio.create_task(push_to_explorer(context, json_response))
return Response(
content=json.dumps(json_response),
@@ -3,6 +3,7 @@
import datetime
import os
import sys
import time
from unittest.mock import patch
# Add tests folder (parent) to sys.path
@@ -55,6 +56,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
response_text = response.content[0].text
assert "zurich" in response_text.lower()
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
@@ -5,6 +5,7 @@ import datetime
import json
import os
import sys
import time
from pathlib import Path
from typing import Dict, List
@@ -209,6 +210,9 @@ async def test_response_with_tool_call(
responses.append(response)
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
@@ -264,6 +268,9 @@ async def test_streaming_response_with_tool_call(
assert city in response[1][0].text.lower()
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
@@ -338,6 +345,9 @@ async def test_response_with_tool_call_with_image(
assert response[1].stop_reason == "end_turn"
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
@@ -3,6 +3,7 @@
import datetime
import os
import sys
import time
# Add tests folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -60,6 +61,9 @@ async def test_response_without_tool_call(
assert cities[queries.index(query)] in response_text.lower()
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
@@ -126,6 +130,9 @@ async def test_streaming_response_without_tool_call(
assert cities[queries.index(query)] in response_text.lower()
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
@@ -2,6 +2,7 @@
import os
import sys
import time
import uuid
# Add tests folder (parent) to sys.path
@@ -186,6 +187,9 @@ async def test_generate_content_with_tool_call(
expected_final_assistant_message = full_response
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
await _verify_trace_from_explorer(
context, explorer_api_url, dataset_name, expected_final_assistant_message
)
@@ -2,6 +2,7 @@
import os
import sys
import time
import uuid
from pathlib import Path
from unittest.mock import patch
@@ -73,6 +74,9 @@ async def test_generate_content(
expected_assistant_message = full_response
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -135,10 +139,13 @@ async def test_generate_content_with_image(
assert (
"TWO" in chat_response.candidates[0].content.parts[0].text.upper()
or 2 in chat_response.candidates[0].content.parts[0].text
or "2" in chat_response.candidates[0].content.parts[0].text
)
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -196,6 +203,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper()
expected_assistant_message = chat_response.candidates[0].content.parts[0].text
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -3,6 +3,7 @@
import json
import os
import sys
import time
import uuid
# Add tests folder (parent) to sys.path
@@ -101,6 +102,9 @@ async def test_chat_completion_with_tool_call_without_streaming(
assert "15°C" in chat_response_final.choices[0].message.content
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -218,6 +222,9 @@ async def test_chat_completion_with_tool_call_with_streaming(
final_response["content"] += chunk.choices[0].delta.content
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
+12 -1
View File
@@ -3,6 +3,7 @@
import base64
import os
import sys
import time
import uuid
from pathlib import Path
from unittest.mock import patch
@@ -61,6 +62,9 @@ async def test_chat_completion(
expected_assistant_message = full_response
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -134,10 +138,13 @@ async def test_chat_completion_with_image(
assert (
"TWO" in chat_response.choices[0].message.content.upper()
or 2 in chat_response.choices[0].message.content
or "2" in chat_response.choices[0].message.content
)
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
@@ -202,6 +209,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
assert "PARIS" in chat_response.choices[0].message.content.upper()
expected_assistant_message = chat_response.choices[0].message.content
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"