Delete agent directory

This commit is contained in:
公明
2026-04-19 01:19:25 +08:00
committed by GitHub
parent db2c4e7689
commit a33f732d16
3 changed files with 0 additions and 2689 deletions
-1912
View File
File diff suppressed because it is too large Load Diff
-286
View File
@@ -1,286 +0,0 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// setupTestAgent 创建测试用的Agent
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
}
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
},
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
} else {
t.Fatal("大结果应该被检测到并保存")
}
}
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "Small result content",
},
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
if resultStr == "" {
t.Fatal("小结果应该直接返回,不应该为空")
}
}
}
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
}
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
// 测试默认配置
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
}
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
}
-491
View File
@@ -1,491 +0,0 @@
package agent
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
"github.com/pkoukk/tiktoken-go"
"go.uber.org/zap"
)
const (
// DefaultMinRecentMessage 压缩历史消息时保留的最近消息数量,确保最近的对话上下文不被压缩
DefaultMinRecentMessage = 5
// defaultChunkSize 压缩历史消息时每次处理的消息块大小,将旧消息分成多个块进行摘要
defaultChunkSize = 10
// defaultMaxImages 压缩时最多保留的图片数量,超过此数量的图片会被移除以节省上下文空间
defaultMaxImages = 3
// defaultSummaryTimeout 生成消息摘要时的超时时间
defaultSummaryTimeout = 10 * time.Minute
summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。
必须保留的关键信息:
- 已发现的漏洞与潜在攻击路径
- 扫描结果与工具输出(可压缩,但需保留核心发现)
- 获取到的访问凭证、令牌或认证细节
- 系统架构洞察与潜在薄弱点
- 当前评估进展
- 失败尝试与死路(避免重复劳动)
- 关于测试策略的所有决策记录
压缩指南:
- 保留精确技术细节(URL、路径、参数、Payload 等)
- 将冗长的工具输出压缩成概述,但保留关键发现
- 记录版本号与识别出的技术/组件信息
- 保留可能暗示漏洞的原始报错
- 将重复或相似发现整合成一条带有共性说明的结论
请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。
需要压缩的对话片段:
%s
请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。`
)
// MemoryCompressor 负责在调用LLM前压缩历史上下文,以避免Token爆炸。
type MemoryCompressor struct {
maxTotalTokens int
minRecentMessage int
maxImages int
chunkSize int
summaryModel string
timeout time.Duration
tokenCounter TokenCounter
completionClient CompletionClient
logger *zap.Logger
}
// MemoryCompressorConfig 用于初始化 MemoryCompressor。
type MemoryCompressorConfig struct {
MaxTotalTokens int
MinRecentMessage int
MaxImages int
ChunkSize int
SummaryModel string
Timeout time.Duration
TokenCounter TokenCounter
CompletionClient CompletionClient
Logger *zap.Logger
// 当 CompletionClient 为空时,可以通过 OpenAIConfig + HTTPClient 构造默认的客户端。
OpenAIConfig *config.OpenAIConfig
HTTPClient *http.Client
}
// NewMemoryCompressor 创建新的 MemoryCompressor。
func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) {
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
// 如果没有显式配置 MaxTotalTokens,则后续逻辑会根据模型的最大上下文长度进行控制;
// 优先推荐在 config.yaml 的 openai.max_total_tokens 中统一配置。
if cfg.MinRecentMessage <= 0 {
cfg.MinRecentMessage = DefaultMinRecentMessage
}
if cfg.MaxImages <= 0 {
cfg.MaxImages = defaultMaxImages
}
if cfg.ChunkSize <= 0 {
cfg.ChunkSize = defaultChunkSize
}
if cfg.Timeout <= 0 {
cfg.Timeout = defaultSummaryTimeout
}
if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
cfg.SummaryModel = cfg.OpenAIConfig.Model
}
if cfg.SummaryModel == "" {
return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)")
}
if cfg.TokenCounter == nil {
cfg.TokenCounter = NewTikTokenCounter()
}
if cfg.CompletionClient == nil {
if cfg.OpenAIConfig == nil {
return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig")
}
if cfg.HTTPClient == nil {
cfg.HTTPClient = &http.Client{
Timeout: 5 * time.Minute,
}
}
cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger)
}
return &MemoryCompressor{
maxTotalTokens: cfg.MaxTotalTokens,
minRecentMessage: cfg.MinRecentMessage,
maxImages: cfg.MaxImages,
chunkSize: cfg.ChunkSize,
summaryModel: cfg.SummaryModel,
timeout: cfg.Timeout,
tokenCounter: cfg.TokenCounter,
completionClient: cfg.CompletionClient,
logger: cfg.Logger,
}, nil
}
// UpdateConfig 更新OpenAI配置(用于动态更新模型配置)
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
if cfg == nil {
return
}
// 更新summaryModel字段
if cfg.Model != "" {
mc.summaryModel = cfg.Model
}
// 更新completionClient中的配置(如果是OpenAICompletionClient
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
openAIClient.UpdateConfig(cfg)
mc.logger.Info("MemoryCompressor配置已更新",
zap.String("model", cfg.Model),
)
}
}
// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
if len(messages) == 0 {
return messages, false, nil
}
mc.handleImages(messages)
systemMsgs, regularMsgs := mc.splitMessages(messages)
if len(regularMsgs) <= mc.minRecentMessage {
return messages, false, nil
}
effectiveMax := mc.maxTotalTokens
if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
effectiveMax = mc.maxTotalTokens - reservedTokens
}
totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
if totalTokens <= int(float64(effectiveMax)*0.9) {
return messages, false, nil
}
recentStart := len(regularMsgs) - mc.minRecentMessage
recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart)
oldMsgs := regularMsgs[:recentStart]
recentMsgs := regularMsgs[recentStart:]
mc.logger.Info("memory compression triggered",
zap.Int("total_tokens", totalTokens),
zap.Int("max_total_tokens", mc.maxTotalTokens),
zap.Int("reserved_tokens", reservedTokens),
zap.Int("effective_max", effectiveMax),
zap.Int("system_messages", len(systemMsgs)),
zap.Int("regular_messages", len(regularMsgs)),
zap.Int("old_messages", len(oldMsgs)),
zap.Int("recent_messages", len(recentMsgs)))
var compressed []ChatMessage
for i := 0; i < len(oldMsgs); i += mc.chunkSize {
end := i + mc.chunkSize
if end > len(oldMsgs) {
end = len(oldMsgs)
}
chunk := oldMsgs[i:end]
if len(chunk) == 0 {
continue
}
summary, err := mc.summarizeChunk(ctx, chunk)
if err != nil {
mc.logger.Warn("chunk summary failed, fallback to raw chunk",
zap.Error(err),
zap.Int("start", i),
zap.Int("end", end))
compressed = append(compressed, chunk...)
continue
}
compressed = append(compressed, summary)
}
finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs))
finalMessages = append(finalMessages, systemMsgs...)
finalMessages = append(finalMessages, compressed...)
finalMessages = append(finalMessages, recentMsgs...)
return finalMessages, true, nil
}
func (mc *MemoryCompressor) handleImages(messages []ChatMessage) {
if mc.maxImages <= 0 {
return
}
count := 0
for i := len(messages) - 1; i >= 0; i-- {
content := messages[i].Content
if !strings.Contains(content, "[IMAGE]") {
continue
}
count++
if count > mc.maxImages {
messages[i].Content = "[Previously attached image removed to preserve context]"
}
}
}
func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) {
for _, msg := range messages {
if strings.EqualFold(msg.Role, "system") {
systemMsgs = append(systemMsgs, msg)
} else {
regularMsgs = append(regularMsgs, msg)
}
}
return
}
func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int {
total := 0
for _, msg := range systemMsgs {
total += mc.countTokens(msg.Content)
}
for _, msg := range regularMsgs {
total += mc.countTokens(msg.Content)
}
return total
}
// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置)
func (mc *MemoryCompressor) getModelName() string {
// 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
if openAIClient.config != nil && openAIClient.config.Model != "" {
return openAIClient.config.Model
}
}
// 否则使用保存的summaryModel
return mc.summaryModel
}
func (mc *MemoryCompressor) countTokens(text string) int {
if mc.tokenCounter == nil {
return len(text) / 4
}
modelName := mc.getModelName()
count, err := mc.tokenCounter.Count(modelName, text)
if err != nil {
return len(text) / 4
}
return count
}
// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。
func (mc *MemoryCompressor) CountTextTokens(text string) int {
return mc.countTokens(text)
}
// totalTokensFor provides token statistics without mutating the message list.
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
if len(messages) == 0 {
return 0, 0, 0
}
systemMsgs, regularMsgs := mc.splitMessages(messages)
return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs)
}
func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
if len(chunk) == 0 {
return ChatMessage{}, errors.New("chunk is empty")
}
formatted := make([]string, 0, len(chunk))
for _, msg := range chunk {
formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
}
conversation := strings.Join(formatted, "\n")
prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
// 使用动态获取的模型名称,而不是保存的summaryModel
modelName := mc.getModelName()
summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
if err != nil {
return ChatMessage{}, err
}
summary = strings.TrimSpace(summary)
if summary == "" {
return chunk[0], nil
}
return ChatMessage{
Role: "assistant",
Content: fmt.Sprintf("<context_summary message_count='%d'>%s</context_summary>", len(chunk), summary),
}, nil
}
func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
return msg.Content
}
func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
if recentStart <= 0 || recentStart >= len(msgs) {
return recentStart
}
adjusted := recentStart
for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
adjusted--
}
if adjusted != recentStart {
mc.logger.Debug("adjusted recent window to keep tool call context",
zap.Int("original_recent_start", recentStart),
zap.Int("adjusted_recent_start", adjusted),
)
}
return adjusted
}
// TokenCounter 用于计算文本Token数量。
type TokenCounter interface {
Count(model, text string) (int, error)
}
// TikTokenCounter 基于 tiktoken 的 Token 统计器。
type TikTokenCounter struct {
mu sync.RWMutex
cache map[string]*tiktoken.Tiktoken
fallbackEncoding *tiktoken.Tiktoken
}
// NewTikTokenCounter 创建新的 TikTokenCounter。
func NewTikTokenCounter() *TikTokenCounter {
return &TikTokenCounter{
cache: make(map[string]*tiktoken.Tiktoken),
}
}
// Count 实现 TokenCounter 接口。
func (tc *TikTokenCounter) Count(model, text string) (int, error) {
enc, err := tc.encodingForModel(model)
if err != nil {
return len(text) / 4, err
}
tokens := enc.Encode(text, nil, nil)
return len(tokens), nil
}
func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) {
tc.mu.RLock()
if enc, ok := tc.cache[model]; ok {
tc.mu.RUnlock()
return enc, nil
}
tc.mu.RUnlock()
tc.mu.Lock()
defer tc.mu.Unlock()
if enc, ok := tc.cache[model]; ok {
return enc, nil
}
enc, err := tiktoken.EncodingForModel(model)
if err != nil {
if tc.fallbackEncoding == nil {
tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
}
tc.cache[model] = tc.fallbackEncoding
return tc.fallbackEncoding, nil
}
tc.cache[model] = enc
return enc, nil
}
// CompletionClient 对话压缩时使用的补全接口。
type CompletionClient interface {
Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error)
}
// OpenAICompletionClient 基于 OpenAI Chat Completion。
type OpenAICompletionClient struct {
config *config.OpenAIConfig
client *openai.Client
logger *zap.Logger
}
// NewOpenAICompletionClient 创建 OpenAICompletionClient。
func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient {
if logger == nil {
logger = zap.NewNop()
}
return &OpenAICompletionClient{
config: cfg,
client: openai.NewClient(cfg, client, logger),
logger: logger,
}
}
// UpdateConfig 更新底层配置。
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
c.config = cfg
if c.client != nil {
c.client.UpdateConfig(cfg)
}
}
// Complete 调用OpenAI获取摘要。
func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
if c.config == nil {
return "", errors.New("openai config is required")
}
if model == "" {
return "", errors.New("model name is required")
}
reqBody := OpenAIRequest{
Model: model,
Messages: []ChatMessage{
{Role: "user", Content: prompt},
},
}
requestCtx := ctx
var cancel context.CancelFunc
if timeout > 0 {
requestCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
var completion OpenAIResponse
if c.client == nil {
return "", errors.New("openai completion client not initialized")
}
if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
if apiErr, ok := err.(*openai.APIError); ok {
return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
}
return "", err
}
if completion.Error != nil {
return "", errors.New(completion.Error.Message)
}
if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
return "", errors.New("empty completion response")
}
return completion.Choices[0].Message.Content, nil
}