package multiagent import ( "context" "testing" "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/middlewares/summarization" "github.com/cloudwego/eino/schema" ) // fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。 // 用于验证 tool-round 超预算时整体被跳过的分支。 func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc { return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) { total := 0 for _, msg := range in.Messages { if msg == nil { continue } switch msg.Role { case schema.Tool: total += tokensPerToolMessage default: total++ } } return total, nil } } // variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果), // 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。 func variableTokenCounter() summarization.TokenCounterFunc { return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) { total := 0 for _, msg := range in.Messages { if msg == nil { continue } if msg.Role == schema.Tool { total += len(msg.Content) continue } total++ total += len(msg.ToolCalls) } return total, nil } } func TestSplitMessagesIntoRounds_Complex(t *testing.T) { msgs := []adk.Message{ schema.UserMessage("q1"), assistantToolCallsMsg("", "c1", "c2"), schema.ToolMessage("r1", "c1"), schema.ToolMessage("r2", "c2"), schema.AssistantMessage("reply1", nil), schema.UserMessage("q2"), assistantToolCallsMsg("", "c3"), schema.ToolMessage("r3", "c3"), } rounds := splitMessagesIntoRounds(msgs) // 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3) if len(rounds) != 5 { t.Fatalf("want 5 rounds, got %d", len(rounds)) } // round 1 应为 tool-round,必须成对 r1 := rounds[1] if len(r1.messages) != 3 { t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages)) } if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 { t.Fatalf("rounds[1][0] must be assistant(tc=2)") } for i := 1; i < 3; i++ { if r1.messages[i].Role != schema.Tool { t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role) } } // 最后一个 round 成对 rLast := rounds[len(rounds)-1] if len(rLast.messages) != 2 { t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages)) } if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool { t.Fatalf("last round must be assistant(tc)+tool(c3)") } } func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) { // 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。 msgs := []adk.Message{ schema.ToolMessage("orphan", "c_old"), schema.UserMessage("continue"), assistantToolCallsMsg("", "c_new"), schema.ToolMessage("r_new", "c_new"), } rounds := splitMessagesIntoRounds(msgs) // user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds if len(rounds) != 2 { t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds)) } for _, r := range rounds { for _, m := range r.messages { if m.Role == schema.Tool && m.ToolCallID == "c_old" { t.Fatalf("orphan tool c_old must not appear in any round") } } } } func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) { // 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。 msgs := []adk.Message{ assistantToolCallsMsg("", "c1"), schema.ToolMessage("r1", "c1"), assistantToolCallsMsg("", "c2"), schema.ToolMessage("r2", "c2"), } rounds := splitMessagesIntoRounds(msgs) if len(rounds) != 2 { t.Fatalf("want 2 rounds, got %d", len(rounds)) } if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" { t.Fatalf("round[0] wrong: %+v", rounds[0].messages) } if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" { t.Fatalf("round[1] wrong: %+v", rounds[1].messages) } } func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) { // assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。 // 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。 // 而 c999 又没有对应 assistant,应被当孤儿丢弃。 msgs := []adk.Message{ assistantToolCallsMsg("", "c1"), schema.ToolMessage("wrong", "c999"), schema.UserMessage("hi"), } rounds := splitMessagesIntoRounds(msgs) // assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补); // 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。 if len(rounds) != 2 { t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds) } for _, r := range rounds { for _, m := range r.messages { if m.Role == schema.Tool && m.ToolCallID == "c999" { t.Fatalf("wrong-owner tool must be dropped as orphan") } } } } func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) { // 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。 sys := schema.SystemMessage("sys") summary := schema.AssistantMessage("summary_content", nil) msgs := []adk.Message{ sys, schema.UserMessage("q1"), schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算 assistantToolCallsMsg("", "c1"), schema.ToolMessage("r1", "c1"), } // token 预算:2 条消息(1 assistant + 1 tool)恰好够用。 // 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budget,assistant(tc:c1) 被挤掉,导致孤儿。 // 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。 out, err := summarizeFinalizeWithRecentAssistantToolTrail( context.Background(), msgs, summary, fixedTokenCounter(1), 2, // 预算:2 tokens ) if err != nil { t.Fatalf("unexpected error: %v", err) } // 必须包含 system + summary if len(out) < 2 { t.Fatalf("output too short: %d", len(out)) } if out[0] != sys { t.Fatalf("first message must be system") } if out[1] != summary { t.Fatalf("second message must be summary") } // 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。 assertNoOrphanTool(t, out) } func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) { // 构造两个大小差异显著的 tool-round: // c_big round 的 tool 结果 content="aaaaaaaaaa"(10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12 // c_ok round 的 tool 结果 content="ok"(2 bytes),round token ≈ 2 + 2 = 4 // 配上 budget=8,使得: // - 最新的 c_ok round(4)能放下; // - 进一步的中间 round(assistant reply + user)也能放下; // - 更早的 c_big round(12)放不下会被跳过(continue),而非 break。 sys := schema.SystemMessage("sys") summary := schema.AssistantMessage("summary_content", nil) msgs := []adk.Message{ sys, schema.UserMessage("q1"), assistantToolCallsMsg("", "c_big"), schema.ToolMessage("aaaaaaaaaa", "c_big"), schema.AssistantMessage("s", nil), schema.UserMessage("q2"), assistantToolCallsMsg("", "c_ok"), schema.ToolMessage("ok", "c_ok"), } out, err := summarizeFinalizeWithRecentAssistantToolTrail( context.Background(), msgs, summary, variableTokenCounter(), 8, ) if err != nil { t.Fatalf("unexpected error: %v", err) } assertNoOrphanTool(t, out) // c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现) for _, m := range out { if m == nil { continue } if m.Role == schema.Tool && m.ToolCallID == "c_big" { t.Fatal("oversized tool round must be skipped: tool(c_big) leaked") } if m.Role == schema.Assistant { for _, tc := range m.ToolCalls { if tc.ID == "c_big" { t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked") } } } } // 最近 round (c_ok) 作为一个原子单位必须整体保留。 foundOKTool, foundOKAsst := false, false for _, m := range out { if m == nil { continue } if m.Role == schema.Tool && m.ToolCallID == "c_ok" { foundOKTool = true } if m.Role == schema.Assistant { for _, tc := range m.ToolCalls { if tc.ID == "c_ok" { foundOKAsst = true } } } } if !foundOKTool || !foundOKAsst { t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool) } } func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) { sys := schema.SystemMessage("sys") summary := schema.AssistantMessage("summary", nil) msgs := []adk.Message{ sys, assistantToolCallsMsg("", "c1"), schema.ToolMessage("r1", "c1"), } out, err := summarizeFinalizeWithRecentAssistantToolTrail( context.Background(), msgs, summary, fixedTokenCounter(1), 0, ) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(out) != 2 || out[0] != sys || out[1] != summary { t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out) } } func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { sys1 := schema.SystemMessage("sys1") sys2 := schema.SystemMessage("sys2") summary := schema.AssistantMessage("s", nil) msgs := []adk.Message{ sys1, schema.UserMessage("q"), sys2, // 非典型位置,但应当被 system group 捕获 } out, err := summarizeFinalizeWithRecentAssistantToolTrail( context.Background(), msgs, summary, fixedTokenCounter(1), 100, ) if err != nil { t.Fatalf("unexpected error: %v", err) } systemCount := 0 for _, m := range out { if m != nil && m.Role == schema.System { systemCount++ } } if systemCount != 2 { t.Fatalf("want 2 system messages retained, got %d", systemCount) } } // assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个 // assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。 func assertNoOrphanTool(t *testing.T, msgs []adk.Message) { t.Helper() provided := make(map[string]struct{}) for _, m := range msgs { if m == nil { continue } if m.Role == schema.Assistant { for _, tc := range m.ToolCalls { if tc.ID != "" { provided[tc.ID] = struct{}{} } } } if m.Role == schema.Tool && m.ToolCallID != "" { if _, ok := provided[m.ToolCallID]; !ok { t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID) } } } }