diff --git a/internal/config/config.go b/internal/config/config.go index c13522c8..72db8808 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,19 +13,19 @@ import ( ) type Config struct { - Server ServerConfig `yaml:"server"` - Log LogConfig `yaml:"log"` - MCP MCPConfig `yaml:"mcp"` - OpenAI OpenAIConfig `yaml:"openai"` - Agent AgentConfig `yaml:"agent"` - Security SecurityConfig `yaml:"security"` - Database DatabaseConfig `yaml:"database"` - Auth AuthConfig `yaml:"auth"` - ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` - Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` - RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式) - Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色 - SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 + Server ServerConfig `yaml:"server"` + Log LogConfig `yaml:"log"` + MCP MCPConfig `yaml:"mcp"` + OpenAI OpenAIConfig `yaml:"openai"` + Agent AgentConfig `yaml:"agent"` + Security SecurityConfig `yaml:"security"` + Database DatabaseConfig `yaml:"database"` + Auth AuthConfig `yaml:"auth"` + ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` + Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` + RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式) + Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色 + SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 } type ServerConfig struct { @@ -83,13 +83,14 @@ type ExternalMCPConfig struct { // ExternalMCPServerConfig 外部MCP服务器配置 type ExternalMCPServerConfig struct { // stdio模式配置 - Command string `yaml:"command,omitempty" json:"command,omitempty"` - Args []string `yaml:"args,omitempty" json:"args,omitempty"` + Command string `yaml:"command,omitempty" json:"command,omitempty"` + Args []string `yaml:"args,omitempty" json:"args,omitempty"` Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式) // HTTP模式配置 - Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio" - URL string `yaml:"url,omitempty" json:"url,omitempty"` + Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp) + URL string `yaml:"url,omitempty" json:"url,omitempty"` + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key) // 通用配置 Description string `yaml:"description,omitempty" json:"description,omitempty"` @@ -108,8 +109,8 @@ type ToolConfig struct { ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗) Description string `yaml:"description"` // 详细描述(用于工具文档) Enabled bool `yaml:"enabled"` - Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) - ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码) } @@ -467,7 +468,7 @@ func LoadRoleFromFile(path string) (*RoleConfig, error) { icon := role.Icon // 去除可能的引号 icon = strings.Trim(icon, `"`) - + // 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制) if len(icon) >= 3 && icon[0] == '\\' { if icon[1] == 'U' && len(icon) >= 10 { @@ -576,12 +577,12 @@ type RolesConfig struct { // RoleConfig 单个角色配置 type RoleConfig struct { - Name string `yaml:"name" json:"name"` // 角色名称 - Description string `yaml:"description" json:"description"` // 角色描述 - UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前) - Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选) - Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName") - MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代) + Name string `yaml:"name" json:"name"` // 角色名称 + Description string `yaml:"description" json:"description"` // 角色描述 + UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前) + Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选) + Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName") + MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代) Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容) - Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用 + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用 } diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go index 1fc9e8b3..a8b57ae6 100644 --- a/internal/handler/external_mcp.go +++ b/internal/handler/external_mcp.go @@ -8,6 +8,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "github.com/gin-gonic/gin" "go.uber.org/zap" "gopkg.in/yaml.v3" @@ -36,12 +37,12 @@ func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { h.mu.RLock() defer h.mu.RUnlock() - + configs := h.manager.GetConfigs() - + // 获取所有外部MCP的工具数量 toolCounts := h.manager.GetToolCounts() - + // 转换为响应格式 result := make(map[string]ExternalMCPResponse) for name, cfg := range configs { @@ -54,13 +55,13 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { } else { status = "disabled" } - + toolCount := toolCounts[name] errorMsg := "" if status == "error" { errorMsg = h.manager.GetError(name) } - + result[name] = ExternalMCPResponse{ Config: cfg, Status: status, @@ -68,7 +69,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { Error: errorMsg, } } - + c.JSON(http.StatusOK, gin.H{ "servers": result, "stats": h.manager.GetStats(), @@ -78,17 +79,17 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { // GetExternalMCP 获取单个外部MCP配置 func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { name := c.Param("name") - + h.mu.RLock() defer h.mu.RUnlock() - + configs := h.manager.GetConfigs() cfg, exists := configs[name] if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"}) return } - + client, clientExists := h.manager.GetClient(name) status := "disconnected" if clientExists { @@ -98,7 +99,7 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { } else { status = "disabled" } - + // 获取工具数量 toolCount := 0 if clientExists && client.IsConnected() { @@ -106,13 +107,13 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { toolCount = count } } - + // 获取错误信息 errorMsg := "" if status == "error" { errorMsg = h.manager.GetError(name) } - + c.JSON(http.StatusOK, ExternalMCPResponse{ Config: cfg, Status: status, @@ -128,38 +129,38 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) return } - + name := c.Param("name") if name == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"}) return } - + // 验证配置 if err := h.validateConfig(req.Config); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - + h.mu.Lock() defer h.mu.Unlock() - + // 添加或更新配置 if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil { h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()}) return } - + // 更新内存中的配置 if h.config.ExternalMCP.Servers == nil { h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) } - + // 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容 // 同时将值迁移到 external_mcp_enable cfg := req.Config - + if req.Config.Disabled { // 用户设置了 disabled: true cfg.ExternalMCPEnable = false @@ -185,16 +186,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { cfg.Enabled = true cfg.Disabled = false } - + h.config.ExternalMCP.Servers[name] = cfg - + // 保存到配置文件 if err := h.saveConfig(); err != nil { h.logger.Error("保存配置失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) return } - + h.logger.Info("外部MCP配置已更新", zap.String("name", name)) c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) } @@ -202,28 +203,28 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { // DeleteExternalMCP 删除外部MCP配置 func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) { name := c.Param("name") - + h.mu.Lock() defer h.mu.Unlock() - + // 移除配置 if err := h.manager.RemoveConfig(name); err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"}) return } - + // 从内存配置中删除 if h.config.ExternalMCP.Servers != nil { delete(h.config.ExternalMCP.Servers, name) } - + // 保存到配置文件 if err := h.saveConfig(); err != nil { h.logger.Error("保存配置失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) return } - + h.logger.Info("外部MCP配置已删除", zap.String("name", name)) c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) } @@ -231,10 +232,10 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) { // StartExternalMCP 启动外部MCP func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { name := c.Param("name") - + h.mu.Lock() defer h.mu.Unlock() - + // 更新配置为启用 if h.config.ExternalMCP.Servers == nil { h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) @@ -242,32 +243,32 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { cfg := h.config.ExternalMCP.Servers[name] cfg.ExternalMCPEnable = true h.config.ExternalMCP.Servers[name] = cfg - + // 保存到配置文件 if err := h.saveConfig(); err != nil { h.logger.Error("保存配置失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) return } - + // 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行) h.logger.Info("开始启动外部MCP", zap.String("name", name)) if err := h.manager.StartClient(name); err != nil { h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), + "error": err.Error(), "status": "error", }) return } - + // 获取客户端状态(应该是connecting) client, exists := h.manager.GetClient(name) status := "connecting" if exists { status = client.GetStatus() } - + // 立即返回,不等待连接完成 // 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态 c.JSON(http.StatusOK, gin.H{ @@ -279,16 +280,16 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { // StopExternalMCP 停止外部MCP func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) { name := c.Param("name") - + h.mu.Lock() defer h.mu.Unlock() - + // 停止客户端 if err := h.manager.StopClient(name); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - + // 更新配置 if h.config.ExternalMCP.Servers == nil { h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) @@ -296,14 +297,14 @@ func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) { cfg := h.config.ExternalMCP.Servers[name] cfg.ExternalMCPEnable = false h.config.ExternalMCP.Servers[name] = cfg - + // 保存到配置文件 if err := h.saveConfig(); err != nil { h.logger.Error("保存配置失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) return } - + h.logger.Info("外部MCP已停止", zap.String("name", name)) c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"}) } @@ -327,7 +328,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)") } } - + switch transport { case "http": if cfg.URL == "" { @@ -344,7 +345,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) default: return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport) } - + return nil } @@ -428,17 +429,17 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi root := doc.Content[0] externalMCPNode := ensureMap(root, "external_mcp") serversNode := ensureMap(externalMCPNode, "servers") - + // 清空现有服务器配置 serversNode.Content = nil - + // 添加新的服务器配置 for name, serverCfg := range cfg.Servers { // 添加服务器名称键 nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name} serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} serversNode.Content = append(serversNode.Content, nameNode, serverNode) - + // 设置服务器配置字段 if serverCfg.Command != "" { setStringInMap(serverNode, "command", serverCfg.Command) @@ -459,6 +460,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi if serverCfg.URL != "" { setStringInMap(serverNode, "url", serverCfg.URL) } + // 保存 headers 字段(HTTP/SSE 请求头) + if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 { + headersNode := ensureMap(serverNode, "headers") + for k, v := range serverCfg.Headers { + setStringInMap(headersNode, k, v) + } + } if serverCfg.Description != "" { setStringInMap(serverNode, "description", serverCfg.Description) } @@ -476,7 +484,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi } // 保留旧的 enabled/disabled 字段以保持向后兼容 originalFields, hasOriginal := originalConfigs[name] - + // 如果原始配置中有 enabled 字段,保留它 if hasOriginal { if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled { @@ -494,7 +502,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi } } } - + // 如果用户在当前请求中明确设置了这些字段,也保存它们 if serverCfg.Enabled { setBoolInMap(serverNode, "enabled", serverCfg.Enabled) @@ -528,8 +536,7 @@ type AddOrUpdateExternalMCPRequest struct { // ExternalMCPResponse 外部MCP响应 type ExternalMCPResponse struct { Config config.ExternalMCPServerConfig `json:"config"` - Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting" - ToolCount int `json:"tool_count"` // 工具数量 + Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting" + ToolCount int `json:"tool_count"` // 工具数量 Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在) } - diff --git a/internal/handler/external_mcp_test.go b/internal/handler/external_mcp_test.go index 0ba0b1bb..a663c489 100644 --- a/internal/handler/external_mcp_test.go +++ b/internal/handler/external_mcp_test.go @@ -11,6 +11,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "github.com/gin-gonic/gin" "go.uber.org/zap" ) @@ -18,7 +19,7 @@ import ( func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { gin.SetMode(gin.TestMode) router := gin.New() - + // 创建临时配置文件 tmpFile, err := os.CreateTemp("", "test-config-*.yaml") if err != nil { @@ -27,7 +28,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n") tmpFile.Close() configPath := tmpFile.Name() - + logger := zap.NewNop() manager := mcp.NewExternalMCPManager(logger) cfg := &config.Config{ @@ -35,9 +36,9 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { Servers: make(map[string]config.ExternalMCPServerConfig), }, } - + handler := NewExternalMCPHandler(manager, cfg, configPath, logger) - + api := router.Group("/api") api.GET("/external-mcp", handler.GetExternalMCPs) api.GET("/external-mcp/stats", handler.GetExternalMCPStats) @@ -46,7 +47,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP) api.POST("/external-mcp/:name/start", handler.StartExternalMCP) api.POST("/external-mcp/:name/stop", handler.StopExternalMCP) - + return router, handler, configPath } @@ -58,7 +59,7 @@ func cleanupTestConfig(configPath string) { func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + // 测试添加stdio模式的配置 configJSON := `{ "command": "python3", @@ -67,41 +68,41 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { "timeout": 300, "enabled": true }` - + var configObj config.ExternalMCPServerConfig if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { t.Fatalf("解析配置JSON失败: %v", err) } - + reqBody := AddOrUpdateExternalMCPRequest{ Config: configObj, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + // 验证配置已添加 req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) - + if w2.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) } - + var response ExternalMCPResponse if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { t.Fatalf("解析响应失败: %v", err) } - + if response.Config.Command != "python3" { t.Errorf("期望command为python3,实际%s", response.Config.Command) } @@ -122,48 +123,48 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + // 测试添加HTTP模式的配置 configJSON := `{ "transport": "http", "url": "http://127.0.0.1:8081/mcp", "enabled": true }` - + var configObj config.ExternalMCPServerConfig if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { t.Fatalf("解析配置JSON失败: %v", err) } - + reqBody := AddOrUpdateExternalMCPRequest{ Config: configObj, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + // 验证配置已添加 req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) - + if w2.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) } - + var response ExternalMCPResponse if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { t.Fatalf("解析响应失败: %v", err) } - + if response.Config.Transport != "http" { t.Errorf("期望transport为http,实际%s", response.Config.Transport) } @@ -178,7 +179,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + testCases := []struct { name string configJSON string @@ -187,7 +188,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { { name: "缺少command和url", configJSON: `{"enabled": true}`, - expectedErr: "需要指定command(stdio模式)或url(http模式)", + expectedErr: "需要指定command(stdio模式)或url(http/sse模式)", }, { name: "stdio模式缺少command", @@ -205,34 +206,34 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { expectedErr: "不支持的传输模式", }, } - + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var configObj config.ExternalMCPServerConfig if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil { t.Fatalf("解析配置JSON失败: %v", err) } - + reqBody := AddOrUpdateExternalMCPRequest{ Config: configObj, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) } - + var response map[string]interface{} if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("解析响应失败: %v", err) } - + errorMsg := response["error"].(string) // 对于stdio模式缺少command的情况,错误信息可能略有不同 if tc.name == "stdio模式缺少command" { @@ -249,28 +250,28 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { router, handler, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + // 先添加一个配置 configObj := config.ExternalMCPServerConfig{ Command: "python3", Enabled: true, } handler.manager.AddOrUpdateConfig("test-delete", configObj) - + // 删除配置 req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + // 验证配置已删除 req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) - + if w2.Code != http.StatusNotFound { t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String()) } @@ -278,7 +279,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { router, handler, _ := setupTestRouter() - + // 添加多个配置 handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ Command: "python3", @@ -288,20 +289,20 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { URL: "http://127.0.0.1:8081/mcp", Enabled: false, }) - + req := httptest.NewRequest("GET", "/api/external-mcp", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + var response map[string]interface{} if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("解析响应失败: %v", err) } - + servers := response["servers"].(map[string]interface{}) if len(servers) != 2 { t.Errorf("期望2个服务器,实际%d", len(servers)) @@ -312,7 +313,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { if _, ok := servers["test2"]; !ok { t.Error("期望包含test2") } - + stats := response["stats"].(map[string]interface{}) if int(stats["total"].(float64)) != 2 { t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64))) @@ -321,7 +322,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { router, handler, _ := setupTestRouter() - + // 添加配置 handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ Command: "python3", @@ -336,20 +337,20 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { Enabled: false, Disabled: true, }) - + req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + var stats map[string]interface{} if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { t.Fatalf("解析响应失败: %v", err) } - + if int(stats["total"].(float64)) != 3 { t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64))) } @@ -364,19 +365,19 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { router, handler, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + // 添加一个禁用的配置 handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ Command: "python3", Enabled: false, Disabled: true, }) - + // 测试启动(可能会失败,因为没有真实的服务器) req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + // 启动可能会失败,但应该返回合理的状态码 if w.Code != http.StatusOK { // 如果启动失败,应该是400或500 @@ -384,12 +385,12 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String()) } } - + // 测试停止 req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) - + if w2.Code != http.StatusOK { t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) } @@ -397,11 +398,11 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) { router, _, _ := setupTestRouter() - + req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + if w.Code != http.StatusNotFound { t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String()) } @@ -410,11 +411,11 @@ func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) { func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) { router, _, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) - + // 删除不存在的配置可能返回200(幂等操作)或404,都是合理的 if w.Code != http.StatusNotFound && w.Code != http.StatusOK { t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String()) @@ -423,23 +424,23 @@ func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) { func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { router, _, _ := setupTestRouter() - + configObj := config.ExternalMCPServerConfig{ Command: "python3", Enabled: true, } - + reqBody := AddOrUpdateExternalMCPRequest{ Config: configObj, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + // 空名称应该返回404或400 if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest { t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String()) @@ -448,15 +449,15 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) { router, _, _ := setupTestRouter() - + // 发送无效的JSON body := []byte(`{"config": invalid json}`) req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) } @@ -465,49 +466,49 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) { func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { router, handler, configPath := setupTestRouter() defer cleanupTestConfig(configPath) - + // 先添加配置 config1 := config.ExternalMCPServerConfig{ Command: "python3", Enabled: true, } handler.manager.AddOrUpdateConfig("test-update", config1) - + // 更新配置 config2 := config.ExternalMCPServerConfig{ URL: "http://127.0.0.1:8081/mcp", Enabled: true, } - + reqBody := AddOrUpdateExternalMCPRequest{ Config: config2, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - + router.ServeHTTP(w, req) - + if w.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) } - + // 验证配置已更新 req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) - + if w2.Code != http.StatusOK { t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) } - + var response ExternalMCPResponse if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { t.Fatalf("解析响应失败: %v", err) } - + if response.Config.URL != "http://127.0.0.1:8081/mcp" { t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) } @@ -515,4 +516,3 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { t.Errorf("期望command为空,实际%s", response.Config.Command) } } - diff --git a/internal/mcp/client_sdk.go b/internal/mcp/client_sdk.go new file mode 100644 index 00000000..25f27c34 --- /dev/null +++ b/internal/mcp/client_sdk.go @@ -0,0 +1,549 @@ +// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性 +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + "go.uber.org/zap" +) + +const ( + clientName = "CyberStrikeAI" + clientVersion = "1.0.0" +) + +// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口 +type sdkClient struct { + session *mcp.ClientSession + client *mcp.Client + logger *zap.Logger + mu sync.RWMutex + status string // "disconnected", "connecting", "connected", "error" +} + +// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用) +func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient { + return &sdkClient{ + session: session, + client: client, + logger: logger, + status: "connected", + } +} + +// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient +type lazySDKClient struct { + serverCfg config.ExternalMCPServerConfig + logger *zap.Logger + inner ExternalMCPClient // 连接成功后为 *sdkClient + mu sync.RWMutex + status string +} + +func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient { + return &lazySDKClient{ + serverCfg: serverCfg, + logger: logger, + status: "connecting", + } +} + +func (c *lazySDKClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *lazySDKClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + if c.inner != nil { + return c.inner.GetStatus() + } + return c.status +} + +func (c *lazySDKClient) IsConnected() bool { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner != nil { + return inner.IsConnected() + } + return false +} + +func (c *lazySDKClient) Initialize(ctx context.Context) error { + c.mu.Lock() + if c.inner != nil { + c.mu.Unlock() + return nil + } + c.mu.Unlock() + + inner, err := createSDKClient(ctx, c.serverCfg, c.logger) + if err != nil { + c.setStatus("error") + return err + } + + c.mu.Lock() + c.inner = inner + c.mu.Unlock() + c.setStatus("connected") + return nil +} + +func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner == nil { + return nil, fmt.Errorf("未连接") + } + return inner.ListTools(ctx) +} + +func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + c.mu.RLock() + inner := c.inner + c.mu.RUnlock() + if inner == nil { + return nil, fmt.Errorf("未连接") + } + return inner.CallTool(ctx, name, args) +} + +func (c *lazySDKClient) Close() error { + c.mu.Lock() + inner := c.inner + c.inner = nil + c.mu.Unlock() + c.setStatus("disconnected") + if inner != nil { + return inner.Close() + } + return nil +} + +func (c *sdkClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *sdkClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *sdkClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *sdkClient) Initialize(ctx context.Context) error { + // sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接 + // 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成 + return nil +} + +func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) { + if c.session == nil { + return nil, fmt.Errorf("未连接") + } + res, err := c.session.ListTools(ctx, nil) + if err != nil { + return nil, err + } + if res == nil { + return nil, nil + } + return sdkToolsToOur(res.Tools), nil +} + +func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + if c.session == nil { + return nil, fmt.Errorf("未连接") + } + params := &mcp.CallToolParams{ + Name: name, + Arguments: args, + } + res, err := c.session.CallTool(ctx, params) + if err != nil { + return nil, err + } + return sdkCallToolResultToOurs(res), nil +} + +func (c *sdkClient) Close() error { + c.setStatus("disconnected") + if c.session != nil { + err := c.session.Close() + c.session = nil + return err + } + return nil +} + +// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool +func sdkToolsToOur(tools []*mcp.Tool) []Tool { + if len(tools) == 0 { + return nil + } + out := make([]Tool, 0, len(tools)) + for _, t := range tools { + if t == nil { + continue + } + schema := make(map[string]interface{}) + if t.InputSchema != nil { + // SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map + if m, ok := t.InputSchema.(map[string]interface{}); ok { + schema = m + } else { + _ = json.Unmarshal(mustJSON(t.InputSchema), &schema) + } + } + desc := t.Description + shortDesc := desc + if t.Annotations != nil && t.Annotations.Title != "" { + shortDesc = t.Annotations.Title + } + out = append(out, Tool{ + Name: t.Name, + Description: desc, + ShortDescription: shortDesc, + InputSchema: schema, + }) + } + return out +} + +// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult +func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult { + if res == nil { + return &ToolResult{Content: []Content{}} + } + content := sdkContentToOurs(res.Content) + return &ToolResult{ + Content: content, + IsError: res.IsError, + } +} + +func sdkContentToOurs(list []mcp.Content) []Content { + if len(list) == 0 { + return nil + } + out := make([]Content, 0, len(list)) + for _, c := range list { + switch v := c.(type) { + case *mcp.TextContent: + out = append(out, Content{Type: "text", Text: v.Text}) + default: + out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)}) + } + } + return out +} + +func mustJSON(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} + +// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。 +// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。 +type simpleHTTPClient struct { + url string + client *http.Client + logger *zap.Logger + mu sync.RWMutex + status string +} + +func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) { + c := &simpleHTTPClient{ + url: url, + client: httpClientWithTimeoutAndHeaders(timeout, headers), + logger: logger, + status: "connecting", + } + if err := c.initialize(ctx); err != nil { + return nil, err + } + c.mu.Lock() + c.status = "connected" + c.mu.Unlock() + return c, nil +} + +func (c *simpleHTTPClient) setStatus(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = s +} + +func (c *simpleHTTPClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *simpleHTTPClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *simpleHTTPClient) Initialize(context.Context) error { + return nil // 已在 newSimpleHTTPClient 中完成 +} + +func (c *simpleHTTPClient) initialize(ctx context.Context) error { + params := InitializeRequest{ + ProtocolVersion: ProtocolVersion, + Capabilities: make(map[string]interface{}), + ClientInfo: ClientInfo{Name: clientName, Version: clientVersion}, + } + paramsJSON, _ := json.Marshal(params) + req := &Message{ + ID: MessageID{value: "1"}, + Method: "initialize", + Version: "2.0", + Params: paramsJSON, + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return fmt.Errorf("initialize: %w", err) + } + if resp.Error != nil { + return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + // 发送 notifications/initialized(协议要求) + notify := &Message{ + ID: MessageID{value: nil}, + Method: "notifications/initialized", + Version: "2.0", + Params: json.RawMessage("{}"), + } + _ = c.sendNotification(notify) + return nil +} + +func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, err + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b)) + } + var out Message + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + return &out, nil +} + +func (c *simpleHTTPClient) sendNotification(msg *Message) error { + body, _ := json.Marshal(msg) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) + httpReq.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(httpReq) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) { + req := &Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/list", + Version: "2.0", + Params: json.RawMessage("{}"), + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + var listResp ListToolsResponse + if err := json.Unmarshal(resp.Result, &listResp); err != nil { + return nil, err + } + return listResp.Tools, nil +} + +func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + params := CallToolRequest{Name: name, Arguments: args} + paramsJSON, _ := json.Marshal(params) + req := &Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/call", + Version: "2.0", + Params: paramsJSON, + } + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code) + } + var callResp CallToolResponse + if err := json.Unmarshal(resp.Result, &callResp); err != nil { + return nil, err + } + return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil +} + +func (c *simpleHTTPClient) Close() error { + c.setStatus("disconnected") + return nil +} + +// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient +// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。 +func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) { + timeout := time.Duration(serverCfg.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + transport := serverCfg.Transport + if transport == "" { + if serverCfg.Command != "" { + transport = "stdio" + } else if serverCfg.URL != "" { + transport = "http" + } else { + return nil, fmt.Errorf("配置缺少 command 或 url") + } + } + + client := mcp.NewClient(&mcp.Implementation{ + Name: clientName, + Version: clientVersion, + }, nil) + + var t mcp.Transport + switch transport { + case "stdio": + if serverCfg.Command == "" { + return nil, fmt.Errorf("stdio 模式需要配置 command") + } + cmd := exec.CommandContext(ctx, serverCfg.Command, serverCfg.Args...) // 使用 ctx 控制超时与取消 + if len(serverCfg.Env) > 0 { + cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...) + } + t = &mcp.CommandTransport{Command: cmd} + case "sse": + if serverCfg.URL == "" { + return nil, fmt.Errorf("sse 模式需要配置 url") + } + httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) + t = &mcp.SSEClientTransport{ + Endpoint: serverCfg.URL, + HTTPClient: httpClient, + } + case "http": + if serverCfg.URL == "" { + return nil, fmt.Errorf("http 模式需要配置 url") + } + httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) + t = &mcp.StreamableClientTransport{ + Endpoint: serverCfg.URL, + HTTPClient: httpClient, + } + case "simple_http": + // 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp) + if serverCfg.URL == "" { + return nil, fmt.Errorf("simple_http 模式需要配置 url") + } + return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger) + default: + return nil, fmt.Errorf("不支持的传输模式: %s", transport) + } + + session, err := client.Connect(ctx, t, nil) + if err != nil { + return nil, fmt.Errorf("连接失败: %w", err) + } + + return newSDKClientFromSession(session, client, logger), nil +} + +func envMapToSlice(env map[string]string) []string { + m := make(map[string]string) + for _, s := range os.Environ() { + if i := strings.IndexByte(s, '='); i > 0 { + m[s[:i]] = s[i+1:] + } + } + for k, v := range env { + m[k] = v + } + out := make([]string, 0, len(m)) + for k, v := range m { + out = append(out, k+"="+v) + } + return out +} + +func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client { + transport := http.DefaultTransport + if len(headers) > 0 { + transport = &headerRoundTripper{ + headers: headers, + base: http.DefaultTransport, + } + } + return &http.Client{ + Timeout: timeout, + Transport: transport, + } +} + +type headerRoundTripper struct { + headers map[string]string + base http.RoundTripper +} + +func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + for k, v := range h.headers { + req.Header.Set(k, v) + } + return h.base.RoundTrip(req) +} diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go index 4eb3410b..04841982 100644 --- a/internal/mcp/external_manager.go +++ b/internal/mcp/external_manager.go @@ -196,7 +196,8 @@ func (m *ExternalMCPManager) StartClient(name string) error { m.mu.Lock() delete(m.errors, name) m.mu.Unlock() - // 连接成功,立即刷新工具数量 + // 延迟再刷新工具数量,避免 SSE/Streamable 连接尚未就绪时立即请求导致 EOF(如值得买等远端) + time.Sleep(2 * time.Second) m.triggerToolCountRefresh() } }() @@ -630,11 +631,20 @@ func (m *ExternalMCPManager) refreshToolCounts() { cancel() if err != nil { - m.logger.Debug("获取外部MCP工具数量失败", - zap.String("name", n), - zap.Error(err), - ) - // 如果获取失败,保留旧值(在更新时处理) + errStr := err.Error() + // SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示 + if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") { + m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)", + zap.String("name", n), + zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"), + zap.Error(err), + ) + } else { + m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", + zap.String("name", n), + zap.Error(err), + ) + } resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值 return } @@ -707,21 +717,13 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() { go m.refreshToolCounts() } -// createClient 创建客户端(不连接) +// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。 func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient { - timeout := time.Duration(serverCfg.Timeout) * time.Second - if timeout <= 0 { - timeout = 30 * time.Second - } - - // 根据传输模式创建客户端 transport := serverCfg.Transport if transport == "" { - // 如果没有指定transport,根据是否有command或url判断 if serverCfg.Command != "" { transport = "stdio" } else if serverCfg.URL != "" { - // 默认使用http,但可以通过transport字段指定sse transport = "http" } else { return nil @@ -733,17 +735,23 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf if serverCfg.URL == "" { return nil } - return NewHTTPMCPClient(serverCfg.URL, timeout, m.logger) + return newLazySDKClient(serverCfg, m.logger) + case "simple_http": + // 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等 + if serverCfg.URL == "" { + return nil + } + return newLazySDKClient(serverCfg, m.logger) case "stdio": if serverCfg.Command == "" { return nil } - return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger) + return newLazySDKClient(serverCfg, m.logger) case "sse": if serverCfg.URL == "" { return nil } - return NewSSEMCPClient(serverCfg.URL, timeout, m.logger) + return newLazySDKClient(serverCfg, m.logger) default: return nil } @@ -773,12 +781,7 @@ func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCP // setClientStatus 设置客户端状态(通过类型断言) func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) { - switch c := client.(type) { - case *HTTPMCPClient: - c.setStatus(status) - case *StdioMCPClient: - c.setStatus(status) - case *SSEMCPClient: + if c, ok := client.(*lazySDKClient); ok { c.setStatus(status) } } diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go index ffd0c8a2..d4c49851 100644 --- a/internal/mcp/external_manager_test.go +++ b/internal/mcp/external_manager_test.go @@ -151,48 +151,26 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) { } } -func TestHTTPMCPClient_Initialize(t *testing.T) { - // 注意:这个测试需要一个真实的HTTP MCP服务器 - // 如果没有服务器,这个测试会失败 - // 在实际测试中,可以使用mock服务器 +// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态 +func TestLazySDKClient_InitializeFails(t *testing.T) { logger := zap.NewNop() - client := NewHTTPMCPClient("http://127.0.0.1:8081/mcp", 5*time.Second, logger) - + // 使用不存在的 HTTP 地址,Initialize 应失败 + cfg := config.ExternalMCPServerConfig{ + Transport: "http", + URL: "http://127.0.0.1:19999/nonexistent", + Timeout: 2, + } + c := newLazySDKClient(cfg, logger) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - - // 这个测试可能会失败,如果没有真实的服务器 - // 在实际环境中,应该使用mock服务器 - err := client.Initialize(ctx) - if err != nil { - t.Logf("初始化失败(可能是没有服务器): %v", err) + err := c.Initialize(ctx) + if err == nil { + t.Fatal("expected error when connecting to invalid server") } - - status := client.GetStatus() - if status == "" { - t.Error("状态不应该为空") + if c.GetStatus() != "error" { + t.Errorf("expected status error, got %s", c.GetStatus()) } - - client.Close() -} - -func TestStdioMCPClient_Initialize(t *testing.T) { - // 注意:这个测试需要一个真实的stdio MCP服务器 - // 如果没有服务器,这个测试会失败 - logger := zap.NewNop() - client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // 这个测试可能会失败,因为echo不是MCP服务器 - // 在实际环境中,应该使用真实的MCP服务器或mock - err := client.Initialize(ctx) - if err != nil { - t.Logf("初始化失败(echo不是MCP服务器): %v", err) - } - - client.Close() + c.Close() } func TestExternalMCPManager_StartStopClient(t *testing.T) { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 81cabd10..94e32a27 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -125,6 +125,13 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { return } + // 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回 + if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" { + s.serveSSESessionMessage(w, r, sessionID) + return + } + + // 简单 POST:请求体为 JSON-RPC,响应在 body 中返回 body, err := io.ReadAll(r.Body) if err != nil { s.sendError(w, nil, -32700, "Parse error", err.Error()) @@ -137,14 +144,56 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { return } - // 处理消息 response := s.handleMessage(&msg) - w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } -// handleSSE 处理SSE连接(用于MCP HTTP传输的事件通道) +// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送 +func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) { + s.mu.RLock() + client, exists := s.sseClients[sessionID] + s.mu.RUnlock() + if !exists || client == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + + var msg Message + if err := json.Unmarshal(body, &msg); err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + + response := s.handleMessage(&msg) + if response == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + respBytes, err := json.Marshal(response) + if err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + return + } + + select { + case client.send <- respBytes: + w.WriteHeader(http.StatusAccepted) + default: + http.Error(w, "session send buffer full", http.StatusServiceUnavailable) + } +} + +// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范: +// 1. 首个事件必须为 event: endpoint,data 为客户端 POST 消息的 URL(含 sessionid) +// 2. 后续事件为 event: message,data 为 JSON-RPC 响应 func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) { flusher, ok := w.(http.Flusher) if !ok { @@ -157,16 +206,25 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") + sessionID := uuid.New().String() client := &sseClient{ - id: uuid.New().String(), - send: make(chan []byte, 8), + id: sessionID, + send: make(chan []byte, 32), } s.addSSEClient(client) defer s.removeSSEClient(client.id) - // 发送初始ready事件,告知客户端连接成功 - fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n") + // 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求) + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if r.URL.Scheme != "" { + scheme = r.URL.Scheme + } + endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID) + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL) flusher.Flush() ticker := time.NewTicker(15 * time.Second) @@ -183,7 +241,6 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg) flusher.Flush() case <-ticker.C: - // 心跳保持连接 fmt.Fprintf(w, ": ping\n\n") flusher.Flush() } diff --git a/internal/mcp/types.go b/internal/mcp/types.go index 91c9b3d8..393717b9 100644 --- a/internal/mcp/types.go +++ b/internal/mcp/types.go @@ -1,11 +1,22 @@ package mcp import ( + "context" "encoding/json" "fmt" "time" ) +// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现) +type ExternalMCPClient interface { + Initialize(ctx context.Context) error + ListTools(ctx context.Context) ([]Tool, error) + CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) + Close() error + IsConnected() bool + GetStatus() string +} + // MCP消息类型 const ( MessageTypeRequest = "request" @@ -29,21 +40,21 @@ func (m *MessageID) UnmarshalJSON(data []byte) error { m.value = nil return nil } - + // 尝试解析为字符串 var str string if err := json.Unmarshal(data, &str); err == nil { m.value = str return nil } - + // 尝试解析为数字 var num json.Number if err := json.Unmarshal(data, &num); err == nil { m.value = num return nil } - + return fmt.Errorf("invalid id type") } @@ -81,15 +92,15 @@ type Message struct { // Error 表示MCP错误 type Error struct { - Code int `json:"code"` - Message string `json:"message"` + Code int `json:"code"` + Message string `json:"message"` Data interface{} `json:"data,omitempty"` } // Tool 表示MCP工具定义 type Tool struct { Name string `json:"name"` - Description string `json:"description"` // 详细描述 + Description string `json:"description"` // 详细描述 ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗) InputSchema map[string]interface{} `json:"inputSchema"` } @@ -127,9 +138,9 @@ type ClientInfo struct { // InitializeResponse 初始化响应 type InitializeResponse struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities ServerCapabilities `json:"capabilities"` - ServerInfo ServerInfo `json:"serverInfo"` + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` } // ServerCapabilities 服务器能力 @@ -178,31 +189,31 @@ type CallToolResponse struct { // ToolExecution 工具执行记录 type ToolExecution struct { - ID string `json:"id"` - ToolName string `json:"toolName"` - Arguments map[string]interface{} `json:"arguments"` - Status string `json:"status"` // pending, running, completed, failed - Result *ToolResult `json:"result,omitempty"` - Error string `json:"error,omitempty"` - StartTime time.Time `json:"startTime"` - EndTime *time.Time `json:"endTime,omitempty"` - Duration time.Duration `json:"duration,omitempty"` + ID string `json:"id"` + ToolName string `json:"toolName"` + Arguments map[string]interface{} `json:"arguments"` + Status string `json:"status"` // pending, running, completed, failed + Result *ToolResult `json:"result,omitempty"` + Error string `json:"error,omitempty"` + StartTime time.Time `json:"startTime"` + EndTime *time.Time `json:"endTime,omitempty"` + Duration time.Duration `json:"duration,omitempty"` } // ToolStats 工具统计信息 type ToolStats struct { - ToolName string `json:"toolName"` - TotalCalls int `json:"totalCalls"` - SuccessCalls int `json:"successCalls"` - FailedCalls int `json:"failedCalls"` + ToolName string `json:"toolName"` + TotalCalls int `json:"totalCalls"` + SuccessCalls int `json:"successCalls"` + FailedCalls int `json:"failedCalls"` LastCallTime *time.Time `json:"lastCallTime,omitempty"` } // Prompt 提示词模板 type Prompt struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Arguments []PromptArgument `json:"arguments,omitempty"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` } // PromptArgument 提示词参数 @@ -257,11 +268,11 @@ type ResourceContent struct { // SamplingRequest 采样请求 type SamplingRequest struct { - Messages []SamplingMessage `json:"messages"` - Model string `json:"model,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` + Messages []SamplingMessage `json:"messages"` + Model string `json:"model,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` } // SamplingMessage 采样消息 @@ -272,9 +283,9 @@ type SamplingMessage struct { // SamplingResponse 采样响应 type SamplingResponse struct { - Content []SamplingContent `json:"content"` - Model string `json:"model,omitempty"` - StopReason string `json:"stopReason,omitempty"` + Content []SamplingContent `json:"content"` + Model string `json:"model,omitempty"` + StopReason string `json:"stopReason,omitempty"` } // SamplingContent 采样内容 @@ -282,4 +293,3 @@ type SamplingContent struct { Type string `json:"type"` Text string `json:"text,omitempty"` } - diff --git a/web/static/js/settings.js b/web/static/js/settings.js index 27aa5b7d..2c87af75 100644 --- a/web/static/js/settings.js +++ b/web/static/js/settings.js @@ -994,6 +994,20 @@ async function loadExternalMCPs() { } } +// 延迟刷新外部MCP列表(用于在保存/连接后拉取后端异步更新的工具数量) +// 可选传入单次延迟毫秒数;不传则执行两次:2.5s 与 5s(覆盖启动后后端异步更新较慢的情况) +function scheduleExternalMCPToolCountRefresh(delayMs) { + const delays = delayMs != null ? [delayMs] : [2500, 5000]; + delays.forEach((d) => { + setTimeout(async () => { + await loadExternalMCPs(); + if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') { + window.refreshMentionTools(); + } + }, d); + }); +} + // 渲染外部MCP列表 function renderExternalMCPList(servers) { const list = document.getElementById('external-mcp-list'); @@ -1354,6 +1368,8 @@ async function saveExternalMCP() { if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') { window.refreshMentionTools(); } + // 后端在连接成功约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数 + scheduleExternalMCPToolCountRefresh(); alert('保存成功'); } catch (error) { console.error('保存外部MCP失败:', error); @@ -1433,6 +1449,8 @@ async function toggleExternalMCP(name, currentStatus) { if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') { window.refreshMentionTools(); } + // 后端约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数 + scheduleExternalMCPToolCountRefresh(); return; } } @@ -1496,6 +1514,8 @@ async function pollExternalMCPStatus(name, maxAttempts = 30) { if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') { window.refreshMentionTools(); } + // 后端约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数 + scheduleExternalMCPToolCountRefresh(); return; } else if (status === 'error' || status === 'disconnected') { // 连接失败,刷新列表并显示错误