mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-24 15:54:05 +02:00
patching
This commit is contained in:
+31
-20
@@ -7,7 +7,7 @@ import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from utils.explorer import error_label, push_trace, validate_guardrails
|
||||
from utils.explorer import PromptPatch, error_label, push_trace, validate_guardrails
|
||||
|
||||
ALLOWED_OPEN_AI_ENDPOINTS = {"chat/completions"}
|
||||
|
||||
@@ -73,26 +73,36 @@ async def openai_proxy(
|
||||
# Update the authorization header to pass the OpenAI API Key to the OpenAI API
|
||||
headers["authorization"] = f"{api_keys[0].strip()}"
|
||||
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
open_ai_request = client.build_request(
|
||||
"POST",
|
||||
f"https://api.openai.com/v1/{endpoint}",
|
||||
content=request_body_bytes,
|
||||
headers=headers,
|
||||
)
|
||||
if is_streaming:
|
||||
return await stream_response(
|
||||
client,
|
||||
open_ai_request,
|
||||
dataset_name,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
async with client:
|
||||
response = await client.send(open_ai_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
while True:
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
open_ai_request = client.build_request(
|
||||
"POST",
|
||||
f"https://api.openai.com/v1/{endpoint}",
|
||||
content=request_body_bytes,
|
||||
headers=headers,
|
||||
)
|
||||
# recompute compute length
|
||||
open_ai_request.headers["content-length"] = str(len(request_body_bytes))
|
||||
|
||||
if is_streaming:
|
||||
return await stream_response(
|
||||
client,
|
||||
open_ai_request,
|
||||
dataset_name,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
async with client:
|
||||
try:
|
||||
response = await client.send(open_ai_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
)
|
||||
except PromptPatch as e:
|
||||
# go into 'messages' payload and prepend the e.prompt to the first 'system' message
|
||||
print(request_body_json)
|
||||
request_body_json["messages"][0]["content"] += "\n" + e.patch
|
||||
request_body_bytes = json.dumps(request_body_json).encode()
|
||||
|
||||
|
||||
async def stream_response(
|
||||
@@ -319,6 +329,7 @@ async def push_to_explorer(
|
||||
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = request_body.get("messages", [])
|
||||
messages = [{**msg} for msg in messages]
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
|
||||
blocked, response = await push_trace(
|
||||
|
||||
@@ -17,6 +17,11 @@ from invariant.analyzer import Policy
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
class PromptPatch(Exception):
|
||||
def __init__(self, patch=""):
|
||||
self.patch = patch
|
||||
|
||||
|
||||
async def push_trace(
|
||||
messages: List[List[Dict[str, Any]]],
|
||||
dataset_name: str,
|
||||
@@ -116,6 +121,13 @@ async def validate_guardrails(
|
||||
if "content" not in msg:
|
||||
msg["content"] = ""
|
||||
|
||||
# get first system prompt message
|
||||
system_prompt = next(
|
||||
(msg["content"] for msg in trace if msg["role"] == "system" and msg["content"]),
|
||||
None,
|
||||
)
|
||||
system_prompt_content = system_prompt if system_prompt else ""
|
||||
|
||||
annotations = []
|
||||
|
||||
for guardrail in guardrails:
|
||||
@@ -133,6 +145,11 @@ async def validate_guardrails(
|
||||
action = label.split("action=")[1].split()[0]
|
||||
if action == "block":
|
||||
blocked = error
|
||||
elif "patch=" in label:
|
||||
print("label is", label)
|
||||
patch = label.split("patch=", 1)[1]
|
||||
if patch not in system_prompt_content:
|
||||
raise PromptPatch(patch)
|
||||
|
||||
ranges = [range for range in error.ranges]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user