mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-02 07:45:24 +02:00
348 lines
11 KiB
Go
348 lines
11 KiB
Go
package multiagent
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
|
||
"cyberstrike-ai/internal/agent"
|
||
"cyberstrike-ai/internal/config"
|
||
|
||
"github.com/bytedance/sonic"
|
||
"github.com/cloudwego/eino/adk"
|
||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||
"github.com/cloudwego/eino/components/model"
|
||
"github.com/cloudwego/eino/schema"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// einoSummarizeUserInstruction is the UserInstruction handed to the Eino
// summarization middleware. Its goal matches the single-agent MemoryCompressor:
// when the conversation history is compressed, the information that matters for
// the (authorized) penetration test must survive — confirmed vulnerabilities,
// credentials, enumerated assets, progress, dead ends and so on — so that later
// (sub-)agents can continue without redoing work.
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。

必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
保留精确技术细节(URL、路径、参数、Payload、版本号、报错原文可摘要但要点不丢)。
将冗长扫描输出概括为结论;重复发现合并表述。
已枚举资产须保留**可继承的摘要**:主域、关键子域/主机短表(或数量+代表样例)、高价值目标与已识别服务/端口要点,避免后续子代理因「看不见清单」而重复全量枚举。

输出须使后续代理能无缝继续同一授权测试任务。`
|
||
|
||
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
|
||
// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。
|
||
func newEinoSummarizationMiddleware(
|
||
ctx context.Context,
|
||
summaryModel model.BaseChatModel,
|
||
appCfg *config.Config,
|
||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||
conversationID string,
|
||
logger *zap.Logger,
|
||
) (adk.ChatModelAgentMiddleware, error) {
|
||
if summaryModel == nil || appCfg == nil {
|
||
return nil, fmt.Errorf("multiagent: summarization 需要 model 与配置")
|
||
}
|
||
maxTotal := appCfg.OpenAI.MaxTotalTokens
|
||
if maxTotal <= 0 {
|
||
maxTotal = 120000
|
||
}
|
||
triggerRatio := 0.8
|
||
emitInternalEvents := true
|
||
if mwCfg != nil {
|
||
triggerRatio = mwCfg.SummarizationTriggerRatioEffective()
|
||
emitInternalEvents = mwCfg.SummarizationEmitInternalEventsEffective()
|
||
}
|
||
// Keep enough safety margin for tokenizer/model-side accounting mismatch.
|
||
trigger := int(float64(maxTotal) * triggerRatio)
|
||
if trigger < 4096 {
|
||
trigger = maxTotal
|
||
if trigger < 4096 {
|
||
trigger = 4096
|
||
}
|
||
}
|
||
preserveMax := trigger / 3
|
||
if preserveMax < 2048 {
|
||
preserveMax = 2048
|
||
}
|
||
|
||
modelName := strings.TrimSpace(appCfg.OpenAI.Model)
|
||
if modelName == "" {
|
||
modelName = "gpt-4o"
|
||
}
|
||
tokenCounter := einoSummarizationTokenCounter(modelName)
|
||
recentTrailMax := trigger / 4
|
||
if recentTrailMax < 2048 {
|
||
recentTrailMax = 2048
|
||
}
|
||
if recentTrailMax > trigger/2 {
|
||
recentTrailMax = trigger / 2
|
||
}
|
||
transcriptPath := ""
|
||
if conv := strings.TrimSpace(conversationID); conv != "" {
|
||
baseRoot := filepath.Join(os.TempDir(), "cyberstrike-summarization")
|
||
if dbPath := strings.TrimSpace(appCfg.Database.Path); dbPath != "" {
|
||
// Persist with the same lifecycle as local conversation storage.
|
||
baseRoot = filepath.Join(filepath.Dir(dbPath), "conversation_artifacts", sanitizeEinoPathSegment(conv), "summarization")
|
||
}
|
||
base := baseRoot
|
||
if mkErr := os.MkdirAll(base, 0o755); mkErr == nil {
|
||
transcriptPath = filepath.Join(base, "transcript.txt")
|
||
}
|
||
}
|
||
|
||
mw, err := summarization.New(ctx, &summarization.Config{
|
||
Model: summaryModel,
|
||
Trigger: &summarization.TriggerCondition{
|
||
ContextTokens: trigger,
|
||
},
|
||
TokenCounter: tokenCounter,
|
||
UserInstruction: einoSummarizeUserInstruction,
|
||
EmitInternalEvents: emitInternalEvents,
|
||
TranscriptFilePath: transcriptPath,
|
||
PreserveUserMessages: &summarization.PreserveUserMessages{
|
||
Enabled: true,
|
||
MaxTokens: preserveMax,
|
||
},
|
||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||
},
|
||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||
if logger == nil {
|
||
return nil
|
||
}
|
||
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
||
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
||
logger.Info("eino summarization 已压缩上下文",
|
||
zap.Int("messages_before", len(before.Messages)),
|
||
zap.Int("messages_after", len(after.Messages)),
|
||
zap.Int("tokens_before_estimated", beforeTokens),
|
||
zap.Int("tokens_after_estimated", afterTokens),
|
||
zap.Int("max_total_tokens", maxTotal),
|
||
zap.Int("trigger_context_tokens", trigger),
|
||
zap.String("transcript_file", transcriptPath),
|
||
)
|
||
return nil
|
||
},
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("summarization.New: %w", err)
|
||
}
|
||
return mw, nil
|
||
}
|
||
|
||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||
//
|
||
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
||
// 把消息切成 round(回合)为原子单位:
|
||
// - user(...) 单条为一个 round;
|
||
// - assistant(tool_calls=[...]) 及其后连续的 role=tool 消息合成一个 round;
|
||
// - 其它 assistant(reply, 无 tool_calls) 单条为一个 round。
|
||
//
|
||
// 倒序挑 round(预算不够即放弃该 round),保证 tool 消息不会跨 round 被孤立。
|
||
func summarizeFinalizeWithRecentAssistantToolTrail(
|
||
ctx context.Context,
|
||
originalMessages []adk.Message,
|
||
summary adk.Message,
|
||
tokenCounter summarization.TokenCounterFunc,
|
||
recentTrailTokenBudget int,
|
||
) ([]adk.Message, error) {
|
||
systemMsgs := make([]adk.Message, 0, len(originalMessages))
|
||
nonSystem := make([]adk.Message, 0, len(originalMessages))
|
||
for _, msg := range originalMessages {
|
||
if msg == nil {
|
||
continue
|
||
}
|
||
if msg.Role == schema.System {
|
||
systemMsgs = append(systemMsgs, msg)
|
||
continue
|
||
}
|
||
nonSystem = append(nonSystem, msg)
|
||
}
|
||
|
||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||
out = append(out, systemMsgs...)
|
||
out = append(out, summary)
|
||
return out, nil
|
||
}
|
||
|
||
rounds := splitMessagesIntoRounds(nonSystem)
|
||
if len(rounds) == 0 {
|
||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||
out = append(out, systemMsgs...)
|
||
out = append(out, summary)
|
||
return out, nil
|
||
}
|
||
|
||
// 目标:至少保留 minRounds 个 round 的执行轨迹;在预算允许时尽量多保留。
|
||
// 优先确保最后一个 round(通常是最新的 tool 往返或 assistant 回复)存在。
|
||
const minRounds = 2
|
||
|
||
selectedRoundsReverse := make([]messageRound, 0, 8)
|
||
selectedCount := 0
|
||
totalTokens := 0
|
||
|
||
tokensOfRound := func(r messageRound) (int, error) {
|
||
if len(r.messages) == 0 {
|
||
return 0, nil
|
||
}
|
||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: r.messages})
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
if n <= 0 {
|
||
n = len(r.messages)
|
||
}
|
||
return n, nil
|
||
}
|
||
|
||
for i := len(rounds) - 1; i >= 0; i-- {
|
||
r := rounds[i]
|
||
n, err := tokensOfRound(r)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 预算不够:已经保留了足够 round 则停,否则跳过该 round 继续往前找
|
||
// (避免一个超大 round 挤占全部预算,至少保证有轨迹)。
|
||
if totalTokens+n > recentTrailTokenBudget {
|
||
if selectedCount >= minRounds {
|
||
break
|
||
}
|
||
continue
|
||
}
|
||
totalTokens += n
|
||
selectedRoundsReverse = append(selectedRoundsReverse, r)
|
||
selectedCount++
|
||
}
|
||
|
||
// 还原时间顺序
|
||
selectedMsgs := make([]adk.Message, 0, 8)
|
||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||
}
|
||
|
||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||
out = append(out, systemMsgs...)
|
||
out = append(out, summary)
|
||
out = append(out, selectedMsgs...)
|
||
return out, nil
|
||
}
|
||
|
||
// messageRound is an indivisible group of messages used when choosing which
// recent messages to keep after summarization. It is either:
//   - an assistant message with tool_calls followed by the role=tool messages
//     that splitMessagesIntoRounds grouped with it (the splitter does not verify
//     that every tool_call in the round actually has a response), or
//   - a single standalone message such as a user message or an assistant reply.
type messageRound struct {
	// messages holds the round's messages in their original order.
	messages []adk.Message
}
|
||
|
||
// splitMessagesIntoRounds 将非 system 消息切分为若干 round,保证:
|
||
// - 每个 assistant(tool_calls) 与其对应的 role=tool 响应消息在同一个 round;
|
||
// - 孤立(无对应 assistant(tool_calls))的 role=tool 消息不会单独成为 round,
|
||
// 而是被丢弃(这些消息在 pair 完整性层面已属孤儿,保留反而会触发 LLM 400)。
|
||
func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
|
||
if len(msgs) == 0 {
|
||
return nil
|
||
}
|
||
rounds := make([]messageRound, 0, len(msgs))
|
||
i := 0
|
||
for i < len(msgs) {
|
||
msg := msgs[i]
|
||
if msg == nil {
|
||
i++
|
||
continue
|
||
}
|
||
switch {
|
||
case msg.Role == schema.Assistant && len(msg.ToolCalls) > 0:
|
||
// 收集该 assistant 提供的 call_id 集合。
|
||
provided := make(map[string]struct{}, len(msg.ToolCalls))
|
||
for _, tc := range msg.ToolCalls {
|
||
if tc.ID != "" {
|
||
provided[tc.ID] = struct{}{}
|
||
}
|
||
}
|
||
round := messageRound{messages: []adk.Message{msg}}
|
||
j := i + 1
|
||
for j < len(msgs) {
|
||
next := msgs[j]
|
||
if next == nil {
|
||
j++
|
||
continue
|
||
}
|
||
if next.Role != schema.Tool {
|
||
break
|
||
}
|
||
if next.ToolCallID != "" {
|
||
if _, ok := provided[next.ToolCallID]; !ok {
|
||
// 下一条 tool 不属于当前 assistant,认为当前 round 结束。
|
||
break
|
||
}
|
||
}
|
||
round.messages = append(round.messages, next)
|
||
j++
|
||
}
|
||
rounds = append(rounds, round)
|
||
i = j
|
||
case msg.Role == schema.Tool:
|
||
// 孤儿 tool 消息:既不跟随在一个 assistant(tool_calls) 后,
|
||
// 说明它对应的 assistant 已被上游裁剪;直接丢弃,下一步到 orphan pruner
|
||
// 兜底也不会出错,但在 round 切分这里就剔除更干净。
|
||
i++
|
||
default:
|
||
// user / assistant(reply) / 其它:单条成 round。
|
||
rounds = append(rounds, messageRound{messages: []adk.Message{msg}})
|
||
i++
|
||
}
|
||
}
|
||
return rounds
|
||
}
|
||
|
||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||
tc := agent.NewTikTokenCounter()
|
||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||
var sb strings.Builder
|
||
for _, msg := range input.Messages {
|
||
if msg == nil {
|
||
continue
|
||
}
|
||
sb.WriteString(string(msg.Role))
|
||
sb.WriteByte('\n')
|
||
if msg.Content != "" {
|
||
sb.WriteString(msg.Content)
|
||
sb.WriteByte('\n')
|
||
}
|
||
if msg.ReasoningContent != "" {
|
||
sb.WriteString(msg.ReasoningContent)
|
||
sb.WriteByte('\n')
|
||
}
|
||
if len(msg.ToolCalls) > 0 {
|
||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||
sb.Write(b)
|
||
sb.WriteByte('\n')
|
||
}
|
||
}
|
||
for _, part := range msg.UserInputMultiContent {
|
||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||
sb.WriteString(part.Text)
|
||
sb.WriteByte('\n')
|
||
}
|
||
}
|
||
}
|
||
for _, tl := range input.Tools {
|
||
if tl == nil {
|
||
continue
|
||
}
|
||
cp := *tl
|
||
cp.Extra = nil
|
||
if text, err := sonic.MarshalString(cp); err == nil {
|
||
sb.WriteString(text)
|
||
sb.WriteByte('\n')
|
||
}
|
||
}
|
||
text := sb.String()
|
||
n, err := tc.Count(openAIModel, text)
|
||
if err != nil {
|
||
return (len(text) + 3) / 4, nil
|
||
}
|
||
return n, nil
|
||
}
|
||
}
|