diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index e8210ae..7b05136 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -299,23 +299,31 @@ impl OpenAiSseParser { #[derive(Debug)] struct StreamState { model: String, - message_started: bool, - text_started: bool, - text_finished: bool, - finished: bool, + message: MessageState, + text: TextState, stop_reason: Option, usage: Option, tool_calls: BTreeMap, } +#[derive(Debug, Default)] +struct MessageState { + started: bool, + finished: bool, +} + +#[derive(Debug, Default)] +struct TextState { + started: bool, + finished: bool, +} + impl StreamState { fn new(model: String) -> Self { Self { model, - message_started: false, - text_started: false, - text_finished: false, - finished: false, + message: MessageState::default(), + text: TextState::default(), stop_reason: None, usage: None, tool_calls: BTreeMap::new(), @@ -324,8 +332,8 @@ impl StreamState { fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { let mut events = Vec::new(); - if !self.message_started { - self.message_started = true; + if !self.message.started { + self.message.started = true; events.push(StreamEvent::MessageStart(MessageStartEvent { message: MessageResponse { id: chunk.id.clone(), @@ -357,8 +365,8 @@ impl StreamState { for choice in chunk.choices { if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { - if !self.text_started { - self.text_started = true; + if !self.text.started { + self.text.started = true; events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { index: 0, content_block: OutputContentBlock::Text { @@ -414,14 +422,14 @@ impl StreamState { } fn finish(&mut self) -> Result, ApiError> { - if self.finished { + if self.message.finished { return Ok(Vec::new()); } - self.finished = true; + self.message.finished = true; let mut events = Vec::new(); - if self.text_started && !self.text_finished { - self.text_finished = true; + if self.text.started && !self.text.finished { + self.text.finished = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0, })); @@ -445,7 +453,7 @@ impl StreamState { } } - if self.message_started { + if self.message.started { events.push(StreamEvent::MessageDelta(MessageDeltaEvent { delta: MessageDelta { stop_reason: Some(