From ca3c43ef7635cf6e7fd9e358c1153908a54c9d64 Mon Sep 17 00:00:00 2001 From: Hemang Date: Mon, 17 Mar 2025 07:35:02 +0100 Subject: [PATCH] Refactor guardrails check for openai route. --- gateway/routes/open_ai.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 0e7338d..ea70bc6 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -310,6 +310,26 @@ async def push_to_explorer( ) +async def get_guardrails_check_result( + context: RequestContextData, json_response: dict[str, Any] +) -> dict[str, Any]: + """Get the guardrails check result""" + messages = list(context.request_json.get("messages", [])) + messages += [choice["message"] for choice in json_response.get("choices", [])] + # TODO: Remove this once the guardrails API is fixed + for message in messages: + if "tool_calls" in message and message["tool_calls"] is None: + message["tool_calls"] = [] + + # Block on the guardrails check + guardrails_execution_result = await check_guardrails( + messages=messages, + guardrails=context.config.guardrails, + invariant_authorization=context.invariant_authorization, + ) + return guardrails_execution_result + + async def handle_non_streaming_response( context: RequestContextData, response: httpx.Response ) -> Response: @@ -332,18 +352,9 @@ async def handle_non_streaming_response( response_code = response.status_code if context.config and context.config.guardrails: - messages = list(context.request_json.get("messages", [])) - messages += [choice["message"] for choice in json_response.get("choices", [])] - # TODO: Remove this once the guardrails API is fixed - for message in messages: - if "tool_calls" in message and message["tool_calls"] is None: - message["tool_calls"] = [] - # Block on the guardrails check - guardrails_execution_result = await check_guardrails( - messages=messages, - guardrails=context.config.guardrails, - invariant_authorization=context.invariant_authorization, + guardrails_execution_result = await get_guardrails_check_result( + context, json_response ) if guardrails_execution_result.get("errors", []): response_string = json.dumps(