mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 14:59:41 +02:00
Refactor guardrails check for openai route.
This commit is contained in:
+22
-11
@@ -310,6 +310,26 @@ async def push_to_explorer(
|
||||
)
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, json_response: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
# TODO: Remove this once the guardrails API is fixed
|
||||
for message in messages:
|
||||
if "tool_calls" in message and message["tool_calls"] is None:
|
||||
message["tool_calls"] = []
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=context.config.guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData, response: httpx.Response
|
||||
) -> Response:
|
||||
@@ -332,18 +352,9 @@ async def handle_non_streaming_response(
|
||||
response_code = response.status_code
|
||||
|
||||
if context.config and context.config.guardrails:
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
# TODO: Remove this once the guardrails API is fixed
|
||||
for message in messages:
|
||||
if "tool_calls" in message and message["tool_calls"] is None:
|
||||
message["tool_calls"] = []
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=context.config.guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, json_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
|
||||
Reference in New Issue
Block a user