diff --git a/go.mod b/go.mod index e5233bfe..7b165b58 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/google/uuid v1.5.0 github.com/mattn/go-sqlite3 v1.14.18 + github.com/pkoukk/tiktoken-go v0.1.8 go.uber.org/zap v1.26.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -13,6 +14,7 @@ require ( require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/go.sum b/go.sum index 2d559b00..404ea598 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -47,6 +49,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkoukk/tiktoken-go v0.1.8 
h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= +github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index a2e3d75e..bc858106 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -24,6 +24,7 @@ type Agent struct { openAIClient *http.Client config *config.OpenAIConfig agentConfig *config.AgentConfig + memoryCompressor *MemoryCompressor mcpServer *mcp.Server externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 logger *zap.Logger @@ -89,13 +90,32 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer // 增加超时时间到30分钟,以支持长时间运行的AI推理 // 特别是当使用流式响应或处理复杂任务时 + httpClient := &http.Client{ + Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 + Transport: transport, + } + + var memoryCompressor *MemoryCompressor + if cfg != nil { + mc, err := NewMemoryCompressor(MemoryCompressorConfig{ + OpenAIConfig: cfg, + HTTPClient: httpClient, + Logger: logger, + }) + if err != nil { + logger.Warn("初始化MemoryCompressor失败,将跳过上下文压缩", zap.Error(err)) + } else { + memoryCompressor = mc + } + } else { + logger.Warn("OpenAI配置为空,无法初始化MemoryCompressor") + } + return &Agent{ - openAIClient: &http.Client{ - Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 - Transport: transport, - }, + openAIClient: httpClient, config: cfg, agentConfig: agentCfg, + memoryCompressor: memoryCompressor, mcpServer: mcpServer, externalMCPMgr: externalMCPMgr, logger: logger, @@ -417,12 +437,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his maxIterations := a.maxIterations for i := 0; i < maxIterations; i++ { + // 每轮调用前先尝试压缩,防止历史消息持续膨胀 + messages = a.applyMemoryCompression(ctx, messages) + // 
检查是否是最后一次迭代 isLastIteration := (i == maxIterations-1) // 获取可用工具 tools := a.getAvailableTools() + // 记录当前上下文的Token用量,展示压缩器运行状态 + if a.memoryCompressor != nil { + totalTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages) + a.logger.Info("memory compressor context stats", + zap.Int("iteration", i+1), + zap.Int("messagesCount", len(messages)), + zap.Int("systemMessages", systemCount), + zap.Int("regularMessages", regularCount), + zap.Int("totalTokens", totalTokens), + zap.Int("maxTotalTokens", a.memoryCompressor.maxTotalTokens), + ) + } + // 发送迭代开始事件 if i == 0 { sendProgress("iteration", "开始分析请求并制定测试策略", map[string]interface{}{ @@ -479,6 +515,25 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his } if response.Error != nil { + if handled, toolName := a.handleMissingToolError(response.Error.Message, &messages); handled { + sendProgress("warning", fmt.Sprintf("模型尝试调用不存在的工具:%s,已提示其改用可用工具。", toolName), map[string]interface{}{ + "toolName": toolName, + }) + a.logger.Warn("模型调用了不存在的工具,将重试", + zap.String("tool", toolName), + zap.String("error", response.Error.Message), + ) + continue + } + if a.handleToolRoleError(response.Error.Message, &messages) { + sendProgress("warning", "检测到未配对的工具结果,已自动修复上下文并重试。", map[string]interface{}{ + "error": response.Error.Message, + }) + a.logger.Warn("检测到未配对的工具消息,已修复并重试", + zap.String("error", response.Error.Message), + ) + continue + } result.Response = "" return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) } @@ -601,6 +656,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Role: "user", Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", }) + messages = a.applyMemoryCompression(ctx, messages) // 立即调用OpenAI获取总结 summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { @@ -639,6 +695,7 @@ func (a *Agent) 
AgentLoopWithProgress(ctx context.Context, userInput string, his Role: "user", Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", }) + messages = a.applyMemoryCompression(ctx, messages) // 立即调用OpenAI获取总结 summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { @@ -674,6 +731,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Content: fmt.Sprintf("已达到最大迭代次数(%d轮)。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", a.maxIterations), } messages = append(messages, finalSummaryPrompt) + messages = a.applyMemoryCompression(ctx, messages) summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { @@ -1052,35 +1110,6 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to return &response, nil } -// parseToolCall 解析工具调用 -func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) { - // 简单解析,实际应该更复杂 - // 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2 - if !strings.HasPrefix(content, "[TOOL_CALL]") { - return nil, fmt.Errorf("不是有效的工具调用格式") - } - - parts := strings.Split(content[len("[TOOL_CALL]"):], ":") - if len(parts) < 2 { - return nil, fmt.Errorf("工具调用格式错误") - } - - toolName := strings.TrimSpace(parts[0]) - argsStr := strings.TrimSpace(parts[1]) - - args := make(map[string]interface{}) - argPairs := strings.Split(argsStr, ",") - for _, pair := range argPairs { - kv := strings.Split(pair, "=") - if len(kv) == 2 { - args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) - } - } - - args["_tool_name"] = toolName - return args, nil -} - // ToolExecutionResult 工具执行结果 type ToolExecutionResult struct { Result string @@ -1286,3 +1315,144 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er return errorMsg } + +// 
applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过token限制 +func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage) []ChatMessage { + if a.memoryCompressor == nil { + return messages + } + + compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages) + if err != nil { + a.logger.Warn("上下文压缩失败,将使用原始消息继续", zap.Error(err)) + return messages + } + if changed { + a.logger.Info("历史上下文已压缩", + zap.Int("originalMessages", len(messages)), + zap.Int("compressedMessages", len(compressed)), + ) + return compressed + } + + return messages +} + +// handleMissingToolError 当LLM调用不存在的工具时,向其追加提示消息并允许继续迭代 +func (a *Agent) handleMissingToolError(errMsg string, messages *[]ChatMessage) (bool, string) { + lowerMsg := strings.ToLower(errMsg) + if !(strings.Contains(lowerMsg, "non-exist tool") || strings.Contains(lowerMsg, "non exist tool")) { + return false, "" + } + + toolName := extractQuotedToolName(errMsg) + if toolName == "" { + toolName = "unknown_tool" + } + + notice := fmt.Sprintf("System notice: the previous call failed with error: %s. Please verify tool availability and proceed using existing tools or pure reasoning.", errMsg) + *messages = append(*messages, ChatMessage{ + Role: "user", + Content: notice, + }) + + return true, toolName +} + +// handleToolRoleError 自动修复因缺失tool_calls导致的OpenAI错误 +func (a *Agent) handleToolRoleError(errMsg string, messages *[]ChatMessage) bool { + if messages == nil { + return false + } + + lowerMsg := strings.ToLower(errMsg) + if !(strings.Contains(lowerMsg, "role 'tool'") && strings.Contains(lowerMsg, "tool_calls")) { + return false + } + + fixed := a.repairOrphanToolMessages(messages) + if !fixed { + return false + } + + notice := "System notice: the previous call failed because some tool outputs lost their corresponding assistant tool_calls context. The history has been repaired. Please continue." 
+ *messages = append(*messages, ChatMessage{ + Role: "user", + Content: notice, + }) + + return true +} + +// repairOrphanToolMessages 清理失去配对的tool消息,避免OpenAI报错 +func (a *Agent) repairOrphanToolMessages(messages *[]ChatMessage) bool { + if messages == nil { + return false + } + + msgs := *messages + if len(msgs) == 0 { + return false + } + + pending := make(map[string]int) + cleaned := make([]ChatMessage, 0, len(msgs)) + removed := false + + for _, msg := range msgs { + switch strings.ToLower(msg.Role) { + case "assistant": + if len(msg.ToolCalls) > 0 { + for _, tc := range msg.ToolCalls { + if tc.ID != "" { + pending[tc.ID]++ + } + } + } + cleaned = append(cleaned, msg) + case "tool": + callID := msg.ToolCallID + if callID == "" { + removed = true + continue + } + if count, exists := pending[callID]; exists && count > 0 { + if count == 1 { + delete(pending, callID) + } else { + pending[callID] = count - 1 + } + cleaned = append(cleaned, msg) + } else { + removed = true + continue + } + default: + cleaned = append(cleaned, msg) + } + } + + if removed { + a.logger.Warn("移除了失配的tool消息以修复对话历史", + zap.Int("original_messages", len(msgs)), + zap.Int("cleaned_messages", len(cleaned)), + ) + *messages = cleaned + } + + return removed +} + +// extractQuotedToolName 尝试从错误信息中提取被引用的工具名称 +func extractQuotedToolName(errMsg string) string { + start := strings.Index(errMsg, "\"") + if start == -1 { + return "" + } + rest := errMsg[start+1:] + end := strings.Index(rest, "\"") + if end == -1 { + return "" + } + return rest[:end] +} diff --git a/internal/agent/memory_compressor.go b/internal/agent/memory_compressor.go new file mode 100644 index 00000000..e472d3e6 --- /dev/null +++ b/internal/agent/memory_compressor.go @@ -0,0 +1,445 @@ +package agent + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/pkoukk/tiktoken-go" + "go.uber.org/zap" +) + +const ( + 
DefaultMaxTotalTokens = 100_000 + DefaultMinRecentMessage = 10 + defaultChunkSize = 10 + defaultMaxImages = 3 + defaultSummaryTimeout = 10 * time.Minute + + summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。 + +必须保留的关键信息: +- 已发现的漏洞与潜在攻击路径 +- 扫描结果与工具输出(可压缩,但需保留核心发现) +- 获取到的访问凭证、令牌或认证细节 +- 系统架构洞察与潜在薄弱点 +- 当前评估进展 +- 失败尝试与死路(避免重复劳动) +- 关于测试策略的所有决策记录 + +压缩指南: +- 保留精确技术细节(URL、路径、参数、Payload 等) +- 将冗长的工具输出压缩成概述,但保留关键发现 +- 记录版本号与识别出的技术/组件信息 +- 保留可能暗示漏洞的原始报错 +- 将重复或相似发现整合成一条带有共性说明的结论 + +请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。 + +需要压缩的对话片段: +%s + +请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。` +) + +// MemoryCompressor 负责在调用LLM前压缩历史上下文,以避免Token爆炸。 +type MemoryCompressor struct { + maxTotalTokens int + minRecentMessage int + maxImages int + chunkSize int + summaryModel string + timeout time.Duration + + tokenCounter TokenCounter + completionClient CompletionClient + logger *zap.Logger +} + +// MemoryCompressorConfig 用于初始化 MemoryCompressor。 +type MemoryCompressorConfig struct { + MaxTotalTokens int + MinRecentMessage int + MaxImages int + ChunkSize int + SummaryModel string + Timeout time.Duration + TokenCounter TokenCounter + CompletionClient CompletionClient + Logger *zap.Logger + + // 当 CompletionClient 为空时,可以通过 OpenAIConfig + HTTPClient 构造默认的客户端。 + OpenAIConfig *config.OpenAIConfig + HTTPClient *http.Client +} + +// NewMemoryCompressor 创建新的 MemoryCompressor。 +func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) { + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + + if cfg.MaxTotalTokens <= 0 { + cfg.MaxTotalTokens = DefaultMaxTotalTokens + } + if cfg.MinRecentMessage <= 0 { + cfg.MinRecentMessage = DefaultMinRecentMessage + } + if cfg.MaxImages <= 0 { + cfg.MaxImages = defaultMaxImages + } + if cfg.ChunkSize <= 0 { + cfg.ChunkSize = defaultChunkSize + } + if cfg.Timeout <= 0 { + cfg.Timeout = defaultSummaryTimeout + } + if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" { + 
cfg.SummaryModel = cfg.OpenAIConfig.Model + } + if cfg.TokenCounter == nil { + cfg.TokenCounter = NewTikTokenCounter() + } + + if cfg.CompletionClient == nil { + if cfg.OpenAIConfig == nil { + return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig") + } + if cfg.HTTPClient == nil { + cfg.HTTPClient = &http.Client{ + Timeout: 5 * time.Minute, + } + } + cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger) + } + + return &MemoryCompressor{ + maxTotalTokens: cfg.MaxTotalTokens, + minRecentMessage: cfg.MinRecentMessage, + maxImages: cfg.MaxImages, + chunkSize: cfg.ChunkSize, + summaryModel: cfg.SummaryModel, + timeout: cfg.Timeout, + tokenCounter: cfg.TokenCounter, + completionClient: cfg.CompletionClient, + logger: cfg.Logger, + }, nil +} + +// CompressHistory 根据Token限制压缩历史消息。 +func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) { + if len(messages) == 0 { + return messages, false, nil + } + + mc.handleImages(messages) + + systemMsgs, regularMsgs := mc.splitMessages(messages) + if len(regularMsgs) <= mc.minRecentMessage { + return messages, false, nil + } + + totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs) + if totalTokens <= int(float64(mc.maxTotalTokens)*0.9) { + return messages, false, nil + } + + recentStart := len(regularMsgs) - mc.minRecentMessage + recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart) + oldMsgs := regularMsgs[:recentStart] + recentMsgs := regularMsgs[recentStart:] + + mc.logger.Info("memory compression triggered", + zap.Int("total_tokens", totalTokens), + zap.Int("max_total_tokens", mc.maxTotalTokens), + zap.Int("system_messages", len(systemMsgs)), + zap.Int("regular_messages", len(regularMsgs)), + zap.Int("old_messages", len(oldMsgs)), + zap.Int("recent_messages", len(recentMsgs))) + + var compressed []ChatMessage + for i := 0; i < len(oldMsgs); i += mc.chunkSize { + end 
:= i + mc.chunkSize + if end > len(oldMsgs) { + end = len(oldMsgs) + } + chunk := oldMsgs[i:end] + if len(chunk) == 0 { + continue + } + summary, err := mc.summarizeChunk(ctx, chunk) + if err != nil { + mc.logger.Warn("chunk summary failed, fallback to raw chunk", + zap.Error(err), + zap.Int("start", i), + zap.Int("end", end)) + compressed = append(compressed, chunk...) + continue + } + compressed = append(compressed, summary) + } + + finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs)) + finalMessages = append(finalMessages, systemMsgs...) + finalMessages = append(finalMessages, compressed...) + finalMessages = append(finalMessages, recentMsgs...) + + return finalMessages, true, nil +} + +func (mc *MemoryCompressor) handleImages(messages []ChatMessage) { + if mc.maxImages <= 0 { + return + } + count := 0 + for i := len(messages) - 1; i >= 0; i-- { + content := messages[i].Content + if !strings.Contains(content, "[IMAGE]") { + continue + } + count++ + if count > mc.maxImages { + messages[i].Content = "[Previously attached image removed to preserve context]" + } + } +} + +func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) { + for _, msg := range messages { + if strings.EqualFold(msg.Role, "system") { + systemMsgs = append(systemMsgs, msg) + } else { + regularMsgs = append(regularMsgs, msg) + } + } + return +} + +func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int { + total := 0 + for _, msg := range systemMsgs { + total += mc.countTokens(msg.Content) + } + for _, msg := range regularMsgs { + total += mc.countTokens(msg.Content) + } + return total +} + +func (mc *MemoryCompressor) countTokens(text string) int { + if mc.tokenCounter == nil { + return len(text) / 4 + } + count, err := mc.tokenCounter.Count(mc.summaryModel, text) + if err != nil { + return len(text) / 4 + } + return count +} + +// totalTokensFor provides token statistics without 
mutating the message list.
+func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
+	if len(messages) == 0 {
+		return 0, 0, 0
+	}
+	systemMsgs, regularMsgs := mc.splitMessages(messages)
+	return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs)
+}
+
+func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
+	if len(chunk) == 0 {
+		return ChatMessage{}, errors.New("chunk is empty")
+	}
+	formatted := make([]string, 0, len(chunk))
+	for _, msg := range chunk {
+		formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
+	}
+	conversation := strings.Join(formatted, "\n")
+	prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
+
+	summary, err := mc.completionClient.Complete(ctx, mc.summaryModel, prompt, mc.timeout)
+	if err != nil {
+		return ChatMessage{}, err
+	}
+	summary = strings.TrimSpace(summary)
+	if summary == "" {
+		return ChatMessage{}, errors.New("empty summary response")
+	}
+
+	return ChatMessage{
+		Role:    "assistant",
+		Content: fmt.Sprintf("[历史上下文摘要,由%d条消息压缩而成]\n%s", len(chunk), summary),
+	}, nil
+}
+
+func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
+	return msg.Content
+}
+
+func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
+	if recentStart <= 0 || recentStart >= len(msgs) {
+		return recentStart
+	}
+
+	adjusted := recentStart
+	for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
+		adjusted--
+	}
+
+	if adjusted != recentStart {
+		mc.logger.Debug("adjusted recent window to keep tool call context",
+			zap.Int("original_recent_start", recentStart),
+			zap.Int("adjusted_recent_start", adjusted),
+		)
+	}
+
+	return adjusted
+}
+
+// TokenCounter 用于计算文本Token数量。
+type TokenCounter interface {
+	Count(model, text string) (int, error)
+}
+
+// TikTokenCounter 基于 tiktoken 的 Token 统计器。
+type TikTokenCounter struct {
+	mu               sync.RWMutex
+	cache            
map[string]*tiktoken.Tiktoken + fallbackEncoding *tiktoken.Tiktoken +} + +// NewTikTokenCounter 创建新的 TikTokenCounter。 +func NewTikTokenCounter() *TikTokenCounter { + return &TikTokenCounter{ + cache: make(map[string]*tiktoken.Tiktoken), + } +} + +// Count 实现 TokenCounter 接口。 +func (tc *TikTokenCounter) Count(model, text string) (int, error) { + enc, err := tc.encodingForModel(model) + if err != nil { + return len(text) / 4, err + } + tokens := enc.Encode(text, nil, nil) + return len(tokens), nil +} + +func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) { + tc.mu.RLock() + if enc, ok := tc.cache[model]; ok { + tc.mu.RUnlock() + return enc, nil + } + tc.mu.RUnlock() + + tc.mu.Lock() + defer tc.mu.Unlock() + + if enc, ok := tc.cache[model]; ok { + return enc, nil + } + + enc, err := tiktoken.EncodingForModel(model) + if err != nil { + if tc.fallbackEncoding == nil { + tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base") + if err != nil { + return nil, err + } + } + tc.cache[model] = tc.fallbackEncoding + return tc.fallbackEncoding, nil + } + + tc.cache[model] = enc + return enc, nil +} + +// CompletionClient 对话压缩时使用的补全接口。 +type CompletionClient interface { + Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) +} + +// OpenAICompletionClient 基于 OpenAI Chat Completion。 +type OpenAICompletionClient struct { + config *config.OpenAIConfig + httpClient *http.Client + logger *zap.Logger +} + +// NewOpenAICompletionClient 创建 OpenAICompletionClient。 +func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient { + if logger == nil { + logger = zap.NewNop() + } + return &OpenAICompletionClient{ + config: cfg, + httpClient: client, + logger: logger, + } +} + +// Complete 调用OpenAI获取摘要。 +func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) { + if c.config 
== nil { + return "", errors.New("openai config is required") + } + + reqBody := OpenAIRequest{ + Model: model, + Messages: []ChatMessage{ + {Role: "user", Content: prompt}, + }, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", err + } + + requestCtx := ctx + var cancel context.CancelFunc + if timeout > 0 { + requestCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, c.config.BaseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("openai completion failed, status: %s", resp.Status) + } + + var completion OpenAIResponse + if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { + return "", err + } + if completion.Error != nil { + return "", errors.New(completion.Error.Message) + } + + if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" { + return "", errors.New("empty completion response") + } + return completion.Choices[0].Message.Content, nil +}