Add files via upload

This commit is contained in:
公明
2025-11-15 20:15:55 +08:00
committed by GitHub
parent 0455549c18
commit ac2c62f882
8 changed files with 274 additions and 181 deletions

View File

@@ -12,6 +12,7 @@
![Preview](./img/外部MCP接入.png)
## Changelog
- 2025.11.16 Added large result pagination feature: when tool execution results exceed the threshold (default 50KB), automatically save to file and return execution ID, support paginated queries, keyword search, conditional filtering, and regex matching through query_execution_result tool, effectively solving the problem of overly long single responses and improving large file processing capabilities
- 2025.11.15 Added external MCP integration feature: support for integrating external MCP servers to extend tool capabilities, supports both stdio and HTTP transport modes, tool-level enable/disable control, complete configuration guide and management APIs
- 2025.11.14 Performance optimizations: optimized tool lookup from O(n) to O(1) using index map, added automatic cleanup mechanism for execution records to prevent memory leaks, and added pagination support for database queries
- 2025.11.13 Added authentication for the web mode, including automatic password generation and in-app password change
@@ -30,6 +31,7 @@
- 💬 **Conversational Interface** - Natural language conversation interface with streaming output (SSE), real-time execution viewing
- 📊 **Conversation History Management** - Complete conversation history records, supports viewing, deletion, and management
- ⚙️ **Visual Configuration Management** - Web interface for system settings, supports real-time loading and saving configurations with required field validation
- 📄 **Large Result Pagination** - When tool execution results exceed the threshold, automatically save to file, support paginated queries, keyword search, conditional filtering, and regex matching, effectively solving the problem of overly long single responses, with examples for various tools (head, tail, grep, sed, etc.) for segmented reading
### Tool Integration
- 🔌 **MCP Protocol Support** - Complete MCP protocol implementation, supports tool registration, invocation, and monitoring

View File

@@ -9,6 +9,7 @@
![详情预览](./img/外部MCP接入.png)
## 更新日志
- 2025.11.16 新增大结果分段读取功能当工具执行结果超过阈值默认50KB自动保存到文件并返回执行ID支持通过 query_execution_result 工具进行分页查询、关键词搜索、条件过滤和正则表达式匹配,有效解决单次返回过长的问题,提升大文件处理能力
- 2025.11.15 新增外部 MCP 接入功能:支持接入外部 MCP 服务器扩展工具能力,支持 stdio 和 HTTP 两种传输模式,支持工具级别的启用/禁用控制,提供完整的配置指南和管理接口
- 2025.11.14 性能优化:工具查找从 O(n) 优化为 O(1)(使用索引映射),添加执行记录自动清理机制防止内存泄漏,数据库查询支持分页加载
- 2025.11.13 Web 端新增统一鉴权,支持自动生成强密码与前端修改密码;
@@ -28,6 +29,7 @@
- 💬 **对话式交互** - 自然语言对话界面支持流式输出SSE实时查看执行过程
- 📊 **对话历史管理** - 完整的对话历史记录,支持查看、删除和管理
- ⚙️ **可视化配置管理** - Web界面配置系统设置支持实时加载和保存配置必填项验证
- 📄 **大结果分段读取** - 当工具执行结果超过阈值时自动保存支持分页查询、关键词搜索、条件过滤和正则表达式匹配有效解决单次返回过长问题提供多种工具head、tail、grep、sed等的分段读取示例
### 工具集成
- 🔌 **MCP协议支持** - 完整实现MCP协议支持工具注册、调用、监控

View File

@@ -27,7 +27,7 @@ mcp:
# 必填项api_key, base_url, model 必须填写才能正常运行
openai:
base_url: https://api.deepseek.com/v1 # API 基础 URL必填
api_key: sk-xxx # API 密钥(必填)
api_key: sk-xxxx # API 密钥(必填)
# 支持的 API 服务商:
# - OpenAI: https://api.openai.com/v1
# - DeepSeek: https://api.deepseek.com/v1

View File

@@ -15,22 +15,23 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// Agent AI代理
type Agent struct {
openAIClient *http.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射OpenAI格式 -> 原始格式用于外部MCP工具
openAIClient *http.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射OpenAI格式 -> 原始格式用于外部MCP工具
}
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
@@ -38,9 +39,10 @@ type ResultStorage interface {
SaveResult(executionID string, toolName string, result string) error
GetResult(executionID string) (string, error)
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
SearchResult(executionID string, keyword string) ([]string, error)
FilterResult(executionID string, filter string) ([]string, error)
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
GetResultPath(executionID string) string
DeleteResult(executionID string) error
}
@@ -50,19 +52,19 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
if maxIterations <= 0 {
maxIterations = 30
}
// 设置大结果阈值默认50KB
largeResultThreshold := 50 * 1024
if agentCfg != nil && agentCfg.LargeResultThreshold > 0 {
largeResultThreshold = agentCfg.LargeResultThreshold
}
// 设置结果存储目录默认tmp
resultStorageDir := "tmp"
if agentCfg != nil && agentCfg.ResultStorageDir != "" {
resultStorageDir = agentCfg.ResultStorageDir
}
// 初始化结果存储
var resultStorage ResultStorage
if resultStorageDir != "" {
@@ -70,21 +72,21 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
// 这里需要在实际使用时初始化
// 暂时设为nil在需要时初始化
}
// 配置HTTP Transport优化连接管理和超时设置
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 300 * time.Second,
KeepAlive: 300 * time.Second,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 30 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 60 * time.Minute, // 响应头超时增加到15分钟应对大响应
DisableKeepAlives: false, // 启用连接复用
DisableKeepAlives: false, // 启用连接复用
}
// 增加超时时间到30分钟以支持长时间运行的AI推理
// 特别是当使用流式响应或处理复杂任务时
return &Agent{
@@ -92,15 +94,15 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
Timeout: 30 * time.Minute, // 从5分钟增加到30分钟
Transport: transport,
},
config: cfg,
agentConfig: agentCfg,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
logger: logger,
maxIterations: maxIterations,
resultStorage: resultStorage,
config: cfg,
agentConfig: agentCfg,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
logger: logger,
maxIterations: maxIterations,
resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolNameMapping: make(map[string]string), // 初始化工具名称映射
}
}
@@ -113,10 +115,10 @@ func (a *Agent) SetResultStorage(storage ResultStorage) {
// ChatMessage 聊天消息
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// MarshalJSON 自定义JSON序列化将tool_calls中的arguments转换为JSON字符串
@@ -149,7 +151,7 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
}
argsJSON = string(argsBytes)
}
toolCallsJSON[i] = map[string]interface{}{
"id": tc.ID,
"type": tc.Type,
@@ -187,15 +189,15 @@ type Choice struct {
// MessageWithTools 带工具调用的消息
type MessageWithTools struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// Tool OpenAI工具定义
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
// FunctionDefinition 函数定义
@@ -213,9 +215,9 @@ type Error struct {
// ToolCall 工具调用
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
// FunctionCall 函数调用
@@ -267,7 +269,7 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
// AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct {
Response string
Response string
MCPExecutionIDs []string
}
@@ -300,14 +302,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。`
messages := []ChatMessage{
{
Role: "system",
Content: systemPrompt,
},
}
// 添加历史消息数据库只保存user和assistant消息
a.logger.Info("处理历史消息",
zap.Int("count", len(historyMessages)),
@@ -332,13 +334,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
)
}
}
a.logger.Info("构建消息数组",
zap.Int("historyMessages", len(historyMessages)),
zap.Int("addedMessages", addedCount),
zap.Int("totalMessages", len(messages)),
)
// 添加当前用户消息
messages = append(messages, ChatMessage{
Role: "user",
@@ -353,7 +355,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
for i := 0; i < maxIterations; i++ {
// 检查是否是最后一次迭代
isLastIteration := (i == maxIterations-1)
// 获取可用工具
tools := a.getAvailableTools()
@@ -451,13 +453,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 发送工具调用开始事件
toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments)
sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"arguments": string(toolArgsJSON),
"toolName": toolCall.Function.Name,
"arguments": string(toolArgsJSON),
"argumentsObj": toolCall.Function.Arguments,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
// 执行工具
@@ -466,23 +468,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 构建详细的错误信息帮助AI理解问题并做出决策
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
messages = append(messages, ChatMessage{
Role: "tool",
Role: "tool",
ToolCallID: toolCall.ID,
Content: errorMsg,
Content: errorMsg,
})
// 发送工具执行失败事件
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": false,
"isError": true,
"error": err.Error(),
"toolName": toolCall.Function.Name,
"success": false,
"isError": true,
"error": err.Error(),
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
a.logger.Warn("工具执行失败,已返回详细错误信息",
zap.String("tool", toolCall.Function.Name),
zap.Error(err),
@@ -490,33 +492,33 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} else {
// 即使工具返回了错误结果IsError=true也继续处理让AI决定下一步
messages = append(messages, ChatMessage{
Role: "tool",
Role: "tool",
ToolCallID: toolCall.ID,
Content: execResult.Result,
Content: execResult.Result,
})
// 收集执行ID
if execResult.ExecutionID != "" {
result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
}
// 发送工具执行成功事件
resultPreview := execResult.Result
if len(resultPreview) > 200 {
resultPreview = resultPreview[:200] + "..."
}
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": !execResult.IsError,
"isError": execResult.IsError,
"result": execResult.Result, // 完整结果
"resultPreview": resultPreview, // 预览结果
"executionId": execResult.ExecutionID,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
"toolName": toolCall.Function.Name,
"success": !execResult.IsError,
"isError": execResult.IsError,
"result": execResult.Result, // 完整结果
"resultPreview": resultPreview, // 预览结果
"executionId": execResult.ExecutionID,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
// 如果工具返回了错误,记录日志但不中断流程
if execResult.IsError {
a.logger.Warn("工具返回错误结果,但继续处理",
@@ -526,7 +528,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
}
}
// 如果是最后一次迭代执行完工具后要求AI进行总结
if isLastIteration {
sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil)
@@ -548,7 +550,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 如果获取总结失败,跳出循环,让后续逻辑处理
break
}
continue
}
@@ -591,7 +593,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 如果都没有内容,跳出循环,让后续逻辑处理
break
}
// 如果完成,返回结果
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
@@ -608,7 +610,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Content: fmt.Sprintf("已达到最大迭代次数(%d轮。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试请提供详细的下一步执行计划。请直接回复不要调用工具。", a.maxIterations),
}
messages = append(messages, finalSummaryPrompt)
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
@@ -618,7 +620,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
return result, nil
}
}
// 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮。系统已执行了多轮测试但由于达到迭代上限无法继续自动执行。建议您查看已执行的工具结果或提出新的测试请求以继续测试。", a.maxIterations)
return result, nil
@@ -629,7 +631,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
func (a *Agent) getAvailableTools() []Tool {
// 从MCP服务器获取所有已注册的内部工具
mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools {
@@ -638,10 +640,10 @@ func (a *Agent) getAvailableTools() []Tool {
if description == "" {
description = mcpTool.Description
}
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
tools = append(tools, Tool{
Type: "function",
Function: FunctionDefinition{
@@ -651,24 +653,24 @@ func (a *Agent) getAvailableTools() []Tool {
},
})
}
// 获取外部MCP工具
if a.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
externalTools, err := a.externalMCPMgr.GetAllTools(ctx)
if err != nil {
a.logger.Warn("获取外部MCP工具失败", zap.Error(err))
} else {
// 获取外部MCP配置用于检查工具启用状态
externalMCPConfigs := a.externalMCPMgr.GetConfigs()
// 清空并重建工具名称映射
a.mu.Lock()
a.toolNameMapping = make(map[string]string)
a.mu.Unlock()
// 将外部MCP工具添加到工具列表只添加启用的工具
for _, externalTool := range externalTools {
// 解析工具名称mcpName::toolName
@@ -679,7 +681,7 @@ func (a *Agent) getAvailableTools() []Tool {
} else {
continue // 跳过格式不正确的工具
}
// 检查工具是否启用
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
@@ -698,30 +700,30 @@ func (a *Agent) getAvailableTools() []Tool {
}
}
}
// 只添加启用的工具
if !enabled {
continue
}
// 使用简短描述(如果存在),否则使用详细描述
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(externalTool.InputSchema)
// 将工具名称中的 "::" 替换为 "__" 以符合OpenAI命名规范
// OpenAI要求工具名称只能包含 [a-zA-Z0-9_-]
openAIName := strings.ReplaceAll(externalTool.Name, "::", "__")
// 保存名称映射关系OpenAI格式 -> 原始格式)
a.mu.Lock()
a.toolNameMapping[openAIName] = externalTool.Name
a.mu.Unlock()
tools = append(tools, Tool{
Type: "function",
Function: FunctionDefinition{
@@ -733,12 +735,12 @@ func (a *Agent) getAvailableTools() []Tool {
}
}
}
a.logger.Debug("获取可用工具列表",
zap.Int("internalTools", len(mcpTools)),
zap.Int("totalTools", len(tools)),
)
return tools
}
@@ -747,13 +749,13 @@ func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]int
if schema == nil {
return schema
}
// 创建新的schema副本
converted := make(map[string]interface{})
for k, v := range schema {
converted[k] = v
}
// 转换properties中的类型
if properties, ok := converted["properties"].(map[string]interface{}); ok {
convertedProperties := make(map[string]interface{})
@@ -779,7 +781,7 @@ func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]int
}
converted["properties"] = convertedProperties
}
return converted
}
@@ -834,7 +836,7 @@ func (a *Agent) isRetryableError(err error) bool {
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
maxRetries := 3
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
response, err := a.callOpenAISingle(ctx, messages, tools)
if err == nil {
@@ -846,14 +848,14 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []
}
return response, nil
}
lastErr = err
// 如果不是可重试的错误,直接返回
if !a.isRetryableError(err) {
return nil, err
}
// 如果不是最后一次重试,等待后重试
if attempt < maxRetries-1 {
// 指数退避2s, 4s, 8s...
@@ -867,7 +869,7 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []
zap.Int("maxRetries", maxRetries),
zap.Duration("backoff", backoff),
)
// 检查上下文是否已取消
select {
case <-ctx.Done():
@@ -877,7 +879,7 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []
}
}
}
return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
}
@@ -924,7 +926,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
// 记录响应头接收时间
headerReceiveTime := time.Now()
headerReceiveDuration := headerReceiveTime.Sub(requestStartTime)
a.logger.Debug("收到OpenAI响应头",
zap.Int("statusCode", resp.StatusCode),
zap.Duration("headerReceiveDuration", headerReceiveDuration),
@@ -934,7 +936,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
// 使用带超时的读取通过context控制
bodyChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go func() {
body, err := io.ReadAll(resp.Body)
if err != nil {
@@ -943,7 +945,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
}
bodyChan <- body
}()
var body []byte
select {
case body = <-bodyChan:
@@ -951,7 +953,7 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
bodyReceiveTime := time.Now()
bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime)
totalDuration := bodyReceiveTime.Sub(requestStartTime)
a.logger.Debug("完成读取OpenAI响应体",
zap.Int("bodySizeKB", len(body)/1024),
zap.Duration("bodyReceiveDuration", bodyReceiveDuration),
@@ -1050,7 +1052,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
// 调用内部MCP工具
result, executionID, err = a.mcpServer.CallTool(ctx, toolName, args)
}
// 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常
if err != nil {
errorMsg := fmt.Sprintf(`工具调用失败
@@ -1068,7 +1070,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
- 检查工具名称是否正确
- 尝试使用其他替代工具
- 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName)
return &ToolExecutionResult{
Result: errorMsg,
ExecutionID: executionID,
@@ -1082,16 +1084,16 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
a.mu.RLock()
threshold := a.largeResultThreshold
storage := a.resultStorage
a.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 异步保存大结果
go func() {
@@ -1109,11 +1111,15 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
)
}
}()
// 返回最小化通知
lines := strings.Split(resultStr, "\n")
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines))
filePath := ""
if storage != nil {
filePath = storage.GetResultPath(executionID)
}
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
return &ToolExecutionResult{
Result: notification,
ExecutionID: executionID,
@@ -1129,20 +1135,53 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
}
// formatMinimalNotification 格式化最小化通知
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int) string {
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存ID: %s。\n\n", executionID))
sb.WriteString("结果信息:\n")
sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName))
sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024))
sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount))
if filePath != "" {
sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath))
}
sb.WriteString("\n")
sb.WriteString("使用以下工具查询完整结果:\n")
sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n")
sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID))
sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID))
sb.WriteString("\n")
if filePath != "" {
sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n")
sb.WriteString("\n")
sb.WriteString("**分段读取示例:**\n")
sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**搜索和正则匹配示例:**\n")
sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**过滤和统计示例:**\n")
sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**完整读取(不推荐大文件):**\n")
sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath))
sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**注意:**\n")
sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n")
sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n")
sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n")
}
return sb.String()
}
@@ -1180,7 +1219,6 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er
2. 如果工具不可用,请尝试使用替代工具
3. 如果这是系统问题,请向用户说明情况并提供建议
4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err)
return errorMsg
}

View File

@@ -53,8 +53,9 @@ func TestAgent_FormatMinimalNotification(t *testing.T) {
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount)
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
@@ -130,7 +131,8 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
// 生成通知
lines := strings.Split(resultStr, "\n")
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines))
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {

View File

@@ -28,9 +28,10 @@ type ResultStorage interface {
SaveResult(executionID string, toolName string, result string) error
GetResult(executionID string) (string, error)
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
SearchResult(executionID string, keyword string) ([]string, error)
FilterResult(executionID string, filter string) ([]string, error)
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
GetResultPath(executionID string) string
DeleteResult(executionID string) error
}
@@ -755,6 +756,11 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str
filter = f
}
useRegex := false
if r, ok := args["use_regex"].(bool); ok {
useRegex = r
}
// 检查结果存储是否可用
if e.resultStorage == nil {
return &mcp.ToolResult{
@@ -774,7 +780,7 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str
if search != "" {
// 搜索模式
matchedLines, err := e.resultStorage.SearchResult(executionID, search)
matchedLines, err := e.resultStorage.SearchResult(executionID, search, useRegex)
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{
@@ -790,7 +796,7 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str
resultPage = paginateLines(matchedLines, page, limit)
} else if filter != "" {
// 过滤模式
filteredLines, err := e.resultStorage.FilterResult(executionID, filter)
filteredLines, err := e.resultStorage.FilterResult(executionID, filter, useRegex)
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{
@@ -853,9 +859,15 @@ func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[str
sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1))
if search != "" {
sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search))
if useRegex {
sb.WriteString(" (正则模式)")
}
}
if filter != "" {
sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter))
if useRegex {
sb.WriteString(" (正则模式)")
}
}
sb.WriteString("\n")
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
@@ -16,22 +17,27 @@ import (
type ResultStorage interface {
// SaveResult 保存工具执行结果
SaveResult(executionID string, toolName string, result string) error
// GetResult 获取完整结果
GetResult(executionID string) (string, error)
// GetResultPage 分页获取结果
GetResultPage(executionID string, page int, limit int) (*ResultPage, error)
// SearchResult 搜索结果
SearchResult(executionID string, keyword string) ([]string, error)
// useRegex: 如果为 true将 keyword 作为正则表达式使用;如果为 false使用简单的字符串包含匹配
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
// FilterResult 过滤结果
FilterResult(executionID string, filter string) ([]string, error)
// useRegex: 如果为 true将 filter 作为正则表达式使用;如果为 false使用简单的字符串包含匹配
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
// GetResultMetadata 获取结果元信息
GetResultMetadata(executionID string) (*ResultMetadata, error)
// GetResultPath 获取结果文件路径
GetResultPath(executionID string) string
// DeleteResult 删除结果
DeleteResult(executionID string) error
}
@@ -67,7 +73,7 @@ func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorag
if err := os.MkdirAll(baseDir, 0755); err != nil {
return nil, fmt.Errorf("创建存储目录失败: %w", err)
}
return &FileResultStorage{
baseDir: baseDir,
logger: logger,
@@ -88,13 +94,13 @@ func (s *FileResultStorage) getMetadataPath(executionID string) string {
func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error {
s.mu.Lock()
defer s.mu.Unlock()
// 保存结果文件
resultPath := s.getResultPath(executionID)
if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil {
return fmt.Errorf("保存结果文件失败: %w", err)
}
// 计算统计信息
lines := strings.Split(result, "\n")
metadata := &ResultMetadata{
@@ -104,25 +110,25 @@ func (s *FileResultStorage) SaveResult(executionID string, toolName string, resu
TotalLines: len(lines),
CreatedAt: time.Now(),
}
// 保存元数据
metadataPath := s.getMetadataPath(executionID)
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("序列化元数据失败: %w", err)
}
if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil {
return fmt.Errorf("保存元数据文件失败: %w", err)
}
s.logger.Info("保存工具执行结果",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Int("size", len(result)),
zap.Int("lines", len(lines)),
)
return nil
}
@@ -130,7 +136,7 @@ func (s *FileResultStorage) SaveResult(executionID string, toolName string, resu
func (s *FileResultStorage) GetResult(executionID string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
resultPath := s.getResultPath(executionID)
data, err := os.ReadFile(resultPath)
if err != nil {
@@ -139,7 +145,7 @@ func (s *FileResultStorage) GetResult(executionID string) (string, error) {
}
return "", fmt.Errorf("读取结果文件失败: %w", err)
}
return string(data), nil
}
@@ -147,7 +153,7 @@ func (s *FileResultStorage) GetResult(executionID string) (string, error) {
func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) {
s.mu.RLock()
defer s.mu.RUnlock()
metadataPath := s.getMetadataPath(executionID)
data, err := os.ReadFile(metadataPath)
if err != nil {
@@ -156,12 +162,12 @@ func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetada
}
return nil, fmt.Errorf("读取元数据文件失败: %w", err)
}
var metadata ResultMetadata
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, fmt.Errorf("解析元数据失败: %w", err)
}
return &metadata, nil
}
@@ -169,17 +175,17 @@ func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetada
func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 分割为行
lines := strings.Split(result, "\n")
totalLines := len(lines)
// 计算分页
totalPages := (totalLines + limit - 1) / limit
if page < 1 {
@@ -188,14 +194,14 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in
if page > totalPages && totalPages > 0 {
page = totalPages
}
// 计算起始和结束索引
start := (page - 1) * limit
end := start + limit
if end > totalLines {
end = totalLines
}
// 提取指定页的行
var pageLines []string
if start < totalLines {
@@ -203,7 +209,7 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in
} else {
pageLines = []string{}
}
return &ResultPage{
Lines: pageLines,
Page: page,
@@ -214,57 +220,78 @@ func (s *FileResultStorage) GetResultPage(executionID string, page int, limit in
}
// SearchResult 搜索结果
func (s *FileResultStorage) SearchResult(executionID string, keyword string) ([]string, error) {
func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 如果使用正则表达式,先编译正则
var regex *regexp.Regexp
if useRegex {
compiledRegex, err := regexp.Compile(keyword)
if err != nil {
return nil, fmt.Errorf("无效的正则表达式: %w", err)
}
regex = compiledRegex
}
// 分割为行并搜索
lines := strings.Split(result, "\n")
var matchedLines []string
for _, line := range lines {
if strings.Contains(line, keyword) {
var matched bool
if useRegex {
matched = regex.MatchString(line)
} else {
matched = strings.Contains(line, keyword)
}
if matched {
matchedLines = append(matchedLines, line)
}
}
return matchedLines, nil
}
// FilterResult 过滤结果
func (s *FileResultStorage) FilterResult(executionID string, filter string) ([]string, error) {
func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) {
// 过滤和搜索逻辑相同,都是查找包含关键词的行
return s.SearchResult(executionID, filter)
return s.SearchResult(executionID, filter, useRegex)
}
// GetResultPath 获取结果文件路径
func (s *FileResultStorage) GetResultPath(executionID string) string {
return s.getResultPath(executionID)
}
// DeleteResult 删除结果
func (s *FileResultStorage) DeleteResult(executionID string) error {
s.mu.Lock()
defer s.mu.Unlock()
resultPath := s.getResultPath(executionID)
metadataPath := s.getMetadataPath(executionID)
// 删除结果文件
if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除结果文件失败: %w", err)
}
// 删除元数据文件
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除元数据文件失败: %w", err)
}
s.logger.Info("删除工具执行结果",
zap.String("executionID", executionID),
)
return nil
}

View File

@@ -256,8 +256,8 @@ func TestFileResultStorage_SearchResult(t *testing.T) {
t.Fatalf("保存结果失败: %v", err)
}
// 搜索包含"error"的行
matchedLines, err := storage.SearchResult(executionID, "error")
// 搜索包含"error"的行(简单字符串匹配)
matchedLines, err := storage.SearchResult(executionID, "error", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
@@ -274,7 +274,7 @@ func TestFileResultStorage_SearchResult(t *testing.T) {
}
// 测试搜索不存在的关键词
noMatch, err := storage.SearchResult(executionID, "nonexistent")
noMatch, err := storage.SearchResult(executionID, "nonexistent", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
@@ -282,6 +282,16 @@ func TestFileResultStorage_SearchResult(t *testing.T) {
if len(noMatch) != 0 {
t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch))
}
// 测试正则表达式搜索
regexMatched, err := storage.SearchResult(executionID, "error.*again", true)
if err != nil {
t.Fatalf("正则搜索失败: %v", err)
}
if len(regexMatched) != 1 {
t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched))
}
}
func TestFileResultStorage_FilterResult(t *testing.T) {
@@ -298,8 +308,8 @@ func TestFileResultStorage_FilterResult(t *testing.T) {
t.Fatalf("保存结果失败: %v", err)
}
// 过滤包含"warning"的行
filteredLines, err := storage.FilterResult(executionID, "warning")
// 过滤包含"warning"的行(简单字符串匹配)
filteredLines, err := storage.FilterResult(executionID, "warning", false)
if err != nil {
t.Fatalf("过滤失败: %v", err)
}