From 1f6e2ed7fb0ee603b5c27bbb962ff96f9f87ca72 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 7 May 2025 00:36:56 +0530 Subject: [PATCH] Update streaming in anthropic route to handle chunks with incomplete events. Introduce an sse_buffer which keeps track of the current incomplete event from the last processed chunk. --- gateway/routes/anthropic.py | 129 ++++++++++++++---- .../test_anthropic_with_tool_call.py | 2 +- .../test_anthropic_without_tool_call.py | 11 +- 3 files changed, 109 insertions(+), 33 deletions(-) diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index bbc2af7..36c4a63 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -380,6 +380,8 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): # guardrailing response (if any) self.guardrails_execution_result = {} + self.sse_buffer = "" # Buffer for incomplete events + async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" if self.context.guardrails: @@ -434,16 +436,78 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): - """Process the chunk and update the merged_response""" - decoded_chunk = chunk.decode().strip() - if not decoded_chunk: - return + """ + Process the chunk and update the merged_response. + Each chunk may contain multiple events, separated by double newlines. + Each event has type and data fields, separated by a newline. + It is possible that a chunk contains some incomplete events. - # process chunk and extend the merged_response - process_chunk(decoded_chunk, self.merged_response) + Example: + + b'event: message_start\ndata: {"type":"message_start","message": + {"id":"msg_01LkayzAaw7b7QkUAw91psyx","type":"message","role":"assistant" + ,"model":"claude-3-5-sonnet-20241022","content":[],"stop_reason":null, + "stop_sequence":null,"usage":{"input_tokens":20,"cache_creation_input_to' + + and + + b'kens":0,"cache_read_input_tokens":0,"output_tokens":1}}}\n\nevent: content_block_start + \ndata: {"type":"content_block_start","index":0,"content_block" + :{"type":"text","text":""} }\n\nevent: ping + \ndata: {"type": "ping"}\n\nevent: content_block_delta + \ndata: {"type":"content_block_delta","index":0,"delta":{"type": + "text_delta","text":"Originally"} }\n\n' + + In this case the first chunk ends with 'cache_creation_input_to' which is + continued in the next chunk. + + in this case we need to maintain a buffer of the incomplete events. + We filter out the ping events and update a merged_response. + """ + # Decode the chunk and add to buffer + decoded_chunk = chunk.decode("utf-8", errors="replace") + self.sse_buffer += decoded_chunk + + # Process complete events from buffer + complete_events, incomplete_events = self.process_complete_events( + self.sse_buffer + ) + self.sse_buffer = incomplete_events + + # Check if we've received message_stop in any events + message_stop_received = False + + # Update the merged_response based on complete events + for event in complete_events: + try: + if "event: message_stop" in event: + message_stop_received = True + + # Extract event data + lines = event.split("\n") + event_type = None + event_data = None + + for line in lines: + if line.startswith("event:"): + event_type = line[6:].strip() + elif line.startswith("data:"): + event_data = line[5:].strip() + + if event_data and event_type != "ping": # Skip ping events + try: + event_json = json.loads(event_data) + update_merged_response(event_json, self.merged_response) + except json.JSONDecodeError as e: + print( + f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...", + flush=True, + ) + except Exception as e: + print(f"Error processing event: {e}", flush=True) # on last stream chunk, run output guardrails - if "event: message_stop" in decoded_chunk and self.context.guardrails: + if message_stop_received and self.context.guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( self.context, @@ -468,8 +532,32 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): value=f"event: error\ndata: {error_chunk}\n\n".encode() ) + def process_complete_events(self, buffer): + """Process the buffer and extract complete SSE events. + + Returns: + Tuple[List[str], str]: A tuple containing a list of + complete events and the remaining buffer with incomplete events. + """ + # Split on double newlines which separate SSE events + if not buffer: + return [], "" + events = [] + remaining = buffer + + # Process events that are complete (ending with \n\n) + while "\n\n" in remaining: + pos = remaining.find("\n\n") + if pos >= 0: + event = remaining[: pos + 2] + remaining = remaining[pos + 2 :] + if event.strip(): # Skip empty events + events.append(event) + + return events, remaining + async def on_end(self): - """on_end: send full merged response to the exploree (if configured)""" + """on_end: send full merged response to the explorer (if configured)""" # don't block on the response from explorer (.create_task) if self.context.dataset_name: asyncio.create_task( @@ -498,20 +586,6 @@ async def handle_streaming_response( ) -def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None: - """ - Process the chunk of text and update the merged_response - Example of chunk list can be find in: - ../../resources/streaming_chunk_text/anthropic.txt - """ - for text_block in chunk.split("\n\n"): - # might be empty block - if len(text_block.split("\ndata:")) > 1: - event_text = text_block.split("\ndata:")[1] - event = json.loads(event_text) - update_merged_response(event, merged_response) - - def update_merged_response( event: dict[str, Any], merged_response: dict[str, Any] ) -> None: @@ -527,22 +601,25 @@ def update_merged_response( final Message content array. 3. One or more message_delta events, indicating top-level changes to the final Message object. A final message_stop event. + We filter out the ping eventss """ - if event.get("type") == MESSAGE_START: + event_type = event.get("type") + + if event_type == MESSAGE_START: merged_response.update(**event.get("message")) - elif event.get("type") == CONTENT_BLOCK_START: + elif event_type == CONTENT_BLOCK_START: index = event.get("index") if index >= len(merged_response.get("content")): merged_response["content"].append(event.get("content_block")) if event.get("content_block").get("type") == "tool_use": merged_response.get("content")[-1]["input"] = "" - elif event.get("type") == CONTENT_BLOCK_DELTA: + elif event_type == CONTENT_BLOCK_DELTA: index = event.get("index") delta = event.get("delta") if delta.get("type") == "text_delta": merged_response.get("content")[index]["text"] += delta.get("text") elif delta.get("type") == "input_json_delta": merged_response.get("content")[index]["input"] += delta.get("partial_json") - elif event.get("type") == MESSAGE_DELTA: + elif event_type == MESSAGE_DELTA: merged_response["usage"].update(**event.get("usage")) diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 5e42881..1ed7dd5 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -179,7 +179,7 @@ async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_ex weather_agent = WeatherAgent(gateway_url, push_to_explorer) - query = "Tell me the weather for New York" + query = "Tell me the weather for New York in Celsius" city = "new york" # Process each query diff --git a/tests/integration/anthropic/test_anthropic_without_tool_call.py b/tests/integration/anthropic/test_anthropic_without_tool_call.py index 280341f..727e76a 100644 --- a/tests/integration/anthropic/test_anthropic_without_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_without_tool_call.py @@ -8,10 +8,9 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_anthropic_client - import pytest import requests +from utils import get_anthropic_client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -90,9 +89,9 @@ async def test_streaming_response_without_tool_call( cities = ["zurich", "new york", "london"] queries = [ - "Can you introduce Zurich, Switzerland within 200 words?", - "Tell me the history of New York within 100 words?", - "How's the weather in London next week?", + "Can you introduce Zurich, Switzerland in 2 short sentences?", + "Tell me the history of New York within 2 short sentences.", + "Explain the geography of London in 2 short sentences.", ] # Process each query responses = [] @@ -102,7 +101,7 @@ async def test_streaming_response_without_tool_call( with client.messages.stream( model="claude-3-5-sonnet-20241022", - max_tokens=1024, + max_tokens=200, messages=messages, ) as response: for reply in response.text_stream: