mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-16 04:49:04 +02:00
Refactor Anthropic code to use RequestContextData
This commit is contained in:
+24
-28
@@ -12,8 +12,11 @@ from common.constants import (
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import push_trace
|
||||
from converters.anthropic_to_invariant import convert_anthropic_to_invariant_message_format
|
||||
from converters.anthropic_to_invariant import (
|
||||
convert_anthropic_to_invariant_message_format,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -62,7 +65,7 @@ async def anthropic_v1_messages_gateway(
|
||||
headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key
|
||||
|
||||
request_body = await request.body()
|
||||
request_body_json = json.loads(request_body)
|
||||
request_json = json.loads(request_body)
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
anthropic_request = client.build_request(
|
||||
"POST",
|
||||
@@ -71,40 +74,38 @@ async def anthropic_v1_messages_gateway(
|
||||
data=request_body,
|
||||
)
|
||||
|
||||
if request_body_json.get("stream"):
|
||||
return await handle_streaming_response(
|
||||
client, anthropic_request, dataset_name, invariant_authorization
|
||||
)
|
||||
response = await client.send(anthropic_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
context = RequestContextData(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
)
|
||||
|
||||
if request_json.get("stream"):
|
||||
return await handle_streaming_response(context, client, anthropic_request)
|
||||
response = await client.send(anthropic_request)
|
||||
return await handle_non_streaming_response(context, response)
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
dataset_name: str,
|
||||
context: RequestContextData,
|
||||
merged_response: dict[str, Any],
|
||||
request_body: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
# Combine the messages from the request body and Anthropic response
|
||||
messages = request_body.get("messages", [])
|
||||
messages = context.request_json.get("messages", [])
|
||||
messages += [merged_response]
|
||||
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
_ = await push_trace(
|
||||
dataset_name=dataset_name,
|
||||
dataset_name=context.dataset_name,
|
||||
messages=[converted_messages],
|
||||
invariant_authorization=invariant_authorization,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
response: httpx.Response,
|
||||
dataset_name: Optional[str],
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: Optional[str],
|
||||
) -> Response:
|
||||
"""Handles non-streaming Anthropic responses"""
|
||||
try:
|
||||
@@ -120,12 +121,10 @@ async def handle_non_streaming_response(
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
if dataset_name:
|
||||
if context.dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
context,
|
||||
json_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
@@ -136,10 +135,9 @@ async def handle_non_streaming_response(
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
dataset_name: Optional[str],
|
||||
invariant_authorization: Optional[str],
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
merged_response = []
|
||||
@@ -162,12 +160,10 @@ async def handle_streaming_response(
|
||||
yield chunk
|
||||
|
||||
process_chunk_text(chunk_decode, merged_response)
|
||||
if dataset_name:
|
||||
if context.dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
context,
|
||||
merged_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
Reference in New Issue
Block a user