From 750c83d3f88025394afcdf503d2dc718bf74efd5 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 14:41:18 +0200 Subject: [PATCH] Add calls to execute logging guardrails before pushing to explorer. --- gateway/integrations/guardrails.py | 23 -------- gateway/routes/anthropic.py | 59 ++++++++++--------- gateway/routes/gemini.py | 13 +++- gateway/routes/open_ai.py | 41 ++++++++----- ...est_generate_content_without_tool_calls.py | 6 +- 5 files changed, 73 insertions(+), 69 deletions(-) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 412418e..b7377c3 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -1,7 +1,6 @@ """Utility functions for Guardrails execution.""" import asyncio -import json import os import time from typing import Any, Dict, List @@ -351,28 +350,6 @@ async def check_guardrails( async with httpx.AsyncClient() as client: url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") try: - print( - "Hello there this is the request to guardrails: ", - json.dumps( - { - "messages": messages, - "policies": [g.content for g in guardrails], - }, - indent=2, - ), - flush=True, - ) - print( - "Hello there this is the request to guardrails: ", - json.dumps( - { - "Authorization": invariant_authorization, - "Accept": "application/json", - }, - indent=2, - ), - flush=True, - ) result = await client.post( f"{url}/api/v1/policy/check/batch", json={ diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 9905258..09ce85b 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -120,7 +120,7 @@ def create_metadata( def combine_request_and_response_messages( - context: RequestContext, json_response: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ): """Combine the request and response messages""" messages = [] @@ -129,13 +129,13 @@ def combine_request_and_response_messages( {"role": "system", "content": context.request_json.get("system")} ) messages.extend(context.request_json.get("messages", [])) - if len(json_response) > 0: - messages.append(json_response) + if len(response_json) > 0: + messages.append(response_json) return messages async def get_guardrails_check_result( - context: RequestContext, action: GuardrailAction, json_response: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" # Determine which guardrails to apply based on the action @@ -147,7 +147,7 @@ async def get_guardrails_check_result( if not guardrails: return {} - messages = combine_request_and_response_messages(context, json_response) + messages = combine_request_and_response_messages(context, response_json) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check @@ -170,10 +170,22 @@ async def push_to_explorer( guardrails_execution_result.get("errors", []) ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + # Combine the messages from the request body and Anthropic response messages = combine_request_and_response_messages(context, merged_response) - converted_messages = convert_anthropic_to_invariant_message_format(messages) + _ = await push_trace( dataset_name=context.dataset_name, messages=[converted_messages], @@ -200,7 +212,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): # response data self.response: Optional[httpx.Response] = None self.response_string: Optional[str] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing response (if any) self.guardrails_execution_result = {} @@ -209,7 +221,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, action=GuardrailAction.BLOCK, json_response={} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -243,10 +255,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): ) async def request(self): + """Make the request to the Anthropic API.""" self.response = await self.client.send(self.anthropic_request) try: - json_response = self.response.json() + response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -255,11 +268,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=json_response.get("error", "Unknown error from Anthropic"), + detail=response_json.get("error", "Unknown error from Anthropic"), ) - self.json_response = json_response - self.response_string = json.dumps(json_response) + self.response_json = response_json + self.response_string = json.dumps(response_json) return self._make_response( content=self.response_string, @@ -284,7 +297,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): """Checks guardrails after the response is received, and asynchronously pushes to Explorer.""" # ensure the response data is available assert self.response is not None, "response is None" - assert self.json_response is not None, "json_response is None" + assert self.response_json is not None, "response_json is None" assert self.response_string is not None, "response_string is None" if self.context.dataset_guardrails: @@ -292,12 +305,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.json_response, - ) - print( - "Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ", - guardrails_execution_result, - flush=True, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): guardrail_response_string = json.dumps( @@ -313,7 +321,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, guardrails_execution_result, ) ) @@ -330,7 +338,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): # Push to Explorer - don't block on its response asyncio.create_task( push_to_explorer( - self.context, self.json_response, guardrails_execution_result + self.context, self.response_json, guardrails_execution_result ) ) @@ -378,7 +386,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -440,12 +448,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, - ) - print( - "Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ", - self.guardrails_execution_result, - flush=True, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 1e21b90..b390461 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -290,7 +290,6 @@ async def stream_response( async def event_generator(): async for chunk in response.instrumented_event_generator(): yield chunk - print("chunk", chunk) return StreamingResponse( event_generator(), @@ -408,6 +407,18 @@ async def push_to_explorer( guardrails_execution_result.get("errors", []) ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=response_json, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index ada5309..f4a20f4 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -152,7 +152,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -203,7 +203,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -438,6 +438,19 @@ async def push_to_explorer( not in FINISH_REASON_TO_PUSH_TRACE ): annotations = create_annotations_from_guardrails_errors(guardrails_errors) + + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + # Combine the messages from the request body and the choices from the OpenAI response messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] @@ -453,7 +466,7 @@ async def push_to_explorer( async def get_guardrails_check_result( context: RequestContext, action: GuardrailAction, - json_response: dict[str, Any] | None = None, + response_json: dict[str, Any] | None = None, ) -> dict[str, Any]: """Get the guardrails check result""" # Determine which guardrails to apply based on the action @@ -466,8 +479,8 @@ async def get_guardrails_check_result( return {} messages = list(context.request_json.get("messages", [])) - if json_response is not None: - messages += [choice["message"] for choice in json_response.get("choices", [])] + if response_json is not None: + messages += [choice["message"] for choice in response_json.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( @@ -499,7 +512,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): # request outputs self.response: Optional[httpx.Response] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing output (if any) self.guardrails_execution_result: Optional[dict] = None @@ -545,7 +558,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.response = await self.client.send(self.open_ai_request) try: - self.json_response = self.response.json() + self.response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -554,10 +567,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=self.json_response.get("error", "Unknown error from OpenAI API"), + detail=self.response_json.get("error", "Unknown error from OpenAI API"), ) - response_string = json.dumps(self.json_response) + response_string = json.dumps(self.response_json) response_code = self.response.status_code return Response( @@ -577,8 +590,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.response is not None ), "on_end called before 'self.response' was available" assert ( - self.json_response is not None - ), "on_end called before 'self.json_response' was available" + self.response_json is not None + ), "on_end called before 'self.response_json' was available" # extract original response status code response_code = self.response.status_code @@ -589,7 +602,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.json_response, + response_json=self.response_json, ) if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -605,7 +618,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, self.guardrails_execution_result, ) ) @@ -624,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, # include any guardrailing errors if available self.guardrails_execution_result, ) diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index cf3f42f..84ed352 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -195,14 +195,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( chat_response = client.models.generate_content( model="gemini-2.0-flash", - contents="What is the capital of Spain?", + contents="What is the capital of Denmark?", config={ "maxOutputTokens": 100, }, ) # Verify the chat response - assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper() + assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper() expected_assistant_message = chat_response.candidates[0].content.parts[0].text # Wait for the trace to be saved @@ -229,7 +229,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( assert trace["messages"] == [ { "role": "user", - "content": [{"text": "What is the capital of Spain?", "type": "text"}], + "content": [{"text": "What is the capital of Denmark?", "type": "text"}], }, { "role": "assistant",