Add files via upload

This commit is contained in:
公明
2025-11-14 01:44:28 +08:00
committed by GitHub
parent 1b14070cee
commit f8dbfbb65f
5 changed files with 144 additions and 66 deletions
+1
View File
@@ -10,6 +10,7 @@
![Preview](./img/mcp-stdio2.png)
## Changelog
- 2025.11.14 Performance optimizations: optimized tool lookup from O(n) to O(1) using index map, added automatic cleanup mechanism for execution records to prevent memory leaks, and added pagination support for database queries
- 2025.11.13 Added authentication for the web mode, including automatic password generation and in-app password change
- 2025.11.13 Added `Settings` feature in the frontend
- 2025.11.13 Added MCP Stdio mode support, now seamlessly integrated and usable in code editors, CLI, and automation scripts
+1
View File
@@ -7,6 +7,7 @@
![详情预览](./img/mcp-stdio2.png)
## 更新日志
- 2025.11.14 性能优化:工具查找从 O(n) 优化为 O(1)(使用索引映射),添加执行记录自动清理机制防止内存泄漏,数据库查询支持分页加载
- 2025.11.13 Web 端新增统一鉴权,支持自动生成强密码与前端修改密码;
- 2025.11.13 在前端新增`设置`功能;
- 2025.11.13 新增 MCP Stdio 模式支持,现可在代码编辑器、CLI 及自动化脚本等多种场景下,无缝集成并使用全套安全工具;
+17 -3
View File
@@ -69,16 +69,30 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
return nil
}
// LoadToolExecutions 加载所有工具执行记录
// LoadToolExecutions 加载所有工具执行记录(支持分页)
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
return db.LoadToolExecutionsWithPagination(0, 1000)
}
// LoadToolExecutionsWithPagination 分页加载工具执行记录
// limit: 最大返回记录数,0 表示使用默认值 1000
// offset: 跳过的记录数,用于分页
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int) ([]*mcp.ToolExecution, error) {
if limit <= 0 {
limit = 1000 // 默认限制
}
if limit > 10000 {
limit = 10000 // 最大限制,防止一次性加载过多数据
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
ORDER BY start_time DESC
LIMIT 1000
LIMIT ? OFFSET ?
`
rows, err := db.Query(query)
rows, err := db.Query(query, limit, offset)
if err != nil {
return nil, err
}
+65 -19
View File
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"os"
"sort"
"strings"
"sync"
"time"
@@ -27,15 +28,16 @@ type MonitorStorage interface {
// Server MCP服务器
type Server struct {
tools map[string]ToolHandler
toolDefs map[string]Tool // 工具定义
executions map[string]*ToolExecution
stats map[string]*ToolStats
prompts map[string]*Prompt // 提示词模板
resources map[string]*Resource // 资源
storage MonitorStorage // 可选的持久化存储
mu sync.RWMutex
logger *zap.Logger
tools map[string]ToolHandler
toolDefs map[string]Tool // 工具定义
executions map[string]*ToolExecution
stats map[string]*ToolStats
prompts map[string]*Prompt // 提示词模板
resources map[string]*Resource // 资源
storage MonitorStorage // 可选的持久化存储
mu sync.RWMutex
logger *zap.Logger
maxExecutionsInMemory int // 内存中最大执行记录数
}
// ToolHandler 工具处理函数
@@ -49,14 +51,15 @@ func NewServer(logger *zap.Logger) *Server {
// NewServerWithStorage 创建新的MCP服务器(带持久化存储)
func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
s := &Server{
tools: make(map[string]ToolHandler),
toolDefs: make(map[string]Tool),
executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats),
prompts: make(map[string]*Prompt),
resources: make(map[string]*Resource),
storage: storage,
logger: logger,
tools: make(map[string]ToolHandler),
toolDefs: make(map[string]Tool),
executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats),
prompts: make(map[string]*Prompt),
resources: make(map[string]*Resource),
storage: storage,
logger: logger,
maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录
}
// 初始化默认提示词和资源
@@ -267,6 +270,8 @@ func (s *Server) handleCallTool(msg *Message) *Message {
s.mu.Lock()
s.executions[executionID] = execution
// 如果内存中的执行记录超过限制,清理最旧的记录
s.cleanupOldExecutions()
s.mu.Unlock()
if s.storage != nil {
@@ -499,9 +504,11 @@ func (s *Server) loadHistoricalData() {
} else {
s.mu.Lock()
for _, exec := range executions {
// 只加载最近1000条,避免内存占用过大
if len(s.executions) < 1000 {
// 只加载最近 maxExecutionsInMemory 条,避免内存占用过大
if len(s.executions) < s.maxExecutionsInMemory {
s.executions[exec.ID] = exec
} else {
break
}
}
s.mu.Unlock()
@@ -618,6 +625,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
s.mu.Lock()
s.executions[executionID] = execution
// 如果内存中的执行记录超过限制,清理最旧的记录
s.cleanupOldExecutions()
s.mu.Unlock()
if s.storage != nil {
@@ -689,6 +698,43 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
return finalResult, executionID, nil
}
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
func (s *Server) cleanupOldExecutions() {
if len(s.executions) <= s.maxExecutionsInMemory {
return
}
// 按开始时间排序,找出最旧的记录
type execWithTime struct {
id string
startTime time.Time
}
execs := make([]execWithTime, 0, len(s.executions))
for id, exec := range s.executions {
execs = append(execs, execWithTime{
id: id,
startTime: exec.StartTime,
})
}
// 使用 sort 包进行高效排序(最旧的在前)
sort.Slice(execs, func(i, j int) bool {
return execs[i].startTime.Before(execs[j].startTime)
})
// 删除最旧的记录,保留 maxExecutionsInMemory 条
toDelete := len(s.executions) - s.maxExecutionsInMemory
for i := 0; i < toDelete; i++ {
delete(s.executions, execs[i].id)
}
s.logger.Debug("清理旧的执行记录",
zap.Int("before", len(execs)),
zap.Int("after", len(s.executions)),
zap.Int("deleted", toDelete),
)
}
// initDefaultPrompts 初始化默认提示词模板
func (s *Server) initDefaultPrompts() {
s.mu.Lock()
+60 -44
View File
@@ -9,23 +9,43 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// Executor 安全工具执行器
type Executor struct {
config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server
logger *zap.Logger
}
// NewExecutor 创建新的执行器
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
return &Executor{
executor := &Executor{
config: cfg,
toolIndex: make(map[string]*config.ToolConfig),
mcpServer: mcpServer,
logger: logger,
}
// 构建工具索引
executor.buildToolIndex()
return executor
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig)
for i := range e.config.Tools {
if e.config.Tools[i].Enabled {
e.toolIndex[e.config.Tools[i].Name] = &e.config.Tools[i]
}
}
e.logger.Info("工具索引构建完成",
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
}
// ExecuteTool 执行安全工具
@@ -34,30 +54,24 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
zap.String("toolName", toolName),
zap.Any("args", args),
)
// 特殊处理:exec工具直接执行系统命令
if toolName == "exec" {
e.logger.Info("执行exec工具")
return e.executeSystemCommand(ctx, args)
}
// 查找工具配置
var toolConfig *config.ToolConfig
for i := range e.config.Tools {
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
toolConfig = &e.config.Tools[i]
break
}
}
if toolConfig == nil {
// 使用索引查找工具配置O(1) 查找)
toolConfig, exists := e.toolIndex[toolName]
if !exists {
e.logger.Error("工具未找到或未启用",
zap.String("toolName", toolName),
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
}
e.logger.Info("找到工具配置",
zap.String("toolName", toolName),
zap.String("command", toolConfig.Command),
@@ -66,13 +80,13 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 构建命令 - 根据工具类型使用不同的参数格式
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
e.logger.Info("构建命令参数完成",
zap.String("toolName", toolName),
zap.Strings("cmdArgs", cmdArgs),
zap.Int("argsCount", len(cmdArgs)),
)
// 验证命令参数
if len(cmdArgs) == 0 {
e.logger.Warn("命令参数为空",
@@ -92,7 +106,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
e.logger.Info("执行安全工具",
zap.String("tool", toolName),
zap.Strings("args", cmdArgs),
@@ -136,8 +150,12 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
e.logger.Info("开始注册工具",
zap.Int("totalTools", len(e.config.Tools)),
zap.Int("enabledTools", len(e.toolIndex)),
)
// 重新构建索引(以防配置更新)
e.buildToolIndex()
for i, toolConfig := range e.config.Tools {
if !toolConfig.Enabled {
e.logger.Debug("跳过未启用的工具",
@@ -149,7 +167,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
// 创建工具配置的副本,避免闭包问题
toolName := toolConfig.Name
toolConfigCopy := toolConfig
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
shortDesc := toolConfigCopy.ShortDescription
if shortDesc == "" {
@@ -166,7 +184,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
shortDesc = desc
}
}
tool := mcp.Tool{
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
@@ -189,7 +207,7 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
zap.Int("index", i),
)
}
e.logger.Info("工具注册完成",
zap.Int("registeredCount", len(e.config.Tools)),
)
@@ -208,7 +226,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
hasScanType = true
scanTypeValue = scanType
}
// 添加固定参数(如果指定了 scan_type,可能需要过滤掉默认的扫描类型参数)
if hasScanType && toolName == "nmap" {
// 对于 nmap,如果指定了 scan_type,跳过默认的 -sT -sV -sC
@@ -237,7 +255,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
continue
}
value := e.getParamValue(args, param)
if value == nil {
if param.Required {
@@ -255,7 +273,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Type == "bool" {
var boolVal bool
var ok bool
// 尝试多种类型转换
if boolVal, ok = value.(bool); ok {
// 已经是布尔值
@@ -272,7 +290,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
boolVal = strVal == "true" || strVal == "1" || strVal == "yes"
ok = true
}
if ok {
if !boolVal {
continue // false 时不添加任何参数
@@ -340,7 +358,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
continue
}
if param.Position != nil && *param.Position == i {
value := e.getParamValue(args, param)
if value == nil {
@@ -365,14 +383,14 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
}
}
}
// 特殊处理:additional_args 参数(需要按空格分割成多个参数)
if additionalArgs, ok := args["additional_args"].(string); ok && additionalArgs != "" {
// 按空格分割,但保留引号内的内容
additionalArgsList := e.parseAdditionalArgs(additionalArgs)
cmdArgs = append(cmdArgs, additionalArgsList...)
}
// 特殊处理:scan_type 参数(需要按空格分割并插入到合适位置)
if hasScanType {
scanTypeArgs := e.parseAdditionalArgs(scanTypeValue)
@@ -403,7 +421,7 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
// 如果没有定义参数配置,使用固定参数和通用处理
// 添加固定参数
cmdArgs = append(cmdArgs, toolConfig.Args...)
// 通用处理:将参数转换为命令行参数
for key, value := range args {
if key == "_tool_name" {
@@ -426,23 +444,23 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
if argsStr == "" {
return []string{}
}
result := make([]string, 0)
var current strings.Builder
inQuotes := false
var quoteChar rune
escapeNext := false
runes := []rune(argsStr)
for i := 0; i < len(runes); i++ {
r := runes[i]
if escapeNext {
current.WriteRune(r)
escapeNext = false
continue
}
if r == '\\' {
// 检查下一个字符是否是引号
if i+1 < len(runes) && (runes[i+1] == '"' || runes[i+1] == '\'') {
@@ -456,19 +474,19 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
}
continue
}
if !inQuotes && (r == '"' || r == '\'') {
inQuotes = true
quoteChar = r
continue
}
if inQuotes && r == quoteChar {
inQuotes = false
quoteChar = 0
continue
}
if !inQuotes && (r == ' ' || r == '\t' || r == '\n') {
if current.Len() > 0 {
result = append(result, current.String())
@@ -476,20 +494,20 @@ func (e *Executor) parseAdditionalArgs(argsStr string) []string {
}
continue
}
current.WriteRune(r)
}
// 处理最后一个参数(如果存在)
if current.Len() > 0 {
result = append(result, current.String())
}
// 如果解析结果为空,使用简单的空格分割作为降级方案
if len(result) == 0 {
result = strings.Fields(argsStr)
}
return result
}
@@ -638,9 +656,9 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
// buildInputSchema 构建输入模式
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
schema := map[string]interface{}{
"type": "object",
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
"required": []string{},
}
// 如果配置中定义了参数,优先使用配置中的参数定义
@@ -651,7 +669,7 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
for _, param := range toolConfig.Parameters {
// 转换类型为OpenAI/JSON Schema标准类型
openAIType := e.convertToOpenAIType(param.Type)
prop := map[string]interface{}{
"type": openAIType,
"description": param.Description,
@@ -750,5 +768,3 @@ func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[s
"generatedAt": time.Now(),
}
}