From c2177faaa8e242a0e6b577f39b193c533b7e84a9 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Fri, 28 Mar 2025 20:53:23 +0100 Subject: [PATCH] anthropic integration of pipelined and pre-guardrailing --- gateway/integrations/guardrails.py | 12 + gateway/routes/anthropic.py | 394 +++++++++++++----- gateway/routes/open_ai.py | 11 +- .../guardrails/test_guardrails_anthropic.py | 94 +++++ .../guardrails/test_guardrails_open_ai.py | 2 +- 5 files changed, 403 insertions(+), 110 deletions(-) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 6b3719a..481884f 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -118,6 +118,18 @@ class ExtraItem: return f"" +class Replacement(ExtraItem): + """ + Like ExtraItem, but used to replace the full request result in case of 'InstrumentedResponse'. + """ + + def __init__(self, value): + super().__init__(value, end_of_stream=True) + + def __str__(self): + return f"" + + class InstrumentedStreamingResponse: def __init__(self): # request statistics diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 807be91..2f3e243 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -5,6 +5,7 @@ import json from typing import Any, Optional import httpx +from regex import R from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse @@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import ( ) from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData -from integrations.guardrails import check_guardrails, preload_guardrails +from integrations.guardrails import ( + ExtraItem, + InstrumentedResponse, + InstrumentedStreamingResponse, + Replacement, + check_guardrails, + preload_guardrails, +) gateway = 
APIRouter() @@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway( if request_json.get("stream"): return await handle_streaming_response(context, client, anthropic_request) - response = await client.send(anthropic_request) - return await handle_non_streaming_response(context, response) + return await handle_non_streaming_response(context, client, anthropic_request) def create_metadata( @@ -110,7 +117,8 @@ def combine_request_and_response_messages( {"role": "system", "content": context.request_json.get("system")} ) messages.extend(context.request_json.get("messages", [])) - messages.append(json_response) + if len(json_response) > 0: + messages.append(json_response) return messages @@ -154,56 +162,282 @@ async def push_to_explorer( ) +class InstrumentedAnthropicResponse(InstrumentedResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + anthropic_request: httpx.Request, + ): + super().__init__() + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.anthropic_request: httpx.Request = anthropic_request + + # response data + self.response: Optional[httpx.Response] = None + self.response_string: Optional[str] = None + self.json_response: Optional[dict[str, Any]] = None + + # guardrailing response (if any) + self.guardrails_execution_result = {} + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + 
push_to_explorer( + self.context, + {}, + self.guardrails_execution_result, + ) + ) + + # if we find something, we prevent the request from going through + # and return an error instead + return Replacement( + Response( + content=error_chunk, + status_code=400, + media_type="application/json", + headers={"content-type": "application/json"}, + ) + ) + + async def request(self): + self.response = await self.client.send(self.anthropic_request) + + try: + json_response = self.response.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=self.response.status_code, + detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}", + ) from e + if self.response.status_code != 200: + raise HTTPException( + status_code=self.response.status_code, + detail=json_response.get("error", "Unknown error from Anthropic"), + ) + + self.json_response = json_response + self.response_string = json.dumps(json_response) + + return self._make_response( + content=self.response_string, + status_code=self.response.status_code, + ) + + def _make_response(self, content: str, status_code: int): + """Creates a new Response object with the correct headers and content""" + assert self.response is not None, "response is None" + + updated_headers = self.response.headers.copy() + updated_headers.pop("Content-Length", None) + + return Response( + content=content, + status_code=status_code, + media_type="application/json", + headers=dict(updated_headers), + ) + + async def on_end(self): + """Checks guardrails after the response is received, and asynchronously pushes to Explorer.""" + # ensure the response data is available + assert self.response is not None, "response is None" + assert self.json_response is not None, "json_response is None" + assert self.response_string is not None, "response_string is None" + + if self.context.config and self.context.config.guardrails: + # Block on the guardrails check + guardrails_execution_result = await 
get_guardrails_check_result( + self.context, self.json_response + ) + if guardrails_execution_result.get("errors", []): + guardrail_response_string = json.dumps( + { + "error": "[Invariant] The response did not pass the guardrails", + "details": guardrails_execution_result, + } + ) + + # push to explorer (if configured) + if self.context.dataset_name: + # Push to Explorer - don't block on its response + asyncio.create_task( + push_to_explorer( + self.context, + self.json_response, + guardrails_execution_result, + ) + ) + + return Replacement( + self._make_response( + content=guardrail_response_string, + status_code=400, + ) + ) + + # push to explorer (if configured) + if self.context.dataset_name: + # Push to Explorer - don't block on its response + asyncio.create_task( + push_to_explorer( + self.context, self.json_response, guardrails_execution_result + ) + ) + + async def handle_non_streaming_response( context: RequestContextData, - response: httpx.Response, + client: httpx.AsyncClient, + anthropic_request: httpx.Request, ) -> Response: """Handles non-streaming Anthropic responses""" - try: - json_response = response.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=response.status_code, - detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}", - ) from e - if response.status_code != 200: - raise HTTPException( - status_code=response.status_code, - detail=json_response.get("error", "Unknown error from Anthropic"), - ) - - guardrails_execution_result = {} - response_string = json.dumps(json_response) - response_code = response.status_code - - if context.config and context.config.guardrails: - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, json_response - ) - if guardrails_execution_result.get("errors", []): - response_string = json.dumps( - { - "error": "[Invariant] The response did not pass the guardrails", - "details": 
guardrails_execution_result, - } - ) - response_code = 400 - if context.dataset_name: - # Push to Explorer - don't block on its response - asyncio.create_task( - push_to_explorer(context, json_response, guardrails_execution_result) - ) - - updated_headers = response.headers.copy() - updated_headers.pop("Content-Length", None) - return Response( - content=response_string, - status_code=response_code, - media_type="application/json", - headers=dict(updated_headers), + response = InstrumentedAnthropicResponse( + context=context, + client=client, + anthropic_request=anthropic_request, ) + return await response.instrumented_request() + + +class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + anthropic_request: httpx.Request, + ): + super().__init__() + + # request parameters + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.anthropic_request: httpx.Request = anthropic_request + + # response data + self.merged_response = {} + + # guardrailing response (if any) + self.guardrails_execution_result = {} + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely 
(end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return ExtraItem( + f"event: error\ndata: {error_chunk}\n\n".encode(), + end_of_stream=True, + ) + + async def event_generator(self): + """Actual streaming response generator""" + response = await self.client.send(self.anthropic_request, stream=True) + if response.status_code != 200: + error_content = await response.aread() + try: + error_json = json.loads(error_content) + error_detail = error_json.get("error", "Unknown error from Anthropic") + except json.JSONDecodeError: + error_detail = { + "error": "Failed to decode error response from Anthropic" + } + raise HTTPException(status_code=response.status_code, detail=error_detail) + + # iterate over the response stream + async for chunk in response.aiter_bytes(): + yield chunk + + async def on_chunk(self, chunk): + decoded_chunk = chunk.decode().strip() + if not decoded_chunk: + return + + # process chunk and extend the merged_response + process_chunk(decoded_chunk, self.merged_response) + + # on last stream chunk, run output guardrails + if ( + "event: message_stop" in decoded_chunk + and self.context.config + and self.context.config.guardrails + ): + # Block on the guardrails check + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "type": "error", + "error": { + "message": "[Invariant] The response did not pass the guardrails", + "details": self.guardrails_execution_result, + }, + } + ) + + # yield an extra error chunk (without preventing the original chunk to go through after, + # so client gets the proper message_stop event still) + return ExtraItem( + value=f"event: error\ndata: {error_chunk}\n\n".encode() + ) + + async def on_end(self): + """on_end: send full merged response to the explorer (if configured)""" + # don't block on the response from explorer
(.create_task) + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + async def handle_streaming_response( context: RequestContextData, @@ -211,63 +445,15 @@ async def handle_streaming_response( anthropic_request: httpx.Request, ) -> StreamingResponse: """Handles streaming Anthropic responses""" - merged_response = {} + response = InstrumentedAnthropicStreamingResposne( + context=context, + client=client, + anthropic_request=anthropic_request, + ) - response = await client.send(anthropic_request, stream=True) - if response.status_code != 200: - error_content = await response.aread() - try: - error_json = json.loads(error_content) - error_detail = error_json.get("error", "Unknown error from Anthropic") - except json.JSONDecodeError: - error_detail = {"error": "Failed to decode error response from Anthropic"} - raise HTTPException(status_code=response.status_code, detail=error_detail) - - async def event_generator() -> Any: - async for chunk in response.aiter_bytes(): - decoded_chunk = chunk.decode().strip() - if not decoded_chunk: - continue - process_chunk(decoded_chunk, merged_response) - if ( - "event: message_stop" in decoded_chunk - and context.config - and context.config.guardrails - ): - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "type": "error", - "error": { - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - }, - } - ) - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - yield f"event: error\ndata: {error_chunk}\n\n".encode() - return - yield chunk - - if 
context.dataset_name: - # Push to Explorer - don't block on the response - asyncio.create_task(push_to_explorer(context, merged_response)) - - generator = event_generator() - - return StreamingResponse(generator, media_type="text/event-stream") + return StreamingResponse( + response.instrumented_event_generator(), media_type="text/event-stream" + ) def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None: diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 629dc6b..6ef3808 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -80,13 +80,13 @@ async def openai_chat_completions_gateway( asyncio.create_task(preload_guardrails(context)) if request_json.get("stream", False): - return await stream_response( + return await handle_stream_response( context, client, open_ai_request, ) - return await non_stream_response(context, client, open_ai_request) + return await handle_non_stream_response(context, client, open_ai_request) class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): @@ -158,7 +158,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): # if we find something, we end the stream prematurely (end_of_stream=True) # and yield an error chunk instead of actually beginning the stream return ExtraItem( - f"data: {error_chunk}\n\n".encode(), end_of_stream=True + f"data: {error_chunk}\n\n".encode(), + end_of_stream=True, ) async def on_chunk(self, chunk): @@ -231,7 +232,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): yield chunk -async def stream_response( +async def handle_stream_response( context: RequestContextData, client: httpx.AsyncClient, open_ai_request: httpx.Request, @@ -598,7 +599,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): ) -async def non_stream_response( +async def handle_non_stream_response( context: RequestContextData, client: httpx.AsyncClient, open_ai_request: httpx.Request, diff --git 
a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index 45ba1fd..173e5ab 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file( == "get_capital is called with Germany as argument" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize( + "do_stream, push_to_explorer", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_input_from_guardrail_from_file( + explorer_api_url, gateway_url, do_stream, push_to_explorer +): + """Test input guardrail enforcement with Anthropic.""" + if not os.getenv("INVARIANT_API_KEY"): + pytest.fail("No INVARIANT_API_KEY set, failing") + + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + + client = Anthropic( + http_client=Client( + headers={ + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + ), + base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/anthropic", + ) + + request = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 100, + "messages": [{"role": "user", "content": "Tell me more about Fight Club."}], + } + + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + _ = client.messages.create(**request, stream=False) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "Users must not mention the magic phrase 'Fight Club'" in str( + exc_info.value + ) + + else: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create(**request, stream=True) + for _ in chat_response: + pass + + assert ( + 
"[Invariant] The request did not pass the guardrails" + in exc_info.value.message + ) + assert "Users must not mention the magic phrase 'Fight Club'" in str( + exc_info.value.body + ) + + if push_to_explorer: + time.sleep(2) + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + # in case of input guardrailing, the pushed trace will not contain a response + trace = trace_response.json() + assert len(trace["messages"]) == 1, "Only the user message should be present" + assert trace["messages"][0] == { + "role": "user", + "content": "Tell me more about Fight Club.", + } + + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + assert len(annotations) == 1 + assert ( + annotations[0]["content"] + == "Users must not mention the magic phrase 'Fight Club'" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + ) diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index acc2f67..cf9e274 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file( trace = trace_response.json() # in case of input guardrailing, the pushed trace will not contain a response - assert len(trace["messages"]) == 1 + assert len(trace["messages"]) == 1, "Trace should only contain the user message" assert trace["messages"][0] == { "role": "user", "content": "Tell me more about Fight Club.",