From 62441973393b91b9cdc6313f7758fe5d96fed238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:06:02 +0800 Subject: [PATCH] multiagent: add system-merge and continuation-dedup middlewares, unify ChatModel tail chain --- .../continuation_user_dedup_middleware.go | 104 ++++++++++++++++++ ...continuation_user_dedup_middleware_test.go | 65 +++++++++++ internal/multiagent/eino_adk_run_loop.go | 2 + .../eino_chat_model_tail_middleware.go | 50 +++++++++ internal/multiagent/eino_orchestration.go | 22 ++-- internal/multiagent/eino_single_runner.go | 15 +-- internal/multiagent/eino_summarize.go | 14 ++- internal/multiagent/eino_summarize_test.go | 21 +++- .../multiagent/eino_summarize_transcript.go | 23 +++- internal/multiagent/eino_transient_retry.go | 9 +- .../multiagent/eino_transient_retry_test.go | 15 +++ .../orphan_tool_pruner_middleware.go | 2 +- internal/multiagent/runner.go | 60 +++++----- .../system_message_normalizer_middleware.go | 86 +++++++++++++++ ...stem_message_normalizer_middleware_test.go | 75 +++++++++++++ 15 files changed, 493 insertions(+), 70 deletions(-) create mode 100644 internal/multiagent/continuation_user_dedup_middleware.go create mode 100644 internal/multiagent/continuation_user_dedup_middleware_test.go create mode 100644 internal/multiagent/eino_chat_model_tail_middleware.go create mode 100644 internal/multiagent/system_message_normalizer_middleware.go create mode 100644 internal/multiagent/system_message_normalizer_middleware_test.go diff --git a/internal/multiagent/continuation_user_dedup_middleware.go b/internal/multiagent/continuation_user_dedup_middleware.go new file mode 100644 index 00000000..fdb3b915 --- /dev/null +++ b/internal/multiagent/continuation_user_dedup_middleware.go @@ -0,0 +1,104 @@ +package multiagent + +import ( + "context" + "strings" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// continuationSessionMarker matches Cursor / IDE session-resume user 
injections. +const continuationSessionMarker = "This session is being continued from a previous conversation" + +// continuationUserDedupMiddleware keeps only the latest session-resume user message when +// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes). +type continuationUserDedupMiddleware struct { + adk.BaseChatModelAgentMiddleware + logger *zap.Logger + phase string +} + +func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware { + return &continuationUserDedupMiddleware{logger: logger, phase: phase} +} + +func (m *continuationUserDedupMiddleware) BeforeModelRewriteState( + ctx context.Context, + state *adk.ChatModelAgentState, + mc *adk.ModelContext, +) (context.Context, *adk.ChatModelAgentState, error) { + _ = mc + if m == nil || state == nil || len(state.Messages) == 0 { + return ctx, state, nil + } + deduped, dropped := dedupContinuationUserMessages(state.Messages) + if dropped == 0 { + return ctx, state, nil + } + if m.logger != nil { + m.logger.Info("eino continuation user messages deduplicated", + zap.String("phase", m.phase), + zap.Int("dropped", dropped), + zap.Int("messages_before", len(state.Messages)), + zap.Int("messages_after", len(deduped)), + ) + } + out := *state + out.Messages = deduped + return ctx, &out, nil +} + +func adkUserMessageText(msg adk.Message) string { + if msg == nil { + return "" + } + var b strings.Builder + if s := strings.TrimSpace(msg.Content); s != "" { + b.WriteString(s) + } + for _, part := range msg.UserInputMultiContent { + if part.Type == schema.ChatMessagePartTypeText { + if s := strings.TrimSpace(part.Text); s != "" { + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(s) + } + } + } + return b.String() +} + +func isContinuationUserMessage(msg adk.Message) bool { + if msg == nil || msg.Role != schema.User { + return false + } + return strings.Contains(adkUserMessageText(msg), continuationSessionMarker) +} + +func 
dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) { + lastIdx := -1 + contCount := 0 + for i, msg := range msgs { + if !isContinuationUserMessage(msg) { + continue + } + contCount++ + lastIdx = i + } + if contCount <= 1 { + return msgs, 0 + } + out := make([]adk.Message, 0, len(msgs)-(contCount-1)) + dropped := 0 + for i, msg := range msgs { + if isContinuationUserMessage(msg) && i != lastIdx { + dropped++ + continue + } + out = append(out, msg) + } + return out, dropped +} diff --git a/internal/multiagent/continuation_user_dedup_middleware_test.go b/internal/multiagent/continuation_user_dedup_middleware_test.go new file mode 100644 index 00000000..75987d86 --- /dev/null +++ b/internal/multiagent/continuation_user_dedup_middleware_test.go @@ -0,0 +1,65 @@ +package multiagent + +import ( + "context" + "strings" + "testing" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +func continuationUser(text string) adk.Message { + return &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text}, + {Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."}, + }, + } +} + +func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) { + msgs := []adk.Message{ + continuationUser("summary old"), + schema.UserMessage("real task"), + continuationUser("summary new"), + } + out, dropped := dedupContinuationUserMessages(msgs) + if dropped != 1 { + t.Fatalf("dropped=%d want 1", dropped) + } + if len(out) != 2 { + t.Fatalf("len=%d want 2", len(out)) + } + if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" { + t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0])) + } + if !strings.Contains(adkUserMessageText(out[1]), "summary new") { + t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1])) + } +} + 
+func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) { + msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")} + out, dropped := dedupContinuationUserMessages(msgs) + if dropped != 0 || len(out) != 2 { + t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out)) + } +} + +func TestContinuationUserDedupMiddleware(t *testing.T) { + mw := newContinuationUserDedupMiddleware(nil, "test") + state := &adk.ChatModelAgentState{Messages: []adk.Message{ + continuationUser("old"), + continuationUser("new"), + schema.UserMessage("task"), + }} + _, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil) + if err != nil { + t.Fatal(err) + } + if len(out.Messages) != 2 { + t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages)) + } +} diff --git a/internal/multiagent/eino_adk_run_loop.go b/internal/multiagent/eino_adk_run_loop.go index 01df2fd3..358a933d 100644 --- a/internal/multiagent/eino_adk_run_loop.go +++ b/internal/multiagent/eino_adk_run_loop.go @@ -627,6 +627,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs if restarted { continue } + } else { + transientRetrier.reset() } if ev.AgentName != "" && progress != nil { iterEinoAgent := orchestratorName diff --git a/internal/multiagent/eino_chat_model_tail_middleware.go b/internal/multiagent/eino_chat_model_tail_middleware.go new file mode 100644 index 00000000..c1d7d4f4 --- /dev/null +++ b/internal/multiagent/eino_chat_model_tail_middleware.go @@ -0,0 +1,50 @@ +package multiagent + +import ( + "github.com/cloudwego/eino/adk" + "go.uber.org/zap" +) + +// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask +// and immediately before each ChatModel invocation pipeline completes. +// +// Order (best practice): +// 1. system merge — accurate token count for summarization +// 2. continuation user dedup — drop stale session-resume injections +// 3. 
summarization +// 4. orphan tool prune +// 5. telemetry +// 6. model-facing trace snapshot +type einoChatModelTailConfig struct { + logger *zap.Logger + phase string + summarization adk.ChatModelAgentMiddleware + modelName string + conversationID string + trace *modelFacingTraceHolder + skipOrphanPruner bool + skipTelemetry bool + skipTrace bool +} + +func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware { + handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase)) + handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase)) + if cfg.summarization != nil { + handlers = append(handlers, cfg.summarization) + } + if !cfg.skipOrphanPruner { + handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase)) + } + if !cfg.skipTelemetry { + if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil { + handlers = append(handlers, teleMw) + } + } + if !cfg.skipTrace && cfg.trace != nil { + if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil { + handlers = append(handlers, capMw) + } + } + return handlers +} diff --git a/internal/multiagent/eino_orchestration.go b/internal/multiagent/eino_orchestration.go index fa387137..1ca95069 100644 --- a/internal/multiagent/eino_orchestration.go +++ b/internal/multiagent/eino_orchestration.go @@ -94,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma if a.SkillMiddleware != nil { execHandlers = append(execHandlers, a.SkillMiddleware) } - // 4. summarization(最后,与 Deep/Supervisor 一致) + // 4. 
pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致) if a.AppCfg != nil { sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger) if sumErr != nil { return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr) } - execHandlers = append(execHandlers, sumMw) - } - // 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、 - // telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。 - execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor")) - if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil { - execHandlers = append(execHandlers, teleMw) - } - if a.ModelFacingTrace != nil { - if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil { - execHandlers = append(execHandlers, capMw) - } + execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{ + logger: a.Logger, + phase: "plan_execute_executor", + summarization: sumMw, + modelName: a.ModelName, + conversationID: a.ConversationID, + trace: a.ModelFacingTrace, + }) } executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{ Model: a.ExecModel, diff --git a/internal/multiagent/eino_single_runner.go b/internal/multiagent/eino_single_runner.go index 2d5cb9cb..d92e0c3c 100644 --- a/internal/multiagent/eino_single_runner.go +++ b/internal/multiagent/eino_single_runner.go @@ -144,13 +144,14 @@ func RunEinoSingleChatModelAgent( } handlers = append(handlers, einoSkillMW) } - handlers = append(handlers, mainSumMw) - if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil { - handlers = append(handlers, teleMw) - } - if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { - handlers = append(handlers, capMw) - } + 
handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{ + logger: logger, + phase: "eino_single", + summarization: mainSumMw, + modelName: appCfg.OpenAI.Model, + conversationID: conversationID, + trace: modelFacingTrace, + }) maxIter := agentMaxIterations(appCfg) diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go index 222601fa..37495fee 100644 --- a/internal/multiagent/eino_summarize.go +++ b/internal/multiagent/eino_summarize.go @@ -257,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail( nonSystem = append(nonSystem, msg) } + mergedSystem := mergeCollectedSystemMessages(systemMsgs) + if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 { - out := make([]adk.Message, 0, len(systemMsgs)+1) - out = append(out, systemMsgs...) + out := make([]adk.Message, 0, len(mergedSystem)+1) + out = append(out, mergedSystem...) out = append(out, summary) return out, nil } rounds := splitMessagesIntoRounds(nonSystem) if len(rounds) == 0 { - out := make([]adk.Message, 0, len(systemMsgs)+1) - out = append(out, systemMsgs...) + out := make([]adk.Message, 0, len(mergedSystem)+1) + out = append(out, mergedSystem...) out = append(out, summary) return out, nil } @@ -319,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail( selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...) } - out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs)) - out = append(out, systemMsgs...) + out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs)) + out = append(out, mergedSystem...) out = append(out, summary) out = append(out, selectedMsgs...) 
return out, nil diff --git a/internal/multiagent/eino_summarize_test.go b/internal/multiagent/eino_summarize_test.go index 94fdefda..0f4672da 100644 --- a/internal/multiagent/eino_summarize_test.go +++ b/internal/multiagent/eino_summarize_test.go @@ -192,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) { if len(out) < 2 { t.Fatalf("output too short: %d", len(out)) } - if out[0] != sys { - t.Fatalf("first message must be system") + if out[0].Role != schema.System || out[0].Content != "sys" { + t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content) } if out[1] != summary { t.Fatalf("second message must be summary") @@ -293,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(out) != 2 || out[0] != sys || out[1] != summary { + if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary { t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out) } } -func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { +func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) { sys1 := schema.SystemMessage("sys1") sys2 := schema.SystemMessage("sys2") summary := schema.AssistantMessage("s", nil) @@ -321,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { for _, m := range out { if m != nil && m.Role == schema.System { systemCount++ + if got := m.Content; got != "sys1\n\nsys2" { + t.Fatalf("unexpected merged system content: %q", got) + } } } - if systemCount != 2 { - t.Fatalf("want 2 system messages retained, got %d", systemCount) + if systemCount != 1 { + t.Fatalf("want 1 merged system message, got %d", systemCount) } } @@ -378,6 +381,12 @@ func TestWriteSummarizationTranscript(t *testing.T) { if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") { t.Fatalf("missing tool round: %q", text) } + if 
!strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) { + t.Fatalf("missing tool name/arguments: %q", text) + } + if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) { + t.Fatalf("transcript should omit tool_call_id: %q", text) + } } func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) { diff --git a/internal/multiagent/eino_summarize_transcript.go b/internal/multiagent/eino_summarize_transcript.go index fcb7e2c3..15b8edb6 100644 --- a/internal/multiagent/eino_summarize_transcript.go +++ b/internal/multiagent/eino_summarize_transcript.go @@ -23,6 +23,11 @@ const ( transcriptSkillsSystemMarker = "# Skills System" ) +type transcriptToolCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + // formatSummarizationTranscript renders pre-compaction messages for transcript.txt. // Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only. func formatSummarizationTranscript(msgs []adk.Message) string { @@ -138,15 +143,21 @@ func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) { } } if len(msg.ToolCalls) > 0 { - if b, err := sonic.Marshal(msg.ToolCalls); err == nil { + if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil { sb.WriteString("tool_calls: ") sb.Write(b) sb.WriteByte('\n') } } - if msg.ToolCallID != "" { - sb.WriteString("tool_call_id: ") - sb.WriteString(msg.ToolCallID) - sb.WriteByte('\n') - } +} + +func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall { + out := make([]transcriptToolCall, 0, len(calls)) + for _, tc := range calls { + out = append(out, transcriptToolCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }) + } + return out } diff --git a/internal/multiagent/eino_transient_retry.go b/internal/multiagent/eino_transient_retry.go index 7090fe68..dfeb8228 100644 --- a/internal/multiagent/eino_transient_retry.go +++ 
b/internal/multiagent/eino_transient_retry.go @@ -62,6 +62,7 @@ func isEinoTransientRunError(err error) bool { "dial tcp", "tls handshake timeout", "stream error", + "goaway", // http2: server sent GOAWAY and closed the connection "unexpected eof", `": eof`, // net/http: Post "url": EOF (often wraps io.EOF) "unexpected end of json", @@ -142,6 +143,9 @@ func (r *einoTransientRunRetrier) attempt() int { return r.attempts } func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts } +// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。 +func (r *einoTransientRunRetrier) reset() { r.attempts = 0 } + func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int { if args != nil && args.RunRetryMaxAttempts > 0 { return args.RunRetryMaxAttempts @@ -177,10 +181,11 @@ const ( // 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。 func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) { if trace := persistTraceSource(args, nil); len(trace) > 0 { - return append([]adk.Message(nil), trace...), einoRestartContextModelTrace + // modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again. 
+ return stripADKSystemMessages(trace), einoRestartContextModelTrace } if len(accumulated) > baseCount { - return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated + return stripADKSystemMessages(accumulated), einoRestartContextAccumulated } return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial } diff --git a/internal/multiagent/eino_transient_retry_test.go b/internal/multiagent/eino_transient_retry_test.go index 0761dc40..0b2a5c69 100644 --- a/internal/multiagent/eino_transient_retry_test.go +++ b/internal/multiagent/eino_transient_retry_test.go @@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) { {"429", errors.New("HTTP 429 Too Many Requests"), true}, {"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true}, {"connection reset", errors.New("read tcp: connection reset by peer"), true}, + {"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true}, {"unexpected eof", errors.New("unexpected EOF"), true}, {"503", errors.New("upstream returned 503"), true}, {"iteration limit", errors.New("max iteration reached"), false}, @@ -90,6 +91,20 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) { } } +func TestEinoTransientRunRetrierReset(t *testing.T) { + t.Parallel() + r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second}) + r.attempts = 3 + r.reset() + if r.attempt() != 0 { + t.Fatalf("after reset: attempt=%d, want 0", r.attempt()) + } + // 重置后下一次退避应从 2s 起算(attempt index 0)。 + if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second { + t.Fatalf("backoff after reset: got %v, want 2s", got) + } +} + func TestAppendUserMessageIfNeeded(t *testing.T) { t.Parallel() msgs := []adk.Message{schema.UserMessage("old task")} diff --git a/internal/multiagent/orphan_tool_pruner_middleware.go 
b/internal/multiagent/orphan_tool_pruner_middleware.go index 8e33f8bb..4e6b42c3 100644 --- a/internal/multiagent/orphan_tool_pruner_middleware.go +++ b/internal/multiagent/orphan_tool_pruner_middleware.go @@ -27,7 +27,7 @@ import ( // 本中间件与之互补,专职兜底正向孤儿。 // - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。 // 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。 -// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask / +// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask / system 合并 / 续聊 dedup / // tool_search)之后,靠近 ChatModel 调用的那一端。 type orphanToolPrunerMiddleware struct { adk.BaseChatModelAgentMiddleware diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index c5425b49..1fd02cf8 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -231,13 +231,13 @@ func RunDeepAgent( } subHandlers = append(subHandlers, einoSkillMW) } - subHandlers = append(subHandlers, subSumMw) - // 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前, - // 以便 telemetry 记录的 token 数与 LLM 实际入参一致。 - subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id)) - if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil { - subHandlers = append(subHandlers, teleMw) - } + subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{ + logger: logger, + phase: "sub_agent:" + id, + summarization: subSumMw, + modelName: appCfg.OpenAI.Model, + conversationID: conversationID, + }) subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready()) subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive) @@ -379,14 +379,14 @@ func RunDeepAgent( if einoSkillMW != nil { deepHandlers = append(deepHandlers, einoSkillMW) } - deepHandlers = append(deepHandlers, mainSumMw) - deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator")) - if 
teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil { - deepHandlers = append(deepHandlers, teleMw) - } - if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { - deepHandlers = append(deepHandlers, capMw) - } + deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{ + logger: logger, + phase: "deep_orchestrator", + summarization: mainSumMw, + modelName: appCfg.OpenAI.Model, + conversationID: conversationID, + trace: modelFacingTrace, + }) supHandlers := []adk.ChatModelAgentMiddleware{} if len(mainOrchestratorPre) > 0 { @@ -395,14 +395,14 @@ func RunDeepAgent( if einoSkillMW != nil { supHandlers = append(supHandlers, einoSkillMW) } - supHandlers = append(supHandlers, mainSumMw) - supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator")) - if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil { - supHandlers = append(supHandlers, teleMw) - } - if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { - supHandlers = append(supHandlers, capMw) - } + supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{ + logger: logger, + phase: "supervisor_orchestrator", + summarization: mainSumMw, + modelName: appCfg.OpenAI.Model, + conversationID: conversationID, + trace: modelFacingTrace, + }) mainToolsCfg := adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ @@ -451,12 +451,14 @@ func RunDeepAgent( SkillMiddleware: einoSkillMW, FilesystemMiddleware: peFsMw, ModelFacingTrace: modelFacingTrace, - PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{ - mainSumMw, - // 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。 - newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"), - newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, 
conversationID, "plan_execute_planner_replanner_rewrite"), - }, + PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{ + logger: logger, + phase: "plan_execute_planner_replanner", + summarization: mainSumMw, + modelName: appCfg.OpenAI.Model, + conversationID: conversationID, + skipTrace: true, + }), }) if perr != nil { return nil, perr diff --git a/internal/multiagent/system_message_normalizer_middleware.go b/internal/multiagent/system_message_normalizer_middleware.go new file mode 100644 index 00000000..6739d202 --- /dev/null +++ b/internal/multiagent/system_message_normalizer_middleware.go @@ -0,0 +1,86 @@ +package multiagent + +import ( + "context" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single +// leading system message before summarization and each ChatModel call. +type systemMessageNormalizerMiddleware struct { + adk.BaseChatModelAgentMiddleware + logger *zap.Logger + phase string +} + +func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware { + return &systemMessageNormalizerMiddleware{logger: logger, phase: phase} +} + +func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState( + ctx context.Context, + state *adk.ChatModelAgentState, + mc *adk.ModelContext, +) (context.Context, *adk.ChatModelAgentState, error) { + _ = mc + if m == nil || state == nil || len(state.Messages) == 0 { + return ctx, state, nil + } + before := countADKSystemMessages(state.Messages) + if before <= 1 { + return ctx, state, nil + } + normalized := normalizeSingleLeadingSystemMessage(state.Messages, "") + if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before { + return ctx, state, nil + } + if m.logger != nil { + m.logger.Info("eino system messages merged", + zap.String("phase", m.phase), + 
zap.Int("system_before", before), + zap.Int("system_after", countADKSystemMessages(normalized)), + zap.Int("messages_before", len(state.Messages)), + zap.Int("messages_after", len(normalized)), + ) + } + out := *state + out.Messages = normalized + return ctx, &out, nil +} + +func countADKSystemMessages(msgs []adk.Message) int { + n := 0 + for _, msg := range msgs { + if msg != nil && msg.Role == schema.System { + n++ + } + } + return n +} + +// stripADKSystemMessages removes all system messages. Use before runner.Run restart when +// genModelInput will prepend a fresh Instruction. +func stripADKSystemMessages(msgs []adk.Message) []adk.Message { + if len(msgs) == 0 { + return msgs + } + out := make([]adk.Message, 0, len(msgs)) + for _, msg := range msgs { + if msg == nil || msg.Role == schema.System { + continue + } + out = append(out, msg) + } + return out +} + +// mergeCollectedSystemMessages collapses multiple system messages into one (or none). +func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message { + if len(systemMsgs) == 0 { + return nil + } + return normalizeSingleLeadingSystemMessage(systemMsgs, "") +} diff --git a/internal/multiagent/system_message_normalizer_middleware_test.go b/internal/multiagent/system_message_normalizer_middleware_test.go new file mode 100644 index 00000000..eaf8219e --- /dev/null +++ b/internal/multiagent/system_message_normalizer_middleware_test.go @@ -0,0 +1,75 @@ +package multiagent + +import ( + "context" + "testing" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +func TestStripADKSystemMessages(t *testing.T) { + in := []adk.Message{ + schema.SystemMessage("a"), + schema.UserMessage("u"), + schema.SystemMessage("b"), + schema.AssistantMessage("x", nil), + } + out := stripADKSystemMessages(in) + if len(out) != 2 { + t.Fatalf("got %d messages, want 2", len(out)) + } + if out[0].Role != schema.User || out[1].Role != schema.Assistant { + t.Fatalf("unexpected roles: %s, %s", 
out[0].Role, out[1].Role) + } +} + +func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) { + holder := newModelFacingTraceHolder() + holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{ + schema.SystemMessage("sys-1"), + schema.SystemMessage("sys-2"), + schema.UserMessage("task"), + }}) + msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0) + if src != einoRestartContextModelTrace { + t.Fatalf("source: got %q want model_trace", src) + } + if len(msgs) != 1 || msgs[0].Role != schema.User { + t.Fatalf("expected user-only restart msgs, got %+v", msgs) + } +} + +func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) { + mw := newSystemMessageNormalizerMiddleware(nil, "test") + state := &adk.ChatModelAgentState{Messages: []adk.Message{ + schema.SystemMessage("a"), + schema.SystemMessage("b"), + schema.UserMessage("u"), + }} + _, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil) + if err != nil { + t.Fatal(err) + } + if countADKSystemMessages(out.Messages) != 1 { + t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages)) + } + if out.Messages[0].Content != "a\n\nb" { + t.Fatalf("merged content: %q", out.Messages[0].Content) + } +} + +func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) { + mw := newSystemMessageNormalizerMiddleware(nil, "test") + state := &adk.ChatModelAgentState{Messages: []adk.Message{ + schema.SystemMessage("only"), + schema.UserMessage("u"), + }} + _, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil) + if err != nil { + t.Fatal(err) + } + if out != state { + t.Fatalf("expected same state pointer for no-op") + } +}