anthropic integration of pipelined and pre-guardrailing

This commit is contained in:
Luca Beurer-Kellner
2025-03-28 20:53:23 +01:00
committed by Hemang
parent 7f820bd79f
commit c2177faaa8
5 changed files with 403 additions and 110 deletions
+12
View File
@@ -118,6 +118,18 @@ class ExtraItem:
return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"
class Replacement(ExtraItem):
    """
    An ExtraItem whose value stands in for the complete request result.

    Intended for the 'InstrumentedResponse' case: instead of adding an item
    next to the original result, the wrapped value replaces it entirely.
    """

    def __init__(self, value):
        # A replacement is always terminal, so end_of_stream is fixed to True.
        super().__init__(value, end_of_stream=True)

    def __str__(self):
        return "<Replacement value={}>".format(self.value)
class InstrumentedStreamingResponse:
def __init__(self):
# request statistics
+290 -104
View File
@@ -5,6 +5,7 @@ import json
from typing import Any, Optional
import httpx
from regex import R
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
@@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import (
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from integrations.guardrails import check_guardrails, preload_guardrails
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
Replacement,
check_guardrails,
preload_guardrails,
)
gateway = APIRouter()
@@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway(
if request_json.get("stream"):
return await handle_streaming_response(context, client, anthropic_request)
response = await client.send(anthropic_request)
return await handle_non_streaming_response(context, response)
return await handle_non_streaming_response(context, client, anthropic_request)
def create_metadata(
@@ -110,7 +117,8 @@ def combine_request_and_response_messages(
{"role": "system", "content": context.request_json.get("system")}
)
messages.extend(context.request_json.get("messages", []))
messages.append(json_response)
if len(json_response) > 0:
messages.append(json_response)
return messages
@@ -154,56 +162,282 @@ async def push_to_explorer(
)
class InstrumentedAnthropicResponse(InstrumentedResponse):
    """
    Non-streaming Anthropic request wrapped with guardrail checks.

    The class provides three hooks (`on_start`, `request`, `on_end`). The
    order in which they are invoked is decided by `InstrumentedResponse`,
    which is not defined in this module. A hook may return a `Replacement`
    to substitute the response sent to the client with an error response.
    """

    def __init__(
        self,
        context: RequestContextData,
        client: httpx.AsyncClient,
        anthropic_request: httpx.Request,
    ):
        super().__init__()

        # request parameters
        self.context: RequestContextData = context
        self.client: httpx.AsyncClient = client
        self.anthropic_request: httpx.Request = anthropic_request

        # response data (populated by `request`)
        self.response: Optional[httpx.Response] = None
        self.response_string: Optional[str] = None
        self.json_response: Optional[dict[str, Any]] = None

        # guardrailing response (if any); stays {} when no guardrails ran
        self.guardrails_execution_result = {}

    async def on_start(self):
        """
        Run the guardrails check with an empty response (input guardrailing).

        Returns:
            A `Replacement` wrapping a 400 JSON error response when the check
            reports errors; otherwise None.
        """
        if self.context.config and self.context.config.guardrails:
            # no model response exists yet, hence the empty dict
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, {}
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "message": "[Invariant] The request did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                        }
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we prevent the request from going through
                # and return an error instead
                return Replacement(
                    Response(
                        content=error_chunk,
                        status_code=400,
                        media_type="application/json",
                        headers={"content-type": "application/json"},
                    )
                )

    async def request(self):
        """
        Send the Anthropic request and convert the result into a `Response`.

        Stores the raw response, its parsed JSON and the re-serialized JSON
        string on the instance.

        Raises:
            HTTPException: if the body is not valid JSON, or if Anthropic
                answered with a status code other than 200.
        """
        self.response = await self.client.send(self.anthropic_request)

        try:
            json_response = self.response.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=self.response.status_code,
                detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}",
            ) from e

        if self.response.status_code != 200:
            raise HTTPException(
                status_code=self.response.status_code,
                detail=json_response.get("error", "Unknown error from Anthropic"),
            )

        self.json_response = json_response
        self.response_string = json.dumps(json_response)

        return self._make_response(
            content=self.response_string,
            status_code=self.response.status_code,
        )

    def _make_response(self, content: str, status_code: int):
        """Creates a new Response object with the correct headers and content"""
        assert self.response is not None, "response is None"

        updated_headers = self.response.headers.copy()
        # the body is re-serialized, so the upstream Content-Length is dropped
        updated_headers.pop("Content-Length", None)

        return Response(
            content=content,
            status_code=status_code,
            media_type="application/json",
            headers=dict(updated_headers),
        )

    async def on_end(self):
        """
        Check guardrails on the received response and push the trace to Explorer.

        Returns:
            A `Replacement` wrapping a 400 JSON error response when the check
            reports errors; otherwise None.
        """
        # ensure the response data is available
        assert self.response is not None, "response is None"
        assert self.json_response is not None, "json_response is None"
        assert self.response_string is not None, "response_string is None"

        if self.context.config and self.context.config.guardrails:
            # Block on the guardrails check. The result is kept on the instance
            # (not in a local) so that the Explorer push below also works when
            # this branch is skipped because no guardrails are configured.
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.json_response
            )
            if self.guardrails_execution_result.get("errors", []):
                guardrail_response_string = json.dumps(
                    {
                        "error": "[Invariant] The response did not pass the guardrails",
                        "details": self.guardrails_execution_result,
                    }
                )

                # push to explorer (if configured)
                if self.context.dataset_name:
                    # Push to Explorer - don't block on its response
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.json_response,
                            self.guardrails_execution_result,
                        )
                    )

                return Replacement(
                    self._make_response(
                        content=guardrail_response_string,
                        status_code=400,
                    )
                )

        # push to explorer (if configured)
        if self.context.dataset_name:
            # Push to Explorer - don't block on its response
            asyncio.create_task(
                push_to_explorer(
                    self.context,
                    self.json_response,
                    self.guardrails_execution_result,
                )
            )
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> Response:
"""Handles non-streaming Anthropic responses"""
try:
json_response = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
)
guardrails_execution_result = {}
response_string = json.dumps(json_response)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, json_response, guardrails_execution_result)
)
updated_headers = response.headers.copy()
updated_headers.pop("Content-Length", None)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(updated_headers),
response = InstrumentedAnthropicResponse(
context=context,
client=client,
anthropic_request=anthropic_request,
)
return await response.instrumented_request()
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
    """
    Streaming Anthropic request wrapped with guardrail checks.

    Provides the `on_start`, `event_generator`, `on_chunk` and `on_end` hooks;
    how and when they are driven is decided by `InstrumentedStreamingResponse`,
    which is not defined in this module.

    NOTE: the class name is misspelled ("Resposne"); it is kept as-is because
    `handle_streaming_response` instantiates it under this name.
    """

    def __init__(
        self,
        context: RequestContextData,
        client: httpx.AsyncClient,
        anthropic_request: httpx.Request,
    ):
        super().__init__()

        # request parameters
        self.context: RequestContextData = context
        self.client: httpx.AsyncClient = client
        self.anthropic_request: httpx.Request = anthropic_request

        # response data, extended chunk by chunk in `on_chunk`
        self.merged_response = {}

        # guardrailing response (if any)
        self.guardrails_execution_result = {}

    async def on_start(self):
        """
        Run the guardrails check on the still-empty merged response (input guardrailing).

        Returns:
            An `ExtraItem` carrying an SSE error event with end_of_stream=True
            when the check reports errors; otherwise None.
        """
        if self.context.config and self.context.config.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.merged_response
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "message": "[Invariant] The request did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                        }
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.merged_response,
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we end the stream prematurely (end_of_stream=True)
                # and yield an error chunk instead of actually beginning the stream
                return ExtraItem(
                    f"event: error\ndata: {error_chunk}\n\n".encode(),
                    end_of_stream=True,
                )

    async def event_generator(self):
        """
        Actual streaming response generator.

        Yields the raw byte chunks of the Anthropic response.

        Raises:
            HTTPException: if Anthropic answered with a status code other than 200.
        """
        response = await self.client.send(self.anthropic_request, stream=True)

        try:
            if response.status_code != 200:
                error_content = await response.aread()
                try:
                    error_json = json.loads(error_content)
                    error_detail = error_json.get(
                        "error", "Unknown error from Anthropic"
                    )
                except json.JSONDecodeError:
                    error_detail = {
                        "error": "Failed to decode error response from Anthropic"
                    }
                raise HTTPException(
                    status_code=response.status_code, detail=error_detail
                )

            # iterate over the response stream
            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            # a response obtained via send(..., stream=True) must be closed
            # explicitly, otherwise the underlying connection is not released
            await response.aclose()

    async def on_chunk(self, chunk):
        """
        Merge one streamed chunk and, on the final chunk, run output guardrails.

        Returns:
            An `ExtraItem` carrying an SSE error event when the chunk contains
            `event: message_stop` and the guardrails check reports errors;
            otherwise None.
        """
        decoded_chunk = chunk.decode().strip()
        if not decoded_chunk:
            return

        # process chunk and extend the merged_response
        process_chunk(decoded_chunk, self.merged_response)

        # on last stream chunk, run output guardrails
        if (
            "event: message_stop" in decoded_chunk
            and self.context.config
            and self.context.config.guardrails
        ):
            # Block on the guardrails check
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.merged_response
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "type": "error",
                        "error": {
                            "message": "[Invariant] The response did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                        },
                    }
                )

                # yield an extra error chunk (without preventing the original chunk to go through after,
                # so client gets the proper message_stop event still)
                return ExtraItem(
                    value=f"event: error\ndata: {error_chunk}\n\n".encode()
                )

    async def on_end(self):
        """on_end: send full merged response to the explorer (if configured)"""
        # don't block on the response from explorer (.create_task)
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context,
                    self.merged_response,
                    self.guardrails_execution_result,
                )
            )
