Add files via upload

增加mcp-stdio
This commit is contained in:
公明
2025-11-13 00:44:37 +08:00
committed by GitHub
parent 0c2cd14567
commit 5a4a1b3269
4 changed files with 326 additions and 37 deletions

112
README.md
View File

@@ -4,6 +4,7 @@
![详情预览](./img/效果1.png)
## 更新日志
- 2025.11.13 新增 MCP stdio 模式支持,可在 Cursor IDE 中直接使用所有安全工具;
- 2025.11.12 增加了任务停止功能,优化前端;
## ✨ 功能特性
@@ -19,6 +20,7 @@
### 工具集成
- 🔌 **MCP协议支持** - 完整实现MCP协议支持工具注册、调用、监控
- 📡 **双传输模式** - 支持HTTP和stdio两种传输方式可在Web应用和IDE中无缝使用
- 🛠️ **灵活工具配置** - 支持从目录加载工具配置YAML易于扩展和维护
- 📈 **实时监控** - 监控所有工具的执行状态、结果、调用次数和统计信息
- 🔍 **漏洞自动分析** - 自动分析工具输出,提取和分类发现的漏洞
@@ -36,6 +38,8 @@ CyberStrikeAI/
├── cmd/
│ ├── server/
│ │ └── main.go # 程序入口启动HTTP服务器
│ ├── mcp-stdio/
│ │ └── main.go # MCP stdio模式入口用于Cursor等IDE集成
│ └── test-config/
│ └── main.go # 配置测试工具
├── internal/
@@ -572,6 +576,23 @@ MCP协议端点支持JSON-RPC 2.0格式的请求。
本项目完整实现了MCPModel Context Protocol协议支持以下功能
### 传输模式
CyberStrikeAI 支持两种 MCP 传输模式:
#### 1. HTTP 模式(默认)
- 通过 HTTP POST 请求进行通信
- 适用于 Web 应用和其他 HTTP 客户端
- 默认监听地址:`0.0.0.0:8081/mcp`
- 可通过 `/api/mcp` 端点访问
#### 2. stdio 模式(新增)
- 通过标准输入输出stdio进行通信
- 适用于 Cursor、Claude Desktop 等 IDE 集成
- 完全符合 JSON-RPC 2.0 规范
- 支持字符串、数字和 null 类型的 id 字段
- 正确处理通知notification消息
### 支持的方法
- `initialize` - 初始化连接,协商协议版本和功能
@@ -582,6 +603,7 @@ MCP协议端点支持JSON-RPC 2.0格式的请求。
- `resources/list` - 列出可用资源
- `resources/read` - 读取资源内容
- `sampling/request` - 采样请求(占位实现)
- `notifications/initialized` - 初始化完成通知stdio 模式)
### 工具执行机制
@@ -594,7 +616,73 @@ MCP协议端点支持JSON-RPC 2.0格式的请求。
- 执行结果或错误信息
- 系统自动跟踪所有工具的执行统计信息
### MCP协议使用示例
### MCP stdio 模式Cursor IDE 集成)
stdio 模式允许你在 Cursor IDE 中直接使用 CyberStrikeAI 的所有安全工具。
#### 编译 stdio 模式程序
```bash
# 在项目根目录执行
go build -o cyberstrike-ai-mcp cmd/mcp-stdio/main.go
```
#### 在 Cursor 中配置
**方法一:通过 UI 配置**
1. 打开 Cursor 设置 → **Tools & MCP**
2. 点击 **Add Custom MCP**
3. 配置如下(请替换为你的实际路径):
```json
{
"mcpServers": {
"cyberstrike-ai": {
"command": "/absolute/path/to/cyberstrike-ai-mcp",
"args": [
"--config",
"/absolute/path/to/config.yaml"
]
}
}
}
```
**方法二:通过项目配置文件**
在项目根目录创建 `.cursor/mcp.json` 文件:
```json
{
"mcpServers": {
"cyberstrike-ai": {
"command": "/Users/yourname/Downloads/CyberStrikeAI-main/cyberstrike-ai-mcp",
"args": [
"--config",
"/Users/yourname/Downloads/CyberStrikeAI-main/config.yaml"
]
}
}
}
```
**重要提示:**
- ✅ 使用绝对路径:`command` 和配置文件路径必须使用绝对路径
- ✅ 可执行权限确保编译后的程序有执行权限Linux/macOS
- ✅ 重启 Cursor配置后需要重启 Cursor 才能生效
配置完成后,重启 Cursor你就可以在聊天中直接使用所有安全工具了
#### stdio 模式特性
- ✅ 完全符合 JSON-RPC 2.0 规范
- ✅ 支持字符串、数字和 null 类型的 id 字段
- ✅ 正确处理通知notification消息
- ✅ 日志输出到 stderr不干扰 JSON-RPC 通信
- ✅ 与 HTTP 模式完全独立,可同时使用
### MCP HTTP 模式使用示例
#### 初始化连接
@@ -826,6 +914,28 @@ parameters:
- ✅ 验证MCP配置中的 `enabled: true`
- ✅ 查看日志中的MCP服务器启动信息
**问题Cursor 中 MCP stdio 模式无法连接**
- ✅ 检查 `cyberstrike-ai-mcp` 程序路径是否正确(使用绝对路径)
- ✅ 检查程序是否有执行权限Linux/macOS`chmod +x cyberstrike-ai-mcp`
- ✅ 检查 `config.yaml` 配置文件路径是否正确
- ✅ 查看 Cursor 的日志输出(通常在 Cursor 的开发者工具中)
- ✅ 确保配置文件中的 `security.tools_dir` 配置正确
**问题Cursor 中工具列表为空**
- ✅ 确保 `config.yaml` 中的 `security.tools_dir` 配置正确
- ✅ 确保工具配置文件在指定目录中
- ✅ 检查工具配置文件格式是否正确YAML 语法)
- ✅ 查看程序日志stderr 输出)
**问题Cursor 中工具执行失败**
- ✅ 确保相应的安全工具已安装在系统中
- ✅ 检查工具是否在系统 PATH 中
- ✅ 查看程序日志stderr 输出)
- ✅ 尝试在终端中直接运行工具命令,确认工具可用
### 日志调试
启用详细日志:

46
cmd/mcp-stdio/main.go Normal file
View File

@@ -0,0 +1,46 @@
package main
import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"flag"
"fmt"
"os"
"go.uber.org/zap"
)
func main() {
var configPath = flag.String("config", "config.yaml", "配置文件路径")
flag.Parse()
// 加载配置
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
os.Exit(1)
}
// 初始化日志stdio 模式下使用 stderr 输出日志,避免干扰 JSON-RPC 通信)
log := logger.New(cfg.Log.Level, "stderr")
// 创建MCP服务器
mcpServer := mcp.NewServer(log.Logger)
// 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 注册工具
executor.RegisterTools(mcpServer)
log.Logger.Info("MCP服务器stdio模式已启动等待消息...")
// 运行 stdio 循环
if err := mcpServer.HandleStdio(); err != nil {
log.Logger.Error("MCP服务器运行失败", zap.Error(err))
os.Exit(1)
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
@@ -93,8 +94,12 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
// handleMessage 处理MCP消息
func (s *Server) handleMessage(msg *Message) *Message {
if msg.ID == "" {
msg.ID = uuid.New().String()
// 检查是否是通知notification- 通知没有id字段不需要响应
isNotification := msg.ID.Value() == nil || msg.ID.String() == ""
// 如果不是通知且ID为空生成新的UUID
if !isNotification && msg.ID.String() == "" {
msg.ID = MessageID{value: uuid.New().String()}
}
switch msg.Method {
@@ -114,11 +119,29 @@ func (s *Server) handleMessage(msg *Message) *Message {
return s.handleReadResource(msg)
case "sampling/request":
return s.handleSamplingRequest(msg)
case "notifications/initialized":
// 通知类型,不需要响应
s.logger.Debug("收到 initialized 通知")
return nil
case "":
// 空方法名,可能是通知,不返回错误
if isNotification {
s.logger.Debug("收到无方法名的通知消息")
return nil
}
fallthrough
default:
// 如果是通知,不返回错误响应
if isNotification {
s.logger.Debug("收到未知通知", zap.String("method", msg.Method))
return nil
}
// 对于请求,返回方法未找到错误
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Method not found"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32601, Message: "Method not found"},
}
}
}
@@ -128,9 +151,10 @@ func (s *Server) handleInitialize(msg *Message) *Message {
var req InitializeRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
@@ -188,9 +212,10 @@ func (s *Server) handleCallTool(msg *Message) *Message {
var req CallToolRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
@@ -222,9 +247,10 @@ func (s *Server) handleCallTool(msg *Message) *Message {
now := time.Now()
execution.EndTime = &now
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Tool not found"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32601, Message: "Tool not found"},
}
}
@@ -481,9 +507,10 @@ func (s *Server) handleGetPrompt(msg *Message) *Message {
var req GetPromptRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
@@ -493,9 +520,10 @@ func (s *Server) handleGetPrompt(msg *Message) *Message {
if !exists {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Prompt not found"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32601, Message: "Prompt not found"},
}
}
@@ -588,9 +616,10 @@ func (s *Server) handleReadResource(msg *Message) *Message {
var req ReadResourceRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
@@ -600,9 +629,10 @@ func (s *Server) handleReadResource(msg *Message) *Message {
if !exists {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32601, Message: "Resource not found"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32601, Message: "Resource not found"},
}
}
@@ -753,9 +783,10 @@ func (s *Server) handleSamplingRequest(msg *Message) *Message {
var req SamplingRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Type: MessageTypeError,
Error: &Error{Code: -32602, Message: "Invalid params"},
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32602, Message: "Invalid params"},
}
}
@@ -797,12 +828,62 @@ func (s *Server) RegisterResource(resource *Resource) {
s.resources[resource.URI] = resource
}
// HandleStdio 处理标准输入输出(用于 stdio 传输模式)
// MCP 协议使用换行分隔的 JSON-RPC 消息
func (s *Server) HandleStdio() error {
decoder := json.NewDecoder(os.Stdin)
encoder := json.NewEncoder(os.Stdout)
// 注意不设置缩进MCP 协议期望紧凑的 JSON 格式
for {
var msg Message
if err := decoder.Decode(&msg); err != nil {
if err == io.EOF {
break
}
// 日志输出到 stderr避免干扰 stdout 的 JSON-RPC 通信
s.logger.Error("读取消息失败", zap.Error(err))
// 发送错误响应
errorMsg := Message{
ID: msg.ID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()},
}
if err := encoder.Encode(errorMsg); err != nil {
return fmt.Errorf("发送错误响应失败: %w", err)
}
continue
}
// 处理消息
response := s.handleMessage(&msg)
// 如果是通知response 为 nil不需要发送响应
if response == nil {
continue
}
// 发送响应
if err := encoder.Encode(response); err != nil {
return fmt.Errorf("发送响应失败: %w", err)
}
}
return nil
}
// sendError 发送错误响应
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
var msgID MessageID
if id != nil {
msgID = MessageID{value: id}
}
response := Message{
ID: fmt.Sprintf("%v", id),
Type: MessageTypeError,
Error: &Error{Code: code, Message: message, Data: data},
ID: msgID,
Type: MessageTypeError,
Version: "2.0",
Error: &Error{Code: code, Message: message, Data: data},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)

View File

@@ -2,6 +2,7 @@ package mcp
import (
"encoding/json"
"fmt"
"time"
)
@@ -16,15 +17,66 @@ const (
// MCP协议版本
const ProtocolVersion = "2024-11-05"
// Message 表示MCP消息
// MessageID 表示JSON-RPC 2.0的id字段可以是字符串、数字或null
type MessageID struct {
value interface{}
}
// UnmarshalJSON 自定义反序列化支持字符串、数字和null
func (m *MessageID) UnmarshalJSON(data []byte) error {
// 尝试解析为null
if string(data) == "null" {
m.value = nil
return nil
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
m.value = str
return nil
}
// 尝试解析为数字
var num json.Number
if err := json.Unmarshal(data, &num); err == nil {
m.value = num
return nil
}
return fmt.Errorf("invalid id type")
}
// MarshalJSON 自定义序列化
func (m MessageID) MarshalJSON() ([]byte, error) {
if m.value == nil {
return []byte("null"), nil
}
return json.Marshal(m.value)
}
// String 返回字符串表示
func (m MessageID) String() string {
if m.value == nil {
return ""
}
return fmt.Sprintf("%v", m.value)
}
// Value 返回原始值
func (m MessageID) Value() interface{} {
return m.value
}
// Message 表示MCP消息符合JSON-RPC 2.0规范)
type Message struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
ID MessageID `json:"id,omitempty"`
Type string `json:"-"` // 内部使用不序列化到JSON
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"`
Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识
}
// Error 表示MCP错误