From fd07112d82bfc2bb93ab6bc65551e62a3295212b Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Mon, 17 Feb 2025 23:24:27 +0100 Subject: [PATCH] patching --- proxy/routes/open_ai.py | 51 +++++++++++++++++++++++++---------------- proxy/utils/explorer.py | 17 ++++++++++++++ 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/proxy/routes/open_ai.py b/proxy/routes/open_ai.py index eb0d00b..2f92c84 100644 --- a/proxy/routes/open_ai.py +++ b/proxy/routes/open_ai.py @@ -7,7 +7,7 @@ import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS -from utils.explorer import error_label, push_trace, validate_guardrails +from utils.explorer import PromptPatch, error_label, push_trace, validate_guardrails ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"} @@ -73,26 +73,36 @@ async def openai_proxy( # Update the authorization header to pass the OpenAI API Key to the OpenAI API headers["authorization"] = f"{api_keys[0].strip()}" - client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) - open_ai_request = client.build_request( - "POST", - f"https://api.openai.com/v1/{endpoint}", - content=request_body_bytes, - headers=headers, - ) - if is_streaming: - return await stream_response( - client, - open_ai_request, - dataset_name, - request_body_json, - invariant_authorization, - ) - async with client: - response = await client.send(open_ai_request) - return await handle_non_streaming_response( - response, dataset_name, request_body_json, invariant_authorization + while True: + client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) + open_ai_request = client.build_request( + "POST", + f"https://api.openai.com/v1/{endpoint}", + content=request_body_bytes, + headers=headers, ) + # recompute compute length + open_ai_request.headers["content-length"] = str(len(request_body_bytes)) + + if is_streaming: + return await stream_response( + client, + open_ai_request, + dataset_name, + request_body_json, + invariant_authorization, + ) + async with client: + try: + response = await client.send(open_ai_request) + return await handle_non_streaming_response( + response, dataset_name, request_body_json, invariant_authorization + ) + except PromptPatch as e: + # go into 'messages' payload and prepend the e.prompt to the first 'system' message + print(request_body_json) + request_body_json["messages"][0]["content"] += "\n" + e.patch + request_body_bytes = json.dumps(request_body_json).encode() async def stream_response( @@ -319,6 +329,7 @@ async def push_to_explorer( # Combine the messages from the request body and the choices from the OpenAI response messages = request_body.get("messages", []) + messages = [{**msg} for msg in messages] messages += [choice["message"] for choice in merged_response.get("choices", [])] blocked, response = await push_trace( diff --git a/proxy/utils/explorer.py b/proxy/utils/explorer.py index 852bcf6..4589886 100644 --- a/proxy/utils/explorer.py +++ b/proxy/utils/explorer.py @@ -17,6 +17,11 @@ from invariant.analyzer import Policy DEFAULT_API_URL = "https://explorer.invariantlabs.ai" +class PromptPatch(Exception): + def __init__(self, patch=""): + self.patch = patch + + async def push_trace( messages: List[List[Dict[str, Any]]], dataset_name: str, @@ -116,6 +121,13 @@ async def validate_guardrails( if "content" not in msg: msg["content"] = "" + # get first system prompt message + system_prompt = next( + (msg["content"] for msg in trace if msg["role"] == "system" and msg["content"]), + None, + ) + system_prompt_content = system_prompt if system_prompt else "" + annotations = [] for guardrail in guardrails: @@ -133,6 +145,11 @@ async def validate_guardrails( action = label.split("action=")[1].split()[0] if action == "block": blocked = error + elif "patch=" in label: + print("label is", label) + patch = label.split("patch=", 1)[1] + if patch not in system_prompt_content: + raise PromptPatch(patch) ranges = [range for range in error.ranges]