Add files via upload

This commit is contained in:
公明
2025-11-20 21:55:45 +08:00
committed by GitHub
parent b5a79ec83a
commit 6f19cac4d2
4 changed files with 654 additions and 33 deletions

2
go.mod
View File

@@ -6,6 +6,7 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.5.0
github.com/mattn/go-sqlite3 v1.14.18
github.com/pkoukk/tiktoken-go v0.1.8
go.uber.org/zap v1.26.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -13,6 +14,7 @@ require (
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect

4
go.sum
View File

@@ -7,6 +7,8 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -47,6 +49,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@@ -24,6 +24,7 @@ type Agent struct {
openAIClient *http.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
memoryCompressor *MemoryCompressor
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
@@ -89,13 +90,32 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
// 增加超时时间到30分钟以支持长时间运行的AI推理
// 特别是当使用流式响应或处理复杂任务时
httpClient := &http.Client{
Timeout: 30 * time.Minute, // 从5分钟增加到30分钟
Transport: transport,
}
var memoryCompressor *MemoryCompressor
if cfg != nil {
mc, err := NewMemoryCompressor(MemoryCompressorConfig{
OpenAIConfig: cfg,
HTTPClient: httpClient,
Logger: logger,
})
if err != nil {
logger.Warn("初始化MemoryCompressor失败将跳过上下文压缩", zap.Error(err))
} else {
memoryCompressor = mc
}
} else {
logger.Warn("OpenAI配置为空无法初始化MemoryCompressor")
}
return &Agent{
openAIClient: &http.Client{
Timeout: 30 * time.Minute, // 从5分钟增加到30分钟
Transport: transport,
},
openAIClient: httpClient,
config: cfg,
agentConfig: agentCfg,
memoryCompressor: memoryCompressor,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
logger: logger,
@@ -417,12 +437,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
maxIterations := a.maxIterations
for i := 0; i < maxIterations; i++ {
// 每轮调用前先尝试压缩,防止历史消息持续膨胀
messages = a.applyMemoryCompression(ctx, messages)
// 检查是否是最后一次迭代
isLastIteration := (i == maxIterations-1)
// 获取可用工具
tools := a.getAvailableTools()
// 记录当前上下文的Token用量展示压缩器运行状态
if a.memoryCompressor != nil {
totalTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages)
a.logger.Info("memory compressor context stats",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
zap.Int("systemMessages", systemCount),
zap.Int("regularMessages", regularCount),
zap.Int("totalTokens", totalTokens),
zap.Int("maxTotalTokens", a.memoryCompressor.maxTotalTokens),
)
}
// 发送迭代开始事件
if i == 0 {
sendProgress("iteration", "开始分析请求并制定测试策略", map[string]interface{}{
@@ -479,6 +515,25 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
if response.Error != nil {
if handled, toolName := a.handleMissingToolError(response.Error.Message, &messages); handled {
sendProgress("warning", fmt.Sprintf("模型尝试调用不存在的工具:%s已提示其改用可用工具。", toolName), map[string]interface{}{
"toolName": toolName,
})
a.logger.Warn("模型调用了不存在的工具,将重试",
zap.String("tool", toolName),
zap.String("error", response.Error.Message),
)
continue
}
if a.handleToolRoleError(response.Error.Message, &messages) {
sendProgress("warning", "检测到未配对的工具结果,已自动修复上下文并重试。", map[string]interface{}{
"error": response.Error.Message,
})
a.logger.Warn("检测到未配对的工具消息,已修复并重试",
zap.String("error", response.Error.Message),
)
continue
}
result.Response = ""
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
@@ -601,6 +656,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
messages = a.applyMemoryCompression(ctx, messages)
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -639,6 +695,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
messages = a.applyMemoryCompression(ctx, messages)
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -674,6 +731,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Content: fmt.Sprintf("已达到最大迭代次数(%d轮。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试请提供详细的下一步执行计划。请直接回复不要调用工具。", a.maxIterations),
}
messages = append(messages, finalSummaryPrompt)
messages = a.applyMemoryCompression(ctx, messages)
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -1052,35 +1110,6 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
return &response, nil
}
// parseToolCall parses a legacy inline tool-call directive of the form
//
//	[TOOL_CALL]tool_name:arg1=value1,arg2=value2
//
// and returns the arguments as a map, with the tool name stored under the
// reserved "_tool_name" key. An error is returned when content does not use
// this format.
func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
	if !strings.HasPrefix(content, "[TOOL_CALL]") {
		return nil, fmt.Errorf("不是有效的工具调用格式")
	}
	// SplitN keeps any ':' inside argument values intact; a plain Split would
	// silently truncate everything after the second colon.
	parts := strings.SplitN(content[len("[TOOL_CALL]"):], ":", 2)
	if len(parts) < 2 {
		return nil, fmt.Errorf("工具调用格式错误")
	}
	toolName := strings.TrimSpace(parts[0])
	argsStr := strings.TrimSpace(parts[1])
	args := make(map[string]interface{})
	for _, pair := range strings.Split(argsStr, ",") {
		// Likewise, split on the first '=' only so values may contain '='
		// (the previous Split dropped such pairs entirely).
		kv := strings.SplitN(pair, "=", 2)
		if len(kv) == 2 {
			args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
		}
	}
	args["_tool_name"] = toolName
	return args, nil
}
// ToolExecutionResult 工具执行结果
type ToolExecutionResult struct {
Result string
@@ -1286,3 +1315,144 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er
return errorMsg
}
// applyMemoryCompression compresses the message history before an LLM call so
// the context stays within the token budget. It is a no-op when no
// MemoryCompressor is configured; on compression failure the original
// messages are returned unchanged (best effort).
func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage) []ChatMessage {
	if a.memoryCompressor == nil {
		return messages
	}
	compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages)
	if err != nil {
		// Best effort: log and keep the uncompressed history.
		a.logger.Warn("上下文压缩失败,将使用原始消息继续", zap.Error(err))
		return messages
	}
	if changed {
		a.logger.Info("历史上下文已压缩",
			zap.Int("originalMessages", len(messages)),
			zap.Int("compressedMessages", len(compressed)),
		)
		return compressed
	}
	return messages
}
// handleMissingToolError recovers from the provider error raised when the LLM
// calls a tool that does not exist (matched case-insensitively on
// "non-exist tool" / "non exist tool" in the error text). When matched, it
// appends a corrective user notice so the next iteration can proceed, and
// returns true plus the offending tool name — best-effort extracted from the
// first double-quoted string in errMsg, "unknown_tool" otherwise.
func (a *Agent) handleMissingToolError(errMsg string, messages *[]ChatMessage) (bool, string) {
	lowerMsg := strings.ToLower(errMsg)
	if !(strings.Contains(lowerMsg, "non-exist tool") || strings.Contains(lowerMsg, "non exist tool")) {
		return false, ""
	}
	toolName := extractQuotedToolName(errMsg)
	if toolName == "" {
		toolName = "unknown_tool"
	}
	notice := fmt.Sprintf("System notice: the previous call failed with error: %s. Please verify tool availability and proceed using existing tools or pure reasoning.", errMsg)
	*messages = append(*messages, ChatMessage{
		Role:    "user",
		Content: notice,
	})
	return true, toolName
}
// handleToolRoleError auto-repairs the OpenAI error raised when a message
// with role "tool" has no preceding assistant message carrying the matching
// tool_calls. When the error matches, orphan tool messages are stripped from
// *messages and a corrective user notice is appended so the loop can retry.
// Returns true only when a repair actually removed something.
func (a *Agent) handleToolRoleError(errMsg string, messages *[]ChatMessage) bool {
	if messages == nil {
		return false
	}
	lowerMsg := strings.ToLower(errMsg)
	// Match the characteristic provider wording: role 'tool' ... tool_calls.
	if !(strings.Contains(lowerMsg, "role 'tool'") && strings.Contains(lowerMsg, "tool_calls")) {
		return false
	}
	fixed := a.repairOrphanToolMessages(messages)
	if !fixed {
		return false
	}
	notice := "System notice: the previous call failed because some tool outputs lost their corresponding assistant tool_calls context. The history has been repaired. Please continue."
	*messages = append(*messages, ChatMessage{
		Role:    "user",
		Content: notice,
	})
	return true
}
// repairOrphanToolMessages removes "tool" role messages that are no longer
// paired with a preceding assistant tool_call, which would otherwise make the
// OpenAI API reject the conversation. It reports whether any message was
// removed; *messages is rewritten only in that case.
func (a *Agent) repairOrphanToolMessages(messages *[]ChatMessage) bool {
	if messages == nil {
		return false
	}
	msgs := *messages
	if len(msgs) == 0 {
		return false
	}
	// pending maps tool_call ID -> number of still-unmatched calls issued by
	// earlier assistant messages.
	pending := make(map[string]int)
	cleaned := make([]ChatMessage, 0, len(msgs))
	removed := false
	for _, msg := range msgs {
		switch strings.ToLower(msg.Role) {
		case "assistant":
			if len(msg.ToolCalls) > 0 {
				for _, tc := range msg.ToolCalls {
					if tc.ID != "" {
						pending[tc.ID]++
					}
				}
			}
			cleaned = append(cleaned, msg)
		case "tool":
			callID := msg.ToolCallID
			if callID == "" {
				// A tool result without an ID can never be paired; drop it.
				removed = true
				continue
			}
			if count, exists := pending[callID]; exists && count > 0 {
				if count == 1 {
					delete(pending, callID)
				} else {
					pending[callID] = count - 1
				}
				cleaned = append(cleaned, msg)
			} else {
				// No matching assistant tool_call appeared before this result.
				removed = true
				continue
			}
		default:
			cleaned = append(cleaned, msg)
		}
	}
	if removed {
		a.logger.Warn("移除了失配的tool消息以修复对话历史",
			zap.Int("original_messages", len(msgs)),
			zap.Int("cleaned_messages", len(cleaned)),
		)
		*messages = cleaned
	}
	return removed
}
// extractQuotedToolName returns the substring between the first pair of
// double quotes in errMsg, or "" when the message contains fewer than two
// double quotes.
func extractQuotedToolName(errMsg string) string {
	_, after, ok := strings.Cut(errMsg, `"`)
	if !ok {
		return ""
	}
	name, _, ok := strings.Cut(after, `"`)
	if !ok {
		return ""
	}
	return name
}

View File

@@ -0,0 +1,445 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/pkoukk/tiktoken-go"
"go.uber.org/zap"
)
const (
	// DefaultMaxTotalTokens is the default token budget for the whole context.
	DefaultMaxTotalTokens = 100_000
	// DefaultMinRecentMessage is the default number of most recent non-system
	// messages that are always kept verbatim (never summarized).
	DefaultMinRecentMessage = 10
	// defaultChunkSize is how many old messages are summarized per LLM call.
	defaultChunkSize = 10
	// defaultMaxImages is how many "[IMAGE]" messages are kept before older
	// ones are scrubbed from the history.
	defaultMaxImages = 3
	// defaultSummaryTimeout bounds a single summarization request.
	defaultSummaryTimeout = 10 * time.Minute

	// summaryPromptTemplate is the (Chinese) prompt used to summarize one
	// chunk of conversation; %s is replaced with the formatted messages.
	summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。
必须保留的关键信息:
- 已发现的漏洞与潜在攻击路径
- 扫描结果与工具输出(可压缩,但需保留核心发现)
- 获取到的访问凭证、令牌或认证细节
- 系统架构洞察与潜在薄弱点
- 当前评估进展
- 失败尝试与死路(避免重复劳动)
- 关于测试策略的所有决策记录
压缩指南:
- 保留精确技术细节URL、路径、参数、Payload 等)
- 将冗长的工具输出压缩成概述,但保留关键发现
- 记录版本号与识别出的技术/组件信息
- 保留可能暗示漏洞的原始报错
- 将重复或相似发现整合成一条带有共性说明的结论
请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。
需要压缩的对话片段:
%s
请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。`
)
// MemoryCompressor compresses historical context before each LLM call to keep
// the conversation within the token budget and avoid token explosion.
type MemoryCompressor struct {
	maxTotalTokens   int           // overall context token budget
	minRecentMessage int           // newest non-system messages always kept verbatim
	maxImages        int           // newest "[IMAGE]" messages kept; older ones scrubbed
	chunkSize        int           // old messages summarized per completion call
	summaryModel     string        // model used for summaries and token counting
	timeout          time.Duration // per-summary request timeout
	tokenCounter     TokenCounter
	completionClient CompletionClient
	logger           *zap.Logger
}
// MemoryCompressorConfig configures NewMemoryCompressor. Any zero-valued
// field is replaced with a sensible default during construction.
type MemoryCompressorConfig struct {
	MaxTotalTokens   int           // <=0 -> DefaultMaxTotalTokens
	MinRecentMessage int           // <=0 -> DefaultMinRecentMessage
	MaxImages        int           // <=0 -> defaultMaxImages
	ChunkSize        int           // <=0 -> defaultChunkSize
	SummaryModel     string        // "" -> OpenAIConfig.Model (when available)
	Timeout          time.Duration // <=0 -> defaultSummaryTimeout
	TokenCounter     TokenCounter  // nil -> NewTikTokenCounter()
	CompletionClient CompletionClient
	Logger           *zap.Logger // nil -> zap.NewNop()
	// When CompletionClient is nil, a default client is built from
	// OpenAIConfig + HTTPClient (HTTPClient itself defaults to a client with
	// a 5-minute timeout).
	OpenAIConfig *config.OpenAIConfig
	HTTPClient   *http.Client
}
// NewMemoryCompressor builds a MemoryCompressor, filling every unset field of
// cfg with its default. When no CompletionClient is supplied, one is
// constructed from OpenAIConfig (+ HTTPClient); if OpenAIConfig is also nil,
// an error is returned.
func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) {
	logger := cfg.Logger
	if logger == nil {
		logger = zap.NewNop()
	}
	// Non-positive numeric settings fall back to package defaults.
	defaultInt := func(v, fallback int) int {
		if v <= 0 {
			return fallback
		}
		return v
	}
	maxTokens := defaultInt(cfg.MaxTotalTokens, DefaultMaxTotalTokens)
	minRecent := defaultInt(cfg.MinRecentMessage, DefaultMinRecentMessage)
	maxImages := defaultInt(cfg.MaxImages, defaultMaxImages)
	chunkSize := defaultInt(cfg.ChunkSize, defaultChunkSize)
	timeout := cfg.Timeout
	if timeout <= 0 {
		timeout = defaultSummaryTimeout
	}
	// The summary model defaults to the main OpenAI model when one is known.
	model := cfg.SummaryModel
	if model == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
		model = cfg.OpenAIConfig.Model
	}
	counter := cfg.TokenCounter
	if counter == nil {
		counter = NewTikTokenCounter()
	}
	client := cfg.CompletionClient
	if client == nil {
		if cfg.OpenAIConfig == nil {
			return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig")
		}
		httpClient := cfg.HTTPClient
		if httpClient == nil {
			httpClient = &http.Client{
				Timeout: 5 * time.Minute,
			}
		}
		client = NewOpenAICompletionClient(cfg.OpenAIConfig, httpClient, logger)
	}
	return &MemoryCompressor{
		maxTotalTokens:   maxTokens,
		minRecentMessage: minRecent,
		maxImages:        maxImages,
		chunkSize:        chunkSize,
		summaryModel:     model,
		timeout:          timeout,
		tokenCounter:     counter,
		completionClient: client,
		logger:           logger,
	}, nil
}
// CompressHistory compresses the message history according to the token
// limit. It returns the resulting slice, whether anything was compressed,
// and an error (the error path is currently unreached — chunk failures
// degrade to keeping the raw chunk — but kept for interface stability).
//
// NOTE: handleImages mutates messages in place (scrubbing old image
// payloads) even when no compression is triggered.
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) {
	if len(messages) == 0 {
		return messages, false, nil
	}
	mc.handleImages(messages)
	// System messages are preserved verbatim and sit outside the compression
	// window.
	systemMsgs, regularMsgs := mc.splitMessages(messages)
	if len(regularMsgs) <= mc.minRecentMessage {
		return messages, false, nil
	}
	totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
	// Only compress once usage crosses 90% of the budget.
	if totalTokens <= int(float64(mc.maxTotalTokens)*0.9) {
		return messages, false, nil
	}
	recentStart := len(regularMsgs) - mc.minRecentMessage
	// Never split tool results away from their assistant tool_calls message.
	recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart)
	oldMsgs := regularMsgs[:recentStart]
	recentMsgs := regularMsgs[recentStart:]
	mc.logger.Info("memory compression triggered",
		zap.Int("total_tokens", totalTokens),
		zap.Int("max_total_tokens", mc.maxTotalTokens),
		zap.Int("system_messages", len(systemMsgs)),
		zap.Int("regular_messages", len(regularMsgs)),
		zap.Int("old_messages", len(oldMsgs)),
		zap.Int("recent_messages", len(recentMsgs)))
	// Summarize the old messages chunk by chunk; a failed chunk falls back to
	// its raw messages so no information is lost.
	var compressed []ChatMessage
	for i := 0; i < len(oldMsgs); i += mc.chunkSize {
		end := i + mc.chunkSize
		if end > len(oldMsgs) {
			end = len(oldMsgs)
		}
		chunk := oldMsgs[i:end]
		if len(chunk) == 0 {
			continue
		}
		summary, err := mc.summarizeChunk(ctx, chunk)
		if err != nil {
			mc.logger.Warn("chunk summary failed, fallback to raw chunk",
				zap.Error(err),
				zap.Int("start", i),
				zap.Int("end", end))
			compressed = append(compressed, chunk...)
			continue
		}
		compressed = append(compressed, summary)
	}
	// Reassemble: system prompt(s) + summaries + untouched recent window.
	finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs))
	finalMessages = append(finalMessages, systemMsgs...)
	finalMessages = append(finalMessages, compressed...)
	finalMessages = append(finalMessages, recentMsgs...)
	return finalMessages, true, nil
}
// handleImages keeps only the most recent maxImages messages containing an
// "[IMAGE]" marker; older image messages have their content replaced with a
// placeholder so they no longer consume context. Mutates messages in place.
func (mc *MemoryCompressor) handleImages(messages []ChatMessage) {
	if mc.maxImages <= 0 {
		return
	}
	// Walk newest-to-oldest so the most recent images survive.
	seen := 0
	for idx := len(messages) - 1; idx >= 0; idx-- {
		if !strings.Contains(messages[idx].Content, "[IMAGE]") {
			continue
		}
		seen++
		if seen > mc.maxImages {
			messages[idx].Content = "[Previously attached image removed to preserve context]"
		}
	}
}
// splitMessages partitions messages into system messages and everything else,
// preserving relative order within each group.
func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) {
	for _, m := range messages {
		switch {
		case strings.EqualFold(m.Role, "system"):
			systemMsgs = append(systemMsgs, m)
		default:
			regularMsgs = append(regularMsgs, m)
		}
	}
	return systemMsgs, regularMsgs
}
// countTotalTokens sums the token counts of the Content of every message in
// both groups.
func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int {
	total := 0
	for _, group := range [][]ChatMessage{systemMsgs, regularMsgs} {
		for _, m := range group {
			total += mc.countTokens(m.Content)
		}
	}
	return total
}
// countTokens returns the token count of text for the summary model, falling
// back to a rough len/4 heuristic when no counter is set or counting fails.
func (mc *MemoryCompressor) countTokens(text string) int {
	const bytesPerTokenGuess = 4
	if mc.tokenCounter == nil {
		return len(text) / bytesPerTokenGuess
	}
	if count, err := mc.tokenCounter.Count(mc.summaryModel, text); err == nil {
		return count
	}
	return len(text) / bytesPerTokenGuess
}
// totalTokensFor reports token statistics for messages without mutating the
// slice: the total token count plus how many messages are system vs regular.
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
	if len(messages) == 0 {
		return 0, 0, 0
	}
	sys, reg := mc.splitMessages(messages)
	totalTokens = mc.countTotalTokens(sys, reg)
	return totalTokens, len(sys), len(reg)
}
// summarizeChunk asks the completion client to compress one chunk of messages
// into a single assistant message wrapped in <context_summary> tags.
// On an LLM error the error is propagated (the caller keeps the raw chunk);
// on an empty summary the first message of the chunk is returned as-is.
func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
	if len(chunk) == 0 {
		return ChatMessage{}, errors.New("chunk is empty")
	}
	// Render the chunk as "role: text" lines for the summary prompt.
	formatted := make([]string, 0, len(chunk))
	for _, msg := range chunk {
		formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
	}
	conversation := strings.Join(formatted, "\n")
	prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
	summary, err := mc.completionClient.Complete(ctx, mc.summaryModel, prompt, mc.timeout)
	if err != nil {
		return ChatMessage{}, err
	}
	summary = strings.TrimSpace(summary)
	if summary == "" {
		// NOTE(review): an empty summary keeps only the first message of the
		// chunk and silently drops the rest — confirm this is intended.
		return chunk[0], nil
	}
	return ChatMessage{
		Role:    "assistant",
		Content: fmt.Sprintf("<context_summary message_count='%d'>%s</context_summary>", len(chunk), summary),
	}, nil
}
// extractMessageText returns the text used when rendering msg into the
// summary prompt; currently this is just the Content field.
func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
	return msg.Content
}
// adjustRecentStartForToolCalls walks the proposed start of the "recent"
// window backwards while it points at a role "tool" message, so that tool
// results are never cut off from the preceding assistant message that issued
// the tool_calls. Returns the (possibly smaller) start index.
func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
	if recentStart <= 0 || recentStart >= len(msgs) {
		return recentStart
	}
	adjusted := recentStart
	for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
		adjusted--
	}
	if adjusted != recentStart {
		mc.logger.Debug("adjusted recent window to keep tool call context",
			zap.Int("original_recent_start", recentStart),
			zap.Int("adjusted_recent_start", adjusted),
		)
	}
	return adjusted
}
// TokenCounter counts the number of tokens in text for a given model.
type TokenCounter interface {
	Count(model, text string) (int, error)
}
// TikTokenCounter is a TokenCounter backed by tiktoken encodings. It caches
// one encoding per model and shares a single lazily-created fallback encoding
// for unknown models. Safe for concurrent use.
type TikTokenCounter struct {
	mu               sync.RWMutex
	cache            map[string]*tiktoken.Tiktoken // per-model encodings (may alias fallbackEncoding)
	fallbackEncoding *tiktoken.Tiktoken            // lazily-initialized "cl100k_base" encoding
}
// NewTikTokenCounter returns an empty, ready-to-use TikTokenCounter.
func NewTikTokenCounter() *TikTokenCounter {
	return &TikTokenCounter{
		cache: make(map[string]*tiktoken.Tiktoken),
	}
}
// Count implements TokenCounter using the tiktoken encoding for model. When
// the encoding cannot be obtained it returns the rough len/4 heuristic
// together with the error, so callers may still use the estimate.
func (tc *TikTokenCounter) Count(model, text string) (int, error) {
	enc, err := tc.encodingForModel(model)
	if err != nil {
		return len(text) / 4, err
	}
	tokens := enc.Encode(text, nil, nil)
	return len(tokens), nil
}
// encodingForModel returns the cached encoding for model, creating and
// caching it on first use (double-checked locking: RLock fast path, then
// recheck under the write lock). Models unknown to tiktoken fall back to a
// shared "cl100k_base" encoding, which is also cached under the model name so
// the miss is only paid once.
func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) {
	tc.mu.RLock()
	if enc, ok := tc.cache[model]; ok {
		tc.mu.RUnlock()
		return enc, nil
	}
	tc.mu.RUnlock()
	tc.mu.Lock()
	defer tc.mu.Unlock()
	// Another goroutine may have filled the cache between the two locks.
	if enc, ok := tc.cache[model]; ok {
		return enc, nil
	}
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		if tc.fallbackEncoding == nil {
			tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base")
			if err != nil {
				return nil, err
			}
		}
		tc.cache[model] = tc.fallbackEncoding
		return tc.fallbackEncoding, nil
	}
	tc.cache[model] = enc
	return enc, nil
}
// CompletionClient produces a text completion for a prompt; the compressor
// uses it to summarize conversation chunks.
type CompletionClient interface {
	Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error)
}
// OpenAICompletionClient implements CompletionClient against an
// OpenAI-compatible Chat Completions endpoint.
type OpenAICompletionClient struct {
	config     *config.OpenAIConfig // provides BaseURL, APIKey (and Model defaults)
	httpClient *http.Client
	logger     *zap.Logger
}
// NewOpenAICompletionClient builds an OpenAICompletionClient. A nil logger is
// replaced with a no-op logger; cfg and client are stored as given.
func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient {
	l := logger
	if l == nil {
		l = zap.NewNop()
	}
	c := &OpenAICompletionClient{
		config:     cfg,
		httpClient: client,
		logger:     l,
	}
	return c
}
// Complete sends a single-turn chat completion request to the configured
// OpenAI-compatible endpoint and returns the assistant's reply text.
//
// When timeout > 0 it is layered on top of ctx via context.WithTimeout.
// Errors are returned for a missing config, marshal/transport failures,
// non-200 HTTP statuses, API-level errors, and empty responses.
func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
	if c.config == nil {
		return "", errors.New("openai config is required")
	}
	reqBody := OpenAIRequest{
		Model: model,
		Messages: []ChatMessage{
			{Role: "user", Content: prompt},
		},
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return "", err
	}
	requestCtx := ctx
	var cancel context.CancelFunc
	if timeout > 0 {
		requestCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, c.config.BaseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("openai completion failed, status: %s", resp.Status)
	}
	// Decode straight from the stream; no need to buffer the whole body.
	var completion OpenAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
		return "", err
	}
	if completion.Error != nil {
		return "", errors.New(completion.Error.Message)
	}
	if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
		return "", errors.New("empty completion response")
	}
	return completion.Choices[0].Message.Content, nil
}