diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index d84a5b3..e0b8ced 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -12,8 +12,11 @@ from common.constants import ( IGNORED_HEADERS, ) from integrations.explorer import push_trace -from converters.anthropic_to_invariant import convert_anthropic_to_invariant_message_format +from converters.anthropic_to_invariant import ( + convert_anthropic_to_invariant_message_format, +) from common.authorization import extract_authorization_from_headers +from common.request_context_data import RequestContextData gateway = APIRouter() @@ -62,7 +65,7 @@ async def anthropic_v1_messages_gateway( headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key request_body = await request.body() - request_body_json = json.loads(request_body) + request_json = json.loads(request_body) client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) anthropic_request = client.build_request( "POST", @@ -71,40 +74,38 @@ async def anthropic_v1_messages_gateway( data=request_body, ) - if request_body_json.get("stream"): - return await handle_streaming_response( - client, anthropic_request, dataset_name, invariant_authorization - ) - response = await client.send(anthropic_request) - return await handle_non_streaming_response( - response, dataset_name, request_body_json, invariant_authorization + context = RequestContextData( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, ) + if request_json.get("stream"): + return await handle_streaming_response(context, client, anthropic_request) + response = await client.send(anthropic_request) + return await handle_non_streaming_response(context, response) + async def push_to_explorer( - dataset_name: str, + context: RequestContextData, merged_response: dict[str, Any], - request_body: dict[str, Any], - invariant_authorization: str, ) -> None: """Pushes the full trace to the Invariant Explorer""" # Combine the messages from the request body and Anthropic response - messages = request_body.get("messages", []) + messages = context.request_json.get("messages", []) messages += [merged_response] converted_messages = convert_anthropic_to_invariant_message_format(messages) _ = await push_trace( - dataset_name=dataset_name, + dataset_name=context.dataset_name, messages=[converted_messages], - invariant_authorization=invariant_authorization, + invariant_authorization=context.invariant_authorization, ) async def handle_non_streaming_response( + context: RequestContextData, response: httpx.Response, - dataset_name: Optional[str], - request_body_json: dict[str, Any], - invariant_authorization: Optional[str], ) -> Response: """Handles non-streaming Anthropic responses""" try: @@ -120,12 +121,10 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from Anthropic"), ) # Only push the trace to explorer if the last message is an end turn message - if dataset_name: + if context.dataset_name: await push_to_explorer( - dataset_name, + context, json_response, - request_body_json, - invariant_authorization, ) return Response( content=json.dumps(json_response), @@ -136,10 +135,9 @@ async def handle_non_streaming_response( async def handle_streaming_response( + context: RequestContextData, client: httpx.AsyncClient, anthropic_request: httpx.Request, - dataset_name: Optional[str], - invariant_authorization: Optional[str], ) -> StreamingResponse: """Handles streaming Anthropic responses""" merged_response = [] @@ -162,12 +160,10 @@ async def handle_streaming_response( yield chunk process_chunk_text(chunk_decode, merged_response) - if dataset_name: + if context.dataset_name: await push_to_explorer( - dataset_name, + context, merged_response[-1], - json.loads(anthropic_request.content), - invariant_authorization, ) generator = event_generator()