diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 00105209..95cca1fb 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -195,6 +195,8 @@ type ChatMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` // ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。 ToolName string `json:"tool_name,omitempty"` + // ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。 + ReasoningContent string `json:"reasoning_content,omitempty"` } // MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串 @@ -208,6 +210,9 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) { if cm.Content != "" { aux["content"] = cm.Content } + if cm.ReasoningContent != "" { + aux["reasoning_content"] = cm.ReasoningContent + } // 添加tool_call_id(如果存在) if cm.ToolCallID != "" { @@ -663,8 +668,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 检查是否有工具调用 if len(choice.Message.ToolCalls) > 0 { - // 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重; - // 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。 + // ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。 + // 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。 if choice.Message.Content != "" { sendProgress("thinking", choice.Message.Content, map[string]interface{}{ "iteration": i + 1, diff --git a/internal/database/conversation.go b/internal/database/conversation.go index d4c91086..d23506a4 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -25,14 +25,15 @@ type Conversation struct { // Message 消息 type Message struct { - ID string `json:"id"` - ConversationID string `json:"conversationId"` - Role string `json:"role"` - Content string `json:"content"` - MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` - ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + ID string 
`json:"id"` + ConversationID string `json:"conversationId"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoningContent,omitempty"` + MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` + ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` } // CreateConversation 创建新对话 @@ -498,8 +499,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [ } _, err := db.Exec( - "INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, conversationID, role, content, mcpIDsJSON, now, now, + "INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + id, conversationID, role, content, "", mcpIDsJSON, now, now, ) if err != nil { return nil, fmt.Errorf("添加消息失败: %w", err) @@ -523,10 +524,30 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [ return message, nil } +// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。 +func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error { + var mcpIDsJSON string + if len(mcpExecutionIDs) > 0 { + jsonData, err := json.Marshal(mcpExecutionIDs) + if err != nil { + return fmt.Errorf("序列化MCP执行ID失败: %w", err) + } + mcpIDsJSON = string(jsonData) + } + _, err := db.Exec( + "UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? 
WHERE id = ?", + content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID, + ) + if err != nil { + return fmt.Errorf("更新助手消息失败: %w", err) + } + return nil +} + // GetMessages 获取对话的所有消息 func (db *DB) GetMessages(conversationID string) ([]Message, error) { rows, err := db.Query( - "SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", + "SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", conversationID, ) if err != nil { @@ -537,13 +558,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) { var messages []Message for rows.Next() { var msg Message + var reasoning sql.NullString var mcpIDsJSON sql.NullString var createdAt string var updatedAt sql.NullString - if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil { + if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil { return nil, fmt.Errorf("扫描消息失败: %w", err) } + if reasoning.Valid { + msg.ReasoningContent = reasoning.String + } // 尝试多种时间格式解析 var err error @@ -683,7 +708,7 @@ type ProcessDetail struct { ID string `json:"id"` MessageID string `json:"messageId"` ConversationID string `json:"conversationId"` - EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error + EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error Message string `json:"message"` Data string `json:"data"` // JSON格式的数据 CreatedAt time.Time `json:"createdAt"` diff --git a/internal/database/database.go b/internal/database/database.go index 4a354294..6321e1a5 100644 --- a/internal/database/database.go 
+++ b/internal/database/database.go @@ -594,6 +594,25 @@ func (db *DB) migrateMessagesTable() error { // 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。 _, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''") + + // reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放 + var rcColCount int + errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount) + if errRC != nil { + if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr) + } + } + } else if rcColCount == 0 { + if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil { + errMsg := strings.ToLower(err.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err) + } + } + } return nil } diff --git a/internal/multiagent/eino_adk_run_loop.go b/internal/multiagent/eino_adk_run_loop.go index f84f537e..07db48e7 100644 --- a/internal/multiagent/eino_adk_run_loop.go +++ b/internal/multiagent/eino_adk_run_loop.go @@ -15,6 +15,7 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/einomcp" + "cyberstrike-ai/internal/openai" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/schema" @@ -550,6 +551,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs var mainAssistantBuf string var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重 var reasoningBuf string + var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示 var streamRecvErr error type streamMsg struct { chunk *schema.Message @@ -597,19 +599,29 @@ 
func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs var reasoningDelta string reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent) if reasoningDelta != "" { - if reasoningStreamID == "" { - reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1)) - progress("thinking_stream_start", " ", map[string]interface{}{ - "streamId": reasoningStreamID, - "source": "eino", - "einoAgent": ev.AgentName, - "einoRole": einoRoleTag(ev.AgentName), - "orchestration": orchMode, + fullDisplay := openai.DisplayReasoningContent(reasoningBuf) + var displayDelta string + if strings.HasPrefix(fullDisplay, prevReasoningDisplay) { + displayDelta = fullDisplay[len(prevReasoningDisplay):] + } else { + displayDelta = fullDisplay + } + prevReasoningDisplay = fullDisplay + if displayDelta != "" { + if reasoningStreamID == "" { + reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1)) + progress("reasoning_chain_stream_start", " ", map[string]interface{}{ + "streamId": reasoningStreamID, + "source": "eino", + "einoAgent": ev.AgentName, + "einoRole": einoRoleTag(ev.AgentName), + "orchestration": orchMode, + }) + } + progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{ + "streamId": reasoningStreamID, }) } - progress("thinking_stream_delta", reasoningDelta, map[string]interface{}{ - "streamId": reasoningStreamID, - }) } } if chunk.Content != "" { @@ -777,7 +789,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs if mv.Role == schema.Assistant { if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" { - progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{ + progress("reasoning_chain", openai.DisplayReasoningContent(strings.TrimSpace(msg.ReasoningContent)), map[string]interface{}{ "conversationId": conversationID, 
"source": "eino", "einoAgent": ev.AgentName, diff --git a/internal/multiagent/eino_single_runner.go b/internal/multiagent/eino_single_runner.go index 34af0af1..1d5267df 100644 --- a/internal/multiagent/eino_single_runner.go +++ b/internal/multiagent/eino_single_runner.go @@ -13,6 +13,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/reasoning" einoopenai "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/adk" @@ -37,6 +38,7 @@ func RunEinoSingleChatModelAgent( history []agent.ChatMessage, roleTools []string, progress func(eventType, message string, data interface{}), + reasoningClient *reasoning.ClientIntent, ) (*RunResult, error) { if appCfg == nil || ag == nil { return nil, fmt.Errorf("eino single: 配置或 Agent 为空") @@ -121,6 +123,7 @@ func RunEinoSingleChatModelAgent( Model: appCfg.OpenAI.Model, HTTPClient: httpClient, } + reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient) mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg) if err != nil { diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go index ade4ec60..b0e418a5 100644 --- a/internal/multiagent/eino_summarize.go +++ b/internal/multiagent/eino_summarize.go @@ -214,7 +214,7 @@ func summarizeFinalizeWithRecentAssistantToolTrail( selectedCount++ } - // 还原时间顺序 + // 还原时间顺序。round 内为原始 *schema.Message 指针,保留 ReasoningContent(DeepSeek 工具续跑所必需)。 selectedMsgs := make([]adk.Message, 0, 8) for i := len(selectedRoundsReverse) - 1; i >= 0; i-- { selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...) 
// AggregatedReasoningFromTraceJSON joins the non-empty `reasoning_content` values of
// every assistant message found in traceJSON, in document order, separated by a
// single newline.
//
// traceJSON is expected to be a JSON array of message objects. Blank input, input
// that does not decode as such an array, and arrays without assistant reasoning all
// yield the empty string. Role matching ignores case and surrounding whitespace.
func AggregatedReasoningFromTraceJSON(traceJSON string) string {
	trimmed := strings.TrimSpace(traceJSON)
	if trimmed == "" {
		return ""
	}

	var messages []map[string]interface{}
	if json.Unmarshal([]byte(trimmed), &messages) != nil {
		return ""
	}

	parts := make([]string, 0, len(messages))
	for _, msg := range messages {
		role, _ := msg["role"].(string)
		if !strings.EqualFold(strings.TrimSpace(role), "assistant") {
			continue
		}
		if rc := reasoningContentFromMessageMap(msg); rc != "" {
			parts = append(parts, rc)
		}
	}
	return strings.Join(parts, "\n")
}

// reasoningContentFromMessageMap returns the trimmed `reasoning_content` entry of m.
// A nil map, a missing key, or a JSON null gives ""; a non-string value is rendered
// with fmt.Sprint and then trimmed.
func reasoningContentFromMessageMap(m map[string]interface{}) string {
	// Indexing a nil map is safe and reports the key as absent.
	raw, present := m["reasoning_content"]
	if !present || raw == nil {
		return ""
	}
	if text, isString := raw.(string); isString {
		return strings.TrimSpace(text)
	}
	return strings.TrimSpace(fmt.Sprint(raw))
}
+{"role":"assistant","content":"c1","reasoning_content":"r1","tool_calls":[{"id":"1","type":"function","function":{"name":"f","arguments":"{}"}}]}, +{"role":"tool","tool_call_id":"1","content":"out"}, +{"role":"assistant","content":"c2","reasoning_content":"r2"} +]` + got := AggregatedReasoningFromTraceJSON(j) + want := "r1\nr2" + if got != want { + t.Fatalf("got %q want %q", got, want) + } + if AggregatedReasoningFromTraceJSON("") != "" || AggregatedReasoningFromTraceJSON("[]") != "" { + t.Fatal("empty expected") + } +} diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index 8a0f0e25..34b2a40c 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -17,6 +17,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/reasoning" einoopenai "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/adk" @@ -48,6 +49,7 @@ type toolCallPendingInfo struct { // RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。 // orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。 +// reasoningClient 来自 ChatRequest.reasoning;可为 nil(机器人/批量等走全局 openai.reasoning)。 func RunDeepAgent( ctx context.Context, appCfg *config.Config, @@ -61,6 +63,7 @@ func RunDeepAgent( progress func(eventType, message string, data interface{}), agentsMarkdownDir string, orchestrationOverride string, + reasoningClient *reasoning.ClientIntent, ) (*RunResult, error) { if appCfg == nil || ma == nil || ag == nil { return nil, fmt.Errorf("multiagent: 配置或 Agent 为空") @@ -163,6 +166,7 @@ func RunDeepAgent( Model: appCfg.OpenAI.Model, HTTPClient: httpClient, } + reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient) deepMaxIter := ma.MaxIteration if deepMaxIter <= 0 { @@ -636,8 +640,13 @@ func historyToMessages(history []agent.ChatMessage, appCfg 
*config.Config, mwCfg } case "assistant": toolSchema := chatToolCallsToSchema(h.ToolCalls) - if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" { - raw = append(raw, schema.AssistantMessage(h.Content, toolSchema)) + hasRC := strings.TrimSpace(h.ReasoningContent) != "" + if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" || hasRC { + am := schema.AssistantMessage(h.Content, toolSchema) + if hasRC { + am.ReasoningContent = strings.TrimSpace(h.ReasoningContent) + } + raw = append(raw, am) } case "tool": if strings.TrimSpace(h.ToolCallID) == "" && strings.TrimSpace(h.Content) == "" { diff --git a/internal/multiagent/runner_reasoning_history_test.go b/internal/multiagent/runner_reasoning_history_test.go new file mode 100644 index 00000000..8027c486 --- /dev/null +++ b/internal/multiagent/runner_reasoning_history_test.go @@ -0,0 +1,22 @@ +package multiagent + +import ( + "testing" + + "cyberstrike-ai/internal/agent" +) + +func TestHistoryToMessagesPreservesReasoningContent(t *testing.T) { + h := []agent.ChatMessage{ + {Role: "user", Content: "u"}, + {Role: "assistant", Content: "c", ReasoningContent: "r1", ToolCalls: []agent.ToolCall{{ID: "t1", Type: "function", Function: agent.FunctionCall{Name: "f", Arguments: map[string]interface{}{}}}}}, + } + msgs := historyToMessages(h, nil, nil) + if len(msgs) != 2 { + t.Fatalf("len=%d", len(msgs)) + } + am := msgs[1] + if am.ReasoningContent != "r1" || am.Content != "c" { + t.Fatalf("got reasoning=%q content=%q", am.ReasoningContent, am.Content) + } +} diff --git a/internal/openai/claude_bridge.go b/internal/openai/claude_bridge.go index e2bf73a1..10319202 100644 --- a/internal/openai/claude_bridge.go +++ b/internal/openai/claude_bridge.go @@ -9,6 +9,9 @@ package openai // Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式 // Auth: Bearer → x-api-key // Tools: OpenAI tools[] → Claude tools[] (input_schema) +// +// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 
映射为 +// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。 import ( "bufio" @@ -38,6 +41,7 @@ type claudeRequest struct { Messages []claudeMessage `json:"messages"` Tools []claudeTool `json:"tools,omitempty"` Stream bool `json:"stream,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` } type claudeMessage struct { @@ -76,6 +80,10 @@ type claudeContentBlock struct { // text block Text string `json:"text,omitempty"` + // thinking block (extended thinking) + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + // tool_use block (assistant 返回) ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` @@ -176,7 +184,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) { // tool_calls (assistant 消息中包含工具调用) if role == "assistant" { + rc, _ := mm["reasoning_content"].(string) + _, thinkingReplay := parseClaudeReasoningAssistantBlocks(rc) + var blocks []claudeContentBlock + for _, tb := range thinkingReplay { + blocks = append(blocks, tb) + } if content != "" { blocks = append(blocks, claudeContentBlock{Type: "text", Text: content}) } @@ -290,6 +304,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) { } } + // Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras. 
+ if th, ok := oai["thinking"]; ok && th != nil { + if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" { + req.Thinking = json.RawMessage(raw) + } + } + return req, nil } @@ -318,9 +339,12 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) { var textContent string var toolCalls []interface{} + var thinkingBlocks []claudeContentBlock for _, block := range cr.Content { switch block.Type { + case "thinking": + thinkingBlocks = append(thinkingBlocks, block) case "text": textContent += block.Text case "tool_use": @@ -344,6 +368,18 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) { if len(toolCalls) > 0 { message["tool_calls"] = toolCalls } + if len(thinkingBlocks) > 0 { + var parts []string + for _, tb := range thinkingBlocks { + if strings.TrimSpace(tb.Thinking) != "" { + parts = append(parts, tb.Thinking) + } + } + rc := appendClaudeReasoningRoundTrip(strings.Join(parts, "\n\n"), thinkingBlocks) + if rc != "" { + message["reasoning_content"] = rc + } + } choice := map[string]interface{}{ "index": 0, @@ -901,8 +937,16 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro reader := bufio.NewReader(resp.Body) blockToToolIndex := make(map[int]int) + blockIndexToType := make(map[int]string) nextToolIndex := 0 + type thinkingAcc struct { + text strings.Builder + sig strings.Builder + } + thinkingByIndex := make(map[int]*thinkingAcc) + var finishedThinking []claudeContentBlock + for { line, readErr := reader.ReadString('\n') if readErr != nil { @@ -947,6 +991,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro blockIdx := int(blockIdxFlt) cb, _ := event["content_block"].(map[string]interface{}) bt, _ := cb["type"].(string) + blockIndexToType[blockIdx] = bt + + if bt == "thinking" { + thinkingByIndex[blockIdx] = &thinkingAcc{} + } if bt == "tool_use" { id, _ := cb["id"].(string) @@ -986,7 +1035,35 @@ func (rt *claudeRoundTripper) 
RoundTrip(req *http.Request) (*http.Response, erro delta, _ := event["delta"].(map[string]interface{}) dt, _ := delta["type"].(string) - if dt == "text_delta" { + if dt == "thinking_delta" { + tPart, _ := delta["thinking"].(string) + if tPart != "" { + if acc := thinkingByIndex[blockIdx]; acc != nil { + acc.text.WriteString(tPart) + } + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{ + "reasoning_content": tPart, + }, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } + } else if dt == "signature_delta" { + sigPart, _ := delta["signature"].(string) + if sigPart != "" { + if acc := thinkingByIndex[blockIdx]; acc != nil { + acc.sig.WriteString(sigPart) + } + } + } else if dt == "text_delta" { text, _ := delta["text"].(string) oaiChunk := map[string]interface{}{ "choices": []map[string]interface{}{ @@ -1031,6 +1108,21 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } } + case "content_block_stop": + blockIdxFlt, _ := event["index"].(float64) + blockIdx := int(blockIdxFlt) + bt := blockIndexToType[blockIdx] + if bt == "thinking" { + if acc := thinkingByIndex[blockIdx]; acc != nil { + finishedThinking = append(finishedThinking, claudeContentBlock{ + Type: "thinking", + Thinking: acc.text.String(), + Signature: acc.sig.String(), + }) + delete(thinkingByIndex, blockIdx) + } + } + case "message_delta": d, _ := event["delta"].(map[string]interface{}) if sr, ok := d["stop_reason"].(string); ok { @@ -1051,6 +1143,25 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } case "message_stop": + if len(finishedThinking) > 0 { + suffix := appendClaudeReasoningRoundTrip("", finishedThinking) + if strings.TrimSpace(suffix) != "" { + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{ + "reasoning_content": suffix, + 
}, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } + } writeLine("data: [DONE]\n\n") pw.Close() return diff --git a/internal/openai/claude_reasoning_roundtrip.go b/internal/openai/claude_reasoning_roundtrip.go new file mode 100644 index 00000000..1eae4c67 --- /dev/null +++ b/internal/openai/claude_reasoning_roundtrip.go @@ -0,0 +1,81 @@ +package openai + +import ( + "encoding/json" + "strings" +) + +// claudeReasoningRoundTripSep separates human-readable reasoning from a JSON payload of +// Anthropic thinking blocks (with signatures) for multi-turn extended thinking + tools. +// Not shown in UI (see DisplayReasoningContent). +const claudeReasoningRoundTripSep = "\n---CSAI_CLAUDE_THINKING_BLOCKS---\n" + +// DisplayReasoningContent returns reasoning text suitable for the UI (strips internal +// Claude round-trip JSON suffix). Safe for DeepSeek/plain reasoning strings (no-op). +func DisplayReasoningContent(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + i := strings.LastIndex(s, claudeReasoningRoundTripSep) + if i < 0 { + return s + } + return strings.TrimSpace(s[:i]) +} + +func appendClaudeReasoningRoundTrip(display string, blocks []claudeContentBlock) string { + var payload []map[string]string + for _, b := range blocks { + if b.Type != "thinking" { + continue + } + payload = append(payload, map[string]string{ + "type": b.Type, + "thinking": b.Thinking, + "signature": b.Signature, + }) + } + if len(payload) == 0 { + return strings.TrimSpace(display) + } + js, err := json.Marshal(payload) + if err != nil { + return strings.TrimSpace(display) + } + d := strings.TrimSpace(display) + if d == "" { + return claudeReasoningRoundTripSep + string(js) + } + return d + claudeReasoningRoundTripSep + string(js) +} + +// parseClaudeReasoningAssistantBlocks extracts Anthropic thinking blocks from an OpenAI-style +// reasoning_content string. 
When no suffix is present, blocks is nil (caller must not invent signatures). +func parseClaudeReasoningAssistantBlocks(reasoningContent string) (display string, blocks []claudeContentBlock) { + reasoningContent = strings.TrimSpace(reasoningContent) + if reasoningContent == "" { + return "", nil + } + idx := strings.LastIndex(reasoningContent, claudeReasoningRoundTripSep) + if idx < 0 { + return reasoningContent, nil + } + display = strings.TrimSpace(reasoningContent[:idx]) + jsonPart := strings.TrimSpace(reasoningContent[idx+len(claudeReasoningRoundTripSep):]) + var arr []struct { + Type string `json:"type"` + Thinking string `json:"thinking"` + Signature string `json:"signature"` + } + if err := json.Unmarshal([]byte(jsonPart), &arr); err != nil { + return reasoningContent, nil + } + for _, x := range arr { + if x.Type != "thinking" { + continue + } + blocks = append(blocks, claudeContentBlock{Type: "thinking", Thinking: x.Thinking, Signature: x.Signature}) + } + return display, blocks +} diff --git a/internal/openai/claude_reasoning_roundtrip_test.go b/internal/openai/claude_reasoning_roundtrip_test.go new file mode 100644 index 00000000..6b112f1a --- /dev/null +++ b/internal/openai/claude_reasoning_roundtrip_test.go @@ -0,0 +1,102 @@ +package openai + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestDisplayReasoningContent(t *testing.T) { + raw := "hello" + claudeReasoningRoundTripSep + `[{"type":"thinking","thinking":"x","signature":"sig"}]` + if d := DisplayReasoningContent(raw); d != "hello" { + t.Fatalf("got %q", d) + } + if DisplayReasoningContent("plain") != "plain" { + t.Fatal() + } +} + +func TestAppendParseClaudeReasoningRoundTrip(t *testing.T) { + blocks := []claudeContentBlock{ + {Type: "thinking", Thinking: "a", Signature: "s1"}, + {Type: "thinking", Thinking: "b", Signature: "s2"}, + } + s := appendClaudeReasoningRoundTrip("sum", blocks) + if !strings.Contains(s, claudeReasoningRoundTripSep) { + t.Fatal("missing sep") + } + 
display, back := parseClaudeReasoningAssistantBlocks(s) + if display != "sum" || len(back) != 2 { + t.Fatalf("display=%q len=%d", display, len(back)) + } + if back[0].Signature != "s1" || back[1].Thinking != "b" { + t.Fatalf("%+v", back) + } +} + +func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) { + rc := appendClaudeReasoningRoundTrip("vis", []claudeContentBlock{ + {Type: "thinking", Thinking: "t1", Signature: "sig1"}, + }) + payload := map[string]interface{}{ + "model": "claude-3-5-sonnet-latest", + "messages": []interface{}{ + map[string]interface{}{ + "role": "assistant", + "content": "out", + "reasoning_content": rc, + }, + }, + } + req, err := convertOpenAIToClaude(payload) + if err != nil { + t.Fatal(err) + } + if len(req.Messages) != 1 { + t.Fatalf("messages=%d", len(req.Messages)) + } + blocks := req.Messages[0].Content.Blocks + if len(blocks) < 2 { + t.Fatalf("blocks=%d", len(blocks)) + } + if blocks[0].Type != "thinking" || blocks[0].Signature != "sig1" { + t.Fatalf("first block %+v", blocks[0]) + } + foundText := false + for _, b := range blocks { + if b.Type == "text" && b.Text == "out" { + foundText = true + } + } + if !foundText { + t.Fatalf("blocks=%+v", blocks) + } +} + +func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) { + claudeBody := []byte(`{ + "id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn", + "content":[ + {"type":"thinking","thinking":"step","signature":"sigx"}, + {"type":"text","text":"hi"} + ] + }`) + oai, err := claudeToOpenAIResponseJSON(claudeBody) + if err != nil { + t.Fatal(err) + } + var wrap map[string]interface{} + if err := json.Unmarshal(oai, &wrap); err != nil { + t.Fatal(err) + } + choices := wrap["choices"].([]interface{}) + ch0 := choices[0].(map[string]interface{}) + msg := ch0["message"].(map[string]interface{}) + rc, _ := msg["reasoning_content"].(string) + if !strings.Contains(rc, "step") || !strings.Contains(rc, claudeReasoningRoundTripSep) { + 
t.Fatalf("reasoning_content=%q", rc) + } + if msg["content"] != "hi" { + t.Fatal() + } +}