async def handle_streaming_response(
context: RequestContextData,
@@ -211,63 +445,15 @@ async def handle_streaming_response(
anthropic_request: httpx.Request,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
merged_response = {}
response = InstrumentedAnthropicStreamingResposne(
context=context,
client=client,
anthropic_request=anthropic_request,
)
response = await client.send(anthropic_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content)
error_detail = error_json.get("error", "Unknown error from Anthropic")
except json.JSONDecodeError:
error_detail = {"error": "Failed to decode error response from Anthropic"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
async def event_generator() -> Any:
async for chunk in response.aiter_bytes():
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
continue
process_chunk(decoded_chunk, merged_response)
if (
"event: message_stop" in decoded_chunk
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"type": "error",
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
},
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"event: error\ndata: {error_chunk}\n\n".encode()
return
yield chunk
if context.dataset_name:
# Push to Explorer - don't block on the response
asyncio.create_task(push_to_explorer(context, merged_response))
generator = event_generator()
return StreamingResponse(generator, media_type="text/event-stream")
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
)
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
+6 -5
View File
@@ -80,13 +80,13 @@ async def openai_chat_completions_gateway(
asyncio.create_task(preload_guardrails(context))
if request_json.get("stream", False):
return await stream_response(
return await handle_stream_response(
context,
client,
open_ai_request,
)
return await non_stream_response(context, client, open_ai_request)
return await handle_non_stream_response(context, client, open_ai_request)
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
@@ -158,7 +158,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"data: {error_chunk}\n\n".encode(), end_of_stream=True
f"data: {error_chunk}\n\n".encode(),
end_of_stream=True,
)
async def on_chunk(self, chunk):
@@ -231,7 +232,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
yield chunk
async def stream_response(
async def handle_stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
@@ -598,7 +599,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
)
async def non_stream_response(
async def handle_non_stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
@@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(
    not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
    explorer_api_url, gateway_url, do_stream, push_to_explorer
):
    """
    Input guardrails must reject an Anthropic request through the gateway.

    Sends a prompt mentioning 'Fight Club' (streaming and non-streaming) and
    expects the guardrail error to surface to the client. When pushing to
    Explorer, the stored trace must contain only the user message plus one
    guardrails-error annotation.
    """
    invariant_api_key = os.getenv("INVARIANT_API_KEY")
    if not invariant_api_key:
        pytest.fail("No INVARIANT_API_KEY set, failing")

    dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"

    # with a dataset in the URL the gateway also pushes the trace to Explorer
    if push_to_explorer:
        base_url = f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
    else:
        base_url = f"{gateway_url}/api/v1/gateway/anthropic"

    auth_headers = {"Invariant-Authorization": f"Bearer {invariant_api_key}"}
    client = Anthropic(
        http_client=Client(headers=auth_headers),
        base_url=base_url,
    )

    prompt = "Tell me more about Fight Club."
    message_request = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": prompt}],
    }

    request_failure = "[Invariant] The request did not pass the guardrails"
    guardrail_message = "Users must not mention the magic phrase 'Fight Club'"

    if do_stream:
        with pytest.raises(APIStatusError) as exc_info:
            chat_response = client.messages.create(**message_request, stream=True)
            # the error only surfaces once the stream is consumed
            for _ in chat_response:
                pass

        assert request_failure in exc_info.value.message
        assert guardrail_message in str(exc_info.value.body)
    else:
        with pytest.raises(BadRequestError) as exc_info:
            _ = client.messages.create(**message_request, stream=False)

        assert exc_info.value.status_code == 400
        assert request_failure in str(exc_info.value)
        assert guardrail_message in str(exc_info.value)

    if not push_to_explorer:
        return

    # give the gateway's non-blocking push some time to arrive
    time.sleep(2)

    traces = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    ).json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # in case of input guardrailing, the pushed trace will not contain a response
    trace = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    ).json()
    assert len(trace["messages"]) == 1, "Only the user message should be present"
    assert trace["messages"][0] == {
        "role": "user",
        "content": "Tell me more about Fight Club.",
    }

    annotations = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
        timeout=5,
    ).json()
    assert len(annotations) == 1

    annotation = annotations[0]
    assert annotation["content"] == guardrail_message
    assert annotation["extra_metadata"]["source"] == "guardrails-error"
@@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file(
trace = trace_response.json()
# in case of input guardrailing, the pushed trace will not contain a response
assert len(trace["messages"]) == 1
assert len(trace["messages"]) == 1, "Trace should only contain the user message"
assert trace["messages"][0] == {
"role": "user",
"content": "Tell me more about Fight Club.",