Refactor Anthropic code to use RequestContextData

This commit is contained in:
Hemang
2025-03-10 15:45:30 +01:00
parent 0554970ce2
commit 7d96ae7af3
+24 -28
View File
@@ -12,8 +12,11 @@ from common.constants import (
IGNORED_HEADERS,
)
from integrations.explorer import push_trace
from converters.anthropic_to_invariant import convert_anthropic_to_invariant_message_format
from converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
gateway = APIRouter()
@@ -62,7 +65,7 @@ async def anthropic_v1_messages_gateway(
headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key
request_body = await request.body()
request_body_json = json.loads(request_body)
request_json = json.loads(request_body)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
anthropic_request = client.build_request(
"POST",
@@ -71,40 +74,38 @@ async def anthropic_v1_messages_gateway(
data=request_body,
)
if request_body_json.get("stream"):
return await handle_streaming_response(
client, anthropic_request, dataset_name, invariant_authorization
)
response = await client.send(anthropic_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
context = RequestContextData(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
)
if request_json.get("stream"):
return await handle_streaming_response(context, client, anthropic_request)
response = await client.send(anthropic_request)
return await handle_non_streaming_response(context, response)
async def push_to_explorer(
dataset_name: str,
context: RequestContextData,
merged_response: dict[str, Any],
request_body: dict[str, Any],
invariant_authorization: str,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
# Combine the messages from the request body and Anthropic response
messages = request_body.get("messages", [])
messages = context.request_json.get("messages", [])
messages += [merged_response]
converted_messages = convert_anthropic_to_invariant_message_format(messages)
_ = await push_trace(
dataset_name=dataset_name,
dataset_name=context.dataset_name,
messages=[converted_messages],
invariant_authorization=invariant_authorization,
invariant_authorization=context.invariant_authorization,
)
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
) -> Response:
"""Handles non-streaming Anthropic responses"""
try:
@@ -120,12 +121,10 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
if dataset_name:
if context.dataset_name:
await push_to_explorer(
dataset_name,
context,
json_response,
request_body_json,
invariant_authorization,
)
return Response(
content=json.dumps(json_response),
@@ -136,10 +135,9 @@ async def handle_non_streaming_response(
async def handle_streaming_response(
context: RequestContextData,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
dataset_name: Optional[str],
invariant_authorization: Optional[str],
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
merged_response = []
@@ -162,12 +160,10 @@ async def handle_streaming_response(
yield chunk
process_chunk_text(chunk_decode, merged_response)
if dataset_name:
if context.dataset_name:
await push_to_explorer(
dataset_name,
context,
merged_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
generator = event_generator()