mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-21 06:26:47 +02:00
Update streaming in anthropic route to handle chunks with incomplete events. Introduce an sse_buffer which keeps track of the current incomplete event from the last processed chunk.
This commit is contained in:
+103
-26
@@ -380,6 +380,8 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
|
||||
self.sse_buffer = "" # Buffer for incomplete events
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.guardrails:
|
||||
@@ -434,16 +436,78 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Process the chunk and update the merged_response"""
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
return
|
||||
"""
|
||||
Process the chunk and update the merged_response.
|
||||
Each chunk may contain multiple events, separated by double newlines.
|
||||
Each event has type and data fields, separated by a newline.
|
||||
It is possible that a chunk contains some incomplete events.
|
||||
|
||||
# process chunk and extend the merged_response
|
||||
process_chunk(decoded_chunk, self.merged_response)
|
||||
Example:
|
||||
|
||||
b'event: message_start\ndata: {"type":"message_start","message":
|
||||
{"id":"msg_01LkayzAaw7b7QkUAw91psyx","type":"message","role":"assistant"
|
||||
,"model":"claude-3-5-sonnet-20241022","content":[],"stop_reason":null,
|
||||
"stop_sequence":null,"usage":{"input_tokens":20,"cache_creation_input_to'
|
||||
|
||||
and
|
||||
|
||||
b'kens":0,"cache_read_input_tokens":0,"output_tokens":1}}}\n\nevent: content_block_start
|
||||
\ndata: {"type":"content_block_start","index":0,"content_block"
|
||||
:{"type":"text","text":""} }\n\nevent: ping
|
||||
\ndata: {"type": "ping"}\n\nevent: content_block_delta
|
||||
\ndata: {"type":"content_block_delta","index":0,"delta":{"type":
|
||||
"text_delta","text":"Originally"} }\n\n'
|
||||
|
||||
In this case the first chunk ends with 'cache_creation_input_to' which is
|
||||
continued in the next chunk.
|
||||
|
||||
in this case we need to maintain a buffer of the incomplete events.
|
||||
We filter out the ping events and update a merged_response.
|
||||
"""
|
||||
# Decode the chunk and add to buffer
|
||||
decoded_chunk = chunk.decode("utf-8", errors="replace")
|
||||
self.sse_buffer += decoded_chunk
|
||||
|
||||
# Process complete events from buffer
|
||||
complete_events, incomplete_events = self.process_complete_events(
|
||||
self.sse_buffer
|
||||
)
|
||||
self.sse_buffer = incomplete_events
|
||||
|
||||
# Check if we've received message_stop in any events
|
||||
message_stop_received = False
|
||||
|
||||
# Update the merged_response based on complete events
|
||||
for event in complete_events:
|
||||
try:
|
||||
if "event: message_stop" in event:
|
||||
message_stop_received = True
|
||||
|
||||
# Extract event data
|
||||
lines = event.split("\n")
|
||||
event_type = None
|
||||
event_data = None
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
event_data = line[5:].strip()
|
||||
|
||||
if event_data and event_type != "ping": # Skip ping events
|
||||
try:
|
||||
event_json = json.loads(event_data)
|
||||
update_merged_response(event_json, self.merged_response)
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...",
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing event: {e}", flush=True)
|
||||
|
||||
# on last stream chunk, run output guardrails
|
||||
if "event: message_stop" in decoded_chunk and self.context.guardrails:
|
||||
if message_stop_received and self.context.guardrails:
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
@@ -468,8 +532,32 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
value=f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
)
|
||||
|
||||
def process_complete_events(self, buffer):
|
||||
"""Process the buffer and extract complete SSE events.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], str]: A tuple containing a list of
|
||||
complete events and the remaining buffer with incomplete events.
|
||||
"""
|
||||
# Split on double newlines which separate SSE events
|
||||
if not buffer:
|
||||
return [], ""
|
||||
events = []
|
||||
remaining = buffer
|
||||
|
||||
# Process events that are complete (ending with \n\n)
|
||||
while "\n\n" in remaining:
|
||||
pos = remaining.find("\n\n")
|
||||
if pos >= 0:
|
||||
event = remaining[: pos + 2]
|
||||
remaining = remaining[pos + 2 :]
|
||||
if event.strip(): # Skip empty events
|
||||
events.append(event)
|
||||
|
||||
return events, remaining
|
||||
|
||||
async def on_end(self):
|
||||
"""on_end: send full merged response to the exploree (if configured)"""
|
||||
"""on_end: send full merged response to the explorer (if configured)"""
|
||||
# don't block on the response from explorer (.create_task)
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
@@ -498,20 +586,6 @@ async def handle_streaming_response(
|
||||
)
|
||||
|
||||
|
||||
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
|
||||
"""
|
||||
Process the chunk of text and update the merged_response
|
||||
Example of chunk list can be find in:
|
||||
../../resources/streaming_chunk_text/anthropic.txt
|
||||
"""
|
||||
for text_block in chunk.split("\n\n"):
|
||||
# might be empty block
|
||||
if len(text_block.split("\ndata:")) > 1:
|
||||
event_text = text_block.split("\ndata:")[1]
|
||||
event = json.loads(event_text)
|
||||
update_merged_response(event, merged_response)
|
||||
|
||||
|
||||
def update_merged_response(
|
||||
event: dict[str, Any], merged_response: dict[str, Any]
|
||||
) -> None:
|
||||
@@ -527,22 +601,25 @@ def update_merged_response(
|
||||
final Message content array.
|
||||
3. One or more message_delta events, indicating top-level changes to the final Message object.
|
||||
A final message_stop event.
|
||||
We filter out the ping eventss
|
||||
|
||||
"""
|
||||
if event.get("type") == MESSAGE_START:
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type == MESSAGE_START:
|
||||
merged_response.update(**event.get("message"))
|
||||
elif event.get("type") == CONTENT_BLOCK_START:
|
||||
elif event_type == CONTENT_BLOCK_START:
|
||||
index = event.get("index")
|
||||
if index >= len(merged_response.get("content")):
|
||||
merged_response["content"].append(event.get("content_block"))
|
||||
if event.get("content_block").get("type") == "tool_use":
|
||||
merged_response.get("content")[-1]["input"] = ""
|
||||
elif event.get("type") == CONTENT_BLOCK_DELTA:
|
||||
elif event_type == CONTENT_BLOCK_DELTA:
|
||||
index = event.get("index")
|
||||
delta = event.get("delta")
|
||||
if delta.get("type") == "text_delta":
|
||||
merged_response.get("content")[index]["text"] += delta.get("text")
|
||||
elif delta.get("type") == "input_json_delta":
|
||||
merged_response.get("content")[index]["input"] += delta.get("partial_json")
|
||||
elif event.get("type") == MESSAGE_DELTA:
|
||||
elif event_type == MESSAGE_DELTA:
|
||||
merged_response["usage"].update(**event.get("usage"))
|
||||
|
||||
@@ -179,7 +179,7 @@ async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_ex
|
||||
|
||||
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
|
||||
|
||||
query = "Tell me the weather for New York"
|
||||
query = "Tell me the weather for New York in Celsius"
|
||||
|
||||
city = "new york"
|
||||
# Process each query
|
||||
|
||||
@@ -8,10 +8,9 @@ import uuid
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_anthropic_client
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from utils import get_anthropic_client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -90,9 +89,9 @@ async def test_streaming_response_without_tool_call(
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
"Can you introduce Zurich, Switzerland within 200 words?",
|
||||
"Tell me the history of New York within 100 words?",
|
||||
"How's the weather in London next week?",
|
||||
"Can you introduce Zurich, Switzerland in 2 short sentences?",
|
||||
"Tell me the history of New York within 2 short sentences.",
|
||||
"Explain the geography of London in 2 short sentences.",
|
||||
]
|
||||
# Process each query
|
||||
responses = []
|
||||
@@ -102,7 +101,7 @@ async def test_streaming_response_without_tool_call(
|
||||
|
||||
with client.messages.stream(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
max_tokens=200,
|
||||
messages=messages,
|
||||
) as response:
|
||||
for reply in response.text_stream:
|
||||
|
||||
Reference in New Issue
Block a user