diff --git a/README.md b/README.md index e6dc0a74..2458a666 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ ![Preview](./img/外部MCP接入.png) ## Changelog +- 2025.11.16 Added large result pagination feature: when tool execution results exceed the threshold (default 50KB), automatically save to file and return execution ID, support paginated queries, keyword search, conditional filtering, and regex matching through query_execution_result tool, effectively solving the problem of overly long single responses and improving large file processing capabilities - 2025.11.15 Added external MCP integration feature: support for integrating external MCP servers to extend tool capabilities, supports both stdio and HTTP transport modes, tool-level enable/disable control, complete configuration guide and management APIs - 2025.11.14 Performance optimizations: optimized tool lookup from O(n) to O(1) using index map, added automatic cleanup mechanism for execution records to prevent memory leaks, and added pagination support for database queries - 2025.11.13 Added authentication for the web mode, including automatic password generation and in-app password change @@ -30,6 +31,7 @@ - 💬 **Conversational Interface** - Natural language conversation interface with streaming output (SSE), real-time execution viewing - 📊 **Conversation History Management** - Complete conversation history records, supports viewing, deletion, and management - ⚙️ **Visual Configuration Management** - Web interface for system settings, supports real-time loading and saving configurations with required field validation +- 📄 **Large Result Pagination** - When tool execution results exceed the threshold, automatically save to file, support paginated queries, keyword search, conditional filtering, and regex matching, effectively solving the problem of overly long single responses, with examples for various tools (head, tail, grep, sed, etc.) for segmented reading ### Tool Integration - 🔌 **MCP Protocol Support** - Complete MCP protocol implementation, supports tool registration, invocation, and monitoring diff --git a/README_CN.md b/README_CN.md index 003160fe..4de90167 100644 --- a/README_CN.md +++ b/README_CN.md @@ -9,6 +9,7 @@ ![详情预览](./img/外部MCP接入.png) ## 更新日志 +- 2025.11.16 新增大结果分段读取功能:当工具执行结果超过阈值(默认50KB)时,自动保存到文件并返回执行ID,支持通过 query_execution_result 工具进行分页查询、关键词搜索、条件过滤和正则表达式匹配,有效解决单次返回过长的问题,提升大文件处理能力 - 2025.11.15 新增外部 MCP 接入功能:支持接入外部 MCP 服务器扩展工具能力,支持 stdio 和 HTTP 两种传输模式,支持工具级别的启用/禁用控制,提供完整的配置指南和管理接口 - 2025.11.14 性能优化:工具查找从 O(n) 优化为 O(1)(使用索引映射),添加执行记录自动清理机制防止内存泄漏,数据库查询支持分页加载 - 2025.11.13 Web 端新增统一鉴权,支持自动生成强密码与前端修改密码; @@ -28,6 +29,7 @@ - 💬 **对话式交互** - 自然语言对话界面,支持流式输出(SSE),实时查看执行过程 - 📊 **对话历史管理** - 完整的对话历史记录,支持查看、删除和管理 - ⚙️ **可视化配置管理** - Web界面配置系统设置,支持实时加载和保存配置,必填项验证 +- 📄 **大结果分段读取** - 当工具执行结果超过阈值时自动保存,支持分页查询、关键词搜索、条件过滤和正则表达式匹配,有效解决单次返回过长问题,提供多种工具(head、tail、grep、sed等)的分段读取示例 ### 工具集成 - 🔌 **MCP协议支持** - 完整实现MCP协议,支持工具注册、调用、监控 diff --git a/config.yaml b/config.yaml index 77972063..4d45200b 100644 --- a/config.yaml +++ b/config.yaml @@ -27,7 +27,7 @@ mcp: # 必填项:api_key, base_url, model 必须填写才能正常运行 openai: base_url: https://api.deepseek.com/v1 # API 基础 URL(必填) - api_key: sk-xxx # API 密钥(必填) + api_key: sk-xxxx # API 密钥(必填) # 支持的 API 服务商: # - OpenAI: https://api.openai.com/v1 # - DeepSeek: https://api.deepseek.com/v1 diff --git a/internal/agent/agent.go b/internal/agent/agent.go index d9a4b706..72517114 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -15,22 +15,23 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/storage" + "go.uber.org/zap" ) // Agent AI代理 type Agent struct { - openAIClient *http.Client - config *config.OpenAIConfig - agentConfig *config.AgentConfig - mcpServer *mcp.Server - externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 - logger *zap.Logger - maxIterations int - resultStorage ResultStorage // 结果存储 - largeResultThreshold int // 大结果阈值(字节) - mu sync.RWMutex // 添加互斥锁以支持并发更新 - toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) + openAIClient *http.Client + config *config.OpenAIConfig + agentConfig *config.AgentConfig + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + logger *zap.Logger + maxIterations int + resultStorage ResultStorage // 结果存储 + largeResultThreshold int // 大结果阈值(字节) + mu sync.RWMutex // 添加互斥锁以支持并发更新 + toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) } // ResultStorage 结果存储接口(直接使用 storage 包的类型) @@ -38,9 +39,10 @@ type ResultStorage interface { SaveResult(executionID string, toolName string, result string) error GetResult(executionID string) (string, error) GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error) - SearchResult(executionID string, keyword string) ([]string, error) - FilterResult(executionID string, filter string) ([]string, error) + SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) + FilterResult(executionID string, filter string, useRegex bool) ([]string, error) GetResultMetadata(executionID string) (*storage.ResultMetadata, error) + GetResultPath(executionID string) string DeleteResult(executionID string) error } @@ -50,19 +52,19 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer if maxIterations <= 0 { maxIterations = 30 } - + // 设置大结果阈值,默认50KB largeResultThreshold := 50 * 1024 if agentCfg != nil && agentCfg.LargeResultThreshold > 0 { largeResultThreshold = agentCfg.LargeResultThreshold } - + // 设置结果存储目录,默认tmp resultStorageDir := "tmp" if agentCfg != nil && agentCfg.ResultStorageDir != "" { resultStorageDir = agentCfg.ResultStorageDir } - + // 初始化结果存储 var resultStorage ResultStorage if resultStorageDir != "" { @@ -70,21 +72,21 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer // 这里需要在实际使用时初始化 // 暂时设为nil,在需要时初始化 } - + // 配置HTTP Transport,优化连接管理和超时设置 transport := &http.Transport{ DialContext: (&net.Dialer{ Timeout: 300 * time.Second, KeepAlive: 300 * time.Second, }).DialContext, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 30 * time.Second, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 30 * time.Second, ResponseHeaderTimeout: 60 * time.Minute, // 响应头超时:增加到15分钟,应对大响应 - DisableKeepAlives: false, // 启用连接复用 + DisableKeepAlives: false, // 启用连接复用 } - + // 增加超时时间到30分钟,以支持长时间运行的AI推理 // 特别是当使用流式响应或处理复杂任务时 return &Agent{ @@ -92,15 +94,15 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 Transport: transport, }, - config: cfg, - agentConfig: agentCfg, - mcpServer: mcpServer, - externalMCPMgr: externalMCPMgr, - logger: logger, - maxIterations: maxIterations, - resultStorage: resultStorage, + config: cfg, + agentConfig: agentCfg, + mcpServer: mcpServer, + externalMCPMgr: externalMCPMgr, + logger: logger, + maxIterations: maxIterations, + resultStorage: resultStorage, largeResultThreshold: largeResultThreshold, - toolNameMapping: make(map[string]string), // 初始化工具名称映射 + toolNameMapping: make(map[string]string), // 初始化工具名称映射 } } @@ -113,10 +115,10 @@ func (a *Agent) SetResultStorage(storage ResultStorage) { // ChatMessage 聊天消息 type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } // MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串 @@ -149,7 +151,7 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) { } argsJSON = string(argsBytes) } - + toolCallsJSON[i] = map[string]interface{}{ "id": tc.ID, "type": tc.Type, @@ -187,15 +189,15 @@ type Choice struct { // MessageWithTools 带工具调用的消息 type MessageWithTools struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } // Tool OpenAI工具定义 type Tool struct { - Type string `json:"type"` - Function FunctionDefinition `json:"function"` + Type string `json:"type"` + Function FunctionDefinition `json:"function"` } // FunctionDefinition 函数定义 @@ -213,9 +215,9 @@ type Error struct { // ToolCall 工具调用 type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function FunctionCall `json:"function"` + ID string `json:"id"` + Type string `json:"type"` + Function FunctionCall `json:"function"` } // FunctionCall 函数调用 @@ -267,7 +269,7 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error { // AgentLoopResult Agent Loop执行结果 type AgentLoopResult struct { - Response string + Response string MCPExecutionIDs []string } @@ -300,14 +302,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his 6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。` - + messages := []ChatMessage{ { Role: "system", Content: systemPrompt, }, } - + // 添加历史消息(数据库只保存user和assistant消息) a.logger.Info("处理历史消息", zap.Int("count", len(historyMessages)), @@ -332,13 +334,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his ) } } - + a.logger.Info("构建消息数组", zap.Int("historyMessages", len(historyMessages)), zap.Int("addedMessages", addedCount), zap.Int("totalMessages", len(messages)), ) - + // 添加当前用户消息 messages = append(messages, ChatMessage{ Role: "user", @@ -353,7 +355,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his for i := 0; i < maxIterations; i++ { // 检查是否是最后一次迭代 isLastIteration := (i == maxIterations-1) - + // 获取可用工具 tools := a.getAvailableTools() @@ -451,13 +453,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 发送工具调用开始事件 toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments) sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "arguments": string(toolArgsJSON), + "toolName": toolCall.Function.Name, + "arguments": string(toolArgsJSON), "argumentsObj": toolCall.Function.Arguments, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, + "toolCallId": toolCall.ID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, }) // 执行工具 @@ -466,23 +468,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 构建详细的错误信息,帮助AI理解问题并做出决策 errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) messages = append(messages, ChatMessage{ - Role: "tool", + Role: "tool", ToolCallID: toolCall.ID, - Content: errorMsg, + Content: errorMsg, }) - + // 发送工具执行失败事件 sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "success": false, - "isError": true, - "error": err.Error(), + "toolName": toolCall.Function.Name, + "success": false, + "isError": true, + "error": err.Error(), "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, }) - + a.logger.Warn("工具执行失败,已返回详细错误信息", zap.String("tool", toolCall.Function.Name), zap.Error(err), @@ -490,33 +492,33 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his } else { // 即使工具返回了错误结果(IsError=true),也继续处理,让AI决定下一步 messages = append(messages, ChatMessage{ - Role: "tool", + Role: "tool", ToolCallID: toolCall.ID, - Content: execResult.Result, + Content: execResult.Result, }) // 收集执行ID if execResult.ExecutionID != "" { result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID) } - + // 发送工具执行成功事件 resultPreview := execResult.Result if len(resultPreview) > 200 { resultPreview = resultPreview[:200] + "..." } sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "success": !execResult.IsError, - "isError": execResult.IsError, - "result": execResult.Result, // 完整结果 - "resultPreview": resultPreview, // 预览结果 - "executionId": execResult.ExecutionID, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, + "toolName": toolCall.Function.Name, + "success": !execResult.IsError, + "isError": execResult.IsError, + "result": execResult.Result, // 完整结果 + "resultPreview": resultPreview, // 预览结果 + "executionId": execResult.ExecutionID, + "toolCallId": toolCall.ID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, }) - + // 如果工具返回了错误,记录日志但不中断流程 if execResult.IsError { a.logger.Warn("工具返回错误结果,但继续处理", @@ -526,7 +528,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his } } } - + // 如果是最后一次迭代,执行完工具后要求AI进行总结 if isLastIteration { sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil) @@ -548,7 +550,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 如果获取总结失败,跳出循环,让后续逻辑处理 break } - + continue } @@ -591,7 +593,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 如果都没有内容,跳出循环,让后续逻辑处理 break } - + // 如果完成,返回结果 if choice.FinishReason == "stop" { sendProgress("progress", "正在生成最终回复...", nil) @@ -608,7 +610,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Content: fmt.Sprintf("已达到最大迭代次数(%d轮)。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", a.maxIterations), } messages = append(messages, finalSummaryPrompt) - + summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { summaryChoice := summaryResponse.Choices[0] @@ -618,7 +620,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his return result, nil } } - + // 如果无法生成总结,返回友好的提示 result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) return result, nil @@ -629,7 +631,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his func (a *Agent) getAvailableTools() []Tool { // 从MCP服务器获取所有已注册的内部工具 mcpTools := a.mcpServer.GetAllTools() - + // 转换为OpenAI格式的工具定义 tools := make([]Tool, 0, len(mcpTools)) for _, mcpTool := range mcpTools { @@ -638,10 +640,10 @@ func (a *Agent) getAvailableTools() []Tool { if description == "" { description = mcpTool.Description } - + // 转换schema中的类型为OpenAI标准类型 convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema) - + tools = append(tools, Tool{ Type: "function", Function: FunctionDefinition{ @@ -651,24 +653,24 @@ func (a *Agent) getAvailableTools() []Tool { }, }) } - + // 获取外部MCP工具 if a.externalMCPMgr != nil { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - + externalTools, err := a.externalMCPMgr.GetAllTools(ctx) if err != nil { a.logger.Warn("获取外部MCP工具失败", zap.Error(err)) } else { // 获取外部MCP配置,用于检查工具启用状态 externalMCPConfigs := a.externalMCPMgr.GetConfigs() - + // 清空并重建工具名称映射 a.mu.Lock() a.toolNameMapping = make(map[string]string) a.mu.Unlock() - + // 将外部MCP工具添加到工具列表(只添加启用的工具) for _, externalTool := range externalTools { // 解析工具名称:mcpName::toolName @@ -679,7 +681,7 @@ func (a *Agent) getAvailableTools() []Tool { } else { continue // 跳过格式不正确的工具 } - + // 检查工具是否启用 enabled := false if cfg, exists := externalMCPConfigs[mcpName]; exists { @@ -698,30 +700,30 @@ func (a *Agent) getAvailableTools() []Tool { } } } - + // 只添加启用的工具 if !enabled { continue } - + // 使用简短描述(如果存在),否则使用详细描述 description := externalTool.ShortDescription if description == "" { description = externalTool.Description } - + // 转换schema中的类型为OpenAI标准类型 convertedSchema := a.convertSchemaTypes(externalTool.InputSchema) - + // 将工具名称中的 "::" 替换为 "__" 以符合OpenAI命名规范 // OpenAI要求工具名称只能包含 [a-zA-Z0-9_-] openAIName := strings.ReplaceAll(externalTool.Name, "::", "__") - + // 保存名称映射关系(OpenAI格式 -> 原始格式) a.mu.Lock() a.toolNameMapping[openAIName] = externalTool.Name a.mu.Unlock() - + tools = append(tools, Tool{ Type: "function", Function: FunctionDefinition{ @@ -733,12 +735,12 @@ func (a *Agent) getAvailableTools() []Tool { } } } - + a.logger.Debug("获取可用工具列表", zap.Int("internalTools", len(mcpTools)), zap.Int("totalTools", len(tools)), ) - + return tools } @@ -747,13 +749,13 @@ func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]int if schema == nil { return schema } - + // 创建新的schema副本 converted := make(map[string]interface{}) for k, v := range schema { converted[k] = v } - + // 转换properties中的类型 if properties, ok := converted["properties"].(map[string]interface{}); ok { convertedProperties := make(map[string]interface{}) @@ -779,7 +781,7 @@ func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]int } converted["properties"] = convertedProperties } - + return converted } @@ -834,7 +836,7 @@ func (a *Agent) isRetryableError(err error) bool { func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { maxRetries := 3 var lastErr error - + for attempt := 0; attempt < maxRetries; attempt++ { response, err := a.callOpenAISingle(ctx, messages, tools) if err == nil { @@ -846,14 +848,14 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools [] } return response, nil } - + lastErr = err - + // 如果不是可重试的错误,直接返回 if !a.isRetryableError(err) { return nil, err } - + // 如果不是最后一次重试,等待后重试 if attempt < maxRetries-1 { // 指数退避:2s, 4s, 8s... @@ -867,7 +869,7 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools [] zap.Int("maxRetries", maxRetries), zap.Duration("backoff", backoff), ) - + // 检查上下文是否已取消 select { case <-ctx.Done(): @@ -877,7 +879,7 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools [] } } } - + return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) } @@ -924,7 +926,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to // 记录响应头接收时间 headerReceiveTime := time.Now() headerReceiveDuration := headerReceiveTime.Sub(requestStartTime) - + a.logger.Debug("收到OpenAI响应头", zap.Int("statusCode", resp.StatusCode), zap.Duration("headerReceiveDuration", headerReceiveDuration), @@ -934,7 +936,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to // 使用带超时的读取(通过context控制) bodyChan := make(chan []byte, 1) errChan := make(chan error, 1) - + go func() { body, err := io.ReadAll(resp.Body) if err != nil { @@ -943,7 +945,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to } bodyChan <- body }() - + var body []byte select { case body = <-bodyChan: @@ -951,7 +953,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to bodyReceiveTime := time.Now() bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime) totalDuration := bodyReceiveTime.Sub(requestStartTime) - + a.logger.Debug("完成读取OpenAI响应体", zap.Int("bodySizeKB", len(body)/1024), zap.Duration("bodyReceiveDuration", bodyReceiveDuration), @@ -1050,7 +1052,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map // 调用内部MCP工具 result, executionID, err = a.mcpServer.CallTool(ctx, toolName, args) } - + // 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常 if err != nil { errorMsg := fmt.Sprintf(`工具调用失败 @@ -1068,7 +1070,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map - 检查工具名称是否正确 - 尝试使用其他替代工具 - 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName) - + return &ToolExecutionResult{ Result: errorMsg, ExecutionID: executionID, @@ -1082,16 +1084,16 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map resultText.WriteString(content.Text) resultText.WriteString("\n") } - + resultStr := resultText.String() resultSize := len(resultStr) - + // 检测大结果并保存 a.mu.RLock() threshold := a.largeResultThreshold storage := a.resultStorage a.mu.RUnlock() - + if resultSize > threshold && storage != nil { // 异步保存大结果 go func() { @@ -1109,11 +1111,15 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map ) } }() - + // 返回最小化通知 lines := strings.Split(resultStr, "\n") - notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines)) - + filePath := "" + if storage != nil { + filePath = storage.GetResultPath(executionID) + } + notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath) + return &ToolExecutionResult{ Result: notification, ExecutionID: executionID, @@ -1129,20 +1135,53 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map } // formatMinimalNotification 格式化最小化通知 -func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int) string { +func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string { var sb strings.Builder - + sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID)) sb.WriteString("结果信息:\n") sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName)) sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024)) sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount)) + if filePath != "" { + sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath)) + } sb.WriteString("\n") - sb.WriteString("使用以下工具查询完整结果:\n") + sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n") sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID)) sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID)) sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID)) - + sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID)) + sb.WriteString("\n") + if filePath != "" { + sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n") + sb.WriteString("\n") + sb.WriteString("**分段读取示例:**\n") + sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath)) + sb.WriteString("\n") + sb.WriteString("**搜索和正则匹配示例:**\n") + sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath)) + sb.WriteString("\n") + sb.WriteString("**过滤和统计示例:**\n") + sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath)) + sb.WriteString("\n") + sb.WriteString("**完整读取(不推荐大文件):**\n") + sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath)) + sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath)) + sb.WriteString("\n") + sb.WriteString("**注意:**\n") + sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n") + sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n") + sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n") + } + return sb.String() } @@ -1180,7 +1219,6 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er 2. 如果工具不可用,请尝试使用替代工具 3. 如果这是系统问题,请向用户说明情况并提供建议 4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err) - + return errorMsg } - diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index c6ec9bd6..fcbcfa64 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -53,8 +53,9 @@ func TestAgent_FormatMinimalNotification(t *testing.T) { toolName := "nmap_scan" size := 50000 lineCount := 1000 + filePath := "tmp/test_exec_001.txt" - notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount) + notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath) // 验证通知包含必要信息 if !strings.Contains(notification, executionID) { @@ -130,7 +131,8 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { // 生成通知 lines := strings.Split(resultStr, "\n") - notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines)) + filePath := storage.GetResultPath(executionID) + notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath) // 验证通知格式 if !strings.Contains(notification, executionID) { diff --git a/internal/security/executor.go b/internal/security/executor.go index e10d258d..166bbe74 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -28,9 +28,10 @@ type ResultStorage interface { SaveResult(executionID string, toolName string, result string) error GetResult(executionID string) (string, error) GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error) - SearchResult(executionID string, keyword string) ([]string, error) - FilterResult(executionID string, filter string) ([]string, error) + SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) + FilterResult(executionID string, filter string, useRegex bool) ([]string, error) GetResultMetadata(executionID string) (*storage.ResultMetadata, error) + GetResultPath(executionID string) string DeleteResult(executionID string) error } @@ -755,6 +756,11 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str filter = f } + useRegex := false + if r, ok := args["use_regex"].(bool); ok { + useRegex = r + } + // 检查结果存储是否可用 if e.resultStorage == nil { return &mcp.ToolResult{ @@ -774,7 +780,7 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str if search != "" { // 搜索模式 - matchedLines, err := e.resultStorage.SearchResult(executionID, search) + matchedLines, err := e.resultStorage.SearchResult(executionID, search, useRegex) if err != nil { return &mcp.ToolResult{ Content: []mcp.Content{ @@ -790,7 +796,7 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str resultPage = paginateLines(matchedLines, page, limit) } else if filter != "" { // 过滤模式 - filteredLines, err := e.resultStorage.FilterResult(executionID, filter) + filteredLines, err := e.resultStorage.FilterResult(executionID, filter, useRegex) if err != nil { return &mcp.ToolResult{ Content: []mcp.Content{ @@ -853,9 +859,15 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1)) if search != "" { sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search)) + if useRegex { + sb.WriteString(" (正则模式)") + } } if filter != "" { sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter)) + if useRegex { + sb.WriteString(" (正则模式)") + } } sb.WriteString("\n") } diff --git a/internal/storage/result_storage.go b/internal/storage/result_storage.go index e3df9e4e..85a8b7b3 100644 --- a/internal/storage/result_storage.go +++ b/internal/storage/result_storage.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strings" "sync" "time" @@ -16,22 +17,27 @@ import ( type ResultStorage interface { // SaveResult 保存工具执行结果 SaveResult(executionID string, toolName string, result string) error - + // GetResult 获取完整结果 GetResult(executionID string) (string, error) - + // GetResultPage 分页获取结果 GetResultPage(executionID string, page int, limit int) (*ResultPage, error) - + // SearchResult 搜索结果 - SearchResult(executionID string, keyword string) ([]string, error) - + // useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配 + SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) + // FilterResult 过滤结果 - FilterResult(executionID string, filter string) ([]string, error) - + // useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配 + FilterResult(executionID string, filter string, useRegex bool) ([]string, error) + // GetResultMetadata 获取结果元信息 GetResultMetadata(executionID string) (*ResultMetadata, error) - + + // GetResultPath 获取结果文件路径 + GetResultPath(executionID string) string + // DeleteResult 删除结果 DeleteResult(executionID string) error } @@ -67,7 +73,7 @@ func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorag if err := os.MkdirAll(baseDir, 0755); err != nil { return nil, fmt.Errorf("创建存储目录失败: %w", err) } - + return &FileResultStorage{ baseDir: baseDir, logger: logger, @@ -88,13 +94,13 @@ func (s *FileResultStorage) getMetadataPath(executionID string) string { func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error { s.mu.Lock() defer s.mu.Unlock() - + // 保存结果文件 resultPath := s.getResultPath(executionID) if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil { return fmt.Errorf("保存结果文件失败: %w", err) } - + // 计算统计信息 lines := strings.Split(result, "\n") metadata := &ResultMetadata{ @@ -104,25 +110,25 @@ func (s *FileResultStorage) SaveResult(executionID string, toolName string, resu TotalLines: len(lines), CreatedAt: time.Now(), } - + // 保存元数据 metadataPath := s.getMetadataPath(executionID) metadataJSON, err := json.Marshal(metadata) if err != nil { return fmt.Errorf("序列化元数据失败: %w", err) } - + if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil { return fmt.Errorf("保存元数据文件失败: %w", err) } - + s.logger.Info("保存工具执行结果", zap.String("executionID", executionID), zap.String("toolName", toolName), zap.Int("size", len(result)), zap.Int("lines", len(lines)), ) - + return nil } @@ -130,7 +136,7 @@ func (s *FileResultStorage) SaveResult(executionID string, toolName string, resu func (s *FileResultStorage) GetResult(executionID string) (string, error) { s.mu.RLock() defer s.mu.RUnlock() - + resultPath := s.getResultPath(executionID) data, err := os.ReadFile(resultPath) if err != nil { @@ -139,7 +145,7 @@ func (s *FileResultStorage) GetResult(executionID string) (string, error) { } return "", fmt.Errorf("读取结果文件失败: %w", err) } - + return string(data), nil } @@ -147,7 +153,7 @@ func (s *FileResultStorage) GetResult(executionID string) (string, error) { func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) { s.mu.RLock() defer s.mu.RUnlock() - + metadataPath := s.getMetadataPath(executionID) data, err := os.ReadFile(metadataPath) if err != nil { @@ -156,12 +162,12 @@ func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetada } return nil, fmt.Errorf("读取元数据文件失败: %w", err) } - + var metadata ResultMetadata if err := json.Unmarshal(data, &metadata); err != nil { return nil, fmt.Errorf("解析元数据失败: %w", err) } - + return &metadata, nil } @@ -169,17 +175,17 @@ func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetada func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) { s.mu.RLock() defer s.mu.RUnlock() - + // 获取完整结果 result, err := s.GetResult(executionID) if err != nil { return nil, err } - + // 分割为行 lines := strings.Split(result, "\n") totalLines := len(lines) - + // 计算分页 totalPages := (totalLines + limit - 1) / limit if page < 1 { @@ -188,14 +194,14 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in if page > totalPages && totalPages > 0 { page = totalPages } - + // 计算起始和结束索引 start := (page - 1) * limit end := start + limit if end > totalLines { end = totalLines } - + // 提取指定页的行 var pageLines []string if start < totalLines { @@ -203,7 +209,7 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in } else { pageLines = []string{} } - + return &ResultPage{ Lines: pageLines, Page: page, @@ -214,57 +220,78 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in } // SearchResult 搜索结果 -func (s *FileResultStorage) SearchResult(executionID string, keyword string) ([]string, error) { +func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) { s.mu.RLock() defer s.mu.RUnlock() - + // 获取完整结果 result, err := s.GetResult(executionID) if err != nil { return nil, err } - + + // 如果使用正则表达式,先编译正则 + var regex *regexp.Regexp + if useRegex { + compiledRegex, err := regexp.Compile(keyword) + if err != nil { + return nil, fmt.Errorf("无效的正则表达式: %w", err) + } + regex = compiledRegex + } + // 分割为行并搜索 lines := strings.Split(result, "\n") var matchedLines []string - + for _, line := range lines { - if strings.Contains(line, keyword) { + var matched bool + if useRegex { + matched = regex.MatchString(line) + } else { + matched = strings.Contains(line, keyword) + } + + if matched { matchedLines = append(matchedLines, line) } } - + return matchedLines, nil } // FilterResult 过滤结果 -func (s *FileResultStorage) FilterResult(executionID string, filter string) ([]string, error) { +func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) { // 过滤和搜索逻辑相同,都是查找包含关键词的行 - return s.SearchResult(executionID, filter) + return s.SearchResult(executionID, filter, useRegex) +} + +// GetResultPath 获取结果文件路径 +func (s *FileResultStorage) GetResultPath(executionID string) string { + return s.getResultPath(executionID) } // DeleteResult 删除结果 func (s *FileResultStorage) DeleteResult(executionID string) error { s.mu.Lock() defer s.mu.Unlock() - + resultPath := s.getResultPath(executionID) metadataPath := s.getMetadataPath(executionID) - + // 删除结果文件 if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("删除结果文件失败: %w", err) } - + // 删除元数据文件 if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("删除元数据文件失败: %w", err) } - + s.logger.Info("删除工具执行结果", zap.String("executionID", executionID), ) - + return nil } - diff --git a/internal/storage/result_storage_test.go b/internal/storage/result_storage_test.go index aaf2bfa1..51305c92 100644 --- a/internal/storage/result_storage_test.go +++ b/internal/storage/result_storage_test.go @@ -256,8 +256,8 @@ func TestFileResultStorage_SearchResult(t *testing.T) { t.Fatalf("保存结果失败: %v", err) } - // 搜索包含"error"的行 - matchedLines, err := storage.SearchResult(executionID, "error") + // 搜索包含"error"的行(简单字符串匹配) + matchedLines, err := storage.SearchResult(executionID, "error", false) if err != nil { t.Fatalf("搜索失败: %v", err) } @@ -274,7 +274,7 @@ func TestFileResultStorage_SearchResult(t *testing.T) { } // 测试搜索不存在的关键词 - noMatch, err := storage.SearchResult(executionID, "nonexistent") + noMatch, err := storage.SearchResult(executionID, "nonexistent", false) if err != nil { t.Fatalf("搜索失败: %v", err) } @@ -282,6 +282,16 @@ func TestFileResultStorage_SearchResult(t *testing.T) { if len(noMatch) != 0 { t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch)) } + + // 测试正则表达式搜索 + regexMatched, err := storage.SearchResult(executionID, "error.*again", true) + if err != nil { + t.Fatalf("正则搜索失败: %v", err) + } + + if len(regexMatched) != 1 { + t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched)) + } } func TestFileResultStorage_FilterResult(t *testing.T) { @@ -298,8 +308,8 @@ func TestFileResultStorage_FilterResult(t *testing.T) { t.Fatalf("保存结果失败: %v", err) } - // 过滤包含"warning"的行 - filteredLines, err := storage.FilterResult(executionID, "warning") + // 过滤包含"warning"的行(简单字符串匹配) + filteredLines, err := storage.FilterResult(executionID, "warning", false) if err != nil { t.Fatalf("过滤失败: %v", err) }