From 0e35506ae14d3cd032172dd67dbf3d330b1b3b22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 1 May 2026 01:00:23 +0800 Subject: [PATCH] Add files via upload --- internal/multiagent/eino_orchestration.go | 3 + internal/multiagent/eino_summarize.go | 176 ++++++--- internal/multiagent/eino_summarize_test.go | 345 ++++++++++++++++++ .../orphan_tool_pruner_middleware.go | 124 +++++++ .../orphan_tool_pruner_middleware_test.go | 131 +++++++ internal/multiagent/runner.go | 7 + 6 files changed, 730 insertions(+), 56 deletions(-) create mode 100644 internal/multiagent/eino_summarize_test.go create mode 100644 internal/multiagent/orphan_tool_pruner_middleware.go create mode 100644 internal/multiagent/orphan_tool_pruner_middleware_test.go diff --git a/internal/multiagent/eino_orchestration.go b/internal/multiagent/eino_orchestration.go index 7dc7a968..dccd99d5 100644 --- a/internal/multiagent/eino_orchestration.go +++ b/internal/multiagent/eino_orchestration.go @@ -95,6 +95,9 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma } execHandlers = append(execHandlers, sumMw) } + // 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、 + // telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。 + execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor")) if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil { execHandlers = append(execHandlers, teleMw) } diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go index 3f8defcd..ade4ec60 100644 --- a/internal/multiagent/eino_summarize.go +++ b/internal/multiagent/eino_summarize.go @@ -130,6 +130,14 @@ func newEinoSummarizationMiddleware( } // summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。 +// +// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。 +// 把消息切成 round(回合)为原子单位: +// - user(...) 单条为一个 round; +// - assistant(tool_calls=[...]) 及其后连续的 role=tool 消息合成一个 round; +// - 其它 assistant(reply, 无 tool_calls) 单条为一个 round。 +// +// 倒序挑 round(预算不够即放弃该 round),保证 tool 消息不会跨 round 被孤立。 func summarizeFinalizeWithRecentAssistantToolTrail( ctx context.Context, originalMessages []adk.Message, @@ -157,80 +165,136 @@ func summarizeFinalizeWithRecentAssistantToolTrail( return out, nil } - selectedReverse := make([]adk.Message, 0, 8) - seen := make(map[adk.Message]struct{}) - totalTokens := 0 - assistantToolKept := 0 - const minAssistantToolTrail = 4 + rounds := splitMessagesIntoRounds(nonSystem) + if len(rounds) == 0 { + out := make([]adk.Message, 0, len(systemMsgs)+1) + out = append(out, systemMsgs...) + out = append(out, summary) + return out, nil + } - tryKeep := func(msg adk.Message) (bool, error) { - if msg == nil { - return false, nil + // 目标:至少保留 minRounds 个 round 的执行轨迹;在预算允许时尽量多保留。 + // 优先确保最后一个 round(通常是最新的 tool 往返或 assistant 回复)存在。 + const minRounds = 2 + + selectedRoundsReverse := make([]messageRound, 0, 8) + selectedCount := 0 + totalTokens := 0 + + tokensOfRound := func(r messageRound) (int, error) { + if len(r.messages) == 0 { + return 0, nil } - if _, ok := seen[msg]; ok { - return false, nil - } - n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: []adk.Message{msg}}) + n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: r.messages}) if err != nil { - return false, err + return 0, err } if n <= 0 { - n = 1 + n = len(r.messages) } + return n, nil + } + + for i := len(rounds) - 1; i >= 0; i-- { + r := rounds[i] + n, err := tokensOfRound(r) + if err != nil { + return nil, err + } + // 预算不够:已经保留了足够 round 则停,否则跳过该 round 继续往前找 + // (避免一个超大 round 挤占全部预算,至少保证有轨迹)。 if totalTokens+n > recentTrailTokenBudget { - return false, nil + if selectedCount >= minRounds { + break + } + continue } totalTokens += n - selectedReverse = append(selectedReverse, msg) - seen[msg] = struct{}{} - return true, nil + selectedRoundsReverse = append(selectedRoundsReverse, r) + selectedCount++ } - // 优先保留最近 assistant/tool,确保执行轨迹可续跑。 - for i := len(nonSystem) - 1; i >= 0; i-- { - msg := nonSystem[i] - if msg.Role != schema.Assistant && msg.Role != schema.Tool { - continue - } - ok, err := tryKeep(msg) - if err != nil { - return nil, err - } - if ok { - assistantToolKept++ - } - if assistantToolKept >= minAssistantToolTrail { - break - } + // 还原时间顺序 + selectedMsgs := make([]adk.Message, 0, 8) + for i := len(selectedRoundsReverse) - 1; i >= 0; i-- { + selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...) } - // 在预算内回填更多最近消息,保持短链路上下文。 - for i := len(nonSystem) - 1; i >= 0; i-- { - _, exists := seen[nonSystem[i]] - if exists { - continue - } - ok, err := tryKeep(nonSystem[i]) - if err != nil { - return nil, err - } - if !ok { - break - } - } - - selected := make([]adk.Message, 0, len(selectedReverse)) - for i := len(selectedReverse) - 1; i >= 0; i-- { - selected = append(selected, selectedReverse[i]) - } - - out := make([]adk.Message, 0, len(systemMsgs)+1+len(selected)) + out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs)) out = append(out, systemMsgs...) out = append(out, summary) - out = append(out, selected...) + out = append(out, selectedMsgs...) return out, nil } +// messageRound 表示一个"不可分割"的消息回合。 +// - 对 assistant(tool_calls) + 随后若干 tool 消息的组合,round 内全部 call_id 成对完整; +// - 对独立的 user / assistant(reply) 消息,round 仅包含该条消息。 +type messageRound struct { + messages []adk.Message +} + +// splitMessagesIntoRounds 将非 system 消息切分为若干 round,保证: +// - 每个 assistant(tool_calls) 与其对应的 role=tool 响应消息在同一个 round; +// - 孤立(无对应 assistant(tool_calls))的 role=tool 消息不会单独成为 round, +// 而是被丢弃(这些消息在 pair 完整性层面已属孤儿,保留反而会触发 LLM 400)。 +func splitMessagesIntoRounds(msgs []adk.Message) []messageRound { + if len(msgs) == 0 { + return nil + } + rounds := make([]messageRound, 0, len(msgs)) + i := 0 + for i < len(msgs) { + msg := msgs[i] + if msg == nil { + i++ + continue + } + switch { + case msg.Role == schema.Assistant && len(msg.ToolCalls) > 0: + // 收集该 assistant 提供的 call_id 集合。 + provided := make(map[string]struct{}, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + if tc.ID != "" { + provided[tc.ID] = struct{}{} + } + } + round := messageRound{messages: []adk.Message{msg}} + j := i + 1 + for j < len(msgs) { + next := msgs[j] + if next == nil { + j++ + continue + } + if next.Role != schema.Tool { + break + } + if next.ToolCallID != "" { + if _, ok := provided[next.ToolCallID]; !ok { + // 下一条 tool 不属于当前 assistant,认为当前 round 结束。 + break + } + } + round.messages = append(round.messages, next) + j++ + } + rounds = append(rounds, round) + i = j + case msg.Role == schema.Tool: + // 孤儿 tool 消息:既不跟随在一个 assistant(tool_calls) 后, + // 说明它对应的 assistant 已被上游裁剪;直接丢弃,下一步到 orphan pruner + // 兜底也不会出错,但在 round 切分这里就剔除更干净。 + i++ + default: + // user / assistant(reply) / 其它:单条成 round。 + rounds = append(rounds, messageRound{messages: []adk.Message{msg}}) + i++ + } + } + return rounds +} + func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { tc := agent.NewTikTokenCounter() return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { diff --git a/internal/multiagent/eino_summarize_test.go b/internal/multiagent/eino_summarize_test.go new file mode 100644 index 00000000..dd8d6da7 --- /dev/null +++ b/internal/multiagent/eino_summarize_test.go @@ -0,0 +1,345 @@ +package multiagent + +import ( + "context" + "testing" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/middlewares/summarization" + "github.com/cloudwego/eino/schema" +) + +// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。 +// 用于验证 tool-round 超预算时整体被跳过的分支。 +func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc { + return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) { + total := 0 + for _, msg := range in.Messages { + if msg == nil { + continue + } + switch msg.Role { + case schema.Tool: + total += tokensPerToolMessage + default: + total++ + } + } + return total, nil + } +} + +// variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果), +// 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。 +func variableTokenCounter() summarization.TokenCounterFunc { + return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) { + total := 0 + for _, msg := range in.Messages { + if msg == nil { + continue + } + if msg.Role == schema.Tool { + total += len(msg.Content) + continue + } + total++ + total += len(msg.ToolCalls) + } + return total, nil + } +} + +func TestSplitMessagesIntoRounds_Complex(t *testing.T) { + msgs := []adk.Message{ + schema.UserMessage("q1"), + assistantToolCallsMsg("", "c1", "c2"), + schema.ToolMessage("r1", "c1"), + schema.ToolMessage("r2", "c2"), + schema.AssistantMessage("reply1", nil), + schema.UserMessage("q2"), + assistantToolCallsMsg("", "c3"), + schema.ToolMessage("r3", "c3"), + } + rounds := splitMessagesIntoRounds(msgs) + // 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3) + if len(rounds) != 5 { + t.Fatalf("want 5 rounds, got %d", len(rounds)) + } + // round 1 应为 tool-round,必须成对 + r1 := rounds[1] + if len(r1.messages) != 3 { + t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages)) + } + if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 { + t.Fatalf("rounds[1][0] must be assistant(tc=2)") + } + for i := 1; i < 3; i++ { + if r1.messages[i].Role != schema.Tool { + t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role) + } + } + // 最后一个 round 成对 + rLast := rounds[len(rounds)-1] + if len(rLast.messages) != 2 { + t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages)) + } + if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool { + t.Fatalf("last round must be assistant(tc)+tool(c3)") + } +} + +func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) { + // 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。 + msgs := []adk.Message{ + schema.ToolMessage("orphan", "c_old"), + schema.UserMessage("continue"), + assistantToolCallsMsg("", "c_new"), + schema.ToolMessage("r_new", "c_new"), + } + rounds := splitMessagesIntoRounds(msgs) + // user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds + if len(rounds) != 2 { + t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds)) + } + for _, r := range rounds { + for _, m := range r.messages { + if m.Role == schema.Tool && m.ToolCallID == "c_old" { + t.Fatalf("orphan tool c_old must not appear in any round") + } + } + } +} + +func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) { + // 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。 + msgs := []adk.Message{ + assistantToolCallsMsg("", "c1"), + schema.ToolMessage("r1", "c1"), + assistantToolCallsMsg("", "c2"), + schema.ToolMessage("r2", "c2"), + } + rounds := splitMessagesIntoRounds(msgs) + if len(rounds) != 2 { + t.Fatalf("want 2 rounds, got %d", len(rounds)) + } + if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" { + t.Fatalf("round[0] wrong: %+v", rounds[0].messages) + } + if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" { + t.Fatalf("round[1] wrong: %+v", rounds[1].messages) + } +} + +func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) { + // assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。 + // 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。 + // 而 c999 又没有对应 assistant,应被当孤儿丢弃。 + msgs := []adk.Message{ + assistantToolCallsMsg("", "c1"), + schema.ToolMessage("wrong", "c999"), + schema.UserMessage("hi"), + } + rounds := splitMessagesIntoRounds(msgs) + // assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补); + // 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。 + if len(rounds) != 2 { + t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds) + } + for _, r := range rounds { + for _, m := range r.messages { + if m.Role == schema.Tool && m.ToolCallID == "c999" { + t.Fatalf("wrong-owner tool must be dropped as orphan") + } + } + } +} + +func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) { + // 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。 + sys := schema.SystemMessage("sys") + summary := schema.AssistantMessage("summary_content", nil) + msgs := []adk.Message{ + sys, + schema.UserMessage("q1"), + schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算 + assistantToolCallsMsg("", "c1"), + schema.ToolMessage("r1", "c1"), + } + + // token 预算:2 条消息(1 assistant + 1 tool)恰好够用。 + // 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budget,assistant(tc:c1) 被挤掉,导致孤儿。 + // 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。 + out, err := summarizeFinalizeWithRecentAssistantToolTrail( + context.Background(), + msgs, + summary, + fixedTokenCounter(1), + 2, // 预算:2 tokens + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 必须包含 system + summary + if len(out) < 2 { + t.Fatalf("output too short: %d", len(out)) + } + if out[0] != sys { + t.Fatalf("first message must be system") + } + if out[1] != summary { + t.Fatalf("second message must be summary") + } + + // 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。 + assertNoOrphanTool(t, out) +} + +func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) { + // 构造两个大小差异显著的 tool-round: + // c_big round 的 tool 结果 content="aaaaaaaaaa"(10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12 + // c_ok round 的 tool 结果 content="ok"(2 bytes),round token ≈ 2 + 2 = 4 + // 配上 budget=8,使得: + // - 最新的 c_ok round(4)能放下; + // - 进一步的中间 round(assistant reply + user)也能放下; + // - 更早的 c_big round(12)放不下会被跳过(continue),而非 break。 + sys := schema.SystemMessage("sys") + summary := schema.AssistantMessage("summary_content", nil) + msgs := []adk.Message{ + sys, + schema.UserMessage("q1"), + assistantToolCallsMsg("", "c_big"), + schema.ToolMessage("aaaaaaaaaa", "c_big"), + schema.AssistantMessage("s", nil), + schema.UserMessage("q2"), + assistantToolCallsMsg("", "c_ok"), + schema.ToolMessage("ok", "c_ok"), + } + + out, err := summarizeFinalizeWithRecentAssistantToolTrail( + context.Background(), + msgs, + summary, + variableTokenCounter(), + 8, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + assertNoOrphanTool(t, out) + + // c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现) + for _, m := range out { + if m == nil { + continue + } + if m.Role == schema.Tool && m.ToolCallID == "c_big" { + t.Fatal("oversized tool round must be skipped: tool(c_big) leaked") + } + if m.Role == schema.Assistant { + for _, tc := range m.ToolCalls { + if tc.ID == "c_big" { + t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked") + } + } + } + } + + // 最近 round (c_ok) 作为一个原子单位必须整体保留。 + foundOKTool, foundOKAsst := false, false + for _, m := range out { + if m == nil { + continue + } + if m.Role == schema.Tool && m.ToolCallID == "c_ok" { + foundOKTool = true + } + if m.Role == schema.Assistant { + for _, tc := range m.ToolCalls { + if tc.ID == "c_ok" { + foundOKAsst = true + } + } + } + } + if !foundOKTool || !foundOKAsst { + t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool) + } +} + +func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) { + sys := schema.SystemMessage("sys") + summary := schema.AssistantMessage("summary", nil) + msgs := []adk.Message{ + sys, + assistantToolCallsMsg("", "c1"), + schema.ToolMessage("r1", "c1"), + } + out, err := summarizeFinalizeWithRecentAssistantToolTrail( + context.Background(), + msgs, + summary, + fixedTokenCounter(1), + 0, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(out) != 2 || out[0] != sys || out[1] != summary { + t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out) + } +} + +func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { + sys1 := schema.SystemMessage("sys1") + sys2 := schema.SystemMessage("sys2") + summary := schema.AssistantMessage("s", nil) + msgs := []adk.Message{ + sys1, + schema.UserMessage("q"), + sys2, // 非典型位置,但应当被 system group 捕获 + } + out, err := summarizeFinalizeWithRecentAssistantToolTrail( + context.Background(), + msgs, + summary, + fixedTokenCounter(1), + 100, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + systemCount := 0 + for _, m := range out { + if m != nil && m.Role == schema.System { + systemCount++ + } + } + if systemCount != 2 { + t.Fatalf("want 2 system messages retained, got %d", systemCount) + } +} + +// assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个 +// assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。 +func assertNoOrphanTool(t *testing.T, msgs []adk.Message) { + t.Helper() + provided := make(map[string]struct{}) + for _, m := range msgs { + if m == nil { + continue + } + if m.Role == schema.Assistant { + for _, tc := range m.ToolCalls { + if tc.ID != "" { + provided[tc.ID] = struct{}{} + } + } + } + if m.Role == schema.Tool && m.ToolCallID != "" { + if _, ok := provided[m.ToolCallID]; !ok { + t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID) + } + } + } +} diff --git a/internal/multiagent/orphan_tool_pruner_middleware.go b/internal/multiagent/orphan_tool_pruner_middleware.go new file mode 100644 index 00000000..8e33f8bb --- /dev/null +++ b/internal/multiagent/orphan_tool_pruner_middleware.go @@ -0,0 +1,124 @@ +package multiagent + +import ( + "context" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// orphanToolPrunerMiddleware 在每次 ChatModel 调用前剪掉没有对应 assistant(tool_calls) 的孤儿 tool 消息。 +// +// 背景: +// - eino 的 summarization 中间件在触发摘要后,默认把所有非 system 消息替换为 1 条 summary 消息; +// 本项目通过自定义 Finalize(summarizeFinalizeWithRecentAssistantToolTrail)在 summary 后回填 +// 最近的 assistant/tool 轨迹。若 Finalize 的保留策略按"条数"截断而未按 round 对齐,可能保留 +// 了 tool 结果却把对应的 assistant(tool_calls) 落在了 summary 前面,形成孤儿 tool 消息。 +// - 同样,reduction / tool_search / 自定义断点恢复等任一改写历史的逻辑,都可能破坏 +// tool_call ↔ tool_result 配对。 +// +// 一旦孤儿 tool 消息进入 ChatModel,OpenAI 兼容 API(含 DashScope / 各类中转)会返回 +// 400 "No tool call found for function call output with call_id ...",并被 Eino 包装成 +// [NodeRunError] 抛出,终止整轮编排。 +// +// 设计取舍: +// - 官方 patchtoolcalls 中间件只补反向(assistant(tc) 缺 tool_result),不处理孤儿 tool。 +// 本中间件与之互补,专职兜底正向孤儿。 +// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。 +// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。 +// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask / +// tool_search)之后,靠近 ChatModel 调用的那一端。 +type orphanToolPrunerMiddleware struct { + adk.BaseChatModelAgentMiddleware + logger *zap.Logger + phase string +} + +// newOrphanToolPrunerMiddleware 构造中间件。phase 仅用于日志区分 deep / supervisor / +// plan_execute_executor / sub_agent,不影响运行时行为。 +func newOrphanToolPrunerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware { + return &orphanToolPrunerMiddleware{ + logger: logger, + phase: phase, + } +} + +// BeforeModelRewriteState 扫描消息列表,收集 assistant.tool_calls 提供的 call_id 集合, +// 再剔除掉 ToolCallID 不在该集合中的 role=tool 消息。 +// +// 复杂度:O(N)。当未发现孤儿时不产生任何分配,state 原样返回以便上游快路径。 +func (m *orphanToolPrunerMiddleware) BeforeModelRewriteState( + ctx context.Context, + state *adk.ChatModelAgentState, + mc *adk.ModelContext, +) (context.Context, *adk.ChatModelAgentState, error) { + _ = mc + if m == nil || state == nil || len(state.Messages) == 0 { + return ctx, state, nil + } + + // 第一遍:收集所有已提供的 tool_call_id;同时快路径判定是否真的存在孤儿。 + provided := make(map[string]struct{}, 8) + for _, msg := range state.Messages { + if msg == nil { + continue + } + if msg.Role == schema.Assistant { + for _, tc := range msg.ToolCalls { + if tc.ID != "" { + provided[tc.ID] = struct{}{} + } + } + } + } + + hasOrphan := false + for _, msg := range state.Messages { + if msg == nil { + continue + } + if msg.Role == schema.Tool && msg.ToolCallID != "" { + if _, ok := provided[msg.ToolCallID]; !ok { + hasOrphan = true + break + } + } + } + if !hasOrphan { + return ctx, state, nil + } + + // 第二遍:生成剪除孤儿后的新消息列表。 + pruned := make([]adk.Message, 0, len(state.Messages)) + droppedIDs := make([]string, 0, 2) + droppedNames := make([]string, 0, 2) + for _, msg := range state.Messages { + if msg == nil { + continue + } + if msg.Role == schema.Tool && msg.ToolCallID != "" { + if _, ok := provided[msg.ToolCallID]; !ok { + droppedIDs = append(droppedIDs, msg.ToolCallID) + droppedNames = append(droppedNames, msg.ToolName) + continue + } + } + pruned = append(pruned, msg) + } + + if m.logger != nil { + m.logger.Warn("eino orphan tool messages pruned before model call", + zap.String("phase", m.phase), + zap.Int("dropped_count", len(droppedIDs)), + zap.Strings("dropped_tool_call_ids", droppedIDs), + zap.Strings("dropped_tool_names", droppedNames), + zap.Int("messages_before", len(state.Messages)), + zap.Int("messages_after", len(pruned)), + ) + } + + ns := *state + ns.Messages = pruned + return ctx, &ns, nil +} diff --git a/internal/multiagent/orphan_tool_pruner_middleware_test.go b/internal/multiagent/orphan_tool_pruner_middleware_test.go new file mode 100644 index 00000000..7af512ea --- /dev/null +++ b/internal/multiagent/orphan_tool_pruner_middleware_test.go @@ -0,0 +1,131 @@ +package multiagent + +import ( + "context" + "testing" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +func assistantToolCallsMsg(content string, callIDs ...string) *schema.Message { + tcs := make([]schema.ToolCall, 0, len(callIDs)) + for _, id := range callIDs { + tcs = append(tcs, schema.ToolCall{ + ID: id, + Type: "function", + Function: schema.FunctionCall{ + Name: "stub_tool", + Arguments: `{}`, + }, + }) + } + return schema.AssistantMessage(content, tcs) +} + +func TestOrphanToolPruner_NoOpWhenPaired(t *testing.T) { + mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware) + + msgs := []adk.Message{ + schema.SystemMessage("sys"), + schema.UserMessage("hi"), + assistantToolCallsMsg("", "c1", "c2"), + schema.ToolMessage("r1", "c1"), + schema.ToolMessage("r2", "c2"), + schema.AssistantMessage("done", nil), + } + in := &adk.ChatModelAgentState{Messages: msgs} + + _, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out == nil { + t.Fatal("expected non-nil state") + } + if len(out.Messages) != len(msgs) { + t.Fatalf("expected %d messages kept, got %d", len(msgs), len(out.Messages)) + } + // 快路径:未发现孤儿时必须原地返回 state,不分配新切片。 + if &out.Messages[0] != &msgs[0] { + t.Fatalf("expected state to be returned as-is (same backing slice) when no orphan present") + } +} + +func TestOrphanToolPruner_DropsOrphanToolMessages(t *testing.T) { + mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware) + + msgs := []adk.Message{ + schema.SystemMessage("sys"), + // 摘要前的 assistant(tc: c_old) 已被裁剪,但对应的 tool 结果漏保留了。 + schema.ToolMessage("orphan result", "c_old"), + schema.UserMessage("continue"), + assistantToolCallsMsg("", "c_new"), + schema.ToolMessage("r_new", "c_new"), + } + in := &adk.ChatModelAgentState{Messages: msgs} + + _, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out == nil { + t.Fatal("expected non-nil state") + } + if len(out.Messages) != len(msgs)-1 { + t.Fatalf("expected %d messages after pruning, got %d", len(msgs)-1, len(out.Messages)) + } + for _, m := range out.Messages { + if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_old" { + t.Fatalf("orphan tool message with ToolCallID=c_old should have been dropped") + } + } + // 合法的 tool(c_new) 必须保留。 + foundNew := false + for _, m := range out.Messages { + if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_new" { + foundNew = true + break + } + } + if !foundNew { + t.Fatal("paired tool message (c_new) must be retained") + } +} + +func TestOrphanToolPruner_EmptyToolCallIDIsIgnored(t *testing.T) { + // 空 ToolCallID 的 tool 消息在真实场景中极罕见,但不应当被误判为孤儿。 + // 语义上把它当作"无法校验,保留",避免误删。 + mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware) + + odd := schema.ToolMessage("no_id", "") + msgs := []adk.Message{ + schema.UserMessage("hi"), + odd, + schema.AssistantMessage("ok", nil), + } + in := &adk.ChatModelAgentState{Messages: msgs} + + _, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(out.Messages) != len(msgs) { + t.Fatalf("empty ToolCallID tool message should be kept, got %d messages", len(out.Messages)) + } +} + +func TestOrphanToolPruner_NilAndEmpty(t *testing.T) { + mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware) + + ctx := context.Background() + // nil state + if _, out, err := mw.BeforeModelRewriteState(ctx, nil, &adk.ModelContext{}); err != nil || out != nil { + t.Fatalf("nil state: expected (nil,nil), got (%v,%v)", out, err) + } + // empty messages + empty := &adk.ChatModelAgentState{} + if _, out, err := mw.BeforeModelRewriteState(ctx, empty, &adk.ModelContext{}); err != nil || out != empty { + t.Fatalf("empty messages: expected same state, got (%v,%v)", out, err) + } +} diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index 79437567..c3ed3a88 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -257,6 +257,9 @@ func RunDeepAgent( subHandlers = append(subHandlers, einoSkillMW) } subHandlers = append(subHandlers, subSumMw) + // 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前, + // 以便 telemetry 记录的 token 数与 LLM 实际入参一致。 + subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id)) if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil { subHandlers = append(subHandlers, teleMw) } @@ -393,6 +396,7 @@ func RunDeepAgent( deepHandlers = append(deepHandlers, einoSkillMW) } deepHandlers = append(deepHandlers, mainSumMw) + deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator")) if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil { deepHandlers = append(deepHandlers, teleMw) } @@ -405,6 +409,7 @@ func RunDeepAgent( supHandlers = append(supHandlers, einoSkillMW) } supHandlers = append(supHandlers, mainSumMw) + supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator")) if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil { supHandlers = append(supHandlers, teleMw) } @@ -455,6 +460,8 @@ func RunDeepAgent( FilesystemMiddleware: peFsMw, PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{ mainSumMw, + // 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。 + newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"), newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"), }, })