diff --git a/README.md b/README.md index 6220be57..3d5ecc74 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte ## Highlights - 🤖 AI decision engine with OpenAI-compatible models (GPT, Claude, DeepSeek, etc.) -- 🔌 Native MCP implementation with HTTP/stdio transports and external MCP federation +- 🔌 Native MCP implementation with HTTP/stdio/SSE transports and external MCP federation - 🧰 100+ prebuilt tool recipes + YAML-based extension system - 📄 Large-result pagination, compression, and searchable archives - 🔗 Attack-chain graph, risk scoring, and step-by-step replay @@ -149,7 +149,7 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain: ### MCP Everywhere - **Web mode** – ships with HTTP MCP server automatically consumed by the UI. - **MCP stdio mode** – `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI. -- **External MCP federation** – register third-party MCP servers (HTTP or stdio) from the UI, toggle them per engagement, and monitor their health and call volume in real time. +- **External MCP federation** – register third-party MCP servers (HTTP, stdio, or SSE) from the UI, toggle them per engagement, and monitor their health and call volume in real time. #### MCP stdio quick start 1. **Build the binary** (run from the project root): @@ -189,6 +189,62 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain: } ``` +#### External MCP federation (HTTP/stdio/SSE) +CyberStrikeAI supports connecting to external MCP servers via three transport modes: +- **HTTP mode** – traditional request/response over HTTP POST +- **stdio mode** – process-based communication via standard input/output +- **SSE mode** – Server-Sent Events for real-time streaming communication + +To add an external MCP server: +1. Open the Web UI and navigate to **Settings → External MCP**. +2. Click **Add External MCP** and provide the configuration in JSON format: + + **HTTP mode example:** + ```json + { + "my-http-mcp": { + "transport": "http", + "url": "http://127.0.0.1:8081/mcp", + "description": "HTTP MCP server", + "timeout": 30 + } + } + ``` + + **stdio mode example:** + ```json + { + "my-stdio-mcp": { + "command": "python3", + "args": ["/path/to/mcp-server.py"], + "description": "stdio MCP server", + "timeout": 30 + } + } + ``` + + **SSE mode example:** + ```json + { + "my-sse-mcp": { + "transport": "sse", + "url": "http://127.0.0.1:8082/sse", + "description": "SSE MCP server", + "timeout": 30 + } + } + ``` + +3. Click **Save** and then **Start** to connect to the server. +4. Monitor the connection status, tool count, and health in real time. + +**SSE mode benefits:** +- Real-time bidirectional communication via Server-Sent Events +- Suitable for scenarios requiring continuous data streaming +- Lower latency for push-based notifications + +A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation purposes. + ### Knowledge Base - **Vector search** – AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool. - **Hybrid retrieval** – combines vector similarity search with keyword matching for better accuracy. @@ -328,6 +384,7 @@ Build an attack chain for the latest engagement and export the node list with se ## Changelog (Recent) +- 2026-01-08 – Added SSE (Server-Sent Events) transport mode support for external MCP servers. External MCP federation now supports HTTP, stdio, and SSE modes. SSE mode enables real-time streaming communication for push-based scenarios. - 2026-01-01 – Added batch task management feature: create task queues with multiple tasks, add/edit/delete tasks before execution, and execute them sequentially. Each task runs as a separate conversation with status tracking (pending/running/completed/failed/cancelled). All queues and tasks are persisted in the database. - 2025-12-25 – Added vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard. - 2025-12-25 – Added conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database. diff --git a/README_CN.md b/README_CN.md index 57191f2e..281f4448 100644 --- a/README_CN.md +++ b/README_CN.md @@ -32,7 +32,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集 ## 特性速览 - 🤖 兼容 OpenAI/DeepSeek/Claude 等模型的智能决策引擎 -- 🔌 原生 MCP 协议,支持 HTTP / stdio 以及外部 MCP 接入 +- 🔌 原生 MCP 协议,支持 HTTP / stdio / SSE 传输模式以及外部 MCP 接入 - 🧰 100+ 现成工具模版 + YAML 扩展能力 - 📄 大结果分页、压缩与全文检索 - 🔗 攻击链可视化、风险打分与步骤回放 @@ -147,7 +147,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集 ### MCP 全场景 - **Web 模式**:自带 HTTP MCP 服务供前端调用。 - **MCP stdio 模式**:`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。 -- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio),按需启停并实时查看调用统计与健康度。 +- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio/SSE),按需启停并实时查看调用统计与健康度。 #### MCP stdio 快速集成 1. **编译可执行文件**(在项目根目录执行): @@ -187,6 +187,62 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集 } ``` +#### 外部 MCP 联邦(HTTP/stdio/SSE) +CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器: +- **HTTP 模式** – 通过 HTTP POST 进行传统的请求/响应通信 +- **stdio 模式** – 通过标准输入/输出进行进程间通信 +- **SSE 模式** – 通过 Server-Sent Events 实现实时流式通信 + +添加外部 MCP 服务器: +1. 打开 Web 界面,进入 **设置 → 外部MCP**。 +2. 点击 **添加外部MCP**,以 JSON 格式提供配置: + + **HTTP 模式示例:** + ```json + { + "my-http-mcp": { + "transport": "http", + "url": "http://127.0.0.1:8081/mcp", + "description": "HTTP MCP 服务器", + "timeout": 30 + } + } + ``` + + **stdio 模式示例:** + ```json + { + "my-stdio-mcp": { + "command": "python3", + "args": ["/path/to/mcp-server.py"], + "description": "stdio MCP 服务器", + "timeout": 30 + } + } + ``` + + **SSE 模式示例:** + ```json + { + "my-sse-mcp": { + "transport": "sse", + "url": "http://127.0.0.1:8082/sse", + "description": "SSE MCP 服务器", + "timeout": 30 + } + } + ``` + +3. 点击 **保存**,然后点击 **启动** 连接服务器。 +4. 实时监控连接状态、工具数量和健康度。 + +**SSE 模式优势:** +- 通过 Server-Sent Events 实现实时双向通信 +- 适用于需要持续数据流的场景 +- 对于基于推送的通知,延迟更低 + +可在 `cmd/test-sse-mcp-server/` 目录找到用于验证的测试 SSE MCP 服务器。 + ### 知识库功能 - **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。 @@ -326,6 +382,7 @@ CyberStrikeAI/ ``` ## Changelog(近期) +- 2026-01-08 —— 新增 SSE(Server-Sent Events)传输模式支持,外部 MCP 联邦现支持 HTTP、stdio 和 SSE 三种模式。SSE 模式支持实时流式通信,适用于基于推送的场景。 - 2026-01-01 —— 新增批量任务管理功能:支持创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务作为独立对话运行,支持状态跟踪(待执行/执行中/已完成/失败/已取消),所有队列和任务数据持久化存储到数据库。 - 2025-12-25 —— 新增漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。 - 2025-12-25 —— 新增对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。 diff --git a/cmd/test-sse-mcp-server/README.md b/cmd/test-sse-mcp-server/README.md new file mode 100644 index 00000000..a30b6009 --- /dev/null +++ b/cmd/test-sse-mcp-server/README.md @@ -0,0 +1,56 @@ +# SSE MCP 测试服务器 + +这是一个用于验证SSE模式外部MCP功能的测试服务器。 + +## 使用方法 + +### 1. 启动测试服务器 + +```bash +cd cmd/test-sse-mcp-server +go run main.go +``` + +服务器将在 `http://127.0.0.1:8082` 启动,提供以下端点: +- `GET /sse` - SSE事件流端点 +- `POST /message` - 消息接收端点 + +### 2. 在CyberStrikeAI中添加配置 + +在Web界面中添加外部MCP配置,使用以下JSON: + +```json +{ + "test-sse-mcp": { + "transport": "sse", + "url": "http://127.0.0.1:8082/sse", + "description": "SSE MCP测试服务器", + "timeout": 30 + } +} +``` + +### 3. 测试功能 + +测试服务器提供两个测试工具: + +1. **test_echo** - 回显输入的文本 + - 参数:`text` (string) - 要回显的文本 + +2. **test_add** - 计算两个数字的和 + - 参数:`a` (number) - 第一个数字 + - 参数:`b` (number) - 第二个数字 + +## 工作原理 + +1. 客户端通过 `GET /sse` 建立SSE连接,接收服务器推送的事件 +2. 客户端通过 `POST /message` 发送MCP协议消息 +3. 服务器处理消息后,通过SSE连接推送响应 + +## 日志 + +服务器会输出以下日志: +- SSE客户端连接/断开 +- 收到的请求(方法名和ID) +- 工具调用详情 + diff --git a/cmd/test-sse-mcp-server/main.go b/cmd/test-sse-mcp-server/main.go new file mode 100644 index 00000000..a336b4bb --- /dev/null +++ b/cmd/test-sse-mcp-server/main.go @@ -0,0 +1,395 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/google/uuid" +) + +const ProtocolVersion = "2024-11-05" + +// Message MCP消息 +type Message struct { + ID interface{} `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + Version string `json:"jsonrpc,omitempty"` +} + +// Error MCP错误 +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// InitializeRequest 初始化请求 +type InitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ClientInfo ClientInfo `json:"clientInfo"` +} + +// ClientInfo 客户端信息 +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResponse 初始化响应 +type InitializeResponse struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` +} + +// ServerCapabilities 服务器能力 +type ServerCapabilities struct { + Tools map[string]interface{} `json:"tools,omitempty"` +} + +// ServerInfo 服务器信息 +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// Tool 工具定义 +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ListToolsResponse 列出工具响应 +type ListToolsResponse struct { + Tools []Tool `json:"tools"` +} + +// CallToolRequest 调用工具请求 +type CallToolRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// CallToolResponse 调用工具响应 +type CallToolResponse struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// Content 内容 +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// SSEServer SSE MCP服务器 +type SSEServer struct { + sseClients map[string]chan []byte + mu sync.RWMutex +} + +func NewSSEServer() *SSEServer { + return &SSEServer{ + sseClients: make(map[string]chan []byte), + } +} + +// handleSSE 处理SSE连接 +func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + clientID := uuid.New().String() + clientChan := make(chan []byte, 10) + + s.mu.Lock() + s.sseClients[clientID] = clientChan + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.sseClients, clientID) + close(clientChan) + s.mu.Unlock() + }() + + // 发送初始ready事件 + fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n") + flusher.Flush() + + log.Printf("SSE客户端连接: %s", clientID) + + // 心跳 + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-r.Context().Done(): + log.Printf("SSE客户端断开: %s", clientID) + return + case msg, ok := <-clientChan: + if !ok { + return + } + fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg) + flusher.Flush() + case <-ticker.C: + // 心跳 + fmt.Fprintf(w, ": ping\n\n") + flusher.Flush() + } + } +} + +// handleMessage 处理POST消息 +func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var msg Message + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID) + + // 处理消息 + response := s.processMessage(&msg) + + // 如果有SSE客户端,通过SSE推送响应 + if response != nil { + responseJSON, _ := json.Marshal(response) + s.mu.RLock() + // 发送给所有SSE客户端 + for _, ch := range s.sseClients { + select { + case ch <- responseJSON: + default: + } + } + s.mu.RUnlock() + } + + // 也直接返回响应(兼容非SSE模式) + if response != nil { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } else { + w.WriteHeader(http.StatusOK) + } +} + +// processMessage 处理MCP消息 +func (s *SSEServer) processMessage(msg *Message) *Message { + switch msg.Method { + case "initialize": + return s.handleInitialize(msg) + case "tools/list": + return s.handleListTools(msg) + case "tools/call": + return s.handleCallTool(msg) + default: + return &Message{ + ID: msg.ID, + Version: "2.0", + Error: &Error{ + Code: -32601, + Message: "Method not found", + }, + } + } +} + +// handleInitialize 处理初始化 +func (s *SSEServer) handleInitialize(msg *Message) *Message { + var req InitializeRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Version: "2.0", + Error: &Error{ + Code: -32602, + Message: "Invalid params", + }, + } + } + + log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version) + + response := InitializeResponse{ + ProtocolVersion: ProtocolVersion, + Capabilities: ServerCapabilities{ + Tools: map[string]interface{}{ + "listChanged": true, + }, + }, + ServerInfo: ServerInfo{ + Name: "Test SSE MCP Server", + Version: "1.0.0", + }, + } + + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Version: "2.0", + Result: result, + } +} + +// handleListTools 处理列出工具 +func (s *SSEServer) handleListTools(msg *Message) *Message { + tools := []Tool{ + { + Name: "test_echo", + Description: "回显输入的文本,用于测试SSE MCP服务器", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "text": map[string]interface{}{ + "type": "string", + "description": "要回显的文本", + }, + }, + "required": []string{"text"}, + }, + }, + { + Name: "test_add", + Description: "计算两个数字的和,用于测试SSE MCP服务器", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "a": map[string]interface{}{ + "type": "number", + "description": "第一个数字", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "第二个数字", + }, + }, + "required": []string{"a", "b"}, + }, + }, + } + + response := ListToolsResponse{Tools: tools} + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Version: "2.0", + Result: result, + } +} + +// handleCallTool 处理工具调用 +func (s *SSEServer) handleCallTool(msg *Message) *Message { + var req CallToolRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Version: "2.0", + Error: &Error{ + Code: -32602, + Message: "Invalid params", + }, + } + } + + log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments) + + var content []Content + + switch req.Name { + case "test_echo": + text, _ := req.Arguments["text"].(string) + content = []Content{ + { + Type: "text", + Text: fmt.Sprintf("回显: %s", text), + }, + } + case "test_add": + var a, b float64 + if val, ok := req.Arguments["a"].(float64); ok { + a = val + } + if val, ok := req.Arguments["b"].(float64); ok { + b = val + } + sum := a + b + content = []Content{ + { + Type: "text", + Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum), + }, + } + default: + return &Message{ + ID: msg.ID, + Version: "2.0", + Error: &Error{ + Code: -32601, + Message: "Tool not found", + }, + } + } + + response := CallToolResponse{ + Content: content, + IsError: false, + } + + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Version: "2.0", + Result: result, + } +} + +func main() { + server := NewSSEServer() + + http.HandleFunc("/sse", server.handleSSE) + http.HandleFunc("/message", server.handleMessage) + + port := ":8082" + log.Printf("SSE MCP测试服务器启动在端口 %s", port) + log.Printf("SSE端点: http://localhost%s/sse", port) + log.Printf("消息端点: http://localhost%s/message", port) + log.Printf("配置示例:") + log.Printf(`{ + "test-sse-mcp": { + "transport": "sse", + "url": "http://127.0.0.1:8082/sse" + } +}`) + + if err := http.ListenAndServe(port, nil); err != nil { + log.Fatal(err) + } +} + diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go index bd68a6af..207566c7 100644 --- a/internal/handler/external_mcp.go +++ b/internal/handler/external_mcp.go @@ -324,7 +324,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) } else if cfg.URL != "" { transport = "http" } else { - return fmt.Errorf("需要指定command(stdio模式)或url(http模式)") + return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)") } } @@ -337,8 +337,12 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) if cfg.Command == "" { return fmt.Errorf("stdio模式需要command") } + case "sse": + if cfg.URL == "" { + return fmt.Errorf("SSE模式需要URL") + } default: - return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport) + return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport) } return nil diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 752d1588..d3571d40 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -1,6 +1,7 @@ package mcp import ( + "bufio" "bytes" "context" "encoding/json" @@ -8,6 +9,7 @@ import ( "io" "net/http" "os/exec" + "strings" "sync" "time" @@ -472,3 +474,407 @@ func (c *StdioMCPClient) Close() error { c.setStatus("disconnected") return nil } + +// SSEMCPClient SSE模式的MCP客户端 +type SSEMCPClient struct { + url string + timeout time.Duration + client *http.Client + logger *zap.Logger + mu sync.RWMutex + status string // "disconnected", "connecting", "connected", "error" + sseConn io.ReadCloser + sseCancel context.CancelFunc + requestID int64 + responses map[string]chan *Message + responsesMu sync.Mutex + ctx context.Context +} + +// NewSSEMCPClient 创建SSE模式的MCP客户端 +func NewSSEMCPClient(url string, timeout time.Duration, logger *zap.Logger) *SSEMCPClient { + if timeout <= 0 { + timeout = 30 * time.Second + } + ctx, cancel := context.WithCancel(context.Background()) + return &SSEMCPClient{ + url: url, + timeout: timeout, + client: &http.Client{Timeout: timeout}, + logger: logger, + status: "disconnected", + responses: make(map[string]chan *Message), + ctx: ctx, + sseCancel: cancel, + } +} + +func (c *SSEMCPClient) setStatus(status string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = status +} + +func (c *SSEMCPClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *SSEMCPClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *SSEMCPClient) Initialize(ctx context.Context) error { + c.setStatus("connecting") + + // 建立SSE连接 + if err := c.connectSSE(); err != nil { + c.setStatus("error") + return fmt.Errorf("建立SSE连接失败: %w", err) + } + + // 启动响应读取goroutine + go c.readSSEResponses() + + // 发送初始化请求 + req := Message{ + ID: MessageID{value: "1"}, + Method: "initialize", + Version: "2.0", + } + + params := InitializeRequest{ + ProtocolVersion: ProtocolVersion, + Capabilities: make(map[string]interface{}), + ClientInfo: ClientInfo{ + Name: "CyberStrikeAI", + Version: "1.0.0", + }, + } + + paramsJSON, _ := json.Marshal(params) + req.Params = paramsJSON + + _, err := c.sendRequest(ctx, &req) + if err != nil { + c.setStatus("error") + c.Close() + return fmt.Errorf("初始化失败: %w", err) + } + + c.setStatus("connected") + return nil +} + +func (c *SSEMCPClient) connectSSE() error { + // 建立SSE连接(GET请求,Accept: text/event-stream) + // SSE连接需要长连接,使用无超时的客户端 + sseClient := &http.Client{ + Timeout: 0, // 无超时,用于长连接 + } + + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + if err != nil { + return fmt.Errorf("创建SSE请求失败: %w", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + + resp, err := sseClient.Do(req) + if err != nil { + return fmt.Errorf("SSE连接失败: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("SSE连接失败,状态码: %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + resp.Body.Close() + return fmt.Errorf("服务器不支持SSE,Content-Type: %s", contentType) + } + + c.sseConn = resp.Body + return nil +} + +func (c *SSEMCPClient) readSSEResponses() { + defer func() { + if r := recover(); r != nil { + c.logger.Error("读取SSE响应时发生panic", zap.Any("error", r)) + } + if c.sseConn != nil { + c.sseConn.Close() + } + c.setStatus("disconnected") + }() + + if c.sseConn == nil { + return + } + + scanner := &sseScanner{reader: bufio.NewReader(c.sseConn)} + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + // 读取SSE事件 + event, err := scanner.readEvent() + if err != nil { + if err == io.EOF { + c.setStatus("disconnected") + return + } + c.logger.Error("读取SSE数据失败", zap.Error(err)) + return + } + + if event == nil || len(event.Data) == 0 { + continue + } + + // 解析JSON消息 + var msg Message + if err := json.Unmarshal(event.Data, &msg); err != nil { + c.logger.Warn("解析SSE消息失败", zap.Error(err), zap.String("data", string(event.Data))) + continue + } + + // 处理响应 + id := msg.ID.String() + c.responsesMu.Lock() + if ch, ok := c.responses[id]; ok { + select { + case ch <- &msg: + default: + } + delete(c.responses, id) + } + c.responsesMu.Unlock() + } +} + +// sseEvent SSE事件 +type sseEvent struct { + Event string + Data []byte + ID string + Retry int +} + +// sseScanner SSE扫描器 +type sseScanner struct { + reader *bufio.Reader +} + +func (s *sseScanner) readEvent() (*sseEvent, error) { + event := &sseEvent{} + + for { + line, err := s.reader.ReadString('\n') + if err != nil { + return nil, err + } + + line = strings.TrimRight(line, "\r\n") + + // 空行表示事件结束 + if len(line) == 0 { + if len(event.Data) > 0 { + return event, nil + } + continue + } + + // 解析SSE行 + if strings.HasPrefix(line, "event: ") { + event.Event = strings.TrimSpace(line[7:]) + } else if strings.HasPrefix(line, "data: ") { + data := []byte(strings.TrimSpace(line[6:])) + if len(event.Data) > 0 { + event.Data = append(event.Data, '\n') + } + event.Data = append(event.Data, data...) + } else if strings.HasPrefix(line, "id: ") { + event.ID = strings.TrimSpace(line[4:]) + } else if strings.HasPrefix(line, "retry: ") { + fmt.Sscanf(line[7:], "%d", &event.Retry) + } + } +} + +func (c *SSEMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { + if c.sseConn == nil { + return nil, fmt.Errorf("SSE连接未建立") + } + + id := msg.ID.String() + if id == "" { + c.mu.Lock() + c.requestID++ + id = fmt.Sprintf("%d", c.requestID) + msg.ID = MessageID{value: id} + c.mu.Unlock() + } + + // 创建响应通道 + responseCh := make(chan *Message, 1) + c.responsesMu.Lock() + c.responses[id] = responseCh + c.responsesMu.Unlock() + + // 通过HTTP POST发送请求(SSE用于接收响应,请求通过POST发送) + body, err := json.Marshal(msg) + if err != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + // 使用POST请求发送消息(通常SSE服务器会提供两个端点:一个用于SSE,一个用于POST) + // 如果URL是SSE端点,尝试使用相同的URL但改为POST,或者使用URL + "/message" + postURL := c.url + if strings.HasSuffix(postURL, "/sse") { + postURL = strings.TrimSuffix(postURL, "/sse") + postURL += "/message" + } else if strings.HasSuffix(postURL, "/events") { + postURL = strings.TrimSuffix(postURL, "/events") + postURL += "/message" + } else if !strings.Contains(postURL, "/message") { + // 如果URL不包含/message,尝试添加 + postURL = strings.TrimSuffix(postURL, "/") + postURL += "/message" + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body)) + if err != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("创建POST请求失败: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("发送POST请求失败: %w", err) + } + defer resp.Body.Close() + + // 如果POST请求直接返回响应(非SSE模式),直接解析 + if resp.StatusCode == http.StatusOK && resp.Header.Get("Content-Type") == "application/json" { + var mcpResp Message + if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + if mcpResp.Error != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code) + } + + return &mcpResp, nil + } + + // 否则等待SSE响应 + select { + case resp := <-responseCh: + if resp.Error != nil { + return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code) + } + return resp, nil + case <-ctx.Done(): + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, ctx.Err() + case <-time.After(c.timeout): + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("请求超时") + } +} + +func (c *SSEMCPClient) ListTools(ctx context.Context) ([]Tool, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/list", + Version: "2.0", + } + + req.Params = json.RawMessage("{}") + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("获取工具列表失败: %w", err) + } + + var listResp ListToolsResponse + if err := json.Unmarshal(resp.Result, &listResp); err != nil { + return nil, fmt.Errorf("解析工具列表失败: %w", err) + } + + return listResp.Tools, nil +} + +func (c *SSEMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/call", + Version: "2.0", + } + + callReq := CallToolRequest{ + Name: name, + Arguments: args, + } + + paramsJSON, _ := json.Marshal(callReq) + req.Params = paramsJSON + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("调用工具失败: %w", err) + } + + var callResp CallToolResponse + if err := json.Unmarshal(resp.Result, &callResp); err != nil { + return nil, fmt.Errorf("解析工具调用结果失败: %w", err) + } + + return &ToolResult{ + Content: callResp.Content, + IsError: callResp.IsError, + }, nil +} + +func (c *SSEMCPClient) Close() error { + c.sseCancel() + + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + } + + c.setStatus("disconnected") + return nil +} diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go index fb4d2530..58d094c4 100644 --- a/internal/mcp/external_manager.go +++ b/internal/mcp/external_manager.go @@ -603,6 +603,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf if serverCfg.Command != "" { transport = "stdio" } else if serverCfg.URL != "" { + // 默认使用http,但可以通过transport字段指定sse transport = "http" } else { return nil @@ -620,6 +621,11 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf return nil } return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger) + case "sse": + if serverCfg.URL == "" { + return nil + } + return NewSSEMCPClient(serverCfg.URL, timeout, m.logger) default: return nil } @@ -654,6 +660,8 @@ func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status st c.setStatus(status) case *StdioMCPClient: c.setStatus(status) + case *SSEMCPClient: + c.setStatus(status) } } diff --git a/web/static/js/settings.js b/web/static/js/settings.js index 8bae61ae..74237f9a 100644 --- a/web/static/js/settings.js +++ b/web/static/js/settings.js @@ -1158,6 +1158,14 @@ function loadExternalMCPExample() { ], description: "示例描述", timeout: 300 + }, + "cyberstrike-ai-http": { + transport: "http", + url: "http://127.0.0.1:8081/mcp" + }, + "cyberstrike-ai-sse": { + transport: "sse", + url: "http://127.0.0.1:8081/mcp/sse" } }; @@ -1231,7 +1239,7 @@ async function saveExternalMCP() { // 验证配置内容 const transport = config.transport || (config.command ? 'stdio' : config.url ? 'http' : ''); if (!transport) { - errorDiv.textContent = `配置错误: "${name}" 需要指定command(stdio模式)或url(http模式)`; + errorDiv.textContent = `配置错误: "${name}" 需要指定command(stdio模式)或url(http/sse模式)`; errorDiv.style.display = 'block'; jsonTextarea.classList.add('error'); return; @@ -1250,6 +1258,13 @@ async function saveExternalMCP() { jsonTextarea.classList.add('error'); return; } + + if (transport === 'sse' && !config.url) { + errorDiv.textContent = `配置错误: "${name}" sse模式需要url字段`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } } // 清除错误提示 diff --git a/web/templates/index.html b/web/templates/index.html index af33a87d..965990e8 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -854,6 +854,13 @@ "transport": "http", "url": "http://127.0.0.1:8081/mcp" } +} + SSE模式:
+ { + "cyberstrike-ai-sse": { + "transport": "sse", + "url": "http://127.0.0.1:8081/mcp/sse" + } }