diff --git a/README.md b/README.md index 195cc46c..b095dbd6 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ ![Preview](./img/mcp-stdio2.png) ## Changelog +- 2025.11.14 Performance optimizations: optimized tool lookup from O(n) to O(1) using index map, added automatic cleanup mechanism for execution records to prevent memory leaks, and added pagination support for database queries - 2025.11.13 Added authentication for the web mode, including automatic password generation and in-app password change - 2025.11.13 Added `Settings` feature in the frontend - 2025.11.13 Added MCP Stdio mode support, now seamlessly integrated and usable in code editors, CLI, and automation scripts diff --git a/README_CN.md b/README_CN.md index c9b7b14a..5bcb077f 100644 --- a/README_CN.md +++ b/README_CN.md @@ -7,6 +7,7 @@ ![详情预览](./img/mcp-stdio2.png) ## 更新日志 +- 2025.11.14 性能优化:工具查找从 O(n) 优化为 O(1)(使用索引映射),添加执行记录自动清理机制防止内存泄漏,数据库查询支持分页加载 - 2025.11.13 Web 端新增统一鉴权,支持自动生成强密码与前端修改密码; - 2025.11.13 在前端新增`设置`功能; - 2025.11.13 新增 MCP Stdio 模式支持,现可在代码编辑器、CLI 及自动化脚本等多种场景下,无缝集成并使用全套安全工具; diff --git a/internal/database/monitor.go b/internal/database/monitor.go index 5b899f87..8140c3b9 100644 --- a/internal/database/monitor.go +++ b/internal/database/monitor.go @@ -69,16 +69,30 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error { return nil } -// LoadToolExecutions 加载所有工具执行记录 +// LoadToolExecutions 加载所有工具执行记录(支持分页) func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) { + return db.LoadToolExecutionsWithPagination(0, 1000) +} + +// LoadToolExecutionsWithPagination 分页加载工具执行记录 +// limit: 最大返回记录数,0 表示使用默认值 1000 +// offset: 跳过的记录数,用于分页 +func (db *DB) LoadToolExecutionsWithPagination(offset, limit int) ([]*mcp.ToolExecution, error) { + if limit <= 0 { + limit = 1000 // 默认限制 + } + if limit > 10000 { + limit = 10000 // 最大限制,防止一次性加载过多数据 + } + query := ` SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms FROM tool_executions ORDER BY start_time DESC - LIMIT 1000 + LIMIT ? OFFSET ? ` - rows, err := db.Query(query) + rows, err := db.Query(query, limit, offset) if err != nil { return nil, err } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index f6601c92..6980dd08 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "sort" "strings" "sync" "time" @@ -27,15 +28,16 @@ type MonitorStorage interface { // Server MCP服务器 type Server struct { - tools map[string]ToolHandler - toolDefs map[string]Tool // 工具定义 - executions map[string]*ToolExecution - stats map[string]*ToolStats - prompts map[string]*Prompt // 提示词模板 - resources map[string]*Resource // 资源 - storage MonitorStorage // 可选的持久化存储 - mu sync.RWMutex - logger *zap.Logger + tools map[string]ToolHandler + toolDefs map[string]Tool // 工具定义 + executions map[string]*ToolExecution + stats map[string]*ToolStats + prompts map[string]*Prompt // 提示词模板 + resources map[string]*Resource // 资源 + storage MonitorStorage // 可选的持久化存储 + mu sync.RWMutex + logger *zap.Logger + maxExecutionsInMemory int // 内存中最大执行记录数 } // ToolHandler 工具处理函数 @@ -49,14 +51,15 @@ func NewServer(logger *zap.Logger) *Server { // NewServerWithStorage 创建新的MCP服务器(带持久化存储) func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server { s := &Server{ - tools: make(map[string]ToolHandler), - toolDefs: make(map[string]Tool), - executions: make(map[string]*ToolExecution), - stats: make(map[string]*ToolStats), - prompts: make(map[string]*Prompt), - resources: make(map[string]*Resource), - storage: storage, - logger: logger, + tools: make(map[string]ToolHandler), + toolDefs: make(map[string]Tool), + executions: make(map[string]*ToolExecution), + stats: make(map[string]*ToolStats), + prompts: make(map[string]*Prompt), + resources: make(map[string]*Resource), + storage: storage, + logger: logger, + maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录 } // 初始化默认提示词和资源 @@ -267,6 +270,8 @@ func (s *Server) handleCallTool(msg *Message) *Message { s.mu.Lock() s.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + s.cleanupOldExecutions() s.mu.Unlock() if s.storage != nil { @@ -499,9 +504,11 @@ func (s *Server) loadHistoricalData() { } else { s.mu.Lock() for _, exec := range executions { - // 只加载最近1000条,避免内存占用过大 - if len(s.executions) < 1000 { + // 只加载最近 maxExecutionsInMemory 条,避免内存占用过大 + if len(s.executions) < s.maxExecutionsInMemory { s.executions[exec.ID] = exec + } else { + break } } s.mu.Unlock() @@ -618,6 +625,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string] s.mu.Lock() s.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + s.cleanupOldExecutions() s.mu.Unlock() if s.storage != nil { @@ -689,6 +698,43 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string] return finalResult, executionID, nil } +// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长 +func (s *Server) cleanupOldExecutions() { + if len(s.executions) <= s.maxExecutionsInMemory { + return + } + + // 按开始时间排序,找出最旧的记录 + type execWithTime struct { + id string + startTime time.Time + } + execs := make([]execWithTime, 0, len(s.executions)) + for id, exec := range s.executions { + execs = append(execs, execWithTime{ + id: id, + startTime: exec.StartTime, + }) + } + + // 使用 sort 包进行高效排序(最旧的在前) + sort.Slice(execs, func(i, j int) bool { + return execs[i].startTime.Before(execs[j].startTime) + }) + + // 删除最旧的记录,保留 maxExecutionsInMemory 条 + toDelete := len(s.executions) - s.maxExecutionsInMemory + for i := 0; i < toDelete; i++ { + delete(s.executions, execs[i].id) + } + + s.logger.Debug("清理旧的执行记录", + zap.Int("before", len(execs)), + zap.Int("after", len(s.executions)), + zap.Int("deleted", toDelete), + ) +} + // initDefaultPrompts 初始化默认提示词模板 func (s *Server) initDefaultPrompts() { s.mu.Lock() diff --git a/internal/security/executor.go b/internal/security/executor.go index 362a03cf..d88984b7 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -9,23 +9,43 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "go.uber.org/zap" ) // Executor 安全工具执行器 type Executor struct { config *config.SecurityConfig + toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 mcpServer *mcp.Server logger *zap.Logger } // NewExecutor 创建新的执行器 func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor { - return &Executor{ + executor := &Executor{ config: cfg, + toolIndex: make(map[string]*config.ToolConfig), mcpServer: mcpServer, logger: logger, } + // 构建工具索引 + executor.buildToolIndex() + return executor +} + +// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1) +func (e *Executor) buildToolIndex() { + e.toolIndex = make(map[string]*config.ToolConfig) + for i := range e.config.Tools { + if e.config.Tools[i].Enabled { + e.toolIndex[e.config.Tools[i].Name] = &e.config.Tools[i] + } + } + e.logger.Info("工具索引构建完成", + zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), + ) } // ExecuteTool 执行安全工具 @@ -34,30 +54,24 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st zap.String("toolName", toolName), zap.Any("args", args), ) - + // 特殊处理:exec工具直接执行系统命令 if toolName == "exec" { e.logger.Info("执行exec工具") return e.executeSystemCommand(ctx, args) } - // 查找工具配置 - var toolConfig *config.ToolConfig - for i := range e.config.Tools { - if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled { - toolConfig = &e.config.Tools[i] - break - } - } - - if toolConfig == nil { + // 使用索引查找工具配置(O(1) 查找) + toolConfig, exists := e.toolIndex[toolName] + if !exists { e.logger.Error("工具未找到或未启用", zap.String("toolName", toolName), zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), ) return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName) } - + e.logger.Info("找到工具配置", zap.String("toolName", toolName), zap.String("command", toolConfig.Command), @@ -66,13 +80,13 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st // 构建命令 - 根据工具类型使用不同的参数格式 cmdArgs := e.buildCommandArgs(toolName, toolConfig, args) - + e.logger.Info("构建命令参数完成", zap.String("toolName", toolName), zap.Strings("cmdArgs", cmdArgs), zap.Int("argsCount", len(cmdArgs)), ) - + // 验证命令参数 if len(cmdArgs) == 0 { e.logger.Warn("命令参数为空", @@ -92,7 +106,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st // 执行命令 cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) - + e.logger.Info("执行安全工具", zap.String("tool", toolName), zap.Strings("args", cmdArgs), @@ -136,8 +150,12 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st func (e *Executor) RegisterTools(mcpServer *mcp.Server) { e.logger.Info("开始注册工具", zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), ) - + + // 重新构建索引(以防配置更新) + e.buildToolIndex() + for i, toolConfig := range e.config.Tools { if !toolConfig.Enabled { e.logger.Debug("跳过未启用的工具", @@ -149,7 +167,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) { // 创建工具配置的副本,避免闭包问题 toolName := toolConfig.Name toolConfigCopy := toolConfig - + // 使用简短描述(如果存在),否则使用详细描述的前100个字符 shortDesc := toolConfigCopy.ShortDescription if shortDesc == "" { @@ -166,7 +184,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) { shortDesc = desc } } - + tool := mcp.Tool{ Name: toolConfigCopy.Name, Description: toolConfigCopy.Description, @@ -189,7 +207,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) { zap.Int("index", i), ) } - + e.logger.Info("工具注册完成", zap.Int("registeredCount", len(e.config.Tools)), ) @@ -208,7 +226,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf hasScanType = true scanTypeValue = scanType } - + // 添加固定参数(如果指定了 scan_type,可能需要过滤掉默认的扫描类型参数) if hasScanType && toolName == "nmap" { // 对于 nmap,如果指定了 scan_type,跳过默认的 -sT -sV -sC @@ -237,7 +255,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" { continue } - + value := e.getParamValue(args, param) if value == nil { if param.Required { @@ -255,7 +273,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf if param.Type == "bool" { var boolVal bool var ok bool - + // 尝试多种类型转换 if boolVal, ok = value.(bool); ok { // 已经是布尔值 @@ -272,7 +290,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf boolVal = strVal == "true" || strVal == "1" || strVal == "yes" ok = true } - + if ok { if !boolVal { continue // false 时不添加任何参数 @@ -340,7 +358,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" { continue } - + if param.Position != nil && *param.Position == i { value := e.getParamValue(args, param) if value == nil { @@ -365,14 +383,14 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf } } } - + // 特殊处理:additional_args 参数(需要按空格分割成多个参数) if additionalArgs, ok := args["additional_args"].(string); ok && additionalArgs != "" { // 按空格分割,但保留引号内的内容 additionalArgsList := e.parseAdditionalArgs(additionalArgs) cmdArgs = append(cmdArgs, additionalArgsList...) } - + // 特殊处理:scan_type 参数(需要按空格分割并插入到合适位置) if hasScanType { scanTypeArgs := e.parseAdditionalArgs(scanTypeValue) @@ -403,7 +421,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf // 如果没有定义参数配置,使用固定参数和通用处理 // 添加固定参数 cmdArgs = append(cmdArgs, toolConfig.Args...) - + // 通用处理:将参数转换为命令行参数 for key, value := range args { if key == "_tool_name" { @@ -426,23 +444,23 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string { if argsStr == "" { return []string{} } - + result := make([]string, 0) var current strings.Builder inQuotes := false var quoteChar rune escapeNext := false - + runes := []rune(argsStr) for i := 0; i < len(runes); i++ { r := runes[i] - + if escapeNext { current.WriteRune(r) escapeNext = false continue } - + if r == '\\' { // 检查下一个字符是否是引号 if i+1 < len(runes) && (runes[i+1] == '"' || runes[i+1] == '\'') { @@ -456,19 +474,19 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string { } continue } - + if !inQuotes && (r == '"' || r == '\'') { inQuotes = true quoteChar = r continue } - + if inQuotes && r == quoteChar { inQuotes = false quoteChar = 0 continue } - + if !inQuotes && (r == ' ' || r == '\t' || r == '\n') { if current.Len() > 0 { result = append(result, current.String()) @@ -476,20 +494,20 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string { } continue } - + current.WriteRune(r) } - + // 处理最后一个参数(如果存在) if current.Len() > 0 { result = append(result, current.String()) } - + // 如果解析结果为空,使用简单的空格分割作为降级方案 if len(result) == 0 { result = strings.Fields(argsStr) } - + return result } @@ -638,9 +656,9 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int // buildInputSchema 构建输入模式 func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} { schema := map[string]interface{}{ - "type": "object", + "type": "object", "properties": map[string]interface{}{}, - "required": []string{}, + "required": []string{}, } // 如果配置中定义了参数,优先使用配置中的参数定义 @@ -651,7 +669,7 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in for _, param := range toolConfig.Parameters { // 转换类型为OpenAI/JSON Schema标准类型 openAIType := e.convertToOpenAIType(param.Type) - + prop := map[string]interface{}{ "type": openAIType, "description": param.Description, @@ -750,5 +768,3 @@ func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[s "generatedAt": time.Now(), } } - -