From 7fb398730958f4b0f6ebf49b6793ec94b1520510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 15 Nov 2025 19:40:59 +0800 Subject: [PATCH] Add files via upload --- config.yaml | 7 +- internal/agent/agent.go | 117 ++++++- internal/agent/agent_test.go | 284 +++++++++++++++ internal/app/app.go | 26 +- internal/config/config.go | 4 +- internal/mcp/external_manager_test.go | 4 +- internal/security/executor.go | 267 +++++++++++++- internal/security/executor_test.go | 268 ++++++++++++++ internal/storage/result_storage.go | 270 +++++++++++++++ internal/storage/result_storage_test.go | 443 ++++++++++++++++++++++++ 10 files changed, 1668 insertions(+), 22 deletions(-) create mode 100644 internal/agent/agent_test.go create mode 100644 internal/security/executor_test.go create mode 100644 internal/storage/result_storage.go create mode 100644 internal/storage/result_storage_test.go diff --git a/config.yaml b/config.yaml index f9e88edd..77972063 100644 --- a/config.yaml +++ b/config.yaml @@ -20,7 +20,7 @@ log: # MCP 协议配置 # MCP (Model Context Protocol) 用于工具注册和调用 mcp: - enabled: true # 是否启用 MCP 服务器 + enabled: false # 是否启用 MCP 服务器(http模式) host: 0.0.0.0 # MCP 服务器监听地址 port: 8081 # MCP 服务器端口 # AI 模型配置(支持 OpenAI 兼容 API) @@ -38,6 +38,8 @@ openai: agent: max_iterations: 30 # 最大迭代次数,AI 代理最多执行多少轮工具调用 # 达到最大迭代次数时,AI 会自动总结测试结果 + large_result_threshold: 51200 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储 + result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下 # 数据库配置 database: path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息 @@ -47,4 +49,5 @@ security: # 系统会从该目录加载所有 .yaml 格式的工具配置文件 # 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件 # 外部MCP配置 -external_mcp: \ No newline at end of file +external_mcp: + servers: {} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 99f6678a..d9a4b706 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -14,6 +14,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" "go.uber.org/zap" ) @@ -21,21 +22,55 @@ import ( type Agent struct { openAIClient *http.Client config *config.OpenAIConfig + agentConfig *config.AgentConfig mcpServer *mcp.Server externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 logger *zap.Logger maxIterations int + resultStorage ResultStorage // 结果存储 + largeResultThreshold int // 大结果阈值(字节) mu sync.RWMutex // 添加互斥锁以支持并发更新 toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) } +// ResultStorage 结果存储接口(直接使用 storage 包的类型) +type ResultStorage interface { + SaveResult(executionID string, toolName string, result string) error + GetResult(executionID string) (string, error) + GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error) + SearchResult(executionID string, keyword string) ([]string, error) + FilterResult(executionID string, filter string) ([]string, error) + GetResultMetadata(executionID string) (*storage.ResultMetadata, error) + DeleteResult(executionID string) error +} + // NewAgent 创建新的Agent -func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent { +func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent { // 如果 maxIterations 为 0 或负数,使用默认值 30 if maxIterations <= 0 { maxIterations = 30 } + // 设置大结果阈值,默认50KB + largeResultThreshold := 50 * 1024 + if agentCfg != nil && agentCfg.LargeResultThreshold > 0 { + largeResultThreshold = agentCfg.LargeResultThreshold + } + + // 设置结果存储目录,默认tmp + resultStorageDir := "tmp" + if agentCfg != nil && agentCfg.ResultStorageDir != "" { + resultStorageDir = agentCfg.ResultStorageDir + } + + // 初始化结果存储 + var resultStorage ResultStorage + if resultStorageDir != "" { + // 导入storage包(避免循环依赖,使用接口) + // 这里需要在实际使用时初始化 + // 暂时设为nil,在需要时初始化 + } + // 配置HTTP Transport,优化连接管理和超时设置 transport := &http.Transport{ DialContext: (&net.Dialer{ @@ -57,15 +92,25 @@ func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, externalMCPMgr *m Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 Transport: transport, }, - config: cfg, - mcpServer: mcpServer, - externalMCPMgr: externalMCPMgr, - logger: logger, - maxIterations: maxIterations, - toolNameMapping: make(map[string]string), // 初始化工具名称映射 + config: cfg, + agentConfig: agentCfg, + mcpServer: mcpServer, + externalMCPMgr: externalMCPMgr, + logger: logger, + maxIterations: maxIterations, + resultStorage: resultStorage, + largeResultThreshold: largeResultThreshold, + toolNameMapping: make(map[string]string), // 初始化工具名称映射 } } +// SetResultStorage 设置结果存储(用于避免循环依赖) +func (a *Agent) SetResultStorage(storage ResultStorage) { + a.mu.Lock() + defer a.mu.Unlock() + a.resultStorage = storage +} + // ChatMessage 聊天消息 type ChatMessage struct { Role string `json:"role"` @@ -1037,14 +1082,70 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map resultText.WriteString(content.Text) resultText.WriteString("\n") } + + resultStr := resultText.String() + resultSize := len(resultStr) + + // 检测大结果并保存 + a.mu.RLock() + threshold := a.largeResultThreshold + storage := a.resultStorage + a.mu.RUnlock() + + if resultSize > threshold && storage != nil { + // 异步保存大结果 + go func() { + if err := storage.SaveResult(executionID, toolName, resultStr); err != nil { + a.logger.Warn("保存大结果失败", + zap.String("executionID", executionID), + zap.String("toolName", toolName), + zap.Error(err), + ) + } else { + a.logger.Info("大结果已保存", + zap.String("executionID", executionID), + zap.String("toolName", toolName), + zap.Int("size", resultSize), + ) + } + }() + + // 返回最小化通知 + lines := strings.Split(resultStr, "\n") + notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines)) + + return &ToolExecutionResult{ + Result: notification, + ExecutionID: executionID, + IsError: result != nil && result.IsError, + }, nil + } return &ToolExecutionResult{ - Result: resultText.String(), + Result: resultStr, ExecutionID: executionID, IsError: result != nil && result.IsError, }, nil } +// formatMinimalNotification 格式化最小化通知 +func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID)) + sb.WriteString("结果信息:\n") + sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName)) + sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024)) + sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount)) + sb.WriteString("\n") + sb.WriteString("使用以下工具查询完整结果:\n") + sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID)) + sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID)) + sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID)) + + return sb.String() +} + // UpdateConfig 更新OpenAI配置 func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) { a.mu.Lock() diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go new file mode 100644 index 00000000..c6ec9bd6 --- /dev/null +++ b/internal/agent/agent_test.go @@ -0,0 +1,284 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" + + "go.uber.org/zap" +) + +// setupTestAgent 创建测试用的Agent +func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) { + logger := zap.NewNop() + mcpServer := mcp.NewServer(logger) + + openAICfg := &config.OpenAIConfig{ + APIKey: "test-key", + BaseURL: "https://api.test.com/v1", + Model: "test-model", + } + + agentCfg := &config.AgentConfig{ + MaxIterations: 10, + LargeResultThreshold: 100, // 设置较小的阈值便于测试 + ResultStorageDir: "", + } + + agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10) + + // 创建测试存储 + tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405")) + testStorage, err := storage.NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建测试存储失败: %v", err) + } + + agent.SetResultStorage(testStorage) + + return agent, testStorage +} + +func TestAgent_FormatMinimalNotification(t *testing.T) { + agent, testStorage := setupTestAgent(t) + _ = testStorage // 避免未使用变量警告 + + executionID := "test_exec_001" + toolName := "nmap_scan" + size := 50000 + lineCount := 1000 + + notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount) + + // 验证通知包含必要信息 + if !strings.Contains(notification, executionID) { + t.Errorf("通知中应该包含执行ID: %s", executionID) + } + + if !strings.Contains(notification, toolName) { + t.Errorf("通知中应该包含工具名称: %s", toolName) + } + + if !strings.Contains(notification, "50000") { + t.Errorf("通知中应该包含大小信息") + } + + if !strings.Contains(notification, "1000") { + t.Errorf("通知中应该包含行数信息") + } + + if !strings.Contains(notification, "query_execution_result") { + t.Errorf("通知中应该包含查询工具的使用说明") + } +} + +func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { + agent, _ := setupTestAgent(t) + + // 创建模拟的MCP工具结果(大结果) + largeResult := &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB + }, + }, + IsError: false, + } + + // 模拟MCP服务器返回大结果 + // 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器 + // 为了简化测试,我们直接测试结果处理逻辑 + + // 设置阈值 + agent.mu.Lock() + agent.largeResultThreshold = 1000 // 设置较小的阈值 + agent.mu.Unlock() + + // 创建执行ID + executionID := "test_exec_large_001" + toolName := "test_tool" + + // 格式化结果 + var resultText strings.Builder + for _, content := range largeResult.Content { + resultText.WriteString(content.Text) + resultText.WriteString("\n") + } + + resultStr := resultText.String() + resultSize := len(resultStr) + + // 检测大结果并保存 + agent.mu.RLock() + threshold := agent.largeResultThreshold + storage := agent.resultStorage + agent.mu.RUnlock() + + if resultSize > threshold && storage != nil { + // 保存大结果 + err := storage.SaveResult(executionID, toolName, resultStr) + if err != nil { + t.Fatalf("保存大结果失败: %v", err) + } + + // 生成通知 + lines := strings.Split(resultStr, "\n") + notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines)) + + // 验证通知格式 + if !strings.Contains(notification, executionID) { + t.Errorf("通知中应该包含执行ID") + } + + // 验证结果已保存 + savedResult, err := storage.GetResult(executionID) + if err != nil { + t.Fatalf("获取保存的结果失败: %v", err) + } + + if savedResult != resultStr { + t.Errorf("保存的结果与原始结果不匹配") + } + } else { + t.Fatal("大结果应该被检测到并保存") + } +} + +func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) { + agent, _ := setupTestAgent(t) + + // 创建小结果 + smallResult := &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "Small result content", + }, + }, + IsError: false, + } + + // 设置较大的阈值 + agent.mu.Lock() + agent.largeResultThreshold = 100000 // 100KB + agent.mu.Unlock() + + // 格式化结果 + var resultText strings.Builder + for _, content := range smallResult.Content { + resultText.WriteString(content.Text) + resultText.WriteString("\n") + } + + resultStr := resultText.String() + resultSize := len(resultStr) + + // 检测大结果 + agent.mu.RLock() + threshold := agent.largeResultThreshold + storage := agent.resultStorage + agent.mu.RUnlock() + + if resultSize > threshold && storage != nil { + t.Fatal("小结果不应该被保存") + } + + // 小结果应该直接返回 + if resultSize <= threshold { + // 这是预期的行为 + if resultStr == "" { + t.Fatal("小结果应该直接返回,不应该为空") + } + } +} + +func TestAgent_SetResultStorage(t *testing.T) { + agent, _ := setupTestAgent(t) + + // 创建新的存储 + tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405")) + newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop()) + if err != nil { + t.Fatalf("创建新存储失败: %v", err) + } + + // 设置新存储 + agent.SetResultStorage(newStorage) + + // 验证存储已更新 + agent.mu.RLock() + currentStorage := agent.resultStorage + agent.mu.RUnlock() + + if currentStorage != newStorage { + t.Fatal("存储未正确更新") + } + + // 清理 + os.RemoveAll(tmpDir) +} + +func TestAgent_NewAgent_DefaultValues(t *testing.T) { + logger := zap.NewNop() + mcpServer := mcp.NewServer(logger) + + openAICfg := &config.OpenAIConfig{ + APIKey: "test-key", + BaseURL: "https://api.test.com/v1", + Model: "test-model", + } + + // 测试默认配置 + agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0) + + if agent.maxIterations != 30 { + t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations) + } + + agent.mu.RLock() + threshold := agent.largeResultThreshold + agent.mu.RUnlock() + + if threshold != 50*1024 { + t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold) + } +} + +func TestAgent_NewAgent_CustomConfig(t *testing.T) { + logger := zap.NewNop() + mcpServer := mcp.NewServer(logger) + + openAICfg := &config.OpenAIConfig{ + APIKey: "test-key", + BaseURL: "https://api.test.com/v1", + Model: "test-model", + } + + agentCfg := &config.AgentConfig{ + MaxIterations: 20, + LargeResultThreshold: 100 * 1024, // 100KB + ResultStorageDir: "custom_tmp", + } + + agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15) + + if agent.maxIterations != 15 { + t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations) + } + + agent.mu.RLock() + threshold := agent.largeResultThreshold + agent.mu.RUnlock() + + if threshold != 100*1024 { + t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) + } +} + diff --git a/internal/app/app.go b/internal/app/app.go index 2fecb409..f16d553e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -13,6 +13,7 @@ import ( "cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/security" + "cyberstrike-ai/internal/storage" "github.com/gin-gonic/gin" "go.uber.org/zap" @@ -85,12 +86,35 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { externalMCPMgr.StartAllEnabled() } + // 初始化结果存储 + resultStorageDir := "tmp" + if cfg.Agent.ResultStorageDir != "" { + resultStorageDir = cfg.Agent.ResultStorageDir + } + + // 确保存储目录存在 + if err := os.MkdirAll(resultStorageDir, 0755); err != nil { + return nil, fmt.Errorf("创建结果存储目录失败: %w", err) + } + + // 创建结果存储实例 + resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger) + if err != nil { + return nil, fmt.Errorf("初始化结果存储失败: %w", err) + } + // 创建Agent maxIterations := cfg.Agent.MaxIterations if maxIterations <= 0 { maxIterations = 30 // 默认值 } - agent := agent.NewAgent(&cfg.OpenAI, mcpServer, externalMCPMgr, log.Logger, maxIterations) + agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) + + // 设置结果存储到Agent + agent.SetResultStorage(resultStorage) + + // 设置结果存储到Executor(用于查询工具) + executor.SetResultStorage(resultStorage) // 获取配置文件路径 configPath := "config.yaml" diff --git a/internal/config/config.go b/internal/config/config.go index e02cef2c..5dc2eacf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,7 +55,9 @@ type DatabaseConfig struct { } type AgentConfig struct { - MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB + ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp } type AuthConfig struct { diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go index 90542c1c..069af5c6 100644 --- a/internal/mcp/external_manager_test.go +++ b/internal/mcp/external_manager_test.go @@ -232,13 +232,13 @@ func TestExternalMCPManager_CallTool(t *testing.T) { manager := NewExternalMCPManager(logger) // 测试调用不存在的工具 - _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{}) + _, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{}) if err == nil { t.Error("应该返回错误") } // 测试无效的工具名称格式 - _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{}) + _, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{}) if err == nil { t.Error("应该返回错误(无效格式)") } diff --git a/internal/security/executor.go b/internal/security/executor.go index d88984b7..e10d258d 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -9,31 +9,50 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" "go.uber.org/zap" ) // Executor 安全工具执行器 type Executor struct { - config *config.SecurityConfig - toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 - mcpServer *mcp.Server - logger *zap.Logger + config *config.SecurityConfig + toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 + mcpServer *mcp.Server + logger *zap.Logger + resultStorage ResultStorage // 结果存储(用于查询工具) +} + +// ResultStorage 结果存储接口(直接使用 storage 包的类型) +type ResultStorage interface { + SaveResult(executionID string, toolName string, result string) error + GetResult(executionID string) (string, error) + GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error) + SearchResult(executionID string, keyword string) ([]string, error) + FilterResult(executionID string, filter string) ([]string, error) + GetResultMetadata(executionID string) (*storage.ResultMetadata, error) + DeleteResult(executionID string) error } // NewExecutor 创建新的执行器 func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor { executor := &Executor{ - config: cfg, - toolIndex: make(map[string]*config.ToolConfig), - mcpServer: mcpServer, - logger: logger, + config: cfg, + toolIndex: make(map[string]*config.ToolConfig), + mcpServer: mcpServer, + logger: logger, + resultStorage: nil, // 稍后通过 SetResultStorage 设置 } // 构建工具索引 executor.buildToolIndex() return executor } +// SetResultStorage 设置结果存储 +func (e *Executor) SetResultStorage(storage ResultStorage) { + e.resultStorage = storage +} + // buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1) func (e *Executor) buildToolIndex() { e.toolIndex = make(map[string]*config.ToolConfig) @@ -78,6 +97,15 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st zap.Strings("args", toolConfig.Args), ) + // 特殊处理:内部工具(command 以 "internal:" 开头) + if strings.HasPrefix(toolConfig.Command, "internal:") { + e.logger.Info("执行内部工具", + zap.String("toolName", toolName), + zap.String("command", toolConfig.Command), + ) + return e.executeInternalTool(ctx, toolName, toolConfig.Command, args) + } + // 构建命令 - 根据工具类型使用不同的参数格式 cmdArgs := e.buildCommandArgs(toolName, toolConfig, args) @@ -653,6 +681,229 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int }, nil } +// executeInternalTool 执行内部工具(不执行外部命令) +func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) { + // 提取内部工具类型(去掉 "internal:" 前缀) + internalToolType := strings.TrimPrefix(command, "internal:") + + e.logger.Info("执行内部工具", + zap.String("toolName", toolName), + zap.String("internalToolType", internalToolType), + zap.Any("args", args), + ) + + // 根据内部工具类型分发处理 + switch internalToolType { + case "query_execution_result": + return e.executeQueryExecutionResult(ctx, args) + default: + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("错误: 未知的内部工具类型: %s", internalToolType), + }, + }, + IsError: true, + }, nil + } +} + +// executeQueryExecutionResult 执行查询执行结果工具 +func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + // 获取 execution_id 参数 + executionID, ok := args["execution_id"].(string) + if !ok || executionID == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: execution_id 参数必需且不能为空", + }, + }, + IsError: true, + }, nil + } + + // 获取可选参数 + page := 1 + if p, ok := args["page"].(float64); ok { + page = int(p) + } + if page < 1 { + page = 1 + } + + limit := 100 + if l, ok := args["limit"].(float64); ok { + limit = int(l) + } + if limit < 1 { + limit = 100 + } + if limit > 500 { + limit = 500 // 限制最大每页行数 + } + + search := "" + if s, ok := args["search"].(string); ok { + search = s + } + + filter := "" + if f, ok := args["filter"].(string); ok { + filter = f + } + + // 检查结果存储是否可用 + if e.resultStorage == nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 结果存储未初始化", + }, + }, + IsError: true, + }, nil + } + + // 执行查询 + var resultPage *storage.ResultPage + var err error + + if search != "" { + // 搜索模式 + matchedLines, err := e.resultStorage.SearchResult(executionID, search) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("搜索失败: %v", err), + }, + }, + IsError: true, + }, nil + } + // 对搜索结果进行分页 + resultPage = paginateLines(matchedLines, page, limit) + } else if filter != "" { + // 过滤模式 + filteredLines, err := e.resultStorage.FilterResult(executionID, filter) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("过滤失败: %v", err), + }, + }, + IsError: true, + }, nil + } + // 对过滤结果进行分页 + resultPage = paginateLines(filteredLines, page, limit) + } else { + // 普通分页查询 + resultPage, err = e.resultStorage.GetResultPage(executionID, page, limit) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("查询失败: %v", err), + }, + }, + IsError: true, + }, nil + } + } + + // 获取元信息 + metadata, err := e.resultStorage.GetResultMetadata(executionID) + if err != nil { + // 元信息获取失败不影响查询结果 + e.logger.Warn("获取结果元信息失败", zap.Error(err)) + } + + // 格式化返回结果 + var sb strings.Builder + sb.WriteString(fmt.Sprintf("查询结果 (执行ID: %s)\n", executionID)) + + if metadata != nil { + sb.WriteString(fmt.Sprintf("工具: %s | 大小: %d 字节 (%.2f KB) | 总行数: %d\n", + metadata.ToolName, metadata.TotalSize, float64(metadata.TotalSize)/1024, metadata.TotalLines)) + } + + sb.WriteString(fmt.Sprintf("第 %d/%d 页,每页 %d 行,共 %d 行\n\n", + resultPage.Page, resultPage.TotalPages, resultPage.Limit, resultPage.TotalLines)) + + if len(resultPage.Lines) == 0 { + sb.WriteString("没有找到匹配的结果。\n") + } else { + for i, line := range resultPage.Lines { + lineNum := (resultPage.Page-1)*resultPage.Limit + i + 1 + sb.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line)) + } + } + + sb.WriteString("\n") + if resultPage.Page < resultPage.TotalPages { + sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1)) + if search != "" { + sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search)) + } + if filter != "" { + sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter)) + } + sb.WriteString("\n") + } + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: sb.String(), + }, + }, + IsError: false, + }, nil +} + +// paginateLines 对行列表进行分页 +func paginateLines(lines []string, page int, limit int) *storage.ResultPage { + totalLines := len(lines) + totalPages := (totalLines + limit - 1) / limit + if page < 1 { + page = 1 + } + if page > totalPages && totalPages > 0 { + page = totalPages + } + + start := (page - 1) * limit + end := start + limit + if end > totalLines { + end = totalLines + } + + var pageLines []string + if start < totalLines { + pageLines = lines[start:end] + } else { + pageLines = []string{} + } + + return &storage.ResultPage{ + Lines: pageLines, + Page: page, + Limit: limit, + TotalLines: totalLines, + TotalPages: totalPages, + } +} + // buildInputSchema 构建输入模式 func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} { schema := map[string]interface{}{ diff --git a/internal/security/executor_test.go b/internal/security/executor_test.go new file mode 100644 index 00000000..2885fcb4 --- /dev/null +++ b/internal/security/executor_test.go @@ -0,0 +1,268 @@ +package security + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" + + "go.uber.org/zap" +) + +// setupTestExecutor 创建测试用的执行器 +func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) { + logger := zap.NewNop() + mcpServer := mcp.NewServer(logger) + + cfg := &config.SecurityConfig{ + Tools: []config.ToolConfig{}, + } + + executor := NewExecutor(cfg, mcpServer, logger) + return executor, mcpServer +} + +// setupTestStorage 创建测试用的存储 +func setupTestStorage(t *testing.T) *storage.FileResultStorage { + tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405")) + logger := zap.NewNop() + + storage, err := storage.NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建测试存储失败: %v", err) + } + + return storage +} + +func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { + executor, _ := setupTestExecutor(t) + testStorage := setupTestStorage(t) + executor.SetResultStorage(testStorage) + + // 准备测试数据 + executionID := "test_exec_001" + toolName := "nmap_scan" + result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred" + + // 保存测试结果 + err := testStorage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存测试结果失败: %v", err) + } + + ctx := context.Background() + + // 测试1: 基本查询(第一页) + args := map[string]interface{}{ + "execution_id": executionID, + "page": float64(1), + "limit": float64(2), + } + + toolResult, err := executor.executeQueryExecutionResult(ctx, args) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if toolResult.IsError { + t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text) + } + + // 验证结果包含预期内容 + resultText := toolResult.Content[0].Text + if !strings.Contains(resultText, executionID) { + t.Errorf("结果中应该包含执行ID: %s", executionID) + } + + if !strings.Contains(resultText, "第 1/") { + t.Errorf("结果中应该包含分页信息") + } + + // 测试2: 搜索功能 + args2 := map[string]interface{}{ + "execution_id": executionID, + "search": "error", + "page": float64(1), + "limit": float64(10), + } + + toolResult2, err := executor.executeQueryExecutionResult(ctx, args2) + if err != nil { + t.Fatalf("执行搜索失败: %v", err) + } + + if toolResult2.IsError { + t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text) + } + + resultText2 := toolResult2.Content[0].Text + if !strings.Contains(resultText2, "error") { + t.Errorf("搜索结果中应该包含关键词: error") + } + + // 测试3: 过滤功能 + args3 := map[string]interface{}{ + "execution_id": executionID, + "filter": "Port", + "page": float64(1), + "limit": float64(10), + } + + toolResult3, err := executor.executeQueryExecutionResult(ctx, args3) + if err != nil { + t.Fatalf("执行过滤失败: %v", err) + } + + if toolResult3.IsError { + t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text) + } + + resultText3 := toolResult3.Content[0].Text + if !strings.Contains(resultText3, "Port") { + t.Errorf("过滤结果中应该包含关键词: Port") + } + + // 测试4: 缺少必需参数 + args4 := map[string]interface{}{ + "page": float64(1), + } + + toolResult4, err := executor.executeQueryExecutionResult(ctx, args4) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult4.IsError { + t.Fatal("缺少execution_id应该返回错误") + } + + // 测试5: 不存在的执行ID + args5 := map[string]interface{}{ + "execution_id": "nonexistent_id", + "page": float64(1), + } + + toolResult5, err := executor.executeQueryExecutionResult(ctx, args5) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult5.IsError { + t.Fatal("不存在的执行ID应该返回错误") + } +} + +func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) { + executor, _ := setupTestExecutor(t) + + ctx := context.Background() + args := map[string]interface{}{ + "test": "value", + } + + // 测试未知的内部工具类型 + toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args) + if err != nil { + t.Fatalf("执行内部工具失败: %v", err) + } + + if !toolResult.IsError { + t.Fatal("未知的工具类型应该返回错误") + } + + if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") { + t.Errorf("错误消息应该包含'未知的内部工具类型'") + } +} + +func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) { + executor, _ := setupTestExecutor(t) + // 不设置存储,测试未初始化的情况 + + ctx := context.Background() + args := map[string]interface{}{ + "execution_id": "test_id", + } + + toolResult, err := executor.executeQueryExecutionResult(ctx, args) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult.IsError { + t.Fatal("未初始化的存储应该返回错误") + } + + if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") { + t.Errorf("错误消息应该包含'结果存储未初始化'") + } +} + +func TestPaginateLines(t *testing.T) { + lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"} + + // 测试第一页 + page := paginateLines(lines, 1, 2) + if page.Page != 1 { + t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page) + } + if page.Limit != 2 { + t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit) + } + if page.TotalLines != 5 { + t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines) + } + if page.TotalPages != 3 { + t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages) + } + if len(page.Lines) != 2 { + t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines)) + } + + // 测试第二页 + page2 := paginateLines(lines, 2, 2) + if len(page2.Lines) != 2 { + t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines)) + } + if page2.Lines[0] != "Line 3" { + t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0]) + } + + // 测试最后一页 + page3 := paginateLines(lines, 3, 2) + if len(page3.Lines) != 1 { + t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines)) + } + + // 测试超出范围的页码(应该返回最后一页) + page4 := paginateLines(lines, 4, 2) + if page4.Page != 3 { + t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page) + } + if len(page4.Lines) != 1 { + t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines)) + } + + // 测试无效页码(小于1) + page0 := paginateLines(lines, 0, 2) + if page0.Page != 1 { + t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page) + } + + // 测试空列表 + emptyPage := paginateLines([]string{}, 1, 10) + if emptyPage.TotalLines != 0 { + t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines) + } + if len(emptyPage.Lines) != 0 { + t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) + } +} + diff --git a/internal/storage/result_storage.go b/internal/storage/result_storage.go new file mode 100644 index 00000000..e3df9e4e --- /dev/null +++ b/internal/storage/result_storage.go @@ -0,0 +1,270 @@ +package storage + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "go.uber.org/zap" +) + +// ResultStorage 结果存储接口 +type ResultStorage interface { + // SaveResult 保存工具执行结果 + SaveResult(executionID string, toolName string, result string) error + + // GetResult 获取完整结果 + GetResult(executionID string) (string, error) + + // GetResultPage 分页获取结果 + GetResultPage(executionID string, page int, limit int) (*ResultPage, error) + + // SearchResult 搜索结果 + SearchResult(executionID string, keyword string) ([]string, error) + + // FilterResult 过滤结果 + FilterResult(executionID string, filter string) ([]string, error) + + // GetResultMetadata 获取结果元信息 + GetResultMetadata(executionID string) (*ResultMetadata, error) + + // DeleteResult 删除结果 + DeleteResult(executionID string) error +} + +// ResultPage 分页结果 +type ResultPage struct { + Lines []string `json:"lines"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalLines int `json:"total_lines"` + TotalPages int `json:"total_pages"` +} + +// ResultMetadata 结果元信息 +type ResultMetadata struct { + ExecutionID string `json:"execution_id"` + ToolName string `json:"tool_name"` + TotalSize int `json:"total_size"` + TotalLines int `json:"total_lines"` + CreatedAt time.Time `json:"created_at"` +} + +// FileResultStorage 基于文件的结果存储实现 +type FileResultStorage struct { + baseDir string + logger *zap.Logger + mu sync.RWMutex +} + +// NewFileResultStorage 创建新的文件结果存储 +func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) { + // 确保目录存在 + if err := os.MkdirAll(baseDir, 0755); err != nil { + return nil, fmt.Errorf("创建存储目录失败: %w", err) + } + + return &FileResultStorage{ + baseDir: baseDir, + logger: logger, + }, nil +} + +// getResultPath 获取结果文件路径 +func (s *FileResultStorage) getResultPath(executionID string) string { + return filepath.Join(s.baseDir, executionID+".txt") +} + +// getMetadataPath 获取元数据文件路径 +func (s *FileResultStorage) getMetadataPath(executionID string) string { + return filepath.Join(s.baseDir, executionID+".meta.json") +} + +// SaveResult 保存工具执行结果 +func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // 保存结果文件 + resultPath := s.getResultPath(executionID) + if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil { + return fmt.Errorf("保存结果文件失败: %w", err) + } + + // 计算统计信息 + lines := strings.Split(result, "\n") + metadata := &ResultMetadata{ + ExecutionID: executionID, + ToolName: toolName, + TotalSize: len(result), + TotalLines: len(lines), + CreatedAt: time.Now(), + } + + // 保存元数据 + metadataPath := s.getMetadataPath(executionID) + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("序列化元数据失败: %w", err) + } + + if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil { + return fmt.Errorf("保存元数据文件失败: %w", err) + } + + s.logger.Info("保存工具执行结果", + zap.String("executionID", executionID), + zap.String("toolName", toolName), + zap.Int("size", len(result)), + zap.Int("lines", len(lines)), + ) + + return nil +} + +// GetResult 获取完整结果 +func (s *FileResultStorage) GetResult(executionID string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + resultPath := s.getResultPath(executionID) + data, err := os.ReadFile(resultPath) + if err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("结果不存在: %s", executionID) + } + return "", fmt.Errorf("读取结果文件失败: %w", err) + } + + return string(data), nil +} + +// GetResultMetadata 获取结果元信息 +func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + metadataPath := s.getMetadataPath(executionID) + data, err := os.ReadFile(metadataPath) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("结果不存在: %s", executionID) + } + return nil, fmt.Errorf("读取元数据文件失败: %w", err) + } + + var metadata ResultMetadata + if err := json.Unmarshal(data, &metadata); err != nil { + return nil, fmt.Errorf("解析元数据失败: %w", err) + } + + return &metadata, nil +} + +// GetResultPage 分页获取结果 +func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // 获取完整结果 + result, err := s.GetResult(executionID) + if err != nil { + return nil, err + } + + // 分割为行 + lines := strings.Split(result, "\n") + totalLines := len(lines) + + // 计算分页 + totalPages := (totalLines + limit - 1) / limit + if page < 1 { + page = 1 + } + if page > totalPages && totalPages > 0 { + page = totalPages + } + + // 计算起始和结束索引 + start := (page - 1) * limit + end := start + limit + if end > totalLines { + end = totalLines + } + + // 提取指定页的行 + var pageLines []string + if start < totalLines { + pageLines = lines[start:end] + } else { + pageLines = []string{} + } + + return &ResultPage{ + Lines: pageLines, + Page: page, + Limit: limit, + TotalLines: totalLines, + TotalPages: totalPages, + }, nil +} + +// SearchResult 搜索结果 +func (s *FileResultStorage) SearchResult(executionID string, keyword string) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // 获取完整结果 + result, err := s.GetResult(executionID) + if err != nil { + return nil, err + } + + // 分割为行并搜索 + lines := strings.Split(result, "\n") + var matchedLines []string + + for _, line := range lines { + if strings.Contains(line, keyword) { + matchedLines = append(matchedLines, line) + } + } + + return matchedLines, nil +} + +// FilterResult 过滤结果 +func (s *FileResultStorage) FilterResult(executionID string, filter string) ([]string, error) { + // 过滤和搜索逻辑相同,都是查找包含关键词的行 + return s.SearchResult(executionID, filter) +} + +// DeleteResult 删除结果 +func (s *FileResultStorage) DeleteResult(executionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + resultPath := s.getResultPath(executionID) + metadataPath := s.getMetadataPath(executionID) + + // 删除结果文件 + if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("删除结果文件失败: %w", err) + } + + // 删除元数据文件 + if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("删除元数据文件失败: %w", err) + } + + s.logger.Info("删除工具执行结果", + zap.String("executionID", executionID), + ) + + return nil +} + diff --git a/internal/storage/result_storage_test.go b/internal/storage/result_storage_test.go new file mode 100644 index 00000000..aaf2bfa1 --- /dev/null +++ b/internal/storage/result_storage_test.go @@ -0,0 +1,443 @@ +package storage + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "go.uber.org/zap" +) + +// setupTestStorage 创建测试用的存储实例 +func setupTestStorage(t *testing.T) (*FileResultStorage, string) { + tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405")) + logger := zap.NewNop() + + storage, err := NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建测试存储失败: %v", err) + } + + return storage, tmpDir +} + +// cleanupTestStorage 清理测试数据 +func cleanupTestStorage(t *testing.T, tmpDir string) { + if err := os.RemoveAll(tmpDir); err != nil { + t.Logf("清理测试目录失败: %v", err) + } +} + +func TestNewFileResultStorage(t *testing.T) { + tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405")) + defer cleanupTestStorage(t, tmpDir) + + logger := zap.NewNop() + storage, err := NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建存储失败: %v", err) + } + + if storage == nil { + t.Fatal("存储实例为nil") + } + + // 验证目录已创建 + if _, err := os.Stat(tmpDir); os.IsNotExist(err) { + t.Fatal("存储目录未创建") + } +} + +func TestFileResultStorage_SaveResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_001" + toolName := "nmap_scan" + result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 验证结果文件存在 + resultPath := filepath.Join(tmpDir, executionID+".txt") + if _, err := os.Stat(resultPath); os.IsNotExist(err) { + t.Fatal("结果文件未创建") + } + + // 验证元数据文件存在 + metadataPath := filepath.Join(tmpDir, executionID+".meta.json") + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + t.Fatal("元数据文件未创建") + } +} + +func TestFileResultStorage_GetResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_002" + toolName := "test_tool" + expectedResult := "Test result content\nLine 2\nLine 3" + + // 先保存结果 + err := storage.SaveResult(executionID, toolName, expectedResult) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 获取结果 + result, err := storage.GetResult(executionID) + if err != nil { + t.Fatalf("获取结果失败: %v", err) + } + + if result != expectedResult { + t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result) + } + + // 测试不存在的执行ID + _, err = storage.GetResult("nonexistent_id") + if err == nil { + t.Fatal("应该返回错误") + } +} + +func TestFileResultStorage_GetResultMetadata(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_003" + toolName := "test_tool" + result := "Line 1\nLine 2\nLine 3" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 获取元数据 + metadata, err := storage.GetResultMetadata(executionID) + if err != nil { + t.Fatalf("获取元数据失败: %v", err) + } + + if metadata.ExecutionID != executionID { + t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID) + } + + if metadata.ToolName != toolName { + t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName) + } + + if metadata.TotalSize != len(result) { + t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize) + } + + expectedLines := len(strings.Split(result, "\n")) + if metadata.TotalLines != expectedLines { + t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines) + } + + // 验证创建时间在合理范围内 + now := time.Now() + if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) { + t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt) + } +} + +func TestFileResultStorage_GetResultPage(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_004" + toolName := "test_tool" + // 创建包含10行的结果 + lines := make([]string, 10) + for i := 0; i < 10; i++ { + lines[i] = fmt.Sprintf("Line %d", i+1) + } + result := strings.Join(lines, "\n") + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 测试第一页(每页3行) + page, err := storage.GetResultPage(executionID, 1, 3) + if err != nil { + t.Fatalf("获取第一页失败: %v", err) + } + + if page.Page != 1 { + t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page) + } + + if page.Limit != 3 { + t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit) + } + + if page.TotalLines != 10 { + t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines) + } + + if page.TotalPages != 4 { + t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages) + } + + if len(page.Lines) != 3 { + t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines)) + } + + if page.Lines[0] != "Line 1" { + t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0]) + } + + // 测试第二页 + page2, err := storage.GetResultPage(executionID, 2, 3) + if err != nil { + t.Fatalf("获取第二页失败: %v", err) + } + + if len(page2.Lines) != 3 { + t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines)) + } + + if page2.Lines[0] != "Line 4" { + t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0]) + } + + // 测试最后一页(可能不满一页) + page4, err := storage.GetResultPage(executionID, 4, 3) + if err != nil { + t.Fatalf("获取第四页失败: %v", err) + } + + if len(page4.Lines) != 1 { + t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines)) + } + + // 测试超出范围的页码(应该返回最后一页) + page5, err := storage.GetResultPage(executionID, 5, 3) + if err != nil { + t.Fatalf("获取第五页失败: %v", err) + } + + // 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容 + if page5.Page != 4 { + t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page) + } + + // 最后一页应该只有1行 + if len(page5.Lines) != 1 { + t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines)) + } +} + +func TestFileResultStorage_SearchResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_005" + toolName := "test_tool" + result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 搜索包含"error"的行 + matchedLines, err := storage.SearchResult(executionID, "error") + if err != nil { + t.Fatalf("搜索失败: %v", err) + } + + if len(matchedLines) != 2 { + t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines)) + } + + // 验证搜索结果内容 + for i, line := range matchedLines { + if !strings.Contains(line, "error") { + t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line) + } + } + + // 测试搜索不存在的关键词 + noMatch, err := storage.SearchResult(executionID, "nonexistent") + if err != nil { + t.Fatalf("搜索失败: %v", err) + } + + if len(noMatch) != 0 { + t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch)) + } +} + +func TestFileResultStorage_FilterResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_006" + toolName := "test_tool" + result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 过滤包含"warning"的行 + filteredLines, err := storage.FilterResult(executionID, "warning") + if err != nil { + t.Fatalf("过滤失败: %v", err) + } + + if len(filteredLines) != 2 { + t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines)) + } + + // 验证过滤结果内容 + for i, line := range filteredLines { + if !strings.Contains(line, "warning") { + t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line) + } + } +} + +func TestFileResultStorage_DeleteResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_007" + toolName := "test_tool" + result := "Test result" + + // 保存结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存结果失败: %v", err) + } + + // 验证文件存在 + resultPath := filepath.Join(tmpDir, executionID+".txt") + metadataPath := filepath.Join(tmpDir, executionID+".meta.json") + + if _, err := os.Stat(resultPath); os.IsNotExist(err) { + t.Fatal("结果文件不存在") + } + + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + t.Fatal("元数据文件不存在") + } + + // 删除结果 + err = storage.DeleteResult(executionID) + if err != nil { + t.Fatalf("删除结果失败: %v", err) + } + + // 验证文件已删除 + if _, err := os.Stat(resultPath); !os.IsNotExist(err) { + t.Fatal("结果文件未被删除") + } + + if _, err := os.Stat(metadataPath); !os.IsNotExist(err) { + t.Fatal("元数据文件未被删除") + } + + // 测试删除不存在的执行ID(应该不报错) + err = storage.DeleteResult("nonexistent_id") + if err != nil { + t.Errorf("删除不存在的执行ID不应该报错: %v", err) + } +} + +func TestFileResultStorage_ConcurrentAccess(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + // 并发保存多个结果 + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(id int) { + executionID := fmt.Sprintf("test_exec_%d", id) + toolName := "test_tool" + result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id) + + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Errorf("并发保存失败 (ID: %s): %v", executionID, err) + } + + // 并发读取 + _, err = storage.GetResult(executionID) + if err != nil { + t.Errorf("并发读取失败 (ID: %s): %v", executionID, err) + } + + done <- true + }(i) + } + + // 等待所有goroutine完成 + for i := 0; i < 10; i++ { + <-done + } +} + +func TestFileResultStorage_LargeResult(t *testing.T) { + storage, tmpDir := setupTestStorage(t) + defer cleanupTestStorage(t, tmpDir) + + executionID := "test_exec_large" + toolName := "test_tool" + + // 创建大结果(1000行) + lines := make([]string, 1000) + for i := 0; i < 1000; i++ { + lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1) + } + result := strings.Join(lines, "\n") + + // 保存大结果 + err := storage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存大结果失败: %v", err) + } + + // 验证元数据 + metadata, err := storage.GetResultMetadata(executionID) + if err != nil { + t.Fatalf("获取元数据失败: %v", err) + } + + if metadata.TotalLines != 1000 { + t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines) + } + + // 测试分页查询大结果 + page, err := storage.GetResultPage(executionID, 1, 100) + if err != nil { + t.Fatalf("获取第一页失败: %v", err) + } + + if page.TotalPages != 10 { + t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages) + } + + if len(page.Lines) != 100 { + t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines)) + } +}