diff --git a/internal/agent/memory_compressor.go b/internal/agent/memory_compressor.go deleted file mode 100644 index c830d1a9..00000000 --- a/internal/agent/memory_compressor.go +++ /dev/null @@ -1,491 +0,0 @@ -package agent - -import ( - "context" - "errors" - "fmt" - "net/http" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/openai" - - "github.com/pkoukk/tiktoken-go" - "go.uber.org/zap" -) - -const ( - // DefaultMinRecentMessage 压缩历史消息时保留的最近消息数量,确保最近的对话上下文不被压缩 - DefaultMinRecentMessage = 5 - // defaultChunkSize 压缩历史消息时每次处理的消息块大小,将旧消息分成多个块进行摘要 - defaultChunkSize = 10 - // defaultMaxImages 压缩时最多保留的图片数量,超过此数量的图片会被移除以节省上下文空间 - defaultMaxImages = 3 - // defaultSummaryTimeout 生成消息摘要时的超时时间 - defaultSummaryTimeout = 10 * time.Minute - - summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。 - -必须保留的关键信息: -- 已发现的漏洞与潜在攻击路径 -- 扫描结果与工具输出(可压缩,但需保留核心发现) -- 获取到的访问凭证、令牌或认证细节 -- 系统架构洞察与潜在薄弱点 -- 当前评估进展 -- 失败尝试与死路(避免重复劳动) -- 关于测试策略的所有决策记录 - -压缩指南: -- 保留精确技术细节(URL、路径、参数、Payload 等) -- 将冗长的工具输出压缩成概述,但保留关键发现 -- 记录版本号与识别出的技术/组件信息 -- 保留可能暗示漏洞的原始报错 -- 将重复或相似发现整合成一条带有共性说明的结论 - -请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。 - -需要压缩的对话片段: -%s - -请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。` -) - -// MemoryCompressor 负责在调用LLM前压缩历史上下文,以避免Token爆炸。 -type MemoryCompressor struct { - maxTotalTokens int - minRecentMessage int - maxImages int - chunkSize int - summaryModel string - timeout time.Duration - - tokenCounter TokenCounter - completionClient CompletionClient - logger *zap.Logger -} - -// MemoryCompressorConfig 用于初始化 MemoryCompressor。 -type MemoryCompressorConfig struct { - MaxTotalTokens int - MinRecentMessage int - MaxImages int - ChunkSize int - SummaryModel string - Timeout time.Duration - TokenCounter TokenCounter - CompletionClient CompletionClient - Logger *zap.Logger - - // 当 CompletionClient 为空时,可以通过 OpenAIConfig + HTTPClient 构造默认的客户端。 - OpenAIConfig *config.OpenAIConfig - HTTPClient *http.Client -} - -// NewMemoryCompressor 创建新的 MemoryCompressor。 -func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) { - if cfg.Logger == nil { - cfg.Logger = zap.NewNop() - } - - // 如果没有显式配置 MaxTotalTokens,则后续逻辑会根据模型的最大上下文长度进行控制; - // 优先推荐在 config.yaml 的 openai.max_total_tokens 中统一配置。 - if cfg.MinRecentMessage <= 0 { - cfg.MinRecentMessage = DefaultMinRecentMessage - } - if cfg.MaxImages <= 0 { - cfg.MaxImages = defaultMaxImages - } - if cfg.ChunkSize <= 0 { - cfg.ChunkSize = defaultChunkSize - } - if cfg.Timeout <= 0 { - cfg.Timeout = defaultSummaryTimeout - } - if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" { - cfg.SummaryModel = cfg.OpenAIConfig.Model - } - if cfg.SummaryModel == "" { - return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)") - } - if cfg.TokenCounter == nil { - cfg.TokenCounter = NewTikTokenCounter() - } - - if cfg.CompletionClient == nil { - if cfg.OpenAIConfig == nil { - return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig") - } - if cfg.HTTPClient == nil { - cfg.HTTPClient = &http.Client{ - Timeout: 5 * time.Minute, - } - } - cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger) - } - - return &MemoryCompressor{ - maxTotalTokens: cfg.MaxTotalTokens, - minRecentMessage: cfg.MinRecentMessage, - maxImages: cfg.MaxImages, - chunkSize: cfg.ChunkSize, - summaryModel: cfg.SummaryModel, - timeout: cfg.Timeout, - tokenCounter: cfg.TokenCounter, - completionClient: cfg.CompletionClient, - logger: cfg.Logger, - }, nil -} - -// UpdateConfig 更新OpenAI配置(用于动态更新模型配置) -func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) { - if cfg == nil { - return - } - - // 更新summaryModel字段 - if cfg.Model != "" { - mc.summaryModel = cfg.Model - } - - // 更新completionClient中的配置(如果是OpenAICompletionClient) - if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok { - openAIClient.UpdateConfig(cfg) - mc.logger.Info("MemoryCompressor配置已更新", - zap.String("model", cfg.Model), - ) - } -} - -// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。 -func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) { - if len(messages) == 0 { - return messages, false, nil - } - - mc.handleImages(messages) - - systemMsgs, regularMsgs := mc.splitMessages(messages) - if len(regularMsgs) <= mc.minRecentMessage { - return messages, false, nil - } - - effectiveMax := mc.maxTotalTokens - if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens { - effectiveMax = mc.maxTotalTokens - reservedTokens - } - - totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs) - if totalTokens <= int(float64(effectiveMax)*0.9) { - return messages, false, nil - } - - recentStart := len(regularMsgs) - mc.minRecentMessage - recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart) - oldMsgs := regularMsgs[:recentStart] - recentMsgs := regularMsgs[recentStart:] - - mc.logger.Info("memory compression triggered", - zap.Int("total_tokens", totalTokens), - zap.Int("max_total_tokens", mc.maxTotalTokens), - zap.Int("reserved_tokens", reservedTokens), - zap.Int("effective_max", effectiveMax), - zap.Int("system_messages", len(systemMsgs)), - zap.Int("regular_messages", len(regularMsgs)), - zap.Int("old_messages", len(oldMsgs)), - zap.Int("recent_messages", len(recentMsgs))) - - var compressed []ChatMessage - for i := 0; i < len(oldMsgs); i += mc.chunkSize { - end := i + mc.chunkSize - if end > len(oldMsgs) { - end = len(oldMsgs) - } - chunk := oldMsgs[i:end] - if len(chunk) == 0 { - continue - } - summary, err := mc.summarizeChunk(ctx, chunk) - if err != nil { - mc.logger.Warn("chunk summary failed, fallback to raw chunk", - zap.Error(err), - zap.Int("start", i), - zap.Int("end", end)) - compressed = append(compressed, chunk...) - continue - } - compressed = append(compressed, summary) - } - - finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs)) - finalMessages = append(finalMessages, systemMsgs...) - finalMessages = append(finalMessages, compressed...) - finalMessages = append(finalMessages, recentMsgs...) - - return finalMessages, true, nil -} - -func (mc *MemoryCompressor) handleImages(messages []ChatMessage) { - if mc.maxImages <= 0 { - return - } - count := 0 - for i := len(messages) - 1; i >= 0; i-- { - content := messages[i].Content - if !strings.Contains(content, "[IMAGE]") { - continue - } - count++ - if count > mc.maxImages { - messages[i].Content = "[Previously attached image removed to preserve context]" - } - } -} - -func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) { - for _, msg := range messages { - if strings.EqualFold(msg.Role, "system") { - systemMsgs = append(systemMsgs, msg) - } else { - regularMsgs = append(regularMsgs, msg) - } - } - return -} - -func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int { - total := 0 - for _, msg := range systemMsgs { - total += mc.countTokens(msg.Content) - } - for _, msg := range regularMsgs { - total += mc.countTokens(msg.Content) - } - return total -} - -// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置) -func (mc *MemoryCompressor) getModelName() string { - // 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称 - if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok { - if openAIClient.config != nil && openAIClient.config.Model != "" { - return openAIClient.config.Model - } - } - // 否则使用保存的summaryModel - return mc.summaryModel -} - -func (mc *MemoryCompressor) countTokens(text string) int { - if mc.tokenCounter == nil { - return len(text) / 4 - } - modelName := mc.getModelName() - count, err := mc.tokenCounter.Count(modelName, text) - if err != nil { - return len(text) / 4 - } - return count -} - -// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。 -func (mc *MemoryCompressor) CountTextTokens(text string) int { - return mc.countTokens(text) -} - -// totalTokensFor provides token statistics without mutating the message list. -func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) { - if len(messages) == 0 { - return 0, 0, 0 - } - systemMsgs, regularMsgs := mc.splitMessages(messages) - return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs) -} - -func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) { - if len(chunk) == 0 { - return ChatMessage{}, errors.New("chunk is empty") - } - formatted := make([]string, 0, len(chunk)) - for _, msg := range chunk { - formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg))) - } - conversation := strings.Join(formatted, "\n") - prompt := fmt.Sprintf(summaryPromptTemplate, conversation) - - // 使用动态获取的模型名称,而不是保存的summaryModel - modelName := mc.getModelName() - summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout) - if err != nil { - return ChatMessage{}, err - } - summary = strings.TrimSpace(summary) - if summary == "" { - return chunk[0], nil - } - - return ChatMessage{ - Role: "assistant", - Content: fmt.Sprintf("%s", len(chunk), summary), - }, nil -} - -func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string { - return msg.Content -} - -func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int { - if recentStart <= 0 || recentStart >= len(msgs) { - return recentStart - } - - adjusted := recentStart - for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") { - adjusted-- - } - - if adjusted != recentStart { - mc.logger.Debug("adjusted recent window to keep tool call context", - zap.Int("original_recent_start", recentStart), - zap.Int("adjusted_recent_start", adjusted), - ) - } - - return adjusted -} - -// TokenCounter 用于计算文本Token数量。 -type TokenCounter interface { - Count(model, text string) (int, error) -} - -// TikTokenCounter 基于 tiktoken 的 Token 统计器。 -type TikTokenCounter struct { - mu sync.RWMutex - cache map[string]*tiktoken.Tiktoken - fallbackEncoding *tiktoken.Tiktoken -} - -// NewTikTokenCounter 创建新的 TikTokenCounter。 -func NewTikTokenCounter() *TikTokenCounter { - return &TikTokenCounter{ - cache: make(map[string]*tiktoken.Tiktoken), - } -} - -// Count 实现 TokenCounter 接口。 -func (tc *TikTokenCounter) Count(model, text string) (int, error) { - enc, err := tc.encodingForModel(model) - if err != nil { - return len(text) / 4, err - } - tokens := enc.Encode(text, nil, nil) - return len(tokens), nil -} - -func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) { - tc.mu.RLock() - if enc, ok := tc.cache[model]; ok { - tc.mu.RUnlock() - return enc, nil - } - tc.mu.RUnlock() - - tc.mu.Lock() - defer tc.mu.Unlock() - - if enc, ok := tc.cache[model]; ok { - return enc, nil - } - - enc, err := tiktoken.EncodingForModel(model) - if err != nil { - if tc.fallbackEncoding == nil { - tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base") - if err != nil { - return nil, err - } - } - tc.cache[model] = tc.fallbackEncoding - return tc.fallbackEncoding, nil - } - - tc.cache[model] = enc - return enc, nil -} - -// CompletionClient 对话压缩时使用的补全接口。 -type CompletionClient interface { - Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) -} - -// OpenAICompletionClient 基于 OpenAI Chat Completion。 -type OpenAICompletionClient struct { - config *config.OpenAIConfig - client *openai.Client - logger *zap.Logger -} - -// NewOpenAICompletionClient 创建 OpenAICompletionClient。 -func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient { - if logger == nil { - logger = zap.NewNop() - } - return &OpenAICompletionClient{ - config: cfg, - client: openai.NewClient(cfg, client, logger), - logger: logger, - } -} - -// UpdateConfig 更新底层配置。 -func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) { - c.config = cfg - if c.client != nil { - c.client.UpdateConfig(cfg) - } -} - -// Complete 调用OpenAI获取摘要。 -func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) { - if c.config == nil { - return "", errors.New("openai config is required") - } - if model == "" { - return "", errors.New("model name is required") - } - - reqBody := OpenAIRequest{ - Model: model, - Messages: []ChatMessage{ - {Role: "user", Content: prompt}, - }, - } - - requestCtx := ctx - var cancel context.CancelFunc - if timeout > 0 { - requestCtx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - - var completion OpenAIResponse - if c.client == nil { - return "", errors.New("openai completion client not initialized") - } - if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil { - if apiErr, ok := err.(*openai.APIError); ok { - return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body) - } - return "", err - } - if completion.Error != nil { - return "", errors.New(completion.Error.Message) - } - - if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" { - return "", errors.New("empty completion response") - } - return completion.Choices[0].Message.Content, nil -}