mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-15 04:51:01 +02:00
Add files via upload
This commit is contained in:
@@ -195,6 +195,8 @@ type ChatMessage struct {
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||
@@ -208,6 +210,9 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
||||
if cm.Content != "" {
|
||||
aux["content"] = cm.Content
|
||||
}
|
||||
if cm.ReasoningContent != "" {
|
||||
aux["reasoning_content"] = cm.ReasoningContent
|
||||
}
|
||||
|
||||
// 添加tool_call_id(如果存在)
|
||||
if cm.ToolCallID != "" {
|
||||
@@ -663,8 +668,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
// ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。
|
||||
// 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
|
||||
@@ -25,14 +25,15 @@ type Conversation struct {
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
@@ -498,8 +499,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, now, now,
|
||||
"INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, "", mcpIDsJSON, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
@@ -523,10 +524,30 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
|
||||
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
|
||||
}
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
|
||||
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新助手消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -537,13 +558,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var reasoning sql.NullString
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
if reasoning.Valid {
|
||||
msg.ReasoningContent = reasoning.String
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
@@ -683,7 +708,7 @@ type ProcessDetail struct {
|
||||
ID string `json:"id"`
|
||||
MessageID string `json:"messageId"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"` // JSON格式的数据
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
@@ -594,6 +594,25 @@ func (db *DB) migrateMessagesTable() error {
|
||||
|
||||
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
||||
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
||||
|
||||
// reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
|
||||
var rcColCount int
|
||||
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
|
||||
if errRC != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
|
||||
}
|
||||
}
|
||||
} else if rcColCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -550,6 +551,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
var mainAssistantBuf string
|
||||
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
||||
var reasoningBuf string
|
||||
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
||||
var streamRecvErr error
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
@@ -597,19 +599,29 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
var reasoningDelta string
|
||||
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
||||
if reasoningDelta != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
fullDisplay := openai.DisplayReasoningContent(reasoningBuf)
|
||||
var displayDelta string
|
||||
if strings.HasPrefix(fullDisplay, prevReasoningDisplay) {
|
||||
displayDelta = fullDisplay[len(prevReasoningDisplay):]
|
||||
} else {
|
||||
displayDelta = fullDisplay
|
||||
}
|
||||
prevReasoningDisplay = fullDisplay
|
||||
if displayDelta != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
progress("thinking_stream_delta", reasoningDelta, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
@@ -777,7 +789,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
||||
progress("reasoning_chain", openai.DisplayReasoningContent(strings.TrimSpace(msg.ReasoningContent)), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -37,6 +38,7 @@ func RunEinoSingleChatModelAgent(
|
||||
history []agent.ChatMessage,
|
||||
roleTools []string,
|
||||
progress func(eventType, message string, data interface{}),
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ag == nil {
|
||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||
@@ -121,6 +123,7 @@ func RunEinoSingleChatModelAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -214,7 +214,7 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
selectedCount++
|
||||
}
|
||||
|
||||
// 还原时间顺序
|
||||
// 还原时间顺序。round 内为原始 *schema.Message 指针,保留 ReasoningContent(DeepSeek 工具续跑所必需)。
|
||||
selectedMsgs := make([]adk.Message, 0, 8)
|
||||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AggregatedReasoningFromTraceJSON concatenates non-empty assistant `reasoning_content`
|
||||
// fields from last_react-style JSON (slice of message objects) in document order.
|
||||
// Used to persist on the single assistant bubble row for audit and for GetMessages fallback
|
||||
// when the full trace JSON is unavailable. For strict per-message replay, prefer last_react_input.
|
||||
func AggregatedReasoningFromTraceJSON(traceJSON string) string {
|
||||
traceJSON = strings.TrimSpace(traceJSON)
|
||||
if traceJSON == "" {
|
||||
return ""
|
||||
}
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceJSON), &arr); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, m := range arr {
|
||||
role, _ := m["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "assistant") {
|
||||
continue
|
||||
}
|
||||
rc := reasoningContentFromMessageMap(m)
|
||||
if rc == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(rc)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func reasoningContentFromMessageMap(m map[string]interface{}) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := m["reasoning_content"].(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAggregatedReasoningFromTraceJSON(t *testing.T) {
|
||||
const j = `[
|
||||
{"role":"user","content":"hi"},
|
||||
{"role":"assistant","content":"c1","reasoning_content":"r1","tool_calls":[{"id":"1","type":"function","function":{"name":"f","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"1","content":"out"},
|
||||
{"role":"assistant","content":"c2","reasoning_content":"r2"}
|
||||
]`
|
||||
got := AggregatedReasoningFromTraceJSON(j)
|
||||
want := "r1\nr2"
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
if AggregatedReasoningFromTraceJSON("") != "" || AggregatedReasoningFromTraceJSON("[]") != "" {
|
||||
t.Fatal("empty expected")
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -48,6 +49,7 @@ type toolCallPendingInfo struct {
|
||||
|
||||
// RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。
|
||||
// orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。
|
||||
// reasoningClient 来自 ChatRequest.reasoning;可为 nil(机器人/批量等走全局 openai.reasoning)。
|
||||
func RunDeepAgent(
|
||||
ctx context.Context,
|
||||
appCfg *config.Config,
|
||||
@@ -61,6 +63,7 @@ func RunDeepAgent(
|
||||
progress func(eventType, message string, data interface{}),
|
||||
agentsMarkdownDir string,
|
||||
orchestrationOverride string,
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ma == nil || ag == nil {
|
||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||
@@ -163,6 +166,7 @@ func RunDeepAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
deepMaxIter := ma.MaxIteration
|
||||
if deepMaxIter <= 0 {
|
||||
@@ -636,8 +640,13 @@ func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg
|
||||
}
|
||||
case "assistant":
|
||||
toolSchema := chatToolCallsToSchema(h.ToolCalls)
|
||||
if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" {
|
||||
raw = append(raw, schema.AssistantMessage(h.Content, toolSchema))
|
||||
hasRC := strings.TrimSpace(h.ReasoningContent) != ""
|
||||
if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" || hasRC {
|
||||
am := schema.AssistantMessage(h.Content, toolSchema)
|
||||
if hasRC {
|
||||
am.ReasoningContent = strings.TrimSpace(h.ReasoningContent)
|
||||
}
|
||||
raw = append(raw, am)
|
||||
}
|
||||
case "tool":
|
||||
if strings.TrimSpace(h.ToolCallID) == "" && strings.TrimSpace(h.Content) == "" {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
)
|
||||
|
||||
func TestHistoryToMessagesPreservesReasoningContent(t *testing.T) {
|
||||
h := []agent.ChatMessage{
|
||||
{Role: "user", Content: "u"},
|
||||
{Role: "assistant", Content: "c", ReasoningContent: "r1", ToolCalls: []agent.ToolCall{{ID: "t1", Type: "function", Function: agent.FunctionCall{Name: "f", Arguments: map[string]interface{}{}}}}},
|
||||
}
|
||||
msgs := historyToMessages(h, nil, nil)
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("len=%d", len(msgs))
|
||||
}
|
||||
am := msgs[1]
|
||||
if am.ReasoningContent != "r1" || am.Content != "c" {
|
||||
t.Fatalf("got reasoning=%q content=%q", am.ReasoningContent, am.Content)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,9 @@ package openai
|
||||
// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式
|
||||
// Auth: Bearer → x-api-key
|
||||
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
||||
//
|
||||
// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
||||
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -38,6 +41,7 @@ type claudeRequest struct {
|
||||
Messages []claudeMessage `json:"messages"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type claudeMessage struct {
|
||||
@@ -76,6 +80,10 @@ type claudeContentBlock struct {
|
||||
// text block
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking block (extended thinking)
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use block (assistant 返回)
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -176,7 +184,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
|
||||
// tool_calls (assistant 消息中包含工具调用)
|
||||
if role == "assistant" {
|
||||
rc, _ := mm["reasoning_content"].(string)
|
||||
_, thinkingReplay := parseClaudeReasoningAssistantBlocks(rc)
|
||||
|
||||
var blocks []claudeContentBlock
|
||||
for _, tb := range thinkingReplay {
|
||||
blocks = append(blocks, tb)
|
||||
}
|
||||
if content != "" {
|
||||
blocks = append(blocks, claudeContentBlock{Type: "text", Text: content})
|
||||
}
|
||||
@@ -290,6 +304,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
||||
if th, ok := oai["thinking"]; ok && th != nil {
|
||||
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||
req.Thinking = json.RawMessage(raw)
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -318,9 +339,12 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
|
||||
var textContent string
|
||||
var toolCalls []interface{}
|
||||
var thinkingBlocks []claudeContentBlock
|
||||
|
||||
for _, block := range cr.Content {
|
||||
switch block.Type {
|
||||
case "thinking":
|
||||
thinkingBlocks = append(thinkingBlocks, block)
|
||||
case "text":
|
||||
textContent += block.Text
|
||||
case "tool_use":
|
||||
@@ -344,6 +368,18 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
if len(thinkingBlocks) > 0 {
|
||||
var parts []string
|
||||
for _, tb := range thinkingBlocks {
|
||||
if strings.TrimSpace(tb.Thinking) != "" {
|
||||
parts = append(parts, tb.Thinking)
|
||||
}
|
||||
}
|
||||
rc := appendClaudeReasoningRoundTrip(strings.Join(parts, "\n\n"), thinkingBlocks)
|
||||
if rc != "" {
|
||||
message["reasoning_content"] = rc
|
||||
}
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
@@ -901,8 +937,16 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
blockToToolIndex := make(map[int]int)
|
||||
blockIndexToType := make(map[int]string)
|
||||
nextToolIndex := 0
|
||||
|
||||
type thinkingAcc struct {
|
||||
text strings.Builder
|
||||
sig strings.Builder
|
||||
}
|
||||
thinkingByIndex := make(map[int]*thinkingAcc)
|
||||
var finishedThinking []claudeContentBlock
|
||||
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
@@ -947,6 +991,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
blockIdx := int(blockIdxFlt)
|
||||
cb, _ := event["content_block"].(map[string]interface{})
|
||||
bt, _ := cb["type"].(string)
|
||||
blockIndexToType[blockIdx] = bt
|
||||
|
||||
if bt == "thinking" {
|
||||
thinkingByIndex[blockIdx] = &thinkingAcc{}
|
||||
}
|
||||
|
||||
if bt == "tool_use" {
|
||||
id, _ := cb["id"].(string)
|
||||
@@ -986,7 +1035,35 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
delta, _ := event["delta"].(map[string]interface{})
|
||||
dt, _ := delta["type"].(string)
|
||||
|
||||
if dt == "text_delta" {
|
||||
if dt == "thinking_delta" {
|
||||
tPart, _ := delta["thinking"].(string)
|
||||
if tPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.text.WriteString(tPart)
|
||||
}
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": tPart,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if dt == "signature_delta" {
|
||||
sigPart, _ := delta["signature"].(string)
|
||||
if sigPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.sig.WriteString(sigPart)
|
||||
}
|
||||
}
|
||||
} else if dt == "text_delta" {
|
||||
text, _ := delta["text"].(string)
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
@@ -1031,6 +1108,21 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
blockIdxFlt, _ := event["index"].(float64)
|
||||
blockIdx := int(blockIdxFlt)
|
||||
bt := blockIndexToType[blockIdx]
|
||||
if bt == "thinking" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
finishedThinking = append(finishedThinking, claudeContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: acc.text.String(),
|
||||
Signature: acc.sig.String(),
|
||||
})
|
||||
delete(thinkingByIndex, blockIdx)
|
||||
}
|
||||
}
|
||||
|
||||
case "message_delta":
|
||||
d, _ := event["delta"].(map[string]interface{})
|
||||
if sr, ok := d["stop_reason"].(string); ok {
|
||||
@@ -1051,6 +1143,25 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
if len(finishedThinking) > 0 {
|
||||
suffix := appendClaudeReasoningRoundTrip("", finishedThinking)
|
||||
if strings.TrimSpace(suffix) != "" {
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": suffix,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
writeLine("data: [DONE]\n\n")
|
||||
pw.Close()
|
||||
return
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// claudeReasoningRoundTripSep separates human-readable reasoning from a JSON payload of
|
||||
// Anthropic thinking blocks (with signatures) for multi-turn extended thinking + tools.
|
||||
// Not shown in UI (see DisplayReasoningContent).
|
||||
const claudeReasoningRoundTripSep = "\n---CSAI_CLAUDE_THINKING_BLOCKS---\n"
|
||||
|
||||
// DisplayReasoningContent returns reasoning text suitable for the UI (strips internal
|
||||
// Claude round-trip JSON suffix). Safe for DeepSeek/plain reasoning strings (no-op).
|
||||
func DisplayReasoningContent(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
i := strings.LastIndex(s, claudeReasoningRoundTripSep)
|
||||
if i < 0 {
|
||||
return s
|
||||
}
|
||||
return strings.TrimSpace(s[:i])
|
||||
}
|
||||
|
||||
func appendClaudeReasoningRoundTrip(display string, blocks []claudeContentBlock) string {
|
||||
var payload []map[string]string
|
||||
for _, b := range blocks {
|
||||
if b.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
payload = append(payload, map[string]string{
|
||||
"type": b.Type,
|
||||
"thinking": b.Thinking,
|
||||
"signature": b.Signature,
|
||||
})
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
js, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
d := strings.TrimSpace(display)
|
||||
if d == "" {
|
||||
return claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
return d + claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
|
||||
// parseClaudeReasoningAssistantBlocks extracts Anthropic thinking blocks from an OpenAI-style
|
||||
// reasoning_content string. When no suffix is present, blocks is nil (caller must not invent signatures).
|
||||
func parseClaudeReasoningAssistantBlocks(reasoningContent string) (display string, blocks []claudeContentBlock) {
|
||||
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||
if reasoningContent == "" {
|
||||
return "", nil
|
||||
}
|
||||
idx := strings.LastIndex(reasoningContent, claudeReasoningRoundTripSep)
|
||||
if idx < 0 {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
display = strings.TrimSpace(reasoningContent[:idx])
|
||||
jsonPart := strings.TrimSpace(reasoningContent[idx+len(claudeReasoningRoundTripSep):])
|
||||
var arr []struct {
|
||||
Type string `json:"type"`
|
||||
Thinking string `json:"thinking"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonPart), &arr); err != nil {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
for _, x := range arr {
|
||||
if x.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, claudeContentBlock{Type: "thinking", Thinking: x.Thinking, Signature: x.Signature})
|
||||
}
|
||||
return display, blocks
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDisplayReasoningContent(t *testing.T) {
|
||||
raw := "hello" + claudeReasoningRoundTripSep + `[{"type":"thinking","thinking":"x","signature":"sig"}]`
|
||||
if d := DisplayReasoningContent(raw); d != "hello" {
|
||||
t.Fatalf("got %q", d)
|
||||
}
|
||||
if DisplayReasoningContent("plain") != "plain" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendParseClaudeReasoningRoundTrip(t *testing.T) {
|
||||
blocks := []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "a", Signature: "s1"},
|
||||
{Type: "thinking", Thinking: "b", Signature: "s2"},
|
||||
}
|
||||
s := appendClaudeReasoningRoundTrip("sum", blocks)
|
||||
if !strings.Contains(s, claudeReasoningRoundTripSep) {
|
||||
t.Fatal("missing sep")
|
||||
}
|
||||
display, back := parseClaudeReasoningAssistantBlocks(s)
|
||||
if display != "sum" || len(back) != 2 {
|
||||
t.Fatalf("display=%q len=%d", display, len(back))
|
||||
}
|
||||
if back[0].Signature != "s1" || back[1].Thinking != "b" {
|
||||
t.Fatalf("%+v", back)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) {
|
||||
rc := appendClaudeReasoningRoundTrip("vis", []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "t1", Signature: "sig1"},
|
||||
})
|
||||
payload := map[string]interface{}{
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"messages": []interface{}{
|
||||
map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "out",
|
||||
"reasoning_content": rc,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := convertOpenAIToClaude(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("messages=%d", len(req.Messages))
|
||||
}
|
||||
blocks := req.Messages[0].Content.Blocks
|
||||
if len(blocks) < 2 {
|
||||
t.Fatalf("blocks=%d", len(blocks))
|
||||
}
|
||||
if blocks[0].Type != "thinking" || blocks[0].Signature != "sig1" {
|
||||
t.Fatalf("first block %+v", blocks[0])
|
||||
}
|
||||
foundText := false
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text == "out" {
|
||||
foundText = true
|
||||
}
|
||||
}
|
||||
if !foundText {
|
||||
t.Fatalf("blocks=%+v", blocks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
||||
claudeBody := []byte(`{
|
||||
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
||||
"content":[
|
||||
{"type":"thinking","thinking":"step","signature":"sigx"},
|
||||
{"type":"text","text":"hi"}
|
||||
]
|
||||
}`)
|
||||
oai, err := claudeToOpenAIResponseJSON(claudeBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var wrap map[string]interface{}
|
||||
if err := json.Unmarshal(oai, &wrap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
choices := wrap["choices"].([]interface{})
|
||||
ch0 := choices[0].(map[string]interface{})
|
||||
msg := ch0["message"].(map[string]interface{})
|
||||
rc, _ := msg["reasoning_content"].(string)
|
||||
if !strings.Contains(rc, "step") || !strings.Contains(rc, claudeReasoningRoundTripSep) {
|
||||
t.Fatalf("reasoning_content=%q", rc)
|
||||
}
|
||||
if msg["content"] != "hi" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user