simplify request instrumentation

This commit is contained in:
Luca Beurer-Kellner
2025-03-28 19:41:38 +01:00
committed by Hemang
parent e66232215e
commit 55db93c8d3
5 changed files with 359 additions and 349 deletions
@@ -6,7 +6,13 @@ import time
from typing import Any, Dict, List
from functools import wraps
from fastapi.responses import StreamingResponse
import httpx
<<<<<<< HEAD:gateway/integrations/guardails.py
=======
from zmq import IO_THREADS
from common.request_context_data import RequestContextData
>>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -99,102 +105,59 @@ async def preload_guardrails(context: "RequestContextData") -> None:
print(f"Error scheduling preload_guardrails task: {e}")
class YieldException(Exception):
class ExtraItem:
"""
Raise this exception in stream instrumentor listeners to
end the stream early, or to emit additional items in a stream.
Return this class in a instrumented stream callback, to yield an extra item in the resulting stream.
"""
def __init__(self, value, end_of_stream=False):
super().__init__(value)
self.value = value
self.end_of_stream = end_of_stream
def __str__(self):
return f"YieldException: {self.value}"
return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"
class StreamInstrumentor:
"""
A class to instrument async iterables with hooks for processing
chunks, before processing, and on completion.
Use `@on('chunk')`, `@on('start')`, and `@on('end')` decorators
to register listeners for different events.
Listeners can simply process data, or alternatively raise a designated
YieldException to yield additional values or stop the stream.
Example usage:
```
instrumentor = StreamInstrumentor()
@instrumentor.on('chunk')
async def process_chunk(chunk):
# Process the chunk
print(f"Processing chunk: {chunk}")
if some_condition:
# Yield an additional value that will be interleaved in the stream
# Pass `end_of_stream=True` to stop the stream after yielding
# Pass `end_of_stream=False` to continue the stream after the interleaved value
raise YieldException("Extra value", end_of_stream=True)
```
"""
class InstrumentedStreamingResponse:
def __init__(self):
# called on every chunk (async)
self.on_chunk_listeners = []
# called once before the first chunk is processed, or even earlier (async)
self.before_listeners = []
# called once on stream completion (async)
self.on_complete_listeners = []
# request statistics
self.stat_token_times = []
self.stat_before_time = None
self.stat_after_time = None
self.stat_first_item_time = None
# decorator
def on(self, event: str):
async def on_chunk(self, chunk: Any) -> ExtraItem | None:
"""
Decorator to register listeners for different events.
This called will be called on every chunk (async).
"""
pass
async def on_start(self) -> ExtraItem | None:
"""
Decorator to register a listener for start events.
"""
pass
async def on_end(self) -> ExtraItem | None:
"""
Decorator to register a listener for end events.
"""
pass
async def event_generator(self):
"""
Streams the async iterable and invokes all instrumented hooks.
Args:
event (str): The event to listen for. Can be 'on_chunk',
'before', or 'on_complete'.
async_iterable: An async iterable to stream.
Returns:
Callable: A decorator to register the listener.
Yields:
The streamed data.
"""
raise NotImplementedError("This method should be implemented in a subclass.")
def decorator(func):
assert asyncio.iscoroutinefunction(
func
), "Listener must be an async function"
if event == "chunk":
if self.on_chunk_listeners is None:
self.on_chunk_listeners = []
self.on_chunk_listeners.append(func)
elif event == "start":
if self.before_listeners is None:
self.before_listeners = []
self.before_listeners.append(func)
elif event == "end":
if self.on_complete_listeners is None:
self.on_complete_listeners = []
self.on_complete_listeners.append(func)
else:
raise ValueError("Invalid event type. Use 'chunk', 'before', or 'end'.")
return func
return decorator
async def stream(self, async_iterable):
async def instrumented_event_generator(self):
"""
Streams the async iterable and invokes all instrumented hooks.
@@ -207,14 +170,11 @@ class StreamInstrumentor:
try:
start = time.time()
# schedule all before listeners which can be run concurrently
before_tasks = [
asyncio.create_task(listener(), name="instrumentor:start")
for listener in self.before_listeners
]
# schedule on_start which can be run concurrently
start_task = asyncio.create_task(self.on_start(), name="instrumentor:start")
# create async iterator from async_iterable
aiterable = aiter(async_iterable)
aiterable = aiter(self.event_generator())
# [STAT] capture start time of first item
start_first_item_request = time.time()
@@ -233,31 +193,23 @@ class StreamInstrumentor:
wait_for_first_item(), name="instrumentor:next:first"
)
# wait for all before listeners to finish
has_end_of_stream = False
for before_task in before_tasks:
try:
await before_task
except YieldException as e:
# yield extra value before any real items
yield e.value
# stop the stream if end_of_stream is True
if e.end_of_stream:
# if first item is already available
if not next_item_task.done():
# cancel the task
next_item_task.cancel()
# [STAT] capture time to first item to be now +0.01
if self.stat_first_item_time is None:
self.stat_first_item_time = (
time.time() - start_first_item_request
) + 0.01
has_end_of_stream = True
# don't wait for the first item if end_of stream is True
if has_end_of_stream:
# if end_of_stream is True, stop the stream
return
# check if 'start_task' yields an extra item
if extra_item := await start_task:
# yield extra value before any real items
yield extra_item.value
# stop the stream if end_of_stream is True
if extra_item.end_of_stream:
# if first item is already available
if not next_item_task.done():
# cancel the task
next_item_task.cancel()
# [STAT] capture time to first item to be now +0.01
if self.stat_first_item_time is None:
self.stat_first_item_time = (
time.time() - start_first_item_request
) + 0.01
# don't wait for the first item if end_of stream is True
return
# [STAT] capture before time stamp
self.stat_before_time = time.time() - start
@@ -282,35 +234,20 @@ class StreamInstrumentor:
time.time() - start - sum(self.stat_token_times)
)
# invoke on_chunk listeners
any_end_of_stream = False
for listener in self.on_chunk_listeners:
try:
await listener(item)
except YieldException as e:
yield e.value
# if end_of_stream is True, stop the stream
if e.end_of_stream:
any_end_of_stream = True
# if end_of_stream is True, stop the stream
if any_end_of_stream:
return
if extra_item := await self.on_chunk(item):
yield extra_item.value
# if end_of_stream is True, stop the stream
if extra_item.end_of_stream:
return
# yield item
yield item
on_complete_tasks = [
asyncio.create_task(listener(), name="instrumentor:end")
for listener in self.on_complete_listeners
]
for result in asyncio.as_completed(on_complete_tasks):
try:
await result
except YieldException as e:
# yield extra value before any real items
yield e.value
# we ignore end_of_stream here, because we are already at the end
# run on_end, before closing the stream (may yield an extra value)
if extra_item := await self.on_end():
# yield extra value before any real items
yield extra_item.value
# we ignore end_of_stream here, because we are already at the end
# [STAT] capture after time stamp
self.stat_after_time = time.time() - start
@@ -344,29 +281,35 @@ class StreamInstrumentor:
print(f" [total: {time.time() - start:.2f}s]")
class RequestInstrumentor(StreamInstrumentor):
class InstrumentedResponse(InstrumentedStreamingResponse):
"""
Like 'StreamInstrumentor', but for non-streaming requests.
Supports similar 'start', 'end' events, but not 'chunk', since everything is assumed
to be processed in one chunk (i.e., the request).
A class to instrument an async request with hooks for concurrent
pre-processing and post-processing (input and output guardrailing).
"""
def on(self, event):
assert event in [
"start",
"end",
], "RequestInstrumentor does not support 'chunk' events"
return super().on(event)
async def event_generator(self):
"""
We implement the 'event_generator' as a single item stream,
where the item is the full result of the request.
"""
yield await self.request()
async def execute(self, request_task):
async def wrapped_request_task():
yield await request_task
async def request(self):
"""
This method should be implemented in a subclass to perform the actual request.
"""
raise NotImplementedError("This method should be implemented in a subclass.")
# pretend the 'request_task' is an async iterable with a single item
result = [item async for item in self.stream(wrapped_request_task())]
assert len(result) >= 1, "RequestInstrumentor must yield at least one item"
return result[-1]
async def instrumented_request(self):
"""
Returns the 'Response' object of the request, after applying all instrumented hooks.
"""
results = [r async for r in self.instrumented_event_generator()]
assert len(results) >= 1, "InstrumentedResponse must yield at least one item"
# we return the last item, in case the end callback yields an extra item. Then,
# don't return the actual result but the 'end' result, e.g. for output guardrailing.
return results[-1]
async def check_guardrails(
@@ -395,6 +338,10 @@ async def check_guardrails(
"Accept": "application/json",
},
)
if not result.is_success:
raise Exception(
f"Guardrails check failed: {result.status_code} - {result.text}"
)
print(f"Guardrail check response: {result.json()}")
return result.json()
except Exception as e:
+1 -1
View File
@@ -18,7 +18,7 @@ from converters.anthropic_to_invariant import (
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from integrations.guardails import check_guardrails, preload_guardrails
from integrations.guardrails import check_guardrails, preload_guardrails
gateway = APIRouter()
+1 -1
View File
@@ -16,7 +16,7 @@ from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardails import check_guardrails, preload_guardrails
from integrations.guardrails import check_guardrails, preload_guardrails
gateway = APIRouter()
+250 -200
View File
@@ -13,10 +13,10 @@ from common.constants import (
IGNORED_HEADERS,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardails import (
RequestInstrumentor,
StreamInstrumentor,
YieldException,
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
check_guardrails,
preload_guardrails,
)
@@ -89,23 +89,131 @@ async def openai_chat_completions_gateway(
return await handle_non_streaming_response(context, client, open_ai_request)
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
"""
Handles streaming the OpenAI response to the client while building a merged_response
The chunks are returned to the caller immediately
The merged_response is built from the chunks as they are received
It is sent to the Invariant Explorer at the end of the stream
"""
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
async def request_and_stream():
# request parameters
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
self.merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
# Each chunk in the stream contains a list called "choices" each entry in the list
# has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
self.choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
self.tool_call_mapping_by_index = {}
async def on_start(self):
# check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)
if self.context.config and self.context.config.guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"data: {error_chunk}\n\n".encode(), end_of_stream=True
)
async def on_chunk(self, chunk):
# process and check each chunk
chunk_text = chunk.decode().strip()
if not chunk_text:
return
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
self.merged_response,
self.choice_mapping_by_index,
self.tool_call_mapping_by_index,
)
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
if (
"data: [DONE]" in chunk_text
and self.context.config
and self.context.config.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# yield an extra error chunk (without preventing the original chunk to go through after)
return ExtraItem(f"data: {error_chunk}\n\n".encode())
# push will happen in on_end
async def on_end(self):
# Send full merged response to the explorer
# Don't block on the response from explorer
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context, self.merged_response, self.guardrails_execution_result
)
)
async def event_generator(self):
"""
Sets off the request and then streams the result.
Actual OpenAI stream response.
"""
response = await client.send(open_ai_request, stream=True)
response = await self.client.send(self.open_ai_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
@@ -119,123 +227,28 @@ async def stream_response(
async for chunk in response.aiter_bytes():
yield chunk
async def event_generator() -> Any:
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
# Each chunk in the stream contains a list called "choices" each entry in the list
# has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
tool_call_mapping_by_index = {}
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
"""
Handles streaming the OpenAI response to the client while building a merged_response
The chunks are returned to the caller immediately
The merged_response is built from the chunks as they are received
It is sent to the Invariant Explorer at the end of the stream
"""
# prepare stream instrumentor
instrumentor = StreamInstrumentor()
response = InstrumentedOpenAIStreamResponse(
context,
client,
open_ai_request,
)
@instrumentor.on("start")
async def precheck_guardrails() -> None:
# check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The request did not pass the guardrails",
"details": guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
raise YieldException(
f"data: {error_chunk}\n\n".encode(), end_of_stream=True
)
@instrumentor.on("chunk")
async def process_chunk(chunk: bytes) -> None:
# process and check each chunk
chunk_text = chunk.decode().strip()
if not chunk_text:
return
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
merged_response,
choice_mapping_by_index,
tool_call_mapping_by_index,
)
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
if (
"data: [DONE]" in chunk_text
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
# yield an extra error chunk (without preventing the original chunk to go through after)
raise YieldException(f"data: {error_chunk}\n\n".encode())
@instrumentor.on("end")
async def send_to_explorer() -> None:
# Send full merged response to the explorer
# Don't block on the response from explorer
if context.dataset_name:
asyncio.create_task(push_to_explorer(context, merged_response))
async for chunk in instrumentor.stream(request_and_stream()):
# Yield chunk to the client
yield chunk
return StreamingResponse(event_generator(), media_type="text/event-stream")
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
)
def initialize_merged_response() -> dict[str, Any]:
@@ -436,73 +449,51 @@ async def get_guardrails_check_result(
return guardrails_execution_result
async def handle_non_streaming_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
"""Handles non-streaming OpenAI responses"""
class InstrumentedOpenAIResponse(InstrumentedResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
instrumentor = RequestInstrumentor()
# request parameters
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
# respond we get and its JSON decoded version
# available once the 'send_request' function has progressed to the point of
# being able to call 'response.json()'
response = None
json_response = None
# request outputs
self.response: Optional[httpx.Response] = None
self.json_response: Optional[dict[str, Any]] = None
async def send_request():
nonlocal response, json_response
self.guardrails_execution_result: Optional[dict] = None
response = await client.send(open_ai_request)
try:
json_response = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail="Invalid JSON response received from OpenAI API",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
response_string = json.dumps(json_response)
response_code = response.status_code
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(response.headers),
)
@instrumentor.on("start")
async def precheck_guardrails() -> None:
async def on_start(self):
# check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)
if context.config and context.config.guardrails:
if self.context.config and self.context.config.guardrails:
# block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(context)
if guardrails_execution_result.get("errors", []):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context
)
if self.guardrails_execution_result.get("errors", []):
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
self.context,
{},
guardrails_execution_result,
self.guardrails_execution_result,
)
)
# replace the response with the error message
raise YieldException(
return ExtraItem(
Response(
content=json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"details": self.guardrails_execution_result,
}
),
status_code=400,
@@ -511,38 +502,77 @@ async def handle_non_streaming_response(
end_of_stream=True,
)
@instrumentor.on("end")
async def postprocess_guardrails() -> None:
async def request(self):
"""
Actual OpenAI request.
"""
self.response = await self.client.send(self.open_ai_request)
try:
self.json_response = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail="Invalid JSON response received from OpenAI API",
) from e
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
)
response_string = json.dumps(self.json_response)
response_code = self.response.status_code
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
async def on_end(self):
"""
Postprocess the OpenAI response and potentially replace it with a guardrails error.
"""
# these two are guaranteed to be set by the time we reach this point (after self.request() was executed)
assert (
self.response is not None
), "on_end called before 'self.response' was available"
assert (
self.json_response is not None
), "on_end called before 'self.json_response' was available"
# at this point, we are guaranteed that 'send_request' has already been executed successfully
response_code = response.status_code
response_code = self.response.status_code
# if we have guardrails, check the response
if context.config and context.config.guardrails:
if self.context.config and self.context.config.guardrails:
# run guardrails again, this time on request + response
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
)
if guardrails_execution_result.get("errors", []):
if self.guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"details": self.guardrails_execution_result,
}
)
response_code = 400
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
json_response,
guardrails_execution_result,
self.context,
self.json_response,
self.guardrails_execution_result,
)
)
# replace the response with the error message
raise YieldException(
return ExtraItem(
Response(
content=response_string,
status_code=response_code,
@@ -550,10 +580,30 @@ async def handle_non_streaming_response(
),
)
# if we don't have guardrails or if the response passed the guardrails (only then, we reach this point)
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(push_to_explorer(context, json_response))
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.guardrails_execution_result,
)
)
# execute instrumented request
return await instrumentor.execute(send_request())
async def handle_non_streaming_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
"""Handles non-streaming OpenAI responses"""
# # execute instrumented request
# return await instrumentor.execute(send_request())
response = InstrumentedOpenAIResponse(
context,
client,
open_ai_request,
)
return await response.instrumented_request()
+13
View File
@@ -2,15 +2,28 @@
import fastapi
import uvicorn
from common.config_manager import GatewayConfigManager
from routes.anthropic import gateway as anthropic_gateway
from routes.gemini import gateway as gemini_gateway
from routes.open_ai import gateway as open_ai_gateway
from starlette_compress import CompressMiddleware
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
"""Lifespan event to load the config manager"""
gateway_config = GatewayConfigManager.get_config()
yield
# Cleanup if needed
del gateway_config
app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs",
redoc_url="/api/v1/gateway/redoc",
openapi_url="/api/v1/gateway/openapi.json",
lifespan=lifespan,
)
app.add_middleware(CompressMiddleware)