From 2789ec2b98496cfd8fba729875026b6c961a3f3a Mon Sep 17 00:00:00 2001 From: Hemang Date: Thu, 13 Feb 2025 14:11:47 +0100 Subject: [PATCH] Use the AsyncClient from the sdk to push traces to Explorer. --- proxy/requirements.txt | 6 ++---- proxy/utils/explorer.py | 43 +++++++++++++++-------------------------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/proxy/requirements.txt b/proxy/requirements.txt index 193dbfb..a564b13 100644 --- a/proxy/requirements.txt +++ b/proxy/requirements.txt @@ -1,7 +1,5 @@ fastapi==0.115.7 httpx==0.28.1 -uvicorn==0.34.0 -invariant-sdk +invariant-sdk>=0.0.10 starlette-compress==1.4.0 -tavily-python -anthropic \ No newline at end of file +uvicorn==0.34.0 \ No newline at end of file diff --git a/proxy/utils/explorer.py b/proxy/utils/explorer.py index d54b4e6..71108d7 100644 --- a/proxy/utils/explorer.py +++ b/proxy/utils/explorer.py @@ -3,53 +3,42 @@ import os from typing import Any, Dict, List -import httpx -from invariant_sdk.types.push_traces import PushTracesRequest +from invariant_sdk.async_client import AsyncClient +from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse DEFAULT_API_URL = "https://explorer.invariantlabs.ai" -PUSH_ENDPOINT = "/api/v1/push/trace" async def push_trace( messages: List[List[Dict[str, Any]]], dataset_name: str, invariant_authorization: str, -) -> Dict[str, str]: +) -> PushTracesResponse: """Pushes traces to the dataset on the Invariant Explorer. + If a dataset with the given name does not exist, it will be created. + Args: messages (List[List[Dict[str, Any]]]): List of messages to push. dataset_name (str): Name of the dataset. - invariant_authorization (str): Authorization token from the + invariant_authorization (str): Value of the invariant-authorization header. Returns: - Dict[str, str]: Response containing the trace ID. + PushTracesResponse: Response containing the trace ID details. """ - api_url = os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/") # Remove any None values from the messages update_messages = [ [{k: v for k, v in msg.items() if v is not None} for msg in msg_list] for msg_list in messages ] request = PushTracesRequest(messages=update_messages, dataset=dataset_name) - async with httpx.AsyncClient() as client: - explorer_push_request = client.build_request( - "POST", - f"{api_url}{PUSH_ENDPOINT}", - json=request.to_json(), - headers={ - "Authorization": f"{invariant_authorization}", - "Accept": "application/json", - }, - ) - try: - response = await client.send(explorer_push_request) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - print(f"Failed to push trace: {e.response.text}") - return {"error": str(e)} - except Exception as e: - print(f"Unexpected error pushing trace: {str(e)}") - return {"error": str(e)} + client = AsyncClient( + api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + api_key=invariant_authorization.split("Bearer ")[1], + ) + try: + return await client.push_trace(request) + except Exception as e: + print(f"Failed to push trace: {e}") + return {"error": str(e)}