Refactor guardrails check for openai route.

This commit is contained in:
Hemang
2025-03-17 07:35:02 +01:00
committed by Hemang Sarkar
parent a5ea86a64e
commit ca3c43ef76
+22 -11
View File
@@ -310,6 +310,26 @@ async def push_to_explorer(
)
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in json_response.get("choices", [])]
# TODO: Remove this once the guardrails API is fixed
for message in messages:
if "tool_calls" in message and message["tool_calls"] is None:
message["tool_calls"] = []
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=context.config.guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
async def handle_non_streaming_response(
context: RequestContextData, response: httpx.Response
) -> Response:
@@ -332,18 +352,9 @@ async def handle_non_streaming_response(
response_code = response.status_code
if context.config and context.config.guardrails:
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in json_response.get("choices", [])]
# TODO: Remove this once the guardrails API is fixed
for message in messages:
if "tool_calls" in message and message["tool_calls"] is None:
message["tool_calls"] = []
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=context.config.guardrails,
invariant_authorization=context.invariant_authorization,
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(