diff --git a/README.md b/README.md index d6304c63..0cfbc5eb 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ ![详情预览](./img/效果1.png) ## 更新日志 + - 2025.11.13 新增 MCP stdio 模式支持,可在 Cursor IDE 中直接使用所有安全工具; - 2025.11.12 增加了任务停止功能,优化前端; ## ✨ 功能特性 @@ -19,6 +20,7 @@ ### 工具集成 - 🔌 **MCP协议支持** - 完整实现MCP协议,支持工具注册、调用、监控 +- 📡 **双传输模式** - 支持HTTP和stdio两种传输方式,可在Web应用和IDE中无缝使用 - 🛠️ **灵活工具配置** - 支持从目录加载工具配置(YAML),易于扩展和维护 - 📈 **实时监控** - 监控所有工具的执行状态、结果、调用次数和统计信息 - 🔍 **漏洞自动分析** - 自动分析工具输出,提取和分类发现的漏洞 @@ -36,6 +38,8 @@ CyberStrikeAI/ ├── cmd/ │ ├── server/ │ │ └── main.go # 程序入口,启动HTTP服务器 +│ ├── mcp-stdio/ +│ │ └── main.go # MCP stdio模式入口(用于Cursor等IDE集成) │ └── test-config/ │ └── main.go # 配置测试工具 ├── internal/ @@ -572,6 +576,23 @@ MCP协议端点,支持JSON-RPC 2.0格式的请求。 本项目完整实现了MCP(Model Context Protocol)协议,支持以下功能: +### 传输模式 + +CyberStrikeAI 支持两种 MCP 传输模式: + +#### 1. HTTP 模式(默认) +- 通过 HTTP POST 请求进行通信 +- 适用于 Web 应用和其他 HTTP 客户端 +- 默认监听地址:`0.0.0.0:8081/mcp` +- 可通过 `/api/mcp` 端点访问 + +#### 2. stdio 模式(新增) +- 通过标准输入输出(stdio)进行通信 +- 适用于 Cursor、Claude Desktop 等 IDE 集成 +- 完全符合 JSON-RPC 2.0 规范 +- 支持字符串、数字和 null 类型的 id 字段 +- 正确处理通知(notification)消息 + ### 支持的方法 - `initialize` - 初始化连接,协商协议版本和功能 @@ -582,6 +603,7 @@ MCP协议端点,支持JSON-RPC 2.0格式的请求。 - `resources/list` - 列出可用资源 - `resources/read` - 读取资源内容 - `sampling/request` - 采样请求(占位实现) +- `notifications/initialized` - 初始化完成通知(stdio 模式) ### 工具执行机制 @@ -594,7 +616,73 @@ MCP协议端点,支持JSON-RPC 2.0格式的请求。 - 执行结果或错误信息 - 系统自动跟踪所有工具的执行统计信息 -### MCP协议使用示例 +### MCP stdio 模式(Cursor IDE 集成) + +stdio 模式允许你在 Cursor IDE 中直接使用 CyberStrikeAI 的所有安全工具。 + +#### 编译 stdio 模式程序 + +```bash +# 在项目根目录执行 +go build -o cyberstrike-ai-mcp cmd/mcp-stdio/main.go +``` + +#### 在 Cursor 中配置 + +**方法一:通过 UI 配置** + +1. 打开 Cursor 设置 → **Tools & MCP** +2. 点击 **Add Custom MCP** +3. 配置如下(请替换为你的实际路径): + +```json +{ + "mcpServers": { + "cyberstrike-ai": { + "command": "/absolute/path/to/cyberstrike-ai-mcp", + "args": [ + "--config", + "/absolute/path/to/config.yaml" + ] + } + } +} +``` + +**方法二:通过项目配置文件** + +在项目根目录创建 `.cursor/mcp.json` 文件: + +```json +{ + "mcpServers": { + "cyberstrike-ai": { + "command": "/Users/yourname/Downloads/CyberStrikeAI-main/cyberstrike-ai-mcp", + "args": [ + "--config", + "/Users/yourname/Downloads/CyberStrikeAI-main/config.yaml" + ] + } + } +} +``` + +**重要提示:** +- ✅ 使用绝对路径:`command` 和配置文件路径必须使用绝对路径 +- ✅ 可执行权限:确保编译后的程序有执行权限(Linux/macOS) +- ✅ 重启 Cursor:配置后需要重启 Cursor 才能生效 + +配置完成后,重启 Cursor,你就可以在聊天中直接使用所有安全工具了! + +#### stdio 模式特性 + +- ✅ 完全符合 JSON-RPC 2.0 规范 +- ✅ 支持字符串、数字和 null 类型的 id 字段 +- ✅ 正确处理通知(notification)消息 +- ✅ 日志输出到 stderr,不干扰 JSON-RPC 通信 +- ✅ 与 HTTP 模式完全独立,可同时使用 + +### MCP HTTP 模式使用示例 #### 初始化连接 @@ -826,6 +914,28 @@ parameters: - ✅ 验证MCP配置中的 `enabled: true` - ✅ 查看日志中的MCP服务器启动信息 +**问题:Cursor 中 MCP stdio 模式无法连接** + +- ✅ 检查 `cyberstrike-ai-mcp` 程序路径是否正确(使用绝对路径) +- ✅ 检查程序是否有执行权限(Linux/macOS):`chmod +x cyberstrike-ai-mcp` +- ✅ 检查 `config.yaml` 配置文件路径是否正确 +- ✅ 查看 Cursor 的日志输出(通常在 Cursor 的开发者工具中) +- ✅ 确保配置文件中的 `security.tools_dir` 配置正确 + +**问题:Cursor 中工具列表为空** + +- ✅ 确保 `config.yaml` 中的 `security.tools_dir` 配置正确 +- ✅ 确保工具配置文件在指定目录中 +- ✅ 检查工具配置文件格式是否正确(YAML 语法) +- ✅ 查看程序日志(stderr 输出) + +**问题:Cursor 中工具执行失败** + +- ✅ 确保相应的安全工具已安装在系统中 +- ✅ 检查工具是否在系统 PATH 中 +- ✅ 查看程序日志(stderr 输出) +- ✅ 尝试在终端中直接运行工具命令,确认工具可用 + ### 日志调试 启用详细日志: diff --git a/cmd/mcp-stdio/main.go b/cmd/mcp-stdio/main.go new file mode 100644 index 00000000..977d794f --- /dev/null +++ b/cmd/mcp-stdio/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/logger" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/security" + "flag" + "fmt" + "os" + + "go.uber.org/zap" +) + +func main() { + var configPath = flag.String("config", "config.yaml", "配置文件路径") + flag.Parse() + + // 加载配置 + cfg, err := config.Load(*configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err) + os.Exit(1) + } + + // 初始化日志(stdio 模式下使用 stderr 输出日志,避免干扰 JSON-RPC 通信) + log := logger.New(cfg.Log.Level, "stderr") + + // 创建MCP服务器 + mcpServer := mcp.NewServer(log.Logger) + + // 创建安全工具执行器 + executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) + + // 注册工具 + executor.RegisterTools(mcpServer) + + log.Logger.Info("MCP服务器(stdio模式)已启动,等待消息...") + + // 运行 stdio 循环 + if err := mcpServer.HandleStdio(); err != nil { + log.Logger.Error("MCP服务器运行失败", zap.Error(err)) + os.Exit(1) + } +} + diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 6de13152..ac668a0c 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "strings" "sync" "time" @@ -93,8 +94,12 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { // handleMessage 处理MCP消息 func (s *Server) handleMessage(msg *Message) *Message { - if msg.ID == "" { - msg.ID = uuid.New().String() + // 检查是否是通知(notification)- 通知没有id字段,不需要响应 + isNotification := msg.ID.Value() == nil || msg.ID.String() == "" + + // 如果不是通知且ID为空,生成新的UUID + if !isNotification && msg.ID.String() == "" { + msg.ID = MessageID{value: uuid.New().String()} } switch msg.Method { @@ -114,11 +119,29 @@ func (s *Server) handleMessage(msg *Message) *Message { return s.handleReadResource(msg) case "sampling/request": return s.handleSamplingRequest(msg) + case "notifications/initialized": + // 通知类型,不需要响应 + s.logger.Debug("收到 initialized 通知") + return nil + case "": + // 空方法名,可能是通知,不返回错误 + if isNotification { + s.logger.Debug("收到无方法名的通知消息") + return nil + } + fallthrough default: + // 如果是通知,不返回错误响应 + if isNotification { + s.logger.Debug("收到未知通知", zap.String("method", msg.Method)) + return nil + } + // 对于请求,返回方法未找到错误 return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32601, Message: "Method not found"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Method not found"}, } } } @@ -128,9 +151,10 @@ func (s *Server) handleInitialize(msg *Message) *Message { var req InitializeRequest if err := json.Unmarshal(msg.Params, &req); err != nil { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32602, Message: "Invalid params"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, } } @@ -188,9 +212,10 @@ func (s *Server) handleCallTool(msg *Message) *Message { var req CallToolRequest if err := json.Unmarshal(msg.Params, &req); err != nil { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32602, Message: "Invalid params"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, } } @@ -222,9 +247,10 @@ func (s *Server) handleCallTool(msg *Message) *Message { now := time.Now() execution.EndTime = &now return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32601, Message: "Tool not found"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Tool not found"}, } } @@ -481,9 +507,10 @@ func (s *Server) handleGetPrompt(msg *Message) *Message { var req GetPromptRequest if err := json.Unmarshal(msg.Params, &req); err != nil { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32602, Message: "Invalid params"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, } } @@ -493,9 +520,10 @@ func (s *Server) handleGetPrompt(msg *Message) *Message { if !exists { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32601, Message: "Prompt not found"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Prompt not found"}, } } @@ -588,9 +616,10 @@ func (s *Server) handleReadResource(msg *Message) *Message { var req ReadResourceRequest if err := json.Unmarshal(msg.Params, &req); err != nil { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32602, Message: "Invalid params"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, } } @@ -600,9 +629,10 @@ func (s *Server) handleReadResource(msg *Message) *Message { if !exists { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32601, Message: "Resource not found"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32601, Message: "Resource not found"}, } } @@ -753,9 +783,10 @@ func (s *Server) handleSamplingRequest(msg *Message) *Message { var req SamplingRequest if err := json.Unmarshal(msg.Params, &req); err != nil { return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Error: &Error{Code: -32602, Message: "Invalid params"}, + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32602, Message: "Invalid params"}, } } @@ -797,12 +828,62 @@ func (s *Server) RegisterResource(resource *Resource) { s.resources[resource.URI] = resource } +// HandleStdio 处理标准输入输出(用于 stdio 传输模式) +// MCP 协议使用换行分隔的 JSON-RPC 消息 +func (s *Server) HandleStdio() error { + decoder := json.NewDecoder(os.Stdin) + encoder := json.NewEncoder(os.Stdout) + // 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式 + + for { + var msg Message + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF { + break + } + // 日志输出到 stderr,避免干扰 stdout 的 JSON-RPC 通信 + s.logger.Error("读取消息失败", zap.Error(err)) + // 发送错误响应 + errorMsg := Message{ + ID: msg.ID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()}, + } + if err := encoder.Encode(errorMsg); err != nil { + return fmt.Errorf("发送错误响应失败: %w", err) + } + continue + } + + // 处理消息 + response := s.handleMessage(&msg) + + // 如果是通知(response 为 nil),不需要发送响应 + if response == nil { + continue + } + + // 发送响应 + if err := encoder.Encode(response); err != nil { + return fmt.Errorf("发送响应失败: %w", err) + } + } + + return nil +} + // sendError 发送错误响应 func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) { + var msgID MessageID + if id != nil { + msgID = MessageID{value: id} + } response := Message{ - ID: fmt.Sprintf("%v", id), - Type: MessageTypeError, - Error: &Error{Code: code, Message: message, Data: data}, + ID: msgID, + Type: MessageTypeError, + Version: "2.0", + Error: &Error{Code: code, Message: message, Data: data}, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) diff --git a/internal/mcp/types.go b/internal/mcp/types.go index 40618a54..91c9b3d8 100644 --- a/internal/mcp/types.go +++ b/internal/mcp/types.go @@ -2,6 +2,7 @@ package mcp import ( "encoding/json" + "fmt" "time" ) @@ -16,15 +17,66 @@ const ( // MCP协议版本 const ProtocolVersion = "2024-11-05" -// Message 表示MCP消息 +// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null +type MessageID struct { + value interface{} +} + +// UnmarshalJSON 自定义反序列化,支持字符串、数字和null +func (m *MessageID) UnmarshalJSON(data []byte) error { + // 尝试解析为null + if string(data) == "null" { + m.value = nil + return nil + } + + // 尝试解析为字符串 + var str string + if err := json.Unmarshal(data, &str); err == nil { + m.value = str + return nil + } + + // 尝试解析为数字 + var num json.Number + if err := json.Unmarshal(data, &num); err == nil { + m.value = num + return nil + } + + return fmt.Errorf("invalid id type") +} + +// MarshalJSON 自定义序列化 +func (m MessageID) MarshalJSON() ([]byte, error) { + if m.value == nil { + return []byte("null"), nil + } + return json.Marshal(m.value) +} + +// String 返回字符串表示 +func (m MessageID) String() string { + if m.value == nil { + return "" + } + return fmt.Sprintf("%v", m.value) +} + +// Value 返回原始值 +func (m MessageID) Value() interface{} { + return m.value +} + +// Message 表示MCP消息(符合JSON-RPC 2.0规范) type Message struct { - ID string `json:"id,omitempty"` - Type string `json:"type"` + ID MessageID `json:"id,omitempty"` + Type string `json:"-"` // 内部使用,不序列化到JSON Method string `json:"method,omitempty"` Params json.RawMessage `json:"params,omitempty"` Result json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` - Version string `json:"jsonrpc,omitempty"` + Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识 } // Error 表示MCP错误