Files
CyberStrikeAI/internal/multiagent/eino_summarize_test.go
T
2026-05-01 01:00:23 +08:00

346 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package multiagent
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/adk/middlewares/summarization"
"github.com/cloudwego/eino/schema"
)
// fixedTokenCounter returns a token counter that charges
// tokensPerToolMessage for every tool-role message and 1 for every other
// non-nil message; nil entries are ignored. The tests use it to exercise the
// branch where a tool round that exceeds the budget is skipped as a whole.
func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc {
	return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
		var count int
		for _, m := range in.Messages {
			if m == nil {
				continue
			}
			if m.Role == schema.Tool {
				count += tokensPerToolMessage
			} else {
				count++
			}
		}
		return count, nil
	}
}
// variableTokenCounter returns a token counter in which a tool-role message
// costs len(Content) bytes, so tool results of different sizes can be told
// apart, and every other non-nil message costs 1 plus len(ToolCalls) as a
// rough stand-in for the tool_calls schema overhead. Nil entries cost nothing.
func variableTokenCounter() summarization.TokenCounterFunc {
	return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
		sum := 0
		for _, m := range in.Messages {
			switch {
			case m == nil:
				// Nil entries contribute nothing.
			case m.Role == schema.Tool:
				sum += len(m.Content)
			default:
				sum += 1 + len(m.ToolCalls)
			}
		}
		return sum, nil
	}
}
// TestSplitMessagesIntoRounds_Complex feeds a two-turn conversation with tool
// calls through splitMessagesIntoRounds and checks both the number of rounds
// and that each assistant(tool_calls) message stays grouped with its tool
// results.
func TestSplitMessagesIntoRounds_Complex(t *testing.T) {
	conversation := []adk.Message{
		schema.UserMessage("q1"),
		assistantToolCallsMsg("", "c1", "c2"),
		schema.ToolMessage("r1", "c1"),
		schema.ToolMessage("r2", "c2"),
		schema.AssistantMessage("reply1", nil),
		schema.UserMessage("q2"),
		assistantToolCallsMsg("", "c3"),
		schema.ToolMessage("r3", "c3"),
	}

	rounds := splitMessagesIntoRounds(conversation)

	// Expected 5 rounds:
	// user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3)
	if got := len(rounds); got != 5 {
		t.Fatalf("want 5 rounds, got %d", got)
	}

	// rounds[1] is the first tool round: one assistant plus both tool results.
	toolRound := rounds[1]
	if got := len(toolRound.messages); got != 3 {
		t.Fatalf("rounds[1] size: want 3, got %d", got)
	}
	if head := toolRound.messages[0]; !(head.Role == schema.Assistant && len(head.ToolCalls) == 2) {
		t.Fatalf("rounds[1][0] must be assistant(tc=2)")
	}
	for offset, m := range toolRound.messages[1:] {
		if m.Role != schema.Tool {
			t.Fatalf("rounds[1][%d] must be tool, got %s", offset+1, m.Role)
		}
	}

	// The final round must also be a complete assistant+tool pair.
	final := rounds[len(rounds)-1]
	if got := len(final.messages); got != 2 {
		t.Fatalf("rounds[last] size: want 2, got %d", got)
	}
	if !(final.messages[0].Role == schema.Assistant && final.messages[1].Role == schema.Tool) {
		t.Fatalf("last round must be assistant(tc)+tool(c3)")
	}
}
// TestSplitMessagesIntoRounds_DropsOrphanTool checks that a tool message at
// the very start of the input, with no assistant(tool_calls) before it, is
// discarded rather than becoming a round of its own.
func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) {
	input := []adk.Message{
		schema.ToolMessage("orphan", "c_old"), // orphan: nothing issued c_old
		schema.UserMessage("continue"),
		assistantToolCallsMsg("", "c_new"),
		schema.ToolMessage("r_new", "c_new"),
	}

	rounds := splitMessagesIntoRounds(input)

	// user(continue) | assistant(tc:c_new)+tool(c_new) -> 2 rounds.
	if got := len(rounds); got != 2 {
		t.Fatalf("want 2 rounds after dropping orphan, got %d", got)
	}
	for _, round := range rounds {
		for _, msg := range round.messages {
			if msg.Role != schema.Tool || msg.ToolCallID != "c_old" {
				continue
			}
			t.Fatalf("orphan tool c_old must not appear in any round")
		}
	}
}
// TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly feeds two
// back-to-back assistant(tool_calls)+tool pairs and checks that each tool
// message lands in the round of the assistant that issued its call ID, never
// in the neighbouring round.
func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) {
	msgs := []adk.Message{
		assistantToolCallsMsg("", "c1"),
		schema.ToolMessage("r1", "c1"),
		assistantToolCallsMsg("", "c2"),
		schema.ToolMessage("r2", "c2"),
	}

	rounds := splitMessagesIntoRounds(msgs)
	if len(rounds) != 2 {
		t.Fatalf("want 2 rounds, got %d", len(rounds))
	}

	for i, wantID := range []string{"c1", "c2"} {
		got := rounds[i].messages
		if len(got) != 2 {
			t.Fatalf("round[%d] wrong: %+v", i, got)
		}
		head, tail := got[0], got[1]
		// Guard the ToolCalls index so a mis-split round is reported as a
		// test failure instead of an index-out-of-range panic.
		if head == nil || head.Role != schema.Assistant ||
			len(head.ToolCalls) == 0 || head.ToolCalls[0].ID != wantID {
			t.Fatalf("round[%d] wrong: %+v", i, got)
		}
		// The paired tool result must answer this round's call ID.
		if tail == nil || tail.Role != schema.Tool || tail.ToolCallID != wantID {
			t.Fatalf("round[%d] wrong: %+v", i, got)
		}
	}
}
// TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant covers an
// assistant(tc:c1) followed by a tool message whose tool_call_id (c999) it
// never issued. The tool must not be attached to that assistant's round, and
// because no assistant owns c999 it must be dropped as an orphan.
func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) {
	input := []adk.Message{
		assistantToolCallsMsg("", "c1"),
		schema.ToolMessage("wrong", "c999"),
		schema.UserMessage("hi"),
	}

	rounds := splitMessagesIntoRounds(input)

	// assistant(tc:c1) has no tool(c1) but is not itself an orphan (per the
	// original author's note, patchtoolcalls back-fills it later), so it
	// forms a round of its own; user(hi) forms the second. Total: 2 rounds.
	if got := len(rounds); got != 2 {
		t.Fatalf("want 2 rounds, got %d: %+v", got, rounds)
	}
	for _, round := range rounds {
		for _, msg := range round.messages {
			if msg.Role != schema.Tool || msg.ToolCallID != "c999" {
				continue
			}
			t.Fatalf("wrong-owner tool must be dropped as orphan")
		}
	}
}
// TestSummarizeFinalize_KeepsToolRoundIntact is the key regression test for
// round-based retention: a tool round must be kept or dropped as a unit, never
// reduced to just its tool message (or just its assistant message).
func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
	sys := schema.SystemMessage("sys")
	summary := schema.AssistantMessage("summary_content", nil)
	msgs := []adk.Message{
		sys,
		schema.UserMessage("q1"),
		schema.AssistantMessage("reply_before_tc", nil), // filler that competes for budget
		assistantToolCallsMsg("", "c1"),
		schema.ToolMessage("r1", "c1"),
	}

	// Budget of 2 tokens is exactly enough for two messages (1 assistant +
	// 1 tool) under fixedTokenCounter(1).
	// A per-message retention strategy could take tool(c1) first and then
	// assistant(reply), pushing assistant(tc:c1) out of the budget and
	// leaving an orphan tool message.
	// With per-round retention the tool round is atomic: both of its
	// messages are kept, or neither is.
	out, err := summarizeFinalizeWithRecentAssistantToolTrail(
		context.Background(),
		msgs,
		summary,
		fixedTokenCounter(1),
		2, // budget: 2 tokens
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Output must start with system + summary.
	if len(out) < 2 {
		t.Fatalf("output too short: %d", len(out))
	}
	if out[0] != sys {
		t.Fatalf("first message must be system")
	}
	if out[1] != summary {
		t.Fatalf("second message must be summary")
	}

	// Invariant 1: every retained tool message has a preceding
	// assistant(tool_calls) in the output that supplies its ToolCallID.
	assertNoOrphanTool(t, out)

	// Invariant 2: the c1 round is atomic in both directions. The helper
	// above only catches a tool without its assistant; this also catches an
	// assistant(tc:c1) retained without tool(c1).
	asstKept, toolKept := false, false
	for _, m := range out {
		if m == nil {
			continue
		}
		switch m.Role {
		case schema.Assistant:
			for _, tc := range m.ToolCalls {
				if tc.ID == "c1" {
					asstKept = true
				}
			}
		case schema.Tool:
			if m.ToolCallID == "c1" {
				toolKept = true
			}
		}
	}
	if asstKept != toolKept {
		t.Fatalf("tool round c1 must be kept or dropped as a unit: assistantKept=%v toolKept=%v", asstKept, toolKept)
	}
}
// TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound builds two
// tool rounds of very different cost under variableTokenCounter:
//
//   - c_big: tool content "aaaaaaaaaa" (10 bytes), round cost = 2 (assistant + 1 tool call) + 10 = 12
//   - c_ok:  tool content "ok" (2 bytes), round cost = 2 + 2 = 4
//
// With budget = 8 the intent is that:
//
//   - the most recent round, c_ok (4), fits;
//   - the rounds between the two tool rounds (assistant reply and user) fit too;
//   - the older c_big round (12) does not fit and is skipped, rather than
//     ending the scan.
func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) {
	sys := schema.SystemMessage("sys")
	summary := schema.AssistantMessage("summary_content", nil)
	history := []adk.Message{
		sys,
		schema.UserMessage("q1"),
		assistantToolCallsMsg("", "c_big"),
		schema.ToolMessage("aaaaaaaaaa", "c_big"),
		schema.AssistantMessage("s", nil),
		schema.UserMessage("q2"),
		assistantToolCallsMsg("", "c_ok"),
		schema.ToolMessage("ok", "c_ok"),
	}

	out, err := summarizeFinalizeWithRecentAssistantToolTrail(
		context.Background(),
		history,
		summary,
		variableTokenCounter(),
		8,
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	assertNoOrphanTool(t, out)

	// Single pass over the output: any trace of the c_big round is a
	// failure, while both halves of the c_ok round are recorded so the pair
	// can be checked afterwards.
	asstKept, toolKept := false, false
	for _, msg := range out {
		if msg == nil {
			continue
		}
		switch msg.Role {
		case schema.Tool:
			switch msg.ToolCallID {
			case "c_big":
				t.Fatal("oversized tool round must be skipped: tool(c_big) leaked")
			case "c_ok":
				toolKept = true
			}
		case schema.Assistant:
			for _, call := range msg.ToolCalls {
				switch call.ID {
				case "c_big":
					t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked")
				case "c_ok":
					asstKept = true
				}
			}
		}
	}

	// The most recent round (c_ok) must be retained whole, as one atomic unit.
	if !(toolKept && asstKept) {
		t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", asstKept, toolKept)
	}
}
// TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly checks that a budget
// of 0 retains no conversation rounds: the output must be exactly the system
// message followed by the summary.
func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
	sys := schema.SystemMessage("sys")
	summary := schema.AssistantMessage("summary", nil)
	history := []adk.Message{
		sys,
		assistantToolCallsMsg("", "c1"),
		schema.ToolMessage("r1", "c1"),
	}

	out, err := summarizeFinalizeWithRecentAssistantToolTrail(
		context.Background(),
		history,
		summary,
		fixedTokenCounter(1),
		0,
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Pointer identity is compared: the very same sys and summary values
	// passed in must come back, in that order, and nothing else.
	onlySysAndSummary := len(out) == 2 && out[0] == sys && out[1] == summary
	if !onlySysAndSummary {
		t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
	}
}
// TestSummarizeFinalize_PreservesAllSystemMessages checks that every
// system-role message in the input is present in the output, including one
// that appears after a user message instead of at the head of the list.
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
	first := schema.SystemMessage("sys1")
	second := schema.SystemMessage("sys2")
	summary := schema.AssistantMessage("s", nil)
	history := []adk.Message{
		first,
		schema.UserMessage("q"),
		second, // unusual position, but still expected to be collected with the system messages
	}

	out, err := summarizeFinalizeWithRecentAssistantToolTrail(
		context.Background(),
		history,
		summary,
		fixedTokenCounter(1),
		100,
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var kept int
	for _, msg := range out {
		if msg == nil || msg.Role != schema.System {
			continue
		}
		kept++
	}
	if kept != 2 {
		t.Fatalf("want 2 system messages retained, got %d", kept)
	}
}
// assertNoOrphanTool fails the test if any tool-role message with a non-empty
// ToolCallID is not preceded, somewhere earlier in msgs, by an assistant
// message whose ToolCalls include that ID. According to the original author,
// such orphan tool messages are the root cause of LLM 400 responses.
// Nil messages and tool messages with an empty ToolCallID are ignored.
func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
	t.Helper()

	// IDs announced so far by assistant(tool_calls) messages, in scan order.
	seen := map[string]struct{}{}
	for _, msg := range msgs {
		if msg == nil {
			continue
		}
		switch msg.Role {
		case schema.Assistant:
			for _, call := range msg.ToolCalls {
				if call.ID == "" {
					continue
				}
				seen[call.ID] = struct{}{}
			}
		case schema.Tool:
			if msg.ToolCallID == "" {
				continue
			}
			if _, announced := seen[msg.ToolCallID]; !announced {
				t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", msg.ToolCallID)
			}
		}
	}
}