mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-24 15:54:05 +02:00
719 lines
26 KiB
Python
719 lines
26 KiB
Python
"""Gateway service to forward requests to the OpenAI APIs"""
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from gateway.common.authorization import extract_authorization_from_headers
|
|
from gateway.common.config_manager import (
|
|
GatewayConfig,
|
|
GatewayConfigManager,
|
|
extract_guardrails_from_header,
|
|
)
|
|
from gateway.common.constants import (
|
|
CLIENT_TIMEOUT,
|
|
IGNORED_HEADERS,
|
|
)
|
|
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
|
from gateway.common.request_context import RequestContext
|
|
from gateway.integrations.explorer import (
|
|
create_annotations_from_guardrails_errors,
|
|
fetch_guardrails_from_explorer,
|
|
push_trace,
|
|
)
|
|
from gateway.integrations.guardrails import (
|
|
ExtraItem,
|
|
InstrumentedResponse,
|
|
InstrumentedStreamingResponse,
|
|
check_guardrails,
|
|
)
|
|
|
|
# FastAPI router that holds every OpenAI-compatible route defined in this module.
gateway = APIRouter()

# Error detail sent with HTTP 400 by validate_headers when no Authorization header is present.
MISSING_AUTH_HEADER = "Missing authorization header"
# finish_reason values that mark a finished turn. push_to_explorer pushes a trace when the
# first choice ends with one of these, when guardrail annotations exist, or when there are no choices.
FINISH_REASON_TO_PUSH_TRACE = ["stop", "length", "content_filter"]
# Header used to look up and forward the OpenAI API key (sent as "Bearer <key>").
OPENAI_AUTHORIZATION_HEADER = "authorization"
|
|
|
|
|
|
def validate_headers(authorization: str = Header(None)):
    """Reject the request with HTTP 400 unless an Authorization header was sent.

    Used as a FastAPI dependency on the chat completions routes.
    """
    if authorization is not None:
        return
    raise HTTPException(status_code=400, detail=MISSING_AUTH_HEADER)
|
|
|
|
|
|
def make_cors_response(request: Request, allow_methods: str) -> Response:
    """Build an empty 204 preflight response that advertises CORS permissions.

    The caller's Origin header is echoed back ("*" when absent). OPTIONS is always
    appended to ``allow_methods``. Browsers may cache the preflight for 86400 seconds.
    """
    origin = request.headers.get("origin", "*")
    cors_headers = {
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Methods": f"{allow_methods}, OPTIONS",
        "Access-Control-Allow-Headers": "Authorization, Content-Type",
        "Access-Control-Max-Age": "86400",
    }
    return Response(status_code=204, headers=cors_headers)
|
|
|
|
|
|
@gateway.options("/{dataset_name}/openai/chat/completions")
@gateway.options("/openai/chat/completions")
async def openai_chat_completions_options(request: Request, dataset_name: str = None):
    """Answer CORS preflight (OPTIONS) requests for the chat completions routes.

    Only POST is advertised. ``dataset_name`` is accepted so that both route
    variants match, but it does not influence the response.
    """
    preflight = make_cors_response(request, allow_methods="POST")
    return preflight
|
|
|
|
|
|
@gateway.options("/{dataset_name}/openai/models")
@gateway.options("/openai/models")
async def openai_models_options(request: Request, dataset_name: str = None):
    """Answer CORS preflight (OPTIONS) requests for the models listing routes.

    Only GET is advertised. ``dataset_name`` is accepted so that both route
    variants match, but it does not influence the response.
    """
    preflight = make_cors_response(request, allow_methods="GET")
    return preflight
|
|
|
|
|
|
@gateway.get("/{dataset_name}/openai/models")
@gateway.get("/openai/models")
async def openai_models_gateway(
    request: Request,
    dataset_name: str = None,  # This is None if the client doesn't want to push to Explorer
):
    """Proxy a GET request to OpenAI's /v1/models endpoint.

    The caller's headers (minus IGNORED_HEADERS) are forwarded together with the
    OpenAI API key extracted by extract_authorization_from_headers. The upstream
    body, status code and headers are relayed back. Headers describing how the
    upstream body was framed or encoded are dropped, because ``result.content`` has
    already been fully read and decoded by httpx.
    """
    # Headers that describe the upstream wire format rather than the decoded body.
    framing_headers = {
        "content-encoding",
        "content-length",
        "transfer-encoding",
        "connection",
    }

    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
    }
    # Ask for an uncompressed body, as the chat completions proxy does.
    headers["accept-encoding"] = "identity"
    _, openai_api_key = extract_authorization_from_headers(
        request, dataset_name, OPENAI_AUTHORIZATION_HEADER
    )
    headers[OPENAI_AUTHORIZATION_HEADER] = "Bearer " + openai_api_key

    async with httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) as client:
        open_ai_request = client.build_request(
            "GET",
            "https://api.openai.com/v1/models",
            headers=headers,
        )
        result = await client.send(open_ai_request)

    # Relaying upstream content-length / content-encoding with the decoded body
    # would give the client inconsistent framing; let the Response recompute them.
    response_headers = {
        k: v for k, v in result.headers.items() if k.lower() not in framing_headers
    }
    return Response(
        content=result.content,
        status_code=result.status_code,
        headers=response_headers,
    )
|
|
|
|
|
|
@gateway.post(
    "/{dataset_name}/openai/chat/completions",
    dependencies=[Depends(validate_headers)],
)
@gateway.post(
    "/openai/chat/completions",
    dependencies=[Depends(validate_headers)],
)
async def openai_chat_completions_gateway(
    request: Request,
    dataset_name: str = None,  # This is None if the client doesn't want to push to Explorer
    config: GatewayConfig = Depends(GatewayConfigManager.get_config),
    header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
    """Proxy calls to the OpenAI chat completions API.

    The caller's headers (minus IGNORED_HEADERS) are forwarded with the extracted
    OpenAI key. Guardrails are resolved next: header guardrails take precedence
    over the dataset guardrails fetched from Explorer. The request is then
    dispatched to the streaming or non-streaming handler based on its ``stream`` flag.

    Raises:
        HTTPException: 400 if the request body is not a JSON object.
    """
    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
    }
    # Ask upstream for an uncompressed body.
    headers["accept-encoding"] = "identity"

    invariant_authorization, openai_api_key = extract_authorization_from_headers(
        request, dataset_name, OPENAI_AUTHORIZATION_HEADER
    )
    headers[OPENAI_AUTHORIZATION_HEADER] = "Bearer " + openai_api_key

    request_body_bytes = await request.body()
    try:
        request_json = json.loads(request_body_bytes)
    except ValueError as e:  # json.JSONDecodeError or an undecodable byte sequence
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from e
    # Everything below reads the body with dict methods (e.g. .get("stream")).
    if not isinstance(request_json, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    dataset_guardrails = None
    if dataset_name:
        # Get the guardrails for the dataset
        dataset_guardrails = await fetch_guardrails_from_explorer(
            dataset_name, invariant_authorization
        )
    context = RequestContext.create(
        request_json=request_json,
        dataset_name=dataset_name,
        invariant_authorization=invariant_authorization,
        guardrails=header_guardrails or dataset_guardrails,
        config=config,
        request=request,
    )

    # Created as late as possible so that a failure above cannot leak it.
    client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
    open_ai_request = client.build_request(
        "POST",
        "https://api.openai.com/v1/chat/completions",
        content=request_body_bytes,
        headers=headers,
    )

    if request_json.get("stream", False):
        return await handle_stream_response(
            context,
            client,
            open_ai_request,
        )

    return await handle_non_stream_response(context, client, open_ai_request)
|
|
|
|
|
|
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
    """
    Does a streaming OpenAI completion request at the core, but also checks guardrails
    before (concurrent) and after the request.

    The upstream bytes produced by event_generator are passed through unchanged; the
    hooks here only observe them to build ``merged_response``. Network chunks may cut
    an SSE event, or a multi-byte UTF-8 character, at any byte. Incoming bytes are
    therefore buffered, and only complete events (terminated by a blank line) are
    decoded and parsed.
    """

    # Strong references to fire-and-forget Explorer pushes. asyncio keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected before
    # it finishes.
    _background_tasks: set[asyncio.Task] = set()

    def __init__(
        self,
        context: RequestContext,
        client: httpx.AsyncClient,
        open_ai_request: httpx.Request,
    ):
        super().__init__()

        # request parameters
        self.context: RequestContext = context
        self.client: httpx.AsyncClient = client
        self.open_ai_request: httpx.Request = open_ai_request

        # guardrailing output (if any)
        self.guardrails_execution_result: Optional[dict] = None

        # merged_response is updated with the data from the chunks in the stream.
        # At the end of the stream it is sent to the explorer.
        self.merged_response = initialize_merged_response()

        # Maps the choice index in the stream to the index in merged_response["choices"]
        self.choice_mapping_by_index = {}
        # Maps "<choice index>-<tool call index>" to its merged tool call entry
        self.tool_call_mapping_by_index = {}

        # Bytes received after the last complete SSE event (i.e. a partial event)
        self._sse_buffer = b""

    def _push_in_background(self) -> None:
        """Schedule push_to_explorer without awaiting it, keeping a reference to the task."""
        task = asyncio.create_task(
            push_to_explorer(
                self.context,
                self.merged_response,
                self.guardrails_execution_result,
            )
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    @staticmethod
    def _has_done_line(event: bytes) -> bool:
        """Whether one of the event's lines is exactly the terminal ``data: [DONE]``."""
        return any(line.strip() == b"data: [DONE]" for line in event.split(b"\n"))

    def _take_complete_events(self, chunk: bytes) -> tuple[str, bool]:
        """
        Append ``chunk`` to the SSE buffer and remove every complete event from it.

        Returns the decoded text of the complete events (joined by blank lines) and
        whether one of them is the terminal ``data: [DONE]`` event.
        """
        buffered = (self._sse_buffer + chunk).replace(b"\r\n", b"\n")
        *complete_events, remainder = buffered.split(b"\n\n")
        if self._has_done_line(remainder):
            # [DONE] is the final event: don't wait for a terminator that may never come
            complete_events.append(remainder)
            remainder = b""
        self._sse_buffer = remainder

        stream_done = any(self._has_done_line(event) for event in complete_events)
        # Complete events end on an ASCII delimiter, so no UTF-8 sequence is cut here.
        events_text = "\n\n".join(
            event.decode("utf-8", errors="replace") for event in complete_events
        ).strip()
        return events_text, stream_done

    async def on_start(self):
        """
        Check guardrails in a pipelined fashion, before processing the first chunk
        (for input guardrailing).

        Returns an ExtraItem that carries an SSE error event and ends the stream when a
        blocking guardrail fails; otherwise returns None.
        """
        if self.context.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context,
                action=GuardrailAction.BLOCK,
                response_json=self.merged_response,
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "message": "[Invariant] The request did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                        }
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    self._push_in_background()

                # if we find something, we end the stream prematurely (end_of_stream=True)
                # and yield an error chunk instead of actually beginning the stream
                return ExtraItem(
                    f"data: {error_chunk}\n\n".encode(),
                    end_of_stream=True,
                )
        return None

    async def on_chunk(self, chunk):
        """
        Merge the complete SSE events in ``chunk`` into merged_response.

        Blocking guardrails are checked once the ``data: [DONE]`` event is seen. On
        failure, an extra SSE error event is returned; the original chunk is still
        relayed after it.
        """
        events_text, stream_done = self._take_complete_events(chunk)

        # This will update merged_response with the data from the complete events
        if events_text:
            process_chunk_text(
                events_text,
                self.merged_response,
                self.choice_mapping_by_index,
                self.tool_call_mapping_by_index,
            )

        # check guardrails at the end of the stream (on the '[DONE]' SSE event)
        if stream_done and self.context.guardrails:
            # Block on the guardrails check
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context,
                action=GuardrailAction.BLOCK,
                response_json=self.merged_response,
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "message": "[Invariant] The response did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                        }
                    }
                )

                # yield an extra error chunk (without preventing the original chunk to go through after)
                return ExtraItem(f"data: {error_chunk}\n\n".encode())

        # push will happen in on_end
        return None

    async def on_end(self):
        """Flushes any trailing partial event and sends the merged response to the explorer."""
        leftover = self._sse_buffer.decode("utf-8", errors="replace").strip()
        self._sse_buffer = b""
        if leftover:
            process_chunk_text(
                leftover,
                self.merged_response,
                self.choice_mapping_by_index,
                self.tool_call_mapping_by_index,
            )

        # don't block on the response from explorer
        if self.context.dataset_name:
            self._push_in_background()

    async def event_generator(self):
        """
        Actual OpenAI stream response.

        Raises:
            HTTPException: with OpenAI's status code when the upstream status is not 200.
        """
        # NOTE(review): if the downstream StreamingResponse has already sent its
        # headers by the time this raises, the status code cannot change - confirm
        # how the base class surfaces this error.
        response = await self.client.send(self.open_ai_request, stream=True)
        if response.status_code != 200:
            error_content = await response.aread()
            try:
                error_json = json.loads(error_content.decode("utf-8"))
                error_detail = error_json.get("error", "Unknown error from OpenAI API")
            except json.JSONDecodeError:
                error_detail = {"error": "Failed to parse OpenAI error response"}
            raise HTTPException(status_code=response.status_code, detail=error_detail)

        # stream out chunks
        async for chunk in response.aiter_bytes():
            yield chunk
|
|
|
|
|
|
async def handle_stream_response(
    context: RequestContext,
    client: httpx.AsyncClient,
    open_ai_request: httpx.Request,
) -> Response:
    """
    Handles streaming the OpenAI response to the client while building a merged_response.

    The chunks are returned to the caller immediately. The merged_response is built
    from the chunks as they are received and sent to the Invariant Explorer at the
    end of the stream.

    ``client`` is closed once the event stream finishes, fails, or is closed by
    its consumer.
    """
    response = InstrumentedOpenAIStreamResponse(
        context,
        client,
        open_ai_request,
    )

    async def stream_then_close_client():
        try:
            async for item in response.instrumented_event_generator():
                yield item
        finally:
            # Without this the per-request client (and its connection pool) leaks.
            # NOTE(review): assumes the base class has no work still using `client`
            # after its generator stops - confirm in InstrumentedStreamingResponse.
            await client.aclose()

    return StreamingResponse(
        stream_then_close_client(), media_type="text/event-stream"
    )
|
|
|
|
|
|
def initialize_merged_response() -> dict[str, Any]:
    """Return a fresh, empty ``chat.completion`` skeleton for accumulating a stream.

    Every call builds a new dict with its own ``choices`` list.
    """
    skeleton: dict[str, Any] = dict.fromkeys(
        ("id", "object", "created", "model", "choices", "usage")
    )
    skeleton["object"] = "chat.completion"
    skeleton["choices"] = []
    return skeleton
|
|
|
|
|
|
def process_chunk_text(
    chunk_text: str,
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Parse every SSE ``data:`` payload in ``chunk_text`` and merge it into merged_response.

    A chunk may hold several events, each a ``data:`` line carrying one JSON object.
    The following are skipped:
    - the ``[DONE]`` sentinel;
    - lines that are not ``data:`` fields;
    - payloads that are not a JSON object (e.g. an event cut off mid-way).

    Only the leading ``data:`` field name is removed, so text such as "data: " inside
    the JSON content is preserved.
    """
    for raw_line in chunk_text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()

        if not payload or payload == "[DONE]":
            continue

        try:
            json_chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue
        # update_merged_response reads the payload with dict methods.
        if not isinstance(json_chunk, dict):
            continue

        update_merged_response(
            json_chunk,
            merged_response,
            choice_mapping_by_index,
            tool_call_mapping_by_index,
        )
|
|
|
|
|
|
def update_merged_response(
    json_chunk: dict[str, Any],
    merged_response: dict[str, Any],
    choice_mapping_by_index: dict[int, int],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
) -> None:
    """Updates the merged_response with the data (content, tool_calls, usage, etc.) from the JSON chunk.

    ``id``, ``created`` and ``model`` keep the first truthy value seen. ``usage`` takes
    the latest non-empty value. Each streamed choice index gets one entry in
    merged_response["choices"], tracked via ``choice_mapping_by_index``.
    """
    for key in ("id", "created", "model"):
        merged_response[key] = merged_response[key] or json_chunk.get(key)

    # Some chunks carry token usage (OpenAI sends it in the final chunk, with empty
    # choices, when stream_options.include_usage is set); keep it for the trace metadata.
    if json_chunk.get("usage"):
        merged_response["usage"] = json_chunk["usage"]

    for choice in json_chunk.get("choices") or []:
        index = choice.get("index", 0)

        if index not in choice_mapping_by_index:
            choice_mapping_by_index[index] = len(merged_response["choices"])
            merged_response["choices"].append(
                {
                    "index": index,
                    "message": {"role": "assistant"},
                    "finish_reason": None,
                }
            )

        existing_choice = merged_response["choices"][choice_mapping_by_index[index]]
        if choice.get("finish_reason"):
            existing_choice["finish_reason"] = choice["finish_reason"]

        update_existing_choice_with_delta(
            existing_choice,
            choice.get("delta") or {},  # tolerate an explicit "delta": null
            tool_call_mapping_by_index,
            choice_index=index,
        )
|
|
|
|
|
|
def update_existing_choice_with_delta(
    existing_choice: dict[str, Any],
    delta: dict[str, Any],
    tool_call_mapping_by_index: dict[str, dict[str, Any]],
    choice_index: int,
) -> None:
    """Updates the choice with the data from the delta.

    - Content fragments are appended to ``message["content"]``.
    - Tool call fragments are merged by the key ``"<choice_index>-<tool index>"``:
      the first fragment creates the entry, later ones set id/name and append
      argument text.
    - A ``finish_reason`` found in the delta is copied onto the choice.
    """
    delta = delta or {}
    message = existing_choice["message"]

    content = delta.get("content")
    if content is not None:
        message["content"] = (message.get("content") or "") + content

    tool_calls = delta.get("tool_calls")
    if isinstance(tool_calls, list):
        message.setdefault("tool_calls", [])

        for tool in tool_calls:
            if not isinstance(tool, dict):
                continue
            tool_index = tool.get("index")
            if tool_index is None:
                continue

            # "function" may be absent or explicitly null in a fragment.
            function = tool.get("function") or {}
            tool_id = tool.get("id")
            name = function.get("name")
            arguments = function.get("arguments") or ""

            key = f"{choice_index}-{tool_index}"
            tool_call_entry = tool_call_mapping_by_index.get(key)
            if tool_call_entry is None:
                tool_call_entry = {
                    "index": tool_index,
                    "id": tool_id,
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": "",
                    },
                }
                tool_call_mapping_by_index[key] = tool_call_entry
                message["tool_calls"].append(tool_call_entry)

            if tool_id:
                tool_call_entry["id"] = tool_id
            if name:
                tool_call_entry["function"]["name"] = name
            if arguments:
                tool_call_entry["function"]["arguments"] += arguments

    finish_reason = delta.get("finish_reason")
    if finish_reason is not None:
        existing_choice["finish_reason"] = finish_reason
|
|
|
|
|
|
def create_metadata(
    context: RequestContext, merged_response: dict[str, Any]
) -> dict[str, Any]:
    """Build the trace metadata.

    The metadata holds:
    - every non-None request parameter except ``messages``;
    - ``via_gateway=True``;
    - the response's ``model`` and ``usage`` when they are set (these override
      request values with the same key).
    """
    metadata: dict[str, Any] = {}
    for key, value in context.request_json.items():
        if key == "messages" or value is None:
            continue
        metadata[key] = value

    metadata["via_gateway"] = True

    for key, value in merged_response.items():
        if key in ("usage", "model") and value is not None:
            metadata[key] = value
    return metadata
|
|
|
|
|
|
async def push_to_explorer(
    context: RequestContext,
    merged_response: dict[str, Any],
    guardrails_execution_result: Optional[dict] = None,
) -> None:
    """Push the request messages plus response choices to Invariant Explorer as a trace.

    Annotations come from the blocking guardrail errors passed in, plus a fresh run
    of the logging guardrails on the same messages. The trace is pushed only if one
    of these holds:
    - there is at least one annotation;
    - ``merged_response`` has no choices;
    - the first choice finished with a reason in FINISH_REASON_TO_PUSH_TRACE.
    """
    blocking_errors = (guardrails_execution_result or {}).get("errors", [])
    annotations = create_annotations_from_guardrails_errors(blocking_errors)

    # Logging guardrails only annotate the trace; they never block the response.
    logging_result = await get_guardrails_check_result(
        context,
        action=GuardrailAction.LOG,
        response_json=merged_response,
    )
    annotations.extend(
        create_annotations_from_guardrails_errors(logging_result.get("errors", []))
    )

    choices = merged_response.get("choices")
    reached_end_of_turn = (not choices) or (
        choices[0].get("finish_reason") in FINISH_REASON_TO_PUSH_TRACE
    )
    if not annotations and not reached_end_of_turn:
        return

    messages = list(context.request_json.get("messages", []))
    messages.extend(choice["message"] for choice in merged_response.get("choices", []))
    await push_trace(
        dataset_name=context.dataset_name,
        invariant_authorization=context.invariant_authorization,
        messages=[messages],
        annotations=[annotations],
        metadata=[create_metadata(context, merged_response)],
    )
|
|
|
|
|
|
async def get_guardrails_check_result(
    context: RequestContext,
    action: GuardrailAction,
    response_json: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Run the guardrails of the given action over the conversation.

    The messages checked are the request's ``messages`` plus, when ``response_json``
    is given, the ``message`` of each of its choices.

    Returns:
        The result of check_guardrails, or ``{}`` when the context has no rule set
        or no guardrails of the requested kind.
    """
    rule_set = context.guardrails
    # push_to_explorer calls this for every pushed trace, even when no guardrails
    # were configured for the request (callers elsewhere test the rule set for
    # truthiness first), so a missing rule set must not raise.
    if rule_set is None:
        return {}

    # Determine which guardrails to apply based on the action
    guardrails = (
        rule_set.logging_guardrails
        if action == GuardrailAction.LOG
        else rule_set.blocking_guardrails
    )
    if not guardrails:
        return {}

    messages = list(context.request_json.get("messages", []))
    if response_json is not None:
        messages += [choice["message"] for choice in response_json.get("choices", [])]

    # Block on the guardrails check
    return await check_guardrails(
        messages=messages,
        guardrails=guardrails,
        context=context,
    )
|
|
|
|
|
|
class InstrumentedOpenAIResponse(InstrumentedResponse):
    """
    Does an OpenAI completion request at the core, but also checks guardrails
    before (concurrent) and after the request.
    """

    # Strong references to fire-and-forget Explorer pushes. asyncio keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected before
    # it finishes.
    _background_tasks: set[asyncio.Task] = set()

    # Upstream headers that describe how the original body was framed or encoded.
    # request() re-serializes the body with json.dumps, so these would no longer
    # match it. Starlette's Response sets content-length itself when none is given.
    _FRAMING_HEADERS = frozenset(
        {"content-encoding", "content-length", "transfer-encoding", "connection"}
    )

    def __init__(
        self,
        context: RequestContext,
        client: httpx.AsyncClient,
        open_ai_request: httpx.Request,
    ):
        super().__init__()

        # request parameters
        self.context: RequestContext = context
        self.client: httpx.AsyncClient = client
        self.open_ai_request: httpx.Request = open_ai_request

        # request outputs (set by request())
        self.response: Optional[httpx.Response] = None
        self.response_json: Optional[dict[str, Any]] = None

        # guardrailing output (if any)
        self.guardrails_execution_result: Optional[dict] = None

    def _push_in_background(self, response_json: dict[str, Any]) -> None:
        """Schedule push_to_explorer without awaiting it, keeping a reference to the task."""
        task = asyncio.create_task(
            push_to_explorer(
                self.context,
                response_json,
                self.guardrails_execution_result,
            )
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def on_start(self):
        """
        Checks guardrails in a pipelined fashion, before processing
        the first chunk (for input guardrailing).

        Returns an ExtraItem that carries a 400 JSON error response and ends the
        response when a blocking guardrail fails on the request messages; otherwise
        returns None.
        """
        if not self.context.guardrails:
            return None

        # block on the guardrails check
        self.guardrails_execution_result = await get_guardrails_check_result(
            self.context, action=GuardrailAction.BLOCK
        )
        if not self.guardrails_execution_result.get("errors", []):
            return None

        # Push annotated trace to the explorer - don't block on its response
        if self.context.dataset_name:
            self._push_in_background({})

        # replace the response with the error message
        return ExtraItem(
            Response(
                content=json.dumps(
                    {
                        "error": "[Invariant] The request did not pass the guardrails",
                        "details": self.guardrails_execution_result,
                    }
                ),
                status_code=400,
                media_type="application/json",
            ),
            end_of_stream=True,
        )

    async def request(self):
        """
        Actual OpenAI request. Returns OpenAI's JSON body as a Response.

        Raises:
            HTTPException: with the upstream status code when the body is not valid
                JSON, or when the upstream status is not 200.
        """
        self.response = await self.client.send(self.open_ai_request)

        try:
            self.response_json = self.response.json()
        except ValueError as e:  # json.JSONDecodeError, or an undecodable body
            raise HTTPException(
                status_code=self.response.status_code,
                detail="Invalid JSON response received from OpenAI API",
            ) from e

        if self.response.status_code != 200:
            error_detail = "Unknown error from OpenAI API"
            # An error body is normally an object, but don't crash if it isn't.
            if isinstance(self.response_json, dict):
                error_detail = self.response_json.get("error", error_detail)
            raise HTTPException(
                status_code=self.response.status_code,
                detail=error_detail,
            )

        relayed_headers = {
            k: v
            for k, v in self.response.headers.items()
            if k.lower() not in self._FRAMING_HEADERS
        }
        return Response(
            content=json.dumps(self.response_json),
            status_code=self.response.status_code,
            media_type="application/json",
            headers=relayed_headers,
        )

    async def on_end(self):
        """Postprocesses the OpenAI response and potentially replace it with a guardrails error."""

        # these two request outputs are guaranteed to be available by the time we reach
        # this point (after self.request() was executed)
        # nevertheless, we check for them to avoid any potential issues
        assert (
            self.response is not None
        ), "on_end called before 'self.response' was available"
        assert (
            self.response_json is not None
        ), "on_end called before 'self.response_json' was available"

        # if we have guardrails, check the response
        if self.context.guardrails:
            # run guardrails again, this time on request + response
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context,
                action=GuardrailAction.BLOCK,
                response_json=self.response_json,
            )
            if self.guardrails_execution_result.get("errors", []):
                response_string = json.dumps(
                    {
                        "error": "[Invariant] The response did not pass the guardrails",
                        "details": self.guardrails_execution_result,
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    self._push_in_background(self.response_json)

                # replace the response with the error message
                return ExtraItem(
                    Response(
                        content=response_string,
                        status_code=400,
                        media_type="application/json",
                    ),
                )

        # Push the trace to the explorer in any case (including any guardrailing
        # result, if available) - don't block on its response
        if self.context.dataset_name:
            self._push_in_background(self.response_json)
        return None
|
|
|
|
|
|
async def handle_non_stream_response(
    context: RequestContext,
    client: httpx.AsyncClient,
    open_ai_request: httpx.Request,
) -> Response:
    """Handles non-streaming OpenAI responses.

    Runs the request through InstrumentedOpenAIResponse and returns its result.
    ``client`` is closed afterwards, whether the call succeeded or raised. By then
    the body has been fully read, because the request is sent without streaming.
    """
    response = InstrumentedOpenAIResponse(
        context,
        client,
        open_ai_request,
    )
    try:
        return await response.instrumented_request()
    finally:
        # Without this the per-request client (and its connection pool) leaks.
        # NOTE(review): if instrumented_request can return while request() is still
        # in flight (e.g. after an early guardrail error), that call will now fail
        # with a closed client - confirm in InstrumentedResponse.
        await client.aclose()
|