This commit is contained in:
Luca Beurer-Kellner
2025-02-17 23:24:27 +01:00
parent 28974cc70c
commit fd07112d82
2 changed files with 48 additions and 20 deletions
+31 -20
View File
@@ -7,7 +7,7 @@ import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from utils.explorer import error_label, push_trace, validate_guardrails
from utils.explorer import PromptPatch, error_label, push_trace, validate_guardrails
ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"}
@@ -73,26 +73,36 @@ async def openai_proxy(
# Update the authorization header to pass the OpenAI API Key to the OpenAI API
headers["authorization"] = f"{api_keys[0].strip()}"
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
open_ai_request = client.build_request(
"POST",
f"https://api.openai.com/v1/{endpoint}",
content=request_body_bytes,
headers=headers,
)
if is_streaming:
return await stream_response(
client,
open_ai_request,
dataset_name,
request_body_json,
invariant_authorization,
)
async with client:
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
while True:
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
open_ai_request = client.build_request(
"POST",
f"https://api.openai.com/v1/{endpoint}",
content=request_body_bytes,
headers=headers,
)
# recompute the content-length header, since the request body may have been patched on a previous loop iteration
open_ai_request.headers["content-length"] = str(len(request_body_bytes))
if is_streaming:
return await stream_response(
client,
open_ai_request,
dataset_name,
request_body_json,
invariant_authorization,
)
async with client:
try:
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
)
except PromptPatch as e:
# go into the 'messages' payload and append e.patch to the content of the first message (expected to be the 'system' message)
print(request_body_json)
request_body_json["messages"][0]["content"] += "\n" + e.patch
request_body_bytes = json.dumps(request_body_json).encode()
async def stream_response(
@@ -319,6 +329,7 @@ async def push_to_explorer(
# Combine the messages from the request body and the choices from the OpenAI response
messages = request_body.get("messages", [])
messages = [{**msg} for msg in messages]
messages += [choice["message"] for choice in merged_response.get("choices", [])]
blocked, response = await push_trace(
+17
View File
@@ -17,6 +17,11 @@ from invariant.analyzer import Policy
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
class PromptPatch(Exception):
    """Signal that a piece of text should be patched into the prompt.

    The exception carries the patch text so that the code catching it can
    amend the outgoing request and try again.

    Attributes:
        patch: The text to add to the prompt. Defaults to an empty string.
    """

    def __init__(self, patch: str = "") -> None:
        # Forward the patch to Exception so that ``args``, ``str(exc)`` and
        # tracebacks carry it even when it was passed as a keyword argument.
        super().__init__(patch)
        self.patch = patch
async def push_trace(
messages: List[List[Dict[str, Any]]],
dataset_name: str,
@@ -116,6 +121,13 @@ async def validate_guardrails(
if "content" not in msg:
msg["content"] = ""
# get first system prompt message
system_prompt = next(
(msg["content"] for msg in trace if msg["role"] == "system" and msg["content"]),
None,
)
system_prompt_content = system_prompt if system_prompt else ""
annotations = []
for guardrail in guardrails:
@@ -133,6 +145,11 @@ async def validate_guardrails(
action = label.split("action=")[1].split()[0]
if action == "block":
blocked = error
elif "patch=" in label:
print("label is", label)
patch = label.split("patch=", 1)[1]
if patch not in system_prompt_content:
raise PromptPatch(patch)
ranges = [range for range in error.ranges]