package agent

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	"cyberstrike-ai/internal/config"
	"cyberstrike-ai/internal/openai"

	"github.com/pkoukk/tiktoken-go"
	"go.uber.org/zap"
)

const (
	// DefaultMinRecentMessage is the number of most recent messages preserved when
	// compressing history, ensuring the latest conversational context is never compressed.
	DefaultMinRecentMessage = 5
	// defaultChunkSize is the number of messages processed per chunk when compressing
	// history; older messages are split into chunks of this size for summarization.
	defaultChunkSize = 10
	// defaultMaxImages is the maximum number of images retained during compression;
	// images beyond this count are removed to save context space.
	defaultMaxImages = 3
	// defaultSummaryTimeout is the timeout for generating a message summary.
	defaultSummaryTimeout = 10 * time.Minute

	summaryPromptTemplate = `You are an assistant performing context compression for a security agent. Your task is to compress scan data while keeping all critical penetration-testing information intact.

Critical information that must be preserved:
- Discovered vulnerabilities and potential attack paths
- Scan results and tool output (may be condensed, but core findings must remain)
- Captured credentials, tokens, or authentication details
- System architecture insights and potential weak points
- Current assessment progress
- Failed attempts and dead ends (to avoid duplicated effort)
- All recorded decisions about testing strategy

Compression guidelines:
- Keep exact technical details (URLs, paths, parameters, payloads, etc.)
- Condense verbose tool output into an overview while retaining key findings
- Record version numbers and identified technologies/components
- Preserve raw error messages that may hint at vulnerabilities
- Merge duplicate or similar findings into a single conclusion noting the commonality

Remember: another security agent will rely on this summary to continue testing and must pick up seamlessly without losing any operational context.

Conversation excerpt to compress:
%s

Provide a technically precise and concise summary covering all context relevant to the security assessment.`
)

// MemoryCompressor compresses historical context before each LLM call to avoid token blow-up.
type MemoryCompressor struct {
	maxTotalTokens   int
	minRecentMessage int
	maxImages        int
	chunkSize        int
	summaryModel     string
	timeout          time.Duration

	tokenCounter     TokenCounter
	completionClient CompletionClient
	logger           *zap.Logger
}

// MemoryCompressorConfig initializes a MemoryCompressor.
type MemoryCompressorConfig struct {
	MaxTotalTokens   int
	MinRecentMessage int
	MaxImages        int
	ChunkSize        int
	SummaryModel     string
	Timeout          time.Duration
	TokenCounter     TokenCounter
	CompletionClient CompletionClient
	Logger           *zap.Logger

	// When CompletionClient is nil, a default client can be constructed from OpenAIConfig + HTTPClient.
	OpenAIConfig *config.OpenAIConfig
	HTTPClient   *http.Client
}

// NewMemoryCompressor creates a new MemoryCompressor.
func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) {
	if cfg.Logger == nil {
		cfg.Logger = zap.NewNop()
	}

	// If MaxTotalTokens is not configured explicitly, later logic bounds usage by the
	// model's maximum context length; configuring it centrally via openai.max_total_tokens
	// in config.yaml is recommended.
	if cfg.MinRecentMessage <= 0 {
		cfg.MinRecentMessage = DefaultMinRecentMessage
	}
	if cfg.MaxImages <= 0 {
		cfg.MaxImages = defaultMaxImages
	}
	if cfg.ChunkSize <= 0 {
		cfg.ChunkSize = defaultChunkSize
	}
	if cfg.Timeout <= 0 {
		cfg.Timeout = defaultSummaryTimeout
	}
	if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
		cfg.SummaryModel = cfg.OpenAIConfig.Model
	}
	if cfg.SummaryModel == "" {
		return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)")
	}
	if cfg.TokenCounter == nil {
		cfg.TokenCounter = NewTikTokenCounter()
	}

	if cfg.CompletionClient == nil {
		if cfg.OpenAIConfig == nil {
			return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig")
		}
		if cfg.HTTPClient == nil {
			cfg.HTTPClient = &http.Client{
				Timeout: 5 * time.Minute,
			}
		}
		cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger)
	}

	return &MemoryCompressor{
		maxTotalTokens:   cfg.MaxTotalTokens,
		minRecentMessage: cfg.MinRecentMessage,
		maxImages:        cfg.MaxImages,
		chunkSize:        cfg.ChunkSize,
		summaryModel:     cfg.SummaryModel,
		timeout:          cfg.Timeout,
		tokenCounter:     cfg.TokenCounter,
		completionClient: cfg.CompletionClient,
		logger:           cfg.Logger,
	}, nil
}
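
// exampleNewMemoryCompressor is a construction sketch, not part of the original file;
// the token budget below is an assumed illustrative value. It relies on the fallback
// path above that builds a default OpenAI client from OpenAIConfig.
func exampleNewMemoryCompressor(openAICfg *config.OpenAIConfig, logger *zap.Logger) (*MemoryCompressor, error) {
	return NewMemoryCompressor(MemoryCompressorConfig{
		MaxTotalTokens: 120000, // assumed budget; normally sourced from openai.max_total_tokens
		OpenAIConfig:   openAICfg,
		Logger:         logger,
		// MinRecentMessage, MaxImages, ChunkSize, and Timeout fall back to the defaults above.
	})
}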

// UpdateConfig applies an updated OpenAI configuration (used for dynamic model changes).
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
	if cfg == nil {
		return
	}

	// Update the summaryModel field.
	if cfg.Model != "" {
		mc.summaryModel = cfg.Model
	}

	// Propagate the configuration into the completion client (if it is an OpenAICompletionClient).
	if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
		openAIClient.UpdateConfig(cfg)
		mc.logger.Info("MemoryCompressor configuration updated",
			zap.String("model", cfg.Model),
		)
	}
}

// CompressHistory compresses historical messages against the token limit. reservedTokens
// is the token count reserved for non-message content such as tools; compression uses
// (maxTotalTokens - reservedTokens) as the effective message budget.
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
	if len(messages) == 0 {
		return messages, false, nil
	}

	mc.handleImages(messages)

	systemMsgs, regularMsgs := mc.splitMessages(messages)
	if len(regularMsgs) <= mc.minRecentMessage {
		return messages, false, nil
	}

	effectiveMax := mc.maxTotalTokens
	if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
		effectiveMax = mc.maxTotalTokens - reservedTokens
	}

	totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
	if totalTokens <= int(float64(effectiveMax)*0.9) {
		return messages, false, nil
	}

	recentStart := len(regularMsgs) - mc.minRecentMessage
	recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart)
	oldMsgs := regularMsgs[:recentStart]
	recentMsgs := regularMsgs[recentStart:]

	mc.logger.Info("memory compression triggered",
		zap.Int("total_tokens", totalTokens),
		zap.Int("max_total_tokens", mc.maxTotalTokens),
		zap.Int("reserved_tokens", reservedTokens),
		zap.Int("effective_max", effectiveMax),
		zap.Int("system_messages", len(systemMsgs)),
		zap.Int("regular_messages", len(regularMsgs)),
		zap.Int("old_messages", len(oldMsgs)),
		zap.Int("recent_messages", len(recentMsgs)))

	var compressed []ChatMessage
	for i := 0; i < len(oldMsgs); i += mc.chunkSize {
		end := i + mc.chunkSize
		if end > len(oldMsgs) {
			end = len(oldMsgs)
		}
		chunk := oldMsgs[i:end]
		if len(chunk) == 0 {
			continue
		}
		summary, err := mc.summarizeChunk(ctx, chunk)
		if err != nil {
			mc.logger.Warn("chunk summary failed, fallback to raw chunk",
				zap.Error(err),
				zap.Int("start", i),
				zap.Int("end", end))
			compressed = append(compressed, chunk...)
			continue
		}
		compressed = append(compressed, summary)
	}

	finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs))
	finalMessages = append(finalMessages, systemMsgs...)
	finalMessages = append(finalMessages, compressed...)
	finalMessages = append(finalMessages, recentMsgs...)

	return finalMessages, true, nil
}
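
// exampleCompressBeforeCall is a usage sketch, not part of the original file: it shows
// how an agent loop might invoke CompressHistory before each LLM call, reserving budget
// for the serialized tool schema via CountTextTokens. toolsJSON is a hypothetical
// pre-serialized tools payload.
func exampleCompressBeforeCall(ctx context.Context, mc *MemoryCompressor, history []ChatMessage, toolsJSON string) ([]ChatMessage, error) {
	reserved := mc.CountTextTokens(toolsJSON)
	compressed, didCompress, err := mc.CompressHistory(ctx, history, reserved)
	if err != nil {
		return nil, err
	}
	if didCompress {
		// Older chunks were folded into <context_summary> assistant messages; the most
		// recent messages (plus any leading tool results) are untouched.
		mc.logger.Debug("history compressed before LLM call")
	}
	return compressed, nil
}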

func (mc *MemoryCompressor) handleImages(messages []ChatMessage) {
	if mc.maxImages <= 0 {
		return
	}
	count := 0
	for i := len(messages) - 1; i >= 0; i-- {
		content := messages[i].Content
		if !strings.Contains(content, "[IMAGE]") {
			continue
		}
		count++
		if count > mc.maxImages {
			messages[i].Content = "[Previously attached image removed to preserve context]"
		}
	}
}

func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) {
	for _, msg := range messages {
		if strings.EqualFold(msg.Role, "system") {
			systemMsgs = append(systemMsgs, msg)
		} else {
			regularMsgs = append(regularMsgs, msg)
		}
	}
	return
}

func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int {
	total := 0
	for _, msg := range systemMsgs {
		total += mc.countTokens(msg.Content)
	}
	for _, msg := range regularMsgs {
		total += mc.countTokens(msg.Content)
	}
	return total
}

// getModelName returns the model currently in use, preferring the latest
// configuration held by the completion client.
func (mc *MemoryCompressor) getModelName() string {
	// If the completion client is an OpenAICompletionClient, read the newest model name from it.
	if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
		if openAIClient.config != nil && openAIClient.config.Model != "" {
			return openAIClient.config.Model
		}
	}
	// Otherwise fall back to the stored summaryModel.
	return mc.summaryModel
}

func (mc *MemoryCompressor) countTokens(text string) int {
	if mc.tokenCounter == nil {
		return len(text) / 4
	}
	modelName := mc.getModelName()
	count, err := mc.tokenCounter.Count(modelName, text)
	if err != nil {
		return len(text) / 4
	}
	return count
}

// CountTextTokens exposes text token counting for non-message content such as tools
// (e.g. the agent serializes its tools and counts the result).
func (mc *MemoryCompressor) CountTextTokens(text string) int {
	return mc.countTokens(text)
}

// totalTokensFor provides token statistics without mutating the message list.
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
	if len(messages) == 0 {
		return 0, 0, 0
	}
	systemMsgs, regularMsgs := mc.splitMessages(messages)
	return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs)
}

func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
	if len(chunk) == 0 {
		return ChatMessage{}, errors.New("chunk is empty")
	}
	formatted := make([]string, 0, len(chunk))
	for _, msg := range chunk {
		formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
	}
	conversation := strings.Join(formatted, "\n")
	prompt := fmt.Sprintf(summaryPromptTemplate, conversation)

	// Use the dynamically resolved model name rather than the stored summaryModel.
	modelName := mc.getModelName()
	summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
	if err != nil {
		return ChatMessage{}, err
	}
	summary = strings.TrimSpace(summary)
	if summary == "" {
		return chunk[0], nil
	}

	return ChatMessage{
		Role:    "assistant",
		Content: fmt.Sprintf("<context_summary message_count='%d'>%s</context_summary>", len(chunk), summary),
	}, nil
}

func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
	return msg.Content
}

func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
	if recentStart <= 0 || recentStart >= len(msgs) {
		return recentStart
	}

	adjusted := recentStart
	for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
		adjusted--
	}

	if adjusted != recentStart {
		mc.logger.Debug("adjusted recent window to keep tool call context",
			zap.Int("original_recent_start", recentStart),
			zap.Int("adjusted_recent_start", adjusted),
		)
	}

	return adjusted
}

// TokenCounter counts the number of tokens in a piece of text.
type TokenCounter interface {
	Count(model, text string) (int, error)
}
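
// runeEstimateCounter is a sketch of a custom TokenCounter, not part of the original
// file: a plug-in replacement for models tiktoken does not know, mirroring the len/4
// heuristic this file already uses as its fallback.
type runeEstimateCounter struct{}

func (runeEstimateCounter) Count(model, text string) (int, error) {
	return len(text) / 4, nil // crude estimate: roughly four bytes per token
}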

// TikTokenCounter is a token counter backed by tiktoken.
type TikTokenCounter struct {
	mu               sync.RWMutex
	cache            map[string]*tiktoken.Tiktoken
	fallbackEncoding *tiktoken.Tiktoken
}

// NewTikTokenCounter creates a new TikTokenCounter.
func NewTikTokenCounter() *TikTokenCounter {
	return &TikTokenCounter{
		cache: make(map[string]*tiktoken.Tiktoken),
	}
}

// Count implements the TokenCounter interface.
func (tc *TikTokenCounter) Count(model, text string) (int, error) {
	enc, err := tc.encodingForModel(model)
	if err != nil {
		return len(text) / 4, err
	}
	tokens := enc.Encode(text, nil, nil)
	return len(tokens), nil
}
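
// exampleCountTokens is a usage sketch, not part of the original file (the model name
// and prompt are assumptions): unknown models transparently fall back to the
// cl100k_base encoding inside encodingForModel, so the returned count stays usable.
func exampleCountTokens() int {
	tc := NewTikTokenCounter()
	n, err := tc.Count("gpt-4o", "enumerate exposed admin panels on the target")
	if err != nil {
		// Count already returned the len/4 estimate alongside the error.
		return n
	}
	return n
}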

func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) {
	tc.mu.RLock()
	if enc, ok := tc.cache[model]; ok {
		tc.mu.RUnlock()
		return enc, nil
	}
	tc.mu.RUnlock()

	tc.mu.Lock()
	defer tc.mu.Unlock()

	if enc, ok := tc.cache[model]; ok {
		return enc, nil
	}

	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		if tc.fallbackEncoding == nil {
			tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base")
			if err != nil {
				return nil, err
			}
		}
		tc.cache[model] = tc.fallbackEncoding
		return tc.fallbackEncoding, nil
	}

	tc.cache[model] = enc
	return enc, nil
}

// CompletionClient is the completion interface used when compressing conversations.
type CompletionClient interface {
	Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error)
}
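
// staticCompletionClient is a stub sketch of CompletionClient, not part of the
// original file: it lets the compression path be exercised without a live OpenAI
// endpoint (for instance in tests), always returning a canned summary.
type staticCompletionClient struct{ reply string }

func (s staticCompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
	return s.reply, nil
}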

// OpenAICompletionClient is backed by the OpenAI Chat Completion API.
type OpenAICompletionClient struct {
	config *config.OpenAIConfig
	client *openai.Client
	logger *zap.Logger
}

// NewOpenAICompletionClient creates an OpenAICompletionClient.
func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &OpenAICompletionClient{
		config: cfg,
		client: openai.NewClient(cfg, client, logger),
		logger: logger,
	}
}

// UpdateConfig updates the underlying configuration.
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
	c.config = cfg
	if c.client != nil {
		c.client.UpdateConfig(cfg)
	}
}

// Complete calls OpenAI to obtain a summary.
func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
	if c.config == nil {
		return "", errors.New("openai config is required")
	}
	if model == "" {
		return "", errors.New("model name is required")
	}

	reqBody := OpenAIRequest{
		Model: model,
		Messages: []ChatMessage{
			{Role: "user", Content: prompt},
		},
	}

	requestCtx := ctx
	var cancel context.CancelFunc
	if timeout > 0 {
		requestCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	var completion OpenAIResponse
	if c.client == nil {
		return "", errors.New("openai completion client not initialized")
	}
	if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
		if apiErr, ok := err.(*openai.APIError); ok {
			return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
		}
		return "", err
	}
	if completion.Error != nil {
		return "", errors.New(completion.Error.Message)
	}

	if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
		return "", errors.New("empty completion response")
	}
	return completion.Choices[0].Message.Content, nil
}
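
// exampleOfflineCompressor is a wiring sketch, not part of the original file: it
// combines the staticCompletionClient stub above with an explicit SummaryModel so no
// OpenAIConfig is needed, which is handy for dry-running the compression path.
func exampleOfflineCompressor() (*MemoryCompressor, error) {
	return NewMemoryCompressor(MemoryCompressorConfig{
		MaxTotalTokens:   8000,     // assumed small budget so compression triggers quickly
		SummaryModel:     "gpt-4o", // assumed model name, used here only for token counting
		CompletionClient: staticCompletionClient{reply: "condensed findings"},
	})
}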