mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 23:34:02 +02:00
simplify request instrumentation
This commit is contained in:
committed by
Hemang
parent
e66232215e
commit
55db93c8d3
@@ -6,7 +6,13 @@ import time
|
||||
from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
<<<<<<< HEAD:gateway/integrations/guardails.py
|
||||
=======
|
||||
from zmq import IO_THREADS
|
||||
from common.request_context_data import RequestContextData
|
||||
>>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
@@ -99,102 +105,59 @@ async def preload_guardrails(context: "RequestContextData") -> None:
|
||||
print(f"Error scheduling preload_guardrails task: {e}")
|
||||
|
||||
|
||||
class YieldException(Exception):
|
||||
class ExtraItem:
|
||||
"""
|
||||
Raise this exception in stream instrumentor listeners to
|
||||
end the stream early, or to emit additional items in a stream.
|
||||
Return an instance of this class from an instrumented stream callback to yield an extra item in the resulting stream.
|
||||
"""
|
||||
|
||||
def __init__(self, value, end_of_stream=False):
|
||||
super().__init__(value)
|
||||
self.value = value
|
||||
self.end_of_stream = end_of_stream
|
||||
|
||||
def __str__(self):
|
||||
return f"YieldException: {self.value}"
|
||||
return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"
|
||||
|
||||
|
||||
class StreamInstrumentor:
|
||||
"""
|
||||
A class to instrument async iterables with hooks for processing
|
||||
chunks, before processing, and on completion.
|
||||
|
||||
Use `@on('chunk')`, `@on('start')`, and `@on('end')` decorators
|
||||
to register listeners for different events.
|
||||
|
||||
Listeners can simply process data, or alternatively raise a designated
|
||||
YieldException to yield additional values or stop the stream.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
instrumentor = StreamInstrumentor()
|
||||
|
||||
@instrumentor.on('chunk')
|
||||
async def process_chunk(chunk):
|
||||
# Process the chunk
|
||||
print(f"Processing chunk: {chunk}")
|
||||
|
||||
if some_condition:
|
||||
# Yield an additional value that will be interleaved in the stream
|
||||
# Pass `end_of_stream=True` to stop the stream after yielding
|
||||
# Pass `end_of_stream=False` to continue the stream after the interleaved value
|
||||
raise YieldException("Extra value", end_of_stream=True)
|
||||
```
|
||||
"""
|
||||
|
||||
class InstrumentedStreamingResponse:
|
||||
def __init__(self):
|
||||
# called on every chunk (async)
|
||||
self.on_chunk_listeners = []
|
||||
# called once before the first chunk is processed, or even earlier (async)
|
||||
self.before_listeners = []
|
||||
# called once on stream completion (async)
|
||||
self.on_complete_listeners = []
|
||||
|
||||
# request statistics
|
||||
self.stat_token_times = []
|
||||
self.stat_before_time = None
|
||||
self.stat_after_time = None
|
||||
|
||||
self.stat_first_item_time = None
|
||||
|
||||
# decorator
|
||||
def on(self, event: str):
|
||||
async def on_chunk(self, chunk: Any) -> ExtraItem | None:
|
||||
"""
|
||||
Decorator to register listeners for different events.
|
||||
This callback is called on every chunk (async).
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_start(self) -> ExtraItem | None:
    """
    Hook that runs once when the instrumented stream starts.

    The instrumented generator schedules this coroutine as a task before it
    waits for the first real item. The base implementation does nothing;
    subclasses override it (e.g. for input guardrailing).

    Returns:
        None to let the stream proceed unchanged, or an `ExtraItem` whose
        value is yielded before any real item. If that item has
        `end_of_stream=True`, the stream stops right after it.
    """
    pass
|
||||
|
||||
async def on_end(self) -> ExtraItem | None:
    """
    Hook that runs once after the last real item has been streamed.

    The base implementation does nothing; subclasses override it
    (e.g. for output guardrailing or pushing a trace).

    Returns:
        None, or an `ExtraItem` whose value is yielded as a final item.
        Its `end_of_stream` flag is ignored, because the stream is already
        at its end when this hook runs.
    """
    pass
|
||||
|
||||
async def event_generator(self):
|
||||
"""
|
||||
Streams the async iterable and invokes all instrumented hooks.
|
||||
|
||||
Args:
|
||||
event (str): The event to listen for. Can be 'on_chunk',
|
||||
'before', or 'on_complete'.
|
||||
async_iterable: An async iterable to stream.
|
||||
|
||||
Returns:
|
||||
Callable: A decorator to register the listener.
|
||||
Yields:
|
||||
The streamed data.
|
||||
"""
|
||||
raise NotImplementedError("This method should be implemented in a subclass.")
|
||||
|
||||
def decorator(func):
|
||||
assert asyncio.iscoroutinefunction(
|
||||
func
|
||||
), "Listener must be an async function"
|
||||
|
||||
if event == "chunk":
|
||||
if self.on_chunk_listeners is None:
|
||||
self.on_chunk_listeners = []
|
||||
self.on_chunk_listeners.append(func)
|
||||
elif event == "start":
|
||||
if self.before_listeners is None:
|
||||
self.before_listeners = []
|
||||
self.before_listeners.append(func)
|
||||
elif event == "end":
|
||||
if self.on_complete_listeners is None:
|
||||
self.on_complete_listeners = []
|
||||
self.on_complete_listeners.append(func)
|
||||
else:
|
||||
raise ValueError("Invalid event type. Use 'chunk', 'before', or 'end'.")
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def stream(self, async_iterable):
|
||||
async def instrumented_event_generator(self):
|
||||
"""
|
||||
Streams the async iterable and invokes all instrumented hooks.
|
||||
|
||||
@@ -207,14 +170,11 @@ class StreamInstrumentor:
|
||||
try:
|
||||
start = time.time()
|
||||
|
||||
# schedule all before listeners which can be run concurrently
|
||||
before_tasks = [
|
||||
asyncio.create_task(listener(), name="instrumentor:start")
|
||||
for listener in self.before_listeners
|
||||
]
|
||||
# schedule on_start which can be run concurrently
|
||||
start_task = asyncio.create_task(self.on_start(), name="instrumentor:start")
|
||||
|
||||
# create async iterator from async_iterable
|
||||
aiterable = aiter(async_iterable)
|
||||
aiterable = aiter(self.event_generator())
|
||||
|
||||
# [STAT] capture start time of first item
|
||||
start_first_item_request = time.time()
|
||||
@@ -233,31 +193,23 @@ class StreamInstrumentor:
|
||||
wait_for_first_item(), name="instrumentor:next:first"
|
||||
)
|
||||
|
||||
# wait for all before listeners to finish
|
||||
has_end_of_stream = False
|
||||
for before_task in before_tasks:
|
||||
try:
|
||||
await before_task
|
||||
except YieldException as e:
|
||||
# yield extra value before any real items
|
||||
yield e.value
|
||||
# stop the stream if end_of_stream is True
|
||||
if e.end_of_stream:
|
||||
# if first item is already available
|
||||
if not next_item_task.done():
|
||||
# cancel the task
|
||||
next_item_task.cancel()
|
||||
# [STAT] capture time to first item to be now +0.01
|
||||
if self.stat_first_item_time is None:
|
||||
self.stat_first_item_time = (
|
||||
time.time() - start_first_item_request
|
||||
) + 0.01
|
||||
has_end_of_stream = True
|
||||
|
||||
# don't wait for the first item if end_of stream is True
|
||||
if has_end_of_stream:
|
||||
# if end_of_stream is True, stop the stream
|
||||
return
|
||||
# check if 'start_task' yields an extra item
|
||||
if extra_item := await start_task:
|
||||
# yield extra value before any real items
|
||||
yield extra_item.value
|
||||
# stop the stream if end_of_stream is True
|
||||
if extra_item.end_of_stream:
|
||||
# if first item is already available
|
||||
if not next_item_task.done():
|
||||
# cancel the task
|
||||
next_item_task.cancel()
|
||||
# [STAT] capture time to first item to be now +0.01
|
||||
if self.stat_first_item_time is None:
|
||||
self.stat_first_item_time = (
|
||||
time.time() - start_first_item_request
|
||||
) + 0.01
|
||||
# don't wait for the first item if end_of stream is True
|
||||
return
|
||||
|
||||
# [STAT] capture before time stamp
|
||||
self.stat_before_time = time.time() - start
|
||||
@@ -282,35 +234,20 @@ class StreamInstrumentor:
|
||||
time.time() - start - sum(self.stat_token_times)
|
||||
)
|
||||
|
||||
# invoke on_chunk listeners
|
||||
any_end_of_stream = False
|
||||
for listener in self.on_chunk_listeners:
|
||||
try:
|
||||
await listener(item)
|
||||
except YieldException as e:
|
||||
yield e.value
|
||||
# if end_of_stream is True, stop the stream
|
||||
if e.end_of_stream:
|
||||
any_end_of_stream = True
|
||||
|
||||
# if end_of_stream is True, stop the stream
|
||||
if any_end_of_stream:
|
||||
return
|
||||
if extra_item := await self.on_chunk(item):
|
||||
yield extra_item.value
|
||||
# if end_of_stream is True, stop the stream
|
||||
if extra_item.end_of_stream:
|
||||
return
|
||||
|
||||
# yield item
|
||||
yield item
|
||||
|
||||
on_complete_tasks = [
|
||||
asyncio.create_task(listener(), name="instrumentor:end")
|
||||
for listener in self.on_complete_listeners
|
||||
]
|
||||
for result in asyncio.as_completed(on_complete_tasks):
|
||||
try:
|
||||
await result
|
||||
except YieldException as e:
|
||||
# yield extra value before any real items
|
||||
yield e.value
|
||||
# we ignore end_of_stream here, because we are already at the end
|
||||
# run on_end, before closing the stream (may yield an extra value)
|
||||
if extra_item := await self.on_end():
|
||||
# yield extra value before any real items
|
||||
yield extra_item.value
|
||||
# we ignore end_of_stream here, because we are already at the end
|
||||
|
||||
# [STAT] capture after time stamp
|
||||
self.stat_after_time = time.time() - start
|
||||
@@ -344,29 +281,35 @@ class StreamInstrumentor:
|
||||
print(f" [total: {time.time() - start:.2f}s]")
|
||||
|
||||
|
||||
class RequestInstrumentor(StreamInstrumentor):
|
||||
class InstrumentedResponse(InstrumentedStreamingResponse):
|
||||
"""
|
||||
Like 'StreamInstrumentor', but for non-streaming requests.
|
||||
|
||||
Supports similar 'start', 'end' events, but not 'chunk', since everything is assumed
|
||||
to be processed in one chunk (i.e., the request).
|
||||
A class to instrument an async request with hooks for concurrent
|
||||
pre-processing and post-processing (input and output guardrailing).
|
||||
"""
|
||||
|
||||
def on(self, event):
|
||||
assert event in [
|
||||
"start",
|
||||
"end",
|
||||
], "RequestInstrumentor does not support 'chunk' events"
|
||||
return super().on(event)
|
||||
async def event_generator(self):
    """
    Expose the request as a stream that contains exactly one item.

    Awaits `self.request()` and yields its full result, so the streaming
    instrumentation can be reused for non-streaming requests.
    """
    full_result = await self.request()
    yield full_result
|
||||
|
||||
async def execute(self, request_task):
|
||||
async def wrapped_request_task():
|
||||
yield await request_task
|
||||
async def request(self):
    """
    Perform the actual request; subclasses must override this.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    reason = "This method should be implemented in a subclass."
    raise NotImplementedError(reason)
|
||||
|
||||
# pretend the 'request_task' is an async iterable with a single item
|
||||
result = [item async for item in self.stream(wrapped_request_task())]
|
||||
assert len(result) >= 1, "RequestInstrumentor must yield at least one item"
|
||||
return result[-1]
|
||||
async def instrumented_request(self):
    """
    Run the request through all instrumented hooks and return the final item.

    Drains `self.instrumented_event_generator()` and returns the LAST item
    it yields. If a hook yields an extra item (e.g. a guardrails error from
    the 'end' hook), that item is returned instead of the actual request
    result.

    Raises:
        AssertionError: if the instrumented stream yields no item at all.
    """
    # Sentinel, so that a legitimately yielded `None` still counts as an item.
    nothing_yielded = object()
    final_item = nothing_yielded

    # Only the last item matters, so there is no need to keep the others around.
    async for item in self.instrumented_event_generator():
        final_item = item

    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if final_item is nothing_yielded:
        raise AssertionError("InstrumentedResponse must yield at least one item")

    return final_item
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
@@ -395,6 +338,10 @@ async def check_guardrails(
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
if not result.is_success:
|
||||
raise Exception(
|
||||
f"Guardrails check failed: {result.status_code} - {result.text}"
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
except Exception as e:
|
||||
@@ -18,7 +18,7 @@ from converters.anthropic_to_invariant import (
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from integrations.guardails import check_guardrails, preload_guardrails
|
||||
from integrations.guardrails import check_guardrails, preload_guardrails
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from converters.gemini_to_invariant import convert_request, convert_response
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardails import check_guardrails, preload_guardrails
|
||||
from integrations.guardrails import check_guardrails, preload_guardrails
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
|
||||
+250
-200
@@ -13,10 +13,10 @@ from common.constants import (
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardails import (
|
||||
RequestInstrumentor,
|
||||
StreamInstrumentor,
|
||||
YieldException,
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
InstrumentedStreamingResponse,
|
||||
check_guardrails,
|
||||
preload_guardrails,
|
||||
)
|
||||
@@ -89,23 +89,131 @@ async def openai_chat_completions_gateway(
|
||||
return await handle_non_streaming_response(context, client, open_ai_request)
|
||||
|
||||
|
||||
async def stream_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
"""
|
||||
Handles streaming the OpenAI response to the client while building a merged_response
|
||||
The chunks are returned to the caller immediately
|
||||
The merged_response is built from the chunks as they are received
|
||||
It is sent to the Invariant Explorer at the end of the stream
|
||||
"""
|
||||
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
async def request_and_stream():
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
# guardrailing output (if any)
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
|
||||
# merged_response will be updated with the data from the chunks in the stream
|
||||
# At the end of the stream, this will be sent to the explorer
|
||||
self.merged_response = {
|
||||
"id": None,
|
||||
"object": "chat.completion",
|
||||
"created": None,
|
||||
"model": None,
|
||||
"choices": [],
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
# Each chunk in the stream contains a list called "choices" each entry in the list
|
||||
# has an index.
|
||||
# A choice has a field called "delta" which may contain a list called "tool_calls".
|
||||
# Maps the choice index in the stream to the index in the merged_response["choices"] list
|
||||
self.choice_mapping_by_index = {}
|
||||
# Combines the choice index and tool call index to uniquely identify a tool call
|
||||
self.tool_call_mapping_by_index = {}
|
||||
|
||||
async def on_start(self):
    """
    Input guardrailing: check the request while the stream is being set up.

    Returns None when no guardrails are configured or the check reports no
    errors. Otherwise returns an `ExtraItem` carrying an SSE error chunk with
    `end_of_stream=True`, so the error is sent instead of the actual stream.
    """
    config = self.context.config
    if not (config and config.guardrails):
        return None

    # Block on the guardrails check
    self.guardrails_execution_result = await get_guardrails_check_result(
        self.context, self.merged_response
    )
    if not self.guardrails_execution_result.get("errors", []):
        return None

    error_payload = {
        "error": {
            "message": "[Invariant] The request did not pass the guardrails",
            "details": self.guardrails_execution_result,
        }
    }
    error_chunk = json.dumps(error_payload)

    # Push the annotated trace to the explorer without waiting for its response
    if self.context.dataset_name:
        upload = push_to_explorer(
            self.context,
            self.merged_response,
            self.guardrails_execution_result,
        )
        asyncio.create_task(upload)

    # End the stream prematurely: the error chunk replaces the real stream
    return ExtraItem(f"data: {error_chunk}\n\n".encode(), end_of_stream=True)
|
||||
|
||||
async def on_chunk(self, chunk):
    """
    Fold one streamed chunk into `merged_response` and guardrail the result.

    Chunks that are empty after decoding and stripping are ignored. When the
    chunk contains the '[DONE]' SSE marker and guardrails are configured, the
    merged response is checked. On errors an `ExtraItem` with an SSE error
    chunk is returned, which is emitted in addition to the original chunk.
    Pushing to the explorer is left to `on_end`.
    """
    text = chunk.decode().strip()
    if not text:
        return None

    # Updates merged_response (and both index mappings) with this chunk's data
    process_chunk_text(
        text,
        self.merged_response,
        self.choice_mapping_by_index,
        self.tool_call_mapping_by_index,
    )

    # Guardrails are only checked at the end of the stream ('[DONE]' SSE chunk)
    if "data: [DONE]" not in text:
        return None
    config = self.context.config
    if not (config and config.guardrails):
        return None

    # Block on the guardrails check
    self.guardrails_execution_result = await get_guardrails_check_result(
        self.context, self.merged_response
    )
    if not self.guardrails_execution_result.get("errors", []):
        return None

    error_payload = {
        "error": {
            "message": "[Invariant] The response did not pass the guardrails",
            "details": self.guardrails_execution_result,
        }
    }
    error_chunk = json.dumps(error_payload)

    # Extra error chunk; the original chunk still goes through afterwards
    return ExtraItem(f"data: {error_chunk}\n\n".encode())
|
||||
|
||||
async def on_end(self):
    """
    Push the merged response to the explorer once the stream is over.

    Does nothing without a dataset name. The push is scheduled as a
    background task, so the stream does not wait for the explorer.
    """
    if not self.context.dataset_name:
        return None

    upload = push_to_explorer(
        self.context, self.merged_response, self.guardrails_execution_result
    )
    asyncio.create_task(upload)
|
||||
|
||||
async def event_generator(self):
|
||||
"""
|
||||
Sets off the request and then streams the result.
|
||||
Actual OpenAI stream response.
|
||||
"""
|
||||
response = await client.send(open_ai_request, stream=True)
|
||||
|
||||
response = await self.client.send(self.open_ai_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
@@ -119,123 +227,28 @@ async def stream_response(
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
async def event_generator() -> Any:
|
||||
# merged_response will be updated with the data from the chunks in the stream
|
||||
# At the end of the stream, this will be sent to the explorer
|
||||
merged_response = {
|
||||
"id": None,
|
||||
"object": "chat.completion",
|
||||
"created": None,
|
||||
"model": None,
|
||||
"choices": [],
|
||||
"usage": None,
|
||||
}
|
||||
|
||||
# Each chunk in the stream contains a list called "choices" each entry in the list
|
||||
# has an index.
|
||||
# A choice has a field called "delta" which may contain a list called "tool_calls".
|
||||
# Maps the choice index in the stream to the index in the merged_response["choices"] list
|
||||
choice_mapping_by_index = {}
|
||||
# Combines the choice index and tool call index to uniquely identify a tool call
|
||||
tool_call_mapping_by_index = {}
|
||||
async def stream_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
"""
|
||||
Handles streaming the OpenAI response to the client while building a merged_response
|
||||
The chunks are returned to the caller immediately
|
||||
The merged_response is built from the chunks as they are received
|
||||
It is sent to the Invariant Explorer at the end of the stream
|
||||
"""
|
||||
|
||||
# prepare stream instrumentor
|
||||
instrumentor = StreamInstrumentor()
|
||||
response = InstrumentedOpenAIStreamResponse(
|
||||
context,
|
||||
client,
|
||||
open_ai_request,
|
||||
)
|
||||
|
||||
@instrumentor.on("start")
|
||||
async def precheck_guardrails() -> None:
|
||||
# check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)
|
||||
if context.config and context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, merged_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": "[Invariant] The request did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# if we find something, we end the stream prematurely (end_of_stream=True)
|
||||
# and yield an error chunk instead of actually beginning the stream
|
||||
raise YieldException(
|
||||
f"data: {error_chunk}\n\n".encode(), end_of_stream=True
|
||||
)
|
||||
|
||||
@instrumentor.on("chunk")
|
||||
async def process_chunk(chunk: bytes) -> None:
|
||||
# process and check each chunk
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
return
|
||||
|
||||
# Process the chunk
|
||||
# This will update merged_response with the data from the chunk
|
||||
process_chunk_text(
|
||||
chunk_text,
|
||||
merged_response,
|
||||
choice_mapping_by_index,
|
||||
tool_call_mapping_by_index,
|
||||
)
|
||||
|
||||
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
|
||||
if (
|
||||
"data: [DONE]" in chunk_text
|
||||
and context.config
|
||||
and context.config.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, merged_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
}
|
||||
}
|
||||
)
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# yield an extra error chunk (without preventing the original chunk to go through after)
|
||||
raise YieldException(f"data: {error_chunk}\n\n".encode())
|
||||
|
||||
@instrumentor.on("end")
|
||||
async def send_to_explorer() -> None:
|
||||
# Send full merged response to the explorer
|
||||
# Don't block on the response from explorer
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(push_to_explorer(context, merged_response))
|
||||
|
||||
async for chunk in instrumentor.stream(request_and_stream()):
|
||||
# Yield chunk to the client
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
return StreamingResponse(
|
||||
response.instrumented_event_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
def initialize_merged_response() -> dict[str, Any]:
|
||||
@@ -436,73 +449,51 @@ async def get_guardrails_check_result(
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
"""Handles non-streaming OpenAI responses"""
|
||||
class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
instrumentor = RequestInstrumentor()
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
# respond we get and its JSON decoded version
|
||||
# available once the 'send_request' function has progressed to the point of
|
||||
# being able to call 'response.json()'
|
||||
response = None
|
||||
json_response = None
|
||||
# request outputs
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
|
||||
async def send_request():
|
||||
nonlocal response, json_response
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
|
||||
response = await client.send(open_ai_request)
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="Invalid JSON response received from OpenAI API",
|
||||
) from e
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from OpenAI API"),
|
||||
)
|
||||
|
||||
response_string = json.dumps(json_response)
|
||||
response_code = response.status_code
|
||||
|
||||
return Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
@instrumentor.on("start")
|
||||
async def precheck_guardrails() -> None:
|
||||
async def on_start(self):
|
||||
# check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)
|
||||
if context.config and context.config.guardrails:
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
# block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(context)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
self.context,
|
||||
{},
|
||||
guardrails_execution_result,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# replace the response with the error message
|
||||
raise YieldException(
|
||||
return ExtraItem(
|
||||
Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
"details": self.guardrails_execution_result,
|
||||
}
|
||||
),
|
||||
status_code=400,
|
||||
@@ -511,38 +502,77 @@ async def handle_non_streaming_response(
|
||||
end_of_stream=True,
|
||||
)
|
||||
|
||||
@instrumentor.on("end")
|
||||
async def postprocess_guardrails() -> None:
|
||||
async def request(self):
    """
    Send the actual OpenAI request and wrap the answer in a `Response`.

    Stores the raw response in `self.response` and its decoded JSON body in
    `self.json_response`.

    Raises:
        HTTPException: with the upstream status code, if the body is not
            valid JSON or the upstream status is not 200.
    """
    upstream = await self.client.send(self.open_ai_request)
    self.response = upstream

    try:
        decoded = upstream.json()
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=upstream.status_code,
            detail="Invalid JSON response received from OpenAI API",
        ) from e
    self.json_response = decoded

    if upstream.status_code != 200:
        raise HTTPException(
            status_code=upstream.status_code,
            detail=decoded.get("error", "Unknown error from OpenAI API"),
        )

    return Response(
        content=json.dumps(decoded),
        status_code=upstream.status_code,
        media_type="application/json",
        headers=dict(upstream.headers),
    )
|
||||
|
||||
async def on_end(self):
|
||||
"""
|
||||
Postprocess the OpenAI response and potentially replace it with a guardrails error.
|
||||
"""
|
||||
# these two are guaranteed to be set by the time we reach this point (after self.request() was executed)
|
||||
assert (
|
||||
self.response is not None
|
||||
), "on_end called before 'self.response' was available"
|
||||
assert (
|
||||
self.json_response is not None
|
||||
), "on_end called before 'self.json_response' was available"
|
||||
|
||||
# at this point, we are guaranteed that 'send_request' has already been executed successfully
|
||||
response_code = response.status_code
|
||||
response_code = self.response.status_code
|
||||
|
||||
# if we have guardrails, check the response
|
||||
if context.config and context.config.guardrails:
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
# run guardrails again, this time on request + response
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, json_response
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
"details": self.guardrails_execution_result,
|
||||
}
|
||||
)
|
||||
response_code = 400
|
||||
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
json_response,
|
||||
guardrails_execution_result,
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# replace the response with the error message
|
||||
raise YieldException(
|
||||
return ExtraItem(
|
||||
Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
@@ -550,10 +580,30 @@ async def handle_non_streaming_response(
|
||||
),
|
||||
)
|
||||
|
||||
# if we don't have guardrails or if the response passed the guardrails (only then, we reach this point)
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(push_to_explorer(context, json_response))
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
|
||||
# execute instrumented request
|
||||
return await instrumentor.execute(send_request())
|
||||
|
||||
async def handle_non_streaming_response(
    context: RequestContextData,
    client: httpx.AsyncClient,
    open_ai_request: httpx.Request,
) -> Response:
    """
    Handle a non-streaming OpenAI request.

    Wraps the request in an `InstrumentedOpenAIResponse` and returns the
    outcome of its instrumented execution.
    """
    instrumented = InstrumentedOpenAIResponse(context, client, open_ai_request)
    return await instrumented.instrumented_request()
|
||||
|
||||
@@ -2,15 +2,28 @@
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from common.config_manager import GatewayConfigManager
|
||||
from routes.anthropic import gateway as anthropic_gateway
|
||||
from routes.gemini import gateway as gemini_gateway
|
||||
from routes.open_ai import gateway as open_ai_gateway
|
||||
from starlette_compress import CompressMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """
    Application lifespan: fetch the gateway config before the app starts serving.

    The config is obtained via `GatewayConfigManager.get_config()` on startup;
    on shutdown only the local reference to it is dropped.
    """
    loaded_config = GatewayConfigManager.get_config()
    yield
    # Cleanup if needed
    del loaded_config
|
||||
|
||||
|
||||
app = fastapi.app = fastapi.FastAPI(
|
||||
docs_url="/api/v1/gateway/docs",
|
||||
redoc_url="/api/v1/gateway/redoc",
|
||||
openapi_url="/api/v1/gateway/openapi.json",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.add_middleware(CompressMiddleware)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user