Update streaming in anthropic route to handle chunks with incomplete events. Introduce an sse_buffer which keeps track of the current incomplete event from the last processed chunk.

This commit is contained in:
Hemang
2025-05-07 00:36:56 +05:30
committed by Hemang Sarkar
parent aec7808e3e
commit 1f6e2ed7fb
3 changed files with 109 additions and 33 deletions
+103 -26
View File
@@ -380,6 +380,8 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
# guardrailing response (if any)
self.guardrails_execution_result = {}
self.sse_buffer = "" # Buffer for incomplete events
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.guardrails:
@@ -434,16 +436,78 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
yield chunk
async def on_chunk(self, chunk):
"""Process the chunk and update the merged_response"""
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
return
"""
Process the chunk and update the merged_response.
Each chunk may contain multiple events, separated by double newlines.
Each event has type and data fields, separated by a newline.
It is possible that a chunk contains some incomplete events.
# process chunk and extend the merged_response
process_chunk(decoded_chunk, self.merged_response)
Example:
b'event: message_start\ndata: {"type":"message_start","message":
{"id":"msg_01LkayzAaw7b7QkUAw91psyx","type":"message","role":"assistant"
,"model":"claude-3-5-sonnet-20241022","content":[],"stop_reason":null,
"stop_sequence":null,"usage":{"input_tokens":20,"cache_creation_input_to'
and
b'kens":0,"cache_read_input_tokens":0,"output_tokens":1}}}\n\nevent: content_block_start
\ndata: {"type":"content_block_start","index":0,"content_block"
:{"type":"text","text":""} }\n\nevent: ping
\ndata: {"type": "ping"}\n\nevent: content_block_delta
\ndata: {"type":"content_block_delta","index":0,"delta":{"type":
"text_delta","text":"Originally"} }\n\n'
In this case the first chunk ends with 'cache_creation_input_to' which is
continued in the next chunk.
in this case we need to maintain a buffer of the incomplete events.
We filter out the ping events and update a merged_response.
"""
# Decode the chunk and add to buffer
decoded_chunk = chunk.decode("utf-8", errors="replace")
self.sse_buffer += decoded_chunk
# Process complete events from buffer
complete_events, incomplete_events = self.process_complete_events(
self.sse_buffer
)
self.sse_buffer = incomplete_events
# Check if we've received message_stop in any events
message_stop_received = False
# Update the merged_response based on complete events
for event in complete_events:
try:
if "event: message_stop" in event:
message_stop_received = True
# Extract event data
lines = event.split("\n")
event_type = None
event_data = None
for line in lines:
if line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
event_data = line[5:].strip()
if event_data and event_type != "ping": # Skip ping events
try:
event_json = json.loads(event_data)
update_merged_response(event_json, self.merged_response)
except json.JSONDecodeError as e:
print(
f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...",
flush=True,
)
except Exception as e:
print(f"Error processing event: {e}", flush=True)
# on last stream chunk, run output guardrails
if "event: message_stop" in decoded_chunk and self.context.guardrails:
if message_stop_received and self.context.guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
@@ -468,8 +532,32 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
value=f"event: error\ndata: {error_chunk}\n\n".encode()
)
def process_complete_events(self, buffer):
"""Process the buffer and extract complete SSE events.
Returns:
Tuple[List[str], str]: A tuple containing a list of
complete events and the remaining buffer with incomplete events.
"""
# Split on double newlines which separate SSE events
if not buffer:
return [], ""
events = []
remaining = buffer
# Process events that are complete (ending with \n\n)
while "\n\n" in remaining:
pos = remaining.find("\n\n")
if pos >= 0:
event = remaining[: pos + 2]
remaining = remaining[pos + 2 :]
if event.strip(): # Skip empty events
events.append(event)
return events, remaining
async def on_end(self):
"""on_end: send full merged response to the exploree (if configured)"""
"""on_end: send full merged response to the explorer (if configured)"""
# don't block on the response from explorer (.create_task)
if self.context.dataset_name:
asyncio.create_task(
@@ -498,20 +586,6 @@ async def handle_streaming_response(
)
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
"""
Process the chunk of text and update the merged_response
Example of chunk list can be find in:
../../resources/streaming_chunk_text/anthropic.txt
"""
for text_block in chunk.split("\n\n"):
# might be empty block
if len(text_block.split("\ndata:")) > 1:
event_text = text_block.split("\ndata:")[1]
event = json.loads(event_text)
update_merged_response(event, merged_response)
def update_merged_response(
event: dict[str, Any], merged_response: dict[str, Any]
) -> None:
@@ -527,22 +601,25 @@ def update_merged_response(
final Message content array.
3. One or more message_delta events, indicating top-level changes to the final Message object.
A final message_stop event.
We filter out the ping eventss
"""
if event.get("type") == MESSAGE_START:
event_type = event.get("type")
if event_type == MESSAGE_START:
merged_response.update(**event.get("message"))
elif event.get("type") == CONTENT_BLOCK_START:
elif event_type == CONTENT_BLOCK_START:
index = event.get("index")
if index >= len(merged_response.get("content")):
merged_response["content"].append(event.get("content_block"))
if event.get("content_block").get("type") == "tool_use":
merged_response.get("content")[-1]["input"] = ""
elif event.get("type") == CONTENT_BLOCK_DELTA:
elif event_type == CONTENT_BLOCK_DELTA:
index = event.get("index")
delta = event.get("delta")
if delta.get("type") == "text_delta":
merged_response.get("content")[index]["text"] += delta.get("text")
elif delta.get("type") == "input_json_delta":
merged_response.get("content")[index]["input"] += delta.get("partial_json")
elif event.get("type") == MESSAGE_DELTA:
elif event_type == MESSAGE_DELTA:
merged_response["usage"].update(**event.get("usage"))
@@ -179,7 +179,7 @@ async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_ex
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
query = "Tell me the weather for New York"
query = "Tell me the weather for New York in Celsius"
city = "new york"
# Process each query
@@ -8,10 +8,9 @@ import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_anthropic_client
import pytest
import requests
from utils import get_anthropic_client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -90,9 +89,9 @@ async def test_streaming_response_without_tool_call(
cities = ["zurich", "new york", "london"]
queries = [
"Can you introduce Zurich, Switzerland within 200 words?",
"Tell me the history of New York within 100 words?",
"How's the weather in London next week?",
"Can you introduce Zurich, Switzerland in 2 short sentences?",
"Tell me the history of New York within 2 short sentences.",
"Explain the geography of London in 2 short sentences.",
]
# Process each query
responses = []
@@ -102,7 +101,7 @@ async def test_streaming_response_without_tool_call(
with client.messages.stream(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
max_tokens=200,
messages=messages,
) as response:
for reply in response.text_stream: