mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-07 18:17:00 +02:00
346 lines
10 KiB
Go
346 lines
10 KiB
Go
package multiagent
|
||
|
||
import (
|
||
"context"
|
||
"testing"
|
||
|
||
"github.com/cloudwego/eino/adk"
|
||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||
"github.com/cloudwego/eino/schema"
|
||
)
|
||
|
||
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
|
||
// 用于验证 tool-round 超预算时整体被跳过的分支。
|
||
func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc {
|
||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||
total := 0
|
||
for _, msg := range in.Messages {
|
||
if msg == nil {
|
||
continue
|
||
}
|
||
switch msg.Role {
|
||
case schema.Tool:
|
||
total += tokensPerToolMessage
|
||
default:
|
||
total++
|
||
}
|
||
}
|
||
return total, nil
|
||
}
|
||
}
|
||
|
||
// variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果),
|
||
// 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。
|
||
func variableTokenCounter() summarization.TokenCounterFunc {
|
||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||
total := 0
|
||
for _, msg := range in.Messages {
|
||
if msg == nil {
|
||
continue
|
||
}
|
||
if msg.Role == schema.Tool {
|
||
total += len(msg.Content)
|
||
continue
|
||
}
|
||
total++
|
||
total += len(msg.ToolCalls)
|
||
}
|
||
return total, nil
|
||
}
|
||
}
|
||
|
||
func TestSplitMessagesIntoRounds_Complex(t *testing.T) {
|
||
msgs := []adk.Message{
|
||
schema.UserMessage("q1"),
|
||
assistantToolCallsMsg("", "c1", "c2"),
|
||
schema.ToolMessage("r1", "c1"),
|
||
schema.ToolMessage("r2", "c2"),
|
||
schema.AssistantMessage("reply1", nil),
|
||
schema.UserMessage("q2"),
|
||
assistantToolCallsMsg("", "c3"),
|
||
schema.ToolMessage("r3", "c3"),
|
||
}
|
||
rounds := splitMessagesIntoRounds(msgs)
|
||
// 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3)
|
||
if len(rounds) != 5 {
|
||
t.Fatalf("want 5 rounds, got %d", len(rounds))
|
||
}
|
||
// round 1 应为 tool-round,必须成对
|
||
r1 := rounds[1]
|
||
if len(r1.messages) != 3 {
|
||
t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages))
|
||
}
|
||
if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 {
|
||
t.Fatalf("rounds[1][0] must be assistant(tc=2)")
|
||
}
|
||
for i := 1; i < 3; i++ {
|
||
if r1.messages[i].Role != schema.Tool {
|
||
t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role)
|
||
}
|
||
}
|
||
// 最后一个 round 成对
|
||
rLast := rounds[len(rounds)-1]
|
||
if len(rLast.messages) != 2 {
|
||
t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages))
|
||
}
|
||
if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool {
|
||
t.Fatalf("last round must be assistant(tc)+tool(c3)")
|
||
}
|
||
}
|
||
|
||
func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) {
|
||
// 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。
|
||
msgs := []adk.Message{
|
||
schema.ToolMessage("orphan", "c_old"),
|
||
schema.UserMessage("continue"),
|
||
assistantToolCallsMsg("", "c_new"),
|
||
schema.ToolMessage("r_new", "c_new"),
|
||
}
|
||
rounds := splitMessagesIntoRounds(msgs)
|
||
// user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds
|
||
if len(rounds) != 2 {
|
||
t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds))
|
||
}
|
||
for _, r := range rounds {
|
||
for _, m := range r.messages {
|
||
if m.Role == schema.Tool && m.ToolCallID == "c_old" {
|
||
t.Fatalf("orphan tool c_old must not appear in any round")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) {
|
||
// 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。
|
||
msgs := []adk.Message{
|
||
assistantToolCallsMsg("", "c1"),
|
||
schema.ToolMessage("r1", "c1"),
|
||
assistantToolCallsMsg("", "c2"),
|
||
schema.ToolMessage("r2", "c2"),
|
||
}
|
||
rounds := splitMessagesIntoRounds(msgs)
|
||
if len(rounds) != 2 {
|
||
t.Fatalf("want 2 rounds, got %d", len(rounds))
|
||
}
|
||
if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" {
|
||
t.Fatalf("round[0] wrong: %+v", rounds[0].messages)
|
||
}
|
||
if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" {
|
||
t.Fatalf("round[1] wrong: %+v", rounds[1].messages)
|
||
}
|
||
}
|
||
|
||
func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) {
|
||
// assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。
|
||
// 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。
|
||
// 而 c999 又没有对应 assistant,应被当孤儿丢弃。
|
||
msgs := []adk.Message{
|
||
assistantToolCallsMsg("", "c1"),
|
||
schema.ToolMessage("wrong", "c999"),
|
||
schema.UserMessage("hi"),
|
||
}
|
||
rounds := splitMessagesIntoRounds(msgs)
|
||
// assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补);
|
||
// 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。
|
||
if len(rounds) != 2 {
|
||
t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds)
|
||
}
|
||
for _, r := range rounds {
|
||
for _, m := range r.messages {
|
||
if m.Role == schema.Tool && m.ToolCallID == "c999" {
|
||
t.Fatalf("wrong-owner tool must be dropped as orphan")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||
// 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。
|
||
sys := schema.SystemMessage("sys")
|
||
summary := schema.AssistantMessage("summary_content", nil)
|
||
msgs := []adk.Message{
|
||
sys,
|
||
schema.UserMessage("q1"),
|
||
schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算
|
||
assistantToolCallsMsg("", "c1"),
|
||
schema.ToolMessage("r1", "c1"),
|
||
}
|
||
|
||
// token 预算:2 条消息(1 assistant + 1 tool)恰好够用。
|
||
// 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budget,assistant(tc:c1) 被挤掉,导致孤儿。
|
||
// 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。
|
||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||
context.Background(),
|
||
msgs,
|
||
summary,
|
||
fixedTokenCounter(1),
|
||
2, // 预算:2 tokens
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
|
||
// 必须包含 system + summary
|
||
if len(out) < 2 {
|
||
t.Fatalf("output too short: %d", len(out))
|
||
}
|
||
if out[0] != sys {
|
||
t.Fatalf("first message must be system")
|
||
}
|
||
if out[1] != summary {
|
||
t.Fatalf("second message must be summary")
|
||
}
|
||
|
||
// 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。
|
||
assertNoOrphanTool(t, out)
|
||
}
|
||
|
||
func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) {
|
||
// 构造两个大小差异显著的 tool-round:
|
||
// c_big round 的 tool 结果 content="aaaaaaaaaa"(10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12
|
||
// c_ok round 的 tool 结果 content="ok"(2 bytes),round token ≈ 2 + 2 = 4
|
||
// 配上 budget=8,使得:
|
||
// - 最新的 c_ok round(4)能放下;
|
||
// - 进一步的中间 round(assistant reply + user)也能放下;
|
||
// - 更早的 c_big round(12)放不下会被跳过(continue),而非 break。
|
||
sys := schema.SystemMessage("sys")
|
||
summary := schema.AssistantMessage("summary_content", nil)
|
||
msgs := []adk.Message{
|
||
sys,
|
||
schema.UserMessage("q1"),
|
||
assistantToolCallsMsg("", "c_big"),
|
||
schema.ToolMessage("aaaaaaaaaa", "c_big"),
|
||
schema.AssistantMessage("s", nil),
|
||
schema.UserMessage("q2"),
|
||
assistantToolCallsMsg("", "c_ok"),
|
||
schema.ToolMessage("ok", "c_ok"),
|
||
}
|
||
|
||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||
context.Background(),
|
||
msgs,
|
||
summary,
|
||
variableTokenCounter(),
|
||
8,
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
assertNoOrphanTool(t, out)
|
||
|
||
// c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现)
|
||
for _, m := range out {
|
||
if m == nil {
|
||
continue
|
||
}
|
||
if m.Role == schema.Tool && m.ToolCallID == "c_big" {
|
||
t.Fatal("oversized tool round must be skipped: tool(c_big) leaked")
|
||
}
|
||
if m.Role == schema.Assistant {
|
||
for _, tc := range m.ToolCalls {
|
||
if tc.ID == "c_big" {
|
||
t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 最近 round (c_ok) 作为一个原子单位必须整体保留。
|
||
foundOKTool, foundOKAsst := false, false
|
||
for _, m := range out {
|
||
if m == nil {
|
||
continue
|
||
}
|
||
if m.Role == schema.Tool && m.ToolCallID == "c_ok" {
|
||
foundOKTool = true
|
||
}
|
||
if m.Role == schema.Assistant {
|
||
for _, tc := range m.ToolCalls {
|
||
if tc.ID == "c_ok" {
|
||
foundOKAsst = true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if !foundOKTool || !foundOKAsst {
|
||
t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool)
|
||
}
|
||
}
|
||
|
||
func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||
sys := schema.SystemMessage("sys")
|
||
summary := schema.AssistantMessage("summary", nil)
|
||
msgs := []adk.Message{
|
||
sys,
|
||
assistantToolCallsMsg("", "c1"),
|
||
schema.ToolMessage("r1", "c1"),
|
||
}
|
||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||
context.Background(),
|
||
msgs,
|
||
summary,
|
||
fixedTokenCounter(1),
|
||
0,
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||
}
|
||
}
|
||
|
||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||
sys1 := schema.SystemMessage("sys1")
|
||
sys2 := schema.SystemMessage("sys2")
|
||
summary := schema.AssistantMessage("s", nil)
|
||
msgs := []adk.Message{
|
||
sys1,
|
||
schema.UserMessage("q"),
|
||
sys2, // 非典型位置,但应当被 system group 捕获
|
||
}
|
||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||
context.Background(),
|
||
msgs,
|
||
summary,
|
||
fixedTokenCounter(1),
|
||
100,
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
systemCount := 0
|
||
for _, m := range out {
|
||
if m != nil && m.Role == schema.System {
|
||
systemCount++
|
||
}
|
||
}
|
||
if systemCount != 2 {
|
||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||
}
|
||
}
|
||
|
||
// assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个
|
||
// assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。
|
||
func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
|
||
t.Helper()
|
||
provided := make(map[string]struct{})
|
||
for _, m := range msgs {
|
||
if m == nil {
|
||
continue
|
||
}
|
||
if m.Role == schema.Assistant {
|
||
for _, tc := range m.ToolCalls {
|
||
if tc.ID != "" {
|
||
provided[tc.ID] = struct{}{}
|
||
}
|
||
}
|
||
}
|
||
if m.Role == schema.Tool && m.ToolCallID != "" {
|
||
if _, ok := provided[m.ToolCallID]; !ok {
|
||
t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID)
|
||
}
|
||
}
|
||
}
|
||
}
|