mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 15:29:43 +02:00
anthropic integration of pipelined and pre-guardrailing
This commit is contained in:
committed by
Hemang
parent
7f820bd79f
commit
c2177faaa8
@@ -118,6 +118,18 @@ class ExtraItem:
|
||||
return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"
|
||||
|
||||
|
||||
class Replacement(ExtraItem):
|
||||
"""
|
||||
Like ExtraItem, but used to replace the full request result in case of 'InstrumentedResponse'.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
super().__init__(value, end_of_stream=True)
|
||||
|
||||
def __str__(self):
|
||||
return f"<Replacement value={self.value}>"
|
||||
|
||||
|
||||
class InstrumentedStreamingResponse:
|
||||
def __init__(self):
|
||||
# request statistics
|
||||
|
||||
+290
-104
@@ -5,6 +5,7 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from regex import R
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
@@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import (
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from integrations.guardrails import check_guardrails, preload_guardrails
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
InstrumentedStreamingResponse,
|
||||
Replacement,
|
||||
check_guardrails,
|
||||
preload_guardrails,
|
||||
)
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway(
|
||||
|
||||
if request_json.get("stream"):
|
||||
return await handle_streaming_response(context, client, anthropic_request)
|
||||
response = await client.send(anthropic_request)
|
||||
return await handle_non_streaming_response(context, response)
|
||||
return await handle_non_streaming_response(context, client, anthropic_request)
|
||||
|
||||
|
||||
def create_metadata(
|
||||
@@ -110,7 +117,8 @@ def combine_request_and_response_messages(
|
||||
{"role": "system", "content": context.request_json.get("system")}
|
||||
)
|
||||
messages.extend(context.request_json.get("messages", []))
|
||||
messages.append(json_response)
|
||||
if len(json_response) > 0:
|
||||
messages.append(json_response)
|
||||
return messages
|
||||
|
||||
|
||||
@@ -154,56 +162,282 @@ async def push_to_explorer(
|
||||
)
|
||||
|
||||
|
||||
class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
self.context: RequestContextData = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
# response data
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_string: Optional[str] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": "[Invariant] The request did not pass the guardrails",
|
||||
"details": self.guardrails_execution_result,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
{},
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# if we find something, we prevent the request from going through
|
||||
# and return an error instead
|
||||
return Replacement(
|
||||
Response(
|
||||
content=error_chunk,
|
||||
status_code=400,
|
||||
media_type="application/json",
|
||||
headers={"content-type": "application/json"},
|
||||
)
|
||||
)
|
||||
|
||||
async def request(self):
|
||||
self.response = await self.client.send(self.anthropic_request)
|
||||
|
||||
try:
|
||||
json_response = self.response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}",
|
||||
) from e
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
|
||||
self.json_response = json_response
|
||||
self.response_string = json.dumps(json_response)
|
||||
|
||||
return self._make_response(
|
||||
content=self.response_string,
|
||||
status_code=self.response.status_code,
|
||||
)
|
||||
|
||||
def _make_response(self, content: str, status_code: int):
|
||||
"""Creates a new Response object with the correct headers and content"""
|
||||
assert self.response is not None, "response is None"
|
||||
|
||||
updated_headers = self.response.headers.copy()
|
||||
updated_headers.pop("Content-Length", None)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=status_code,
|
||||
media_type="application/json",
|
||||
headers=dict(updated_headers),
|
||||
)
|
||||
|
||||
async def on_end(self):
|
||||
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
|
||||
# ensure the response data is available
|
||||
assert self.response is not None, "response is None"
|
||||
assert self.json_response is not None, "json_response is None"
|
||||
assert self.response_string is not None, "response_string is None"
|
||||
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
guardrail_response_string = json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
}
|
||||
)
|
||||
|
||||
# push to explorer (if configured)
|
||||
if self.context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
return Replacement(
|
||||
self._make_response(
|
||||
content=guardrail_response_string,
|
||||
status_code=400,
|
||||
)
|
||||
)
|
||||
|
||||
# push to explorer (if configured)
|
||||
if self.context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context, self.json_response, guardrails_execution_result
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
response: httpx.Response,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
) -> Response:
|
||||
"""Handles non-streaming Anthropic responses"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
|
||||
) from e
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
|
||||
guardrails_execution_result = {}
|
||||
response_string = json.dumps(json_response)
|
||||
response_code = response.status_code
|
||||
|
||||
if context.config and context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, json_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
}
|
||||
)
|
||||
response_code = 400
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(context, json_response, guardrails_execution_result)
|
||||
)
|
||||
|
||||
updated_headers = response.headers.copy()
|
||||
updated_headers.pop("Content-Length", None)
|
||||
return Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
headers=dict(updated_headers),
|
||||
response = InstrumentedAnthropicResponse(
|
||||
context=context,
|
||||
client=client,
|
||||
anthropic_request=anthropic_request,
|
||||
)
|
||||
|
||||
return await response.instrumented_request()
|
||||
|
||||
|
||||
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
# response data
|
||||
self.merged_response = {}
|
||||
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": "[Invariant] The request did not pass the guardrails",
|
||||
"details": self.guardrails_execution_result,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.merged_response,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# if we find something, we end the stream prematurely (end_of_stream=True)
|
||||
# and yield an error chunk instead of actually beginning the stream
|
||||
return ExtraItem(
|
||||
f"event: error\ndata: {error_chunk}\n\n".encode(),
|
||||
end_of_stream=True,
|
||||
)
|
||||
|
||||
async def event_generator(self):
|
||||
"""Actual streaming response generator"""
|
||||
response = await self.client.send(self.anthropic_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
error_json = json.loads(error_content)
|
||||
error_detail = error_json.get("error", "Unknown error from Anthropic")
|
||||
except json.JSONDecodeError:
|
||||
error_detail = {
|
||||
"error": "Failed to decode error response from Anthropic"
|
||||
}
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
|
||||
# iterate over the response stream
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
return
|
||||
|
||||
# process chunk and extend the merged_response
|
||||
process_chunk(decoded_chunk, self.merged_response)
|
||||
|
||||
# on last stream chunk, run output guardrails
|
||||
if (
|
||||
"event: message_stop" in decoded_chunk
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": {
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": self.guardrails_execution_result,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# yield an extra error chunk (without preventing the original chunk to go through after,
|
||||
# so client gets the proper message_stop event still)
|
||||
return ExtraItem(
|
||||
value=f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
)
|
||||
|
||||
async def on_end(self):
|
||||
"""on_end: send full merged response to the exploree (if configured)"""
|
||||
# don't block on the response from explorer (.create_task)
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.merged_response,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
context: RequestContextData,
|
||||
@@ -211,63 +445,15 @@ async def handle_streaming_response(
|
||||
anthropic_request: httpx.Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
merged_response = {}
|
||||
response = InstrumentedAnthropicStreamingResposne(
|
||||
context=context,
|
||||
client=client,
|
||||
anthropic_request=anthropic_request,
|
||||
)
|
||||
|
||||
response = await client.send(anthropic_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
error_json = json.loads(error_content)
|
||||
error_detail = error_json.get("error", "Unknown error from Anthropic")
|
||||
except json.JSONDecodeError:
|
||||
error_detail = {"error": "Failed to decode error response from Anthropic"}
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
|
||||
async def event_generator() -> Any:
|
||||
async for chunk in response.aiter_bytes():
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
continue
|
||||
process_chunk(decoded_chunk, merged_response)
|
||||
if (
|
||||
"event: message_stop" in decoded_chunk
|
||||
and context.config
|
||||
and context.config.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, merged_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": {
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
yield f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
return
|
||||
yield chunk
|
||||
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on the response
|
||||
asyncio.create_task(push_to_explorer(context, merged_response))
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
return StreamingResponse(generator, media_type="text/event-stream")
|
||||
return StreamingResponse(
|
||||
response.instrumented_event_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
|
||||
|
||||
@@ -80,13 +80,13 @@ async def openai_chat_completions_gateway(
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
|
||||
if request_json.get("stream", False):
|
||||
return await stream_response(
|
||||
return await handle_stream_response(
|
||||
context,
|
||||
client,
|
||||
open_ai_request,
|
||||
)
|
||||
|
||||
return await non_stream_response(context, client, open_ai_request)
|
||||
return await handle_non_stream_response(context, client, open_ai_request)
|
||||
|
||||
|
||||
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
@@ -158,7 +158,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
# if we find something, we end the stream prematurely (end_of_stream=True)
|
||||
# and yield an error chunk instead of actually beginning the stream
|
||||
return ExtraItem(
|
||||
f"data: {error_chunk}\n\n".encode(), end_of_stream=True
|
||||
f"data: {error_chunk}\n\n".encode(),
|
||||
end_of_stream=True,
|
||||
)
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
@@ -231,7 +232,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def stream_response(
|
||||
async def handle_stream_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
@@ -598,7 +599,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
|
||||
async def non_stream_response(
|
||||
async def handle_non_stream_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
|
||||
@@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file(
|
||||
== "get_capital is called with Germany as argument"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, push_to_explorer",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_input_from_guardrail_from_file(
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test input guardrail enforcement with Anthropic."""
|
||||
if not os.getenv("INVARIANT_API_KEY"):
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
|
||||
client = Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
)
|
||||
|
||||
request = {
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
_ = client.messages.create(**request, stream=False)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
else:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
chat_response = client.messages.create(**request, stream=True)
|
||||
for _ in chat_response:
|
||||
pass
|
||||
|
||||
assert (
|
||||
"[Invariant] The request did not pass the guardrails"
|
||||
in exc_info.value.message
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
exc_info.value.body
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
time.sleep(2)
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
# in case of input guardrailing, the pushed trace will not contain a response
|
||||
trace = trace_response.json()
|
||||
assert len(trace["messages"]) == 1, "Only the user message should be present"
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": "Tell me more about Fight Club.",
|
||||
}
|
||||
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"]
|
||||
== "Users must not mention the magic phrase 'Fight Club'"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
@@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file(
|
||||
trace = trace_response.json()
|
||||
|
||||
# in case of input guardrailing, the pushed trace will not contain a response
|
||||
assert len(trace["messages"]) == 1
|
||||
assert len(trace["messages"]) == 1, "Trace should only contain the user message"
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": "Tell me more about Fight Club.",
|
||||
|
||||
Reference in New Issue
Block a user