From 6c47996ea8319318e441ebd71e78a615a210151b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com>
Date: Mon, 23 Mar 2026 02:37:45 +0800
Subject: [PATCH] Add files via upload

---
 internal/multiagent/eino_summarize.go | 166 ++++++++++++++++++++++++++
 internal/multiagent/runner.go         |  19 +++++++++++++++++++---
 2 files changed, 182 insertions(+), 3 deletions(-)
 create mode 100644 internal/multiagent/eino_summarize.go

diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go
new file mode 100644
index 00000000..81260109
--- /dev/null
+++ b/internal/multiagent/eino_summarize.go
@@ -0,0 +1,166 @@
+package multiagent
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"cyberstrike-ai/internal/agent"
+	"cyberstrike-ai/internal/config"
+
+	"github.com/bytedance/sonic"
+	"github.com/cloudwego/eino/adk"
+	"github.com/cloudwego/eino/adk/middlewares/summarization"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/schema"
+	"go.uber.org/zap"
+)
+
+// einoSummarizeUserInstruction is the user instruction passed to the
+// summarization middleware. According to the original author's note it has the
+// same goal as the single-agent MemoryCompressor: keep the information that is
+// critical to the penetration test while compressing.
+const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
+
+必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
+保留精确技术细节(URL、路径、参数、Payload、版本号、报错原文可摘要但要点不丢)。
+将冗长扫描输出概括为结论;重复发现合并表述。
+
+输出须使后续代理能无缝继续同一授权测试任务。`
+
+// newEinoSummarizationMiddleware builds the Eino ADK summarization middleware
+// (see https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/).
+//
+// Summarization is triggered at 90% of openai.max_total_tokens (120000 when
+// that setting is not positive); the original author's note says this matches
+// the single-agent MemoryCompressor. An error is returned when summaryModel or
+// appCfg is nil, or when summarization.New fails. A nil logger is accepted and
+// only disables the log entry written after a compression.
+func newEinoSummarizationMiddleware(
+	ctx context.Context,
+	summaryModel model.BaseChatModel,
+	appCfg *config.Config,
+	logger *zap.Logger,
+) (adk.ChatModelAgentMiddleware, error) {
+	if summaryModel == nil || appCfg == nil {
+		return nil, fmt.Errorf("multiagent: summarization 需要 model 与配置")
+	}
+	maxTotal := appCfg.OpenAI.MaxTotalTokens
+	if maxTotal <= 0 {
+		maxTotal = 120000
+	}
+	// For very small budgets use the whole budget as the trigger, and never
+	// let the trigger drop below 4096 tokens.
+	trigger := int(float64(maxTotal) * 0.9)
+	if trigger < 4096 {
+		trigger = maxTotal
+		if trigger < 4096 {
+			trigger = 4096
+		}
+	}
+	// Token allowance for preserved user messages: one third of the trigger,
+	// with a floor of 2048.
+	preserveMax := trigger / 3
+	if preserveMax < 2048 {
+		preserveMax = 2048
+	}
+
+	// The model name is only handed to the token counter below.
+	modelName := strings.TrimSpace(appCfg.OpenAI.Model)
+	if modelName == "" {
+		modelName = "gpt-4o"
+	}
+
+	mw, err := summarization.New(ctx, &summarization.Config{
+		Model: summaryModel,
+		Trigger: &summarization.TriggerCondition{
+			ContextTokens: trigger,
+		},
+		TokenCounter:       einoSummarizationTokenCounter(modelName),
+		UserInstruction:    einoSummarizeUserInstruction,
+		EmitInternalEvents: false,
+		PreserveUserMessages: &summarization.PreserveUserMessages{
+			Enabled:   true,
+			MaxTokens: preserveMax,
+		},
+		Callback: func(_ context.Context, before, after adk.ChatModelAgentState) error {
+			if logger == nil {
+				return nil
+			}
+			logger.Info("eino summarization 已压缩上下文",
+				zap.Int("messages_before", len(before.Messages)),
+				zap.Int("messages_after", len(after.Messages)),
+				zap.Int("max_total_tokens", maxTotal),
+				zap.Int("trigger_context_tokens", trigger),
+			)
+			return nil
+		},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("summarization.New: %w", err)
+	}
+	return mw, nil
+}
+
+// einoSummarizationTokenCounter returns the token counter handed to the
+// summarization middleware. It joins message roles, content, reasoning
+// content, tool calls (JSON), text parts of multi-part user input and tool
+// definitions (JSON, Extra dropped) into one string and passes it to the
+// tiktoken counter with openAIModel; a nil input counts as zero tokens.
+func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
+	tc := agent.NewTikTokenCounter()
+	return func(_ context.Context, input *summarization.TokenCounterInput) (int, error) {
+		if input == nil {
+			// Nothing to count; do not dereference a nil input.
+			return 0, nil
+		}
+		var sb strings.Builder
+		for _, msg := range input.Messages {
+			if msg == nil {
+				continue
+			}
+			sb.WriteString(string(msg.Role))
+			sb.WriteByte('\n')
+			if msg.Content != "" {
+				sb.WriteString(msg.Content)
+				sb.WriteByte('\n')
+			}
+			if msg.ReasoningContent != "" {
+				sb.WriteString(msg.ReasoningContent)
+				sb.WriteByte('\n')
+			}
+			if len(msg.ToolCalls) > 0 {
+				// Best effort: tool calls that fail to marshal are not counted.
+				if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
+					sb.Write(b)
+					sb.WriteByte('\n')
+				}
+			}
+			for _, part := range msg.UserInputMultiContent {
+				if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
+					sb.WriteString(part.Text)
+					sb.WriteByte('\n')
+				}
+			}
+		}
+		for _, tl := range input.Tools {
+			if tl == nil {
+				continue
+			}
+			// Marshal a copy so dropping Extra does not touch the caller's tool.
+			cp := *tl
+			cp.Extra = nil
+			if text, err := sonic.MarshalString(cp); err == nil {
+				sb.WriteString(text)
+				sb.WriteByte('\n')
+			}
+		}
+		text := sb.String()
+		n, err := tc.Count(openAIModel, text)
+		if err != nil {
+			// Counter failed: fall back to roughly four bytes per token.
+			return (len(text) + 3) / 4, nil
+		}
+		return n, nil
+	}
+}
diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go
index 49ff124b..bf1d7761 100644
--- a/internal/multiagent/runner.go
+++ b/internal/multiagent/runner.go
@@ -193,6 +193,11 @@ func RunDeepAgent(
 			subMax = subDefaultIter
 		}
 
+		subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, logger)
+		if err != nil {
+			return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
+		}
+
 		sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
 			Name:        id,
 			Description: desc,
@@ -205,6 +210,7 @@ func RunDeepAgent(
 				EmitInternalEvents: true,
 			},
 			MaxIterations: subMax,
+			Handlers:      []adk.ChatModelAgentMiddleware{subSumMw},
 		})
 		if err != nil {
 			return nil, fmt.Errorf("子代理 %q: %w", id, err)
@@ -217,6 +223,11 @@ func RunDeepAgent(
 		return nil, fmt.Errorf("Deep 主模型: %w", err)
 	}
 
+	mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger)
+	if err != nil {
+		return nil, fmt.Errorf("Deep 主代理 summarization 中间件: %w", err)
+	}
+
 	// 与 deep.Config.Name 一致。子代理的 assistant 正文也会经 EmitInternalEvents 流出,若全部当主回复会重复(编排器总结 + 子代理原文)。
 	orchestratorName := "cyberstrike-deep"
 	orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing."
@@ -241,6 +252,7 @@ func RunDeepAgent(
 		WithoutGeneralSubAgent: ma.WithoutGeneralSubAgent,
 		WithoutWriteTodos:      ma.WithoutWriteTodos,
 		MaxIteration:           deepMaxIter,
+		Handlers:               []adk.ChatModelAgentMiddleware{mainSumMw},
 		ToolsConfig: adk.ToolsConfig{
 			ToolsNodeConfig: compose.ToolsNodeConfig{
 				Tools: mainTools,
@@ -484,10 +496,11 @@ func historyToMessages(history []agent.ChatMessage) []adk.Message {
 	if len(history) == 0 {
 		return nil
 	}
-	const maxTurns = 40
+	// Relaxed message cap: cross-turn history is left to Eino Summarization (threshold aligned with openai.max_total_tokens), which compresses it before the model is called, instead of hard-truncating to 40 messages before enqueueing.
+	const maxHistoryMessages = 300
 	start := 0
-	if len(history) > maxTurns {
-		start = len(history) - maxTurns
+	if len(history) > maxHistoryMessages {
+		start = len(history) - maxHistoryMessages
 	}
 	out := make([]adk.Message, 0, len(history[start:]))
 	for _, h := range history[start:] {