mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-02 07:45:24 +02:00
Add files via upload
This commit is contained in:
@@ -95,6 +95,9 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
|
||||
@@ -130,6 +130,14 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
|
||||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||||
//
|
||||
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
||||
// 把消息切成 round(回合)为原子单位:
|
||||
// - user(...) 单条为一个 round;
|
||||
// - assistant(tool_calls=[...]) 及其后连续的 role=tool 消息合成一个 round;
|
||||
// - 其它 assistant(reply, 无 tool_calls) 单条为一个 round。
|
||||
//
|
||||
// 倒序挑 round(预算不够即放弃该 round),保证 tool 消息不会跨 round 被孤立。
|
||||
func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
ctx context.Context,
|
||||
originalMessages []adk.Message,
|
||||
@@ -157,80 +165,136 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
selectedReverse := make([]adk.Message, 0, 8)
|
||||
seen := make(map[adk.Message]struct{})
|
||||
totalTokens := 0
|
||||
assistantToolKept := 0
|
||||
const minAssistantToolTrail = 4
|
||||
rounds := splitMessagesIntoRounds(nonSystem)
|
||||
if len(rounds) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
tryKeep := func(msg adk.Message) (bool, error) {
|
||||
if msg == nil {
|
||||
return false, nil
|
||||
// 目标:至少保留 minRounds 个 round 的执行轨迹;在预算允许时尽量多保留。
|
||||
// 优先确保最后一个 round(通常是最新的 tool 往返或 assistant 回复)存在。
|
||||
const minRounds = 2
|
||||
|
||||
selectedRoundsReverse := make([]messageRound, 0, 8)
|
||||
selectedCount := 0
|
||||
totalTokens := 0
|
||||
|
||||
tokensOfRound := func(r messageRound) (int, error) {
|
||||
if len(r.messages) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if _, ok := seen[msg]; ok {
|
||||
return false, nil
|
||||
}
|
||||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: []adk.Message{msg}})
|
||||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: r.messages})
|
||||
if err != nil {
|
||||
return false, err
|
||||
return 0, err
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
n = len(r.messages)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
r := rounds[i]
|
||||
n, err := tokensOfRound(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 预算不够:已经保留了足够 round 则停,否则跳过该 round 继续往前找
|
||||
// (避免一个超大 round 挤占全部预算,至少保证有轨迹)。
|
||||
if totalTokens+n > recentTrailTokenBudget {
|
||||
return false, nil
|
||||
if selectedCount >= minRounds {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
totalTokens += n
|
||||
selectedReverse = append(selectedReverse, msg)
|
||||
seen[msg] = struct{}{}
|
||||
return true, nil
|
||||
selectedRoundsReverse = append(selectedRoundsReverse, r)
|
||||
selectedCount++
|
||||
}
|
||||
|
||||
// 优先保留最近 assistant/tool,确保执行轨迹可续跑。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
msg := nonSystem[i]
|
||||
if msg.Role != schema.Assistant && msg.Role != schema.Tool {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
assistantToolKept++
|
||||
}
|
||||
if assistantToolKept >= minAssistantToolTrail {
|
||||
break
|
||||
}
|
||||
// 还原时间顺序
|
||||
selectedMsgs := make([]adk.Message, 0, 8)
|
||||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
}
|
||||
|
||||
// 在预算内回填更多最近消息,保持短链路上下文。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
_, exists := seen[nonSystem[i]]
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(nonSystem[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
selected := make([]adk.Message, 0, len(selectedReverse))
|
||||
for i := len(selectedReverse) - 1; i >= 0; i-- {
|
||||
selected = append(selected, selectedReverse[i])
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selected))
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selected...)
|
||||
out = append(out, selectedMsgs...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// messageRound 表示一个"不可分割"的消息回合。
|
||||
// - 对 assistant(tool_calls) + 随后若干 tool 消息的组合,round 内全部 call_id 成对完整;
|
||||
// - 对独立的 user / assistant(reply) 消息,round 仅包含该条消息。
|
||||
type messageRound struct {
|
||||
messages []adk.Message
|
||||
}
|
||||
|
||||
// splitMessagesIntoRounds 将非 system 消息切分为若干 round,保证:
|
||||
// - 每个 assistant(tool_calls) 与其对应的 role=tool 响应消息在同一个 round;
|
||||
// - 孤立(无对应 assistant(tool_calls))的 role=tool 消息不会单独成为 round,
|
||||
// 而是被丢弃(这些消息在 pair 完整性层面已属孤儿,保留反而会触发 LLM 400)。
|
||||
func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
rounds := make([]messageRound, 0, len(msgs))
|
||||
i := 0
|
||||
for i < len(msgs) {
|
||||
msg := msgs[i]
|
||||
if msg == nil {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case msg.Role == schema.Assistant && len(msg.ToolCalls) > 0:
|
||||
// 收集该 assistant 提供的 call_id 集合。
|
||||
provided := make(map[string]struct{}, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
round := messageRound{messages: []adk.Message{msg}}
|
||||
j := i + 1
|
||||
for j < len(msgs) {
|
||||
next := msgs[j]
|
||||
if next == nil {
|
||||
j++
|
||||
continue
|
||||
}
|
||||
if next.Role != schema.Tool {
|
||||
break
|
||||
}
|
||||
if next.ToolCallID != "" {
|
||||
if _, ok := provided[next.ToolCallID]; !ok {
|
||||
// 下一条 tool 不属于当前 assistant,认为当前 round 结束。
|
||||
break
|
||||
}
|
||||
}
|
||||
round.messages = append(round.messages, next)
|
||||
j++
|
||||
}
|
||||
rounds = append(rounds, round)
|
||||
i = j
|
||||
case msg.Role == schema.Tool:
|
||||
// 孤儿 tool 消息:既不跟随在一个 assistant(tool_calls) 后,
|
||||
// 说明它对应的 assistant 已被上游裁剪;直接丢弃,下一步到 orphan pruner
|
||||
// 兜底也不会出错,但在 round 切分这里就剔除更干净。
|
||||
i++
|
||||
default:
|
||||
// user / assistant(reply) / 其它:单条成 round。
|
||||
rounds = append(rounds, messageRound{messages: []adk.Message{msg}})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
|
||||
// 用于验证 tool-round 超预算时整体被跳过的分支。
|
||||
func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc {
|
||||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||||
total := 0
|
||||
for _, msg := range in.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
switch msg.Role {
|
||||
case schema.Tool:
|
||||
total += tokensPerToolMessage
|
||||
default:
|
||||
total++
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
}
|
||||
|
||||
// variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果),
|
||||
// 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。
|
||||
func variableTokenCounter() summarization.TokenCounterFunc {
|
||||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||||
total := 0
|
||||
for _, msg := range in.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool {
|
||||
total += len(msg.Content)
|
||||
continue
|
||||
}
|
||||
total++
|
||||
total += len(msg.ToolCalls)
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_Complex(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("q1"),
|
||||
assistantToolCallsMsg("", "c1", "c2"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
schema.AssistantMessage("reply1", nil),
|
||||
schema.UserMessage("q2"),
|
||||
assistantToolCallsMsg("", "c3"),
|
||||
schema.ToolMessage("r3", "c3"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3)
|
||||
if len(rounds) != 5 {
|
||||
t.Fatalf("want 5 rounds, got %d", len(rounds))
|
||||
}
|
||||
// round 1 应为 tool-round,必须成对
|
||||
r1 := rounds[1]
|
||||
if len(r1.messages) != 3 {
|
||||
t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages))
|
||||
}
|
||||
if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 {
|
||||
t.Fatalf("rounds[1][0] must be assistant(tc=2)")
|
||||
}
|
||||
for i := 1; i < 3; i++ {
|
||||
if r1.messages[i].Role != schema.Tool {
|
||||
t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role)
|
||||
}
|
||||
}
|
||||
// 最后一个 round 成对
|
||||
rLast := rounds[len(rounds)-1]
|
||||
if len(rLast.messages) != 2 {
|
||||
t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages))
|
||||
}
|
||||
if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool {
|
||||
t.Fatalf("last round must be assistant(tc)+tool(c3)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) {
|
||||
// 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。
|
||||
msgs := []adk.Message{
|
||||
schema.ToolMessage("orphan", "c_old"),
|
||||
schema.UserMessage("continue"),
|
||||
assistantToolCallsMsg("", "c_new"),
|
||||
schema.ToolMessage("r_new", "c_new"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds))
|
||||
}
|
||||
for _, r := range rounds {
|
||||
for _, m := range r.messages {
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_old" {
|
||||
t.Fatalf("orphan tool c_old must not appear in any round")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) {
|
||||
// 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。
|
||||
msgs := []adk.Message{
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
assistantToolCallsMsg("", "c2"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds, got %d", len(rounds))
|
||||
}
|
||||
if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" {
|
||||
t.Fatalf("round[0] wrong: %+v", rounds[0].messages)
|
||||
}
|
||||
if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" {
|
||||
t.Fatalf("round[1] wrong: %+v", rounds[1].messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) {
|
||||
// assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。
|
||||
// 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。
|
||||
// 而 c999 又没有对应 assistant,应被当孤儿丢弃。
|
||||
msgs := []adk.Message{
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("wrong", "c999"),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补);
|
||||
// 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds)
|
||||
}
|
||||
for _, r := range rounds {
|
||||
for _, m := range r.messages {
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c999" {
|
||||
t.Fatalf("wrong-owner tool must be dropped as orphan")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||||
// 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary_content", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
schema.UserMessage("q1"),
|
||||
schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
}
|
||||
|
||||
// token 预算:2 条消息(1 assistant + 1 tool)恰好够用。
|
||||
// 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budget,assistant(tc:c1) 被挤掉,导致孤儿。
|
||||
// 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
2, // 预算:2 tokens
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 必须包含 system + summary
|
||||
if len(out) < 2 {
|
||||
t.Fatalf("output too short: %d", len(out))
|
||||
}
|
||||
if out[0] != sys {
|
||||
t.Fatalf("first message must be system")
|
||||
}
|
||||
if out[1] != summary {
|
||||
t.Fatalf("second message must be summary")
|
||||
}
|
||||
|
||||
// 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。
|
||||
assertNoOrphanTool(t, out)
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) {
|
||||
// 构造两个大小差异显著的 tool-round:
|
||||
// c_big round 的 tool 结果 content="aaaaaaaaaa"(10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12
|
||||
// c_ok round 的 tool 结果 content="ok"(2 bytes),round token ≈ 2 + 2 = 4
|
||||
// 配上 budget=8,使得:
|
||||
// - 最新的 c_ok round(4)能放下;
|
||||
// - 进一步的中间 round(assistant reply + user)也能放下;
|
||||
// - 更早的 c_big round(12)放不下会被跳过(continue),而非 break。
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary_content", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
schema.UserMessage("q1"),
|
||||
assistantToolCallsMsg("", "c_big"),
|
||||
schema.ToolMessage("aaaaaaaaaa", "c_big"),
|
||||
schema.AssistantMessage("s", nil),
|
||||
schema.UserMessage("q2"),
|
||||
assistantToolCallsMsg("", "c_ok"),
|
||||
schema.ToolMessage("ok", "c_ok"),
|
||||
}
|
||||
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
variableTokenCounter(),
|
||||
8,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
assertNoOrphanTool(t, out)
|
||||
|
||||
// c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现)
|
||||
for _, m := range out {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_big" {
|
||||
t.Fatal("oversized tool round must be skipped: tool(c_big) leaked")
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID == "c_big" {
|
||||
t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 最近 round (c_ok) 作为一个原子单位必须整体保留。
|
||||
foundOKTool, foundOKAsst := false, false
|
||||
for _, m := range out {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_ok" {
|
||||
foundOKTool = true
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID == "c_ok" {
|
||||
foundOKAsst = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundOKTool || !foundOKAsst {
|
||||
t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
}
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
sys1 := schema.SystemMessage("sys1")
|
||||
sys2 := schema.SystemMessage("sys2")
|
||||
summary := schema.AssistantMessage("s", nil)
|
||||
msgs := []adk.Message{
|
||||
sys1,
|
||||
schema.UserMessage("q"),
|
||||
sys2, // 非典型位置,但应当被 system group 捕获
|
||||
}
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
100,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
systemCount := 0
|
||||
for _, m := range out {
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemCount++
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
// assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个
|
||||
// assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。
|
||||
func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
|
||||
t.Helper()
|
||||
provided := make(map[string]struct{})
|
||||
for _, m := range msgs {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID != "" {
|
||||
if _, ok := provided[m.ToolCallID]; !ok {
|
||||
t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// orphanToolPrunerMiddleware 在每次 ChatModel 调用前剪掉没有对应 assistant(tool_calls) 的孤儿 tool 消息。
|
||||
//
|
||||
// 背景:
|
||||
// - eino 的 summarization 中间件在触发摘要后,默认把所有非 system 消息替换为 1 条 summary 消息;
|
||||
// 本项目通过自定义 Finalize(summarizeFinalizeWithRecentAssistantToolTrail)在 summary 后回填
|
||||
// 最近的 assistant/tool 轨迹。若 Finalize 的保留策略按"条数"截断而未按 round 对齐,可能保留
|
||||
// 了 tool 结果却把对应的 assistant(tool_calls) 落在了 summary 前面,形成孤儿 tool 消息。
|
||||
// - 同样,reduction / tool_search / 自定义断点恢复等任一改写历史的逻辑,都可能破坏
|
||||
// tool_call ↔ tool_result 配对。
|
||||
//
|
||||
// 一旦孤儿 tool 消息进入 ChatModel,OpenAI 兼容 API(含 DashScope / 各类中转)会返回
|
||||
// 400 "No tool call found for function call output with call_id ...",并被 Eino 包装成
|
||||
// [NodeRunError] 抛出,终止整轮编排。
|
||||
//
|
||||
// 设计取舍:
|
||||
// - 官方 patchtoolcalls 中间件只补反向(assistant(tc) 缺 tool_result),不处理孤儿 tool。
|
||||
// 本中间件与之互补,专职兜底正向孤儿。
|
||||
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
|
||||
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
|
||||
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
|
||||
// tool_search)之后,靠近 ChatModel 调用的那一端。
|
||||
type orphanToolPrunerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
// newOrphanToolPrunerMiddleware 构造中间件。phase 仅用于日志区分 deep / supervisor /
|
||||
// plan_execute_executor / sub_agent,不影响运行时行为。
|
||||
func newOrphanToolPrunerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &orphanToolPrunerMiddleware{
|
||||
logger: logger,
|
||||
phase: phase,
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeModelRewriteState 扫描消息列表,收集 assistant.tool_calls 提供的 call_id 集合,
|
||||
// 再剔除掉 ToolCallID 不在该集合中的 role=tool 消息。
|
||||
//
|
||||
// 复杂度:O(N)。当未发现孤儿时不产生任何分配,state 原样返回以便上游快路径。
|
||||
func (m *orphanToolPrunerMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
|
||||
// 第一遍:收集所有已提供的 tool_call_id;同时快路径判定是否真的存在孤儿。
|
||||
provided := make(map[string]struct{}, 8)
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Assistant {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hasOrphan := false
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool && msg.ToolCallID != "" {
|
||||
if _, ok := provided[msg.ToolCallID]; !ok {
|
||||
hasOrphan = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasOrphan {
|
||||
return ctx, state, nil
|
||||
}
|
||||
|
||||
// 第二遍:生成剪除孤儿后的新消息列表。
|
||||
pruned := make([]adk.Message, 0, len(state.Messages))
|
||||
droppedIDs := make([]string, 0, 2)
|
||||
droppedNames := make([]string, 0, 2)
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool && msg.ToolCallID != "" {
|
||||
if _, ok := provided[msg.ToolCallID]; !ok {
|
||||
droppedIDs = append(droppedIDs, msg.ToolCallID)
|
||||
droppedNames = append(droppedNames, msg.ToolName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
pruned = append(pruned, msg)
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
m.logger.Warn("eino orphan tool messages pruned before model call",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("dropped_count", len(droppedIDs)),
|
||||
zap.Strings("dropped_tool_call_ids", droppedIDs),
|
||||
zap.Strings("dropped_tool_names", droppedNames),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(pruned)),
|
||||
)
|
||||
}
|
||||
|
||||
ns := *state
|
||||
ns.Messages = pruned
|
||||
return ctx, &ns, nil
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func assistantToolCallsMsg(content string, callIDs ...string) *schema.Message {
|
||||
tcs := make([]schema.ToolCall, 0, len(callIDs))
|
||||
for _, id := range callIDs {
|
||||
tcs = append(tcs, schema.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "stub_tool",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
})
|
||||
}
|
||||
return schema.AssistantMessage(content, tcs)
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_NoOpWhenPaired(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("sys"),
|
||||
schema.UserMessage("hi"),
|
||||
assistantToolCallsMsg("", "c1", "c2"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
schema.AssistantMessage("done", nil),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if len(out.Messages) != len(msgs) {
|
||||
t.Fatalf("expected %d messages kept, got %d", len(msgs), len(out.Messages))
|
||||
}
|
||||
// 快路径:未发现孤儿时必须原地返回 state,不分配新切片。
|
||||
if &out.Messages[0] != &msgs[0] {
|
||||
t.Fatalf("expected state to be returned as-is (same backing slice) when no orphan present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_DropsOrphanToolMessages(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("sys"),
|
||||
// 摘要前的 assistant(tc: c_old) 已被裁剪,但对应的 tool 结果漏保留了。
|
||||
schema.ToolMessage("orphan result", "c_old"),
|
||||
schema.UserMessage("continue"),
|
||||
assistantToolCallsMsg("", "c_new"),
|
||||
schema.ToolMessage("r_new", "c_new"),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if len(out.Messages) != len(msgs)-1 {
|
||||
t.Fatalf("expected %d messages after pruning, got %d", len(msgs)-1, len(out.Messages))
|
||||
}
|
||||
for _, m := range out.Messages {
|
||||
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_old" {
|
||||
t.Fatalf("orphan tool message with ToolCallID=c_old should have been dropped")
|
||||
}
|
||||
}
|
||||
// 合法的 tool(c_new) 必须保留。
|
||||
foundNew := false
|
||||
for _, m := range out.Messages {
|
||||
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_new" {
|
||||
foundNew = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNew {
|
||||
t.Fatal("paired tool message (c_new) must be retained")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_EmptyToolCallIDIsIgnored(t *testing.T) {
|
||||
// 空 ToolCallID 的 tool 消息在真实场景中极罕见,但不应当被误判为孤儿。
|
||||
// 语义上把它当作"无法校验,保留",避免误删。
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
odd := schema.ToolMessage("no_id", "")
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("hi"),
|
||||
odd,
|
||||
schema.AssistantMessage("ok", nil),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out.Messages) != len(msgs) {
|
||||
t.Fatalf("empty ToolCallID tool message should be kept, got %d messages", len(out.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_NilAndEmpty(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
ctx := context.Background()
|
||||
// nil state
|
||||
if _, out, err := mw.BeforeModelRewriteState(ctx, nil, &adk.ModelContext{}); err != nil || out != nil {
|
||||
t.Fatalf("nil state: expected (nil,nil), got (%v,%v)", out, err)
|
||||
}
|
||||
// empty messages
|
||||
empty := &adk.ChatModelAgentState{}
|
||||
if _, out, err := mw.BeforeModelRewriteState(ctx, empty, &adk.ModelContext{}); err != nil || out != empty {
|
||||
t.Fatalf("empty messages: expected same state, got (%v,%v)", out, err)
|
||||
}
|
||||
}
|
||||
@@ -257,6 +257,9 @@ func RunDeepAgent(
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
|
||||
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
|
||||
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
@@ -393,6 +396,7 @@ func RunDeepAgent(
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
@@ -405,6 +409,7 @@ func RunDeepAgent(
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
@@ -455,6 +460,8 @@ func RunDeepAgent(
|
||||
FilesystemMiddleware: peFsMw,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user