Add files via upload

This commit is contained in:
公明
2026-02-04 01:34:25 +08:00
committed by GitHub
parent 91230f273e
commit 84d54b1ea9
9 changed files with 882 additions and 257 deletions

View File

@@ -13,19 +13,19 @@ import (
)
type Config struct {
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
}
type ServerConfig struct {
@@ -83,13 +83,14 @@ type ExternalMCPConfig struct {
// ExternalMCPServerConfig 外部MCP服务器配置
type ExternalMCPServerConfig struct {
// stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量用于stdio模式
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点如本机 http://127.0.0.1:8081/mcp)
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key
// 通用配置
Description string `yaml:"description,omitempty" json:"description,omitempty"`
@@ -108,8 +109,8 @@ type ToolConfig struct {
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述用于工具列表减少token消耗
Description string `yaml:"description"` // 详细描述(用于工具文档)
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码)
}
@@ -467,7 +468,7 @@ func LoadRoleFromFile(path string) (*RoleConfig, error) {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制或 \uXXXX4位十六进制
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
@@ -576,12 +577,12 @@ type RolesConfig struct {
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表toolKey格式如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容关联的MCP服务器列表已废弃使用tools替代
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表toolKey格式如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容关联的MCP服务器列表已废弃使用tools替代
Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表skill名称列表在执行任务前会读取这些skills的内容
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}

View File

@@ -8,6 +8,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
@@ -36,12 +37,12 @@ func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config,
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
// 获取所有外部MCP的工具数量
toolCounts := h.manager.GetToolCounts()
// 转换为响应格式
result := make(map[string]ExternalMCPResponse)
for name, cfg := range configs {
@@ -54,13 +55,13 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
} else {
status = "disabled"
}
toolCount := toolCounts[name]
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
result[name] = ExternalMCPResponse{
Config: cfg,
Status: status,
@@ -68,7 +69,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
Error: errorMsg,
}
}
c.JSON(http.StatusOK, gin.H{
"servers": result,
"stats": h.manager.GetStats(),
@@ -78,17 +79,17 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
// GetExternalMCP 获取单个外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
cfg, exists := configs[name]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
return
}
client, clientExists := h.manager.GetClient(name)
status := "disconnected"
if clientExists {
@@ -98,7 +99,7 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
} else {
status = "disabled"
}
// 获取工具数量
toolCount := 0
if clientExists && client.IsConnected() {
@@ -106,13 +107,13 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
toolCount = count
}
}
// 获取错误信息
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg,
Status: status,
@@ -128,38 +129,38 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
name := c.Param("name")
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
return
}
// 验证配置
if err := h.validateConfig(req.Config); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 添加或更新配置
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
return
}
// 更新内存中的配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
// 同时将值迁移到 external_mcp_enable
cfg := req.Config
if req.Config.Disabled {
// 用户设置了 disabled: true
cfg.ExternalMCPEnable = false
@@ -185,16 +186,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
cfg.Enabled = true
cfg.Disabled = false
}
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
@@ -202,28 +203,28 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
// DeleteExternalMCP 删除外部MCP配置
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 移除配置
if err := h.manager.RemoveConfig(name); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
return
}
// 从内存配置中删除
if h.config.ExternalMCP.Servers != nil {
delete(h.config.ExternalMCP.Servers, name)
}
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
@@ -231,10 +232,10 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
// StartExternalMCP 启动外部MCP
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 更新配置为启用
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
@@ -242,32 +243,32 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = true
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
// 启动客户端立即创建客户端并设置状态为connecting实际连接在后台进行
h.logger.Info("开始启动外部MCP", zap.String("name", name))
if err := h.manager.StartClient(name); err != nil {
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
"error": err.Error(),
"status": "error",
})
return
}
// 获取客户端状态应该是connecting
client, exists := h.manager.GetClient(name)
status := "connecting"
if exists {
status = client.GetStatus()
}
// 立即返回,不等待连接完成
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
c.JSON(http.StatusOK, gin.H{
@@ -279,16 +280,16 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
// StopExternalMCP 停止外部MCP
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 停止客户端
if err := h.manager.StopClient(name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
@@ -296,14 +297,14 @@ func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = false
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP已停止", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
}
@@ -327,7 +328,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
return fmt.Errorf("需要指定commandstdio模式或urlhttp/sse模式")
}
}
switch transport {
case "http":
if cfg.URL == "" {
@@ -344,7 +345,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
default:
return fmt.Errorf("不支持的传输模式: %s支持的模式: http, stdio, sse", transport)
}
return nil
}
@@ -428,17 +429,17 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
// 清空现有服务器配置
serversNode.Content = nil
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
// 添加服务器名称键
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// 设置服务器配置字段
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
@@ -459,6 +460,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
// 保存 headers 字段HTTP/SSE 请求头)
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
headersNode := ensureMap(serverNode, "headers")
for k, v := range serverCfg.Headers {
setStringInMap(headersNode, k, v)
}
}
if serverCfg.Description != "" {
setStringInMap(serverNode, "description", serverCfg.Description)
}
@@ -476,7 +484,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
}
// 保留旧的 enabled/disabled 字段以保持向后兼容
originalFields, hasOriginal := originalConfigs[name]
// 如果原始配置中有 enabled 字段,保留它
if hasOriginal {
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
@@ -494,7 +502,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
}
}
}
// 如果用户在当前请求中明确设置了这些字段,也保存它们
if serverCfg.Enabled {
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
@@ -528,8 +536,7 @@ type AddOrUpdateExternalMCPRequest struct {
// ExternalMCPResponse 外部MCP响应
type ExternalMCPResponse struct {
Config config.ExternalMCPServerConfig `json:"config"`
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Error string `json:"error,omitempty"` // 错误信息仅在status为error时存在
}

View File

@@ -11,6 +11,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
@@ -18,7 +19,7 @@ import (
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建临时配置文件
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
if err != nil {
@@ -27,7 +28,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
tmpFile.Close()
configPath := tmpFile.Name()
logger := zap.NewNop()
manager := mcp.NewExternalMCPManager(logger)
cfg := &config.Config{
@@ -35,9 +36,9 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
Servers: make(map[string]config.ExternalMCPServerConfig),
},
}
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
api := router.Group("/api")
api.GET("/external-mcp", handler.GetExternalMCPs)
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
@@ -46,7 +47,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
return router, handler, configPath
}
@@ -58,7 +59,7 @@ func cleanupTestConfig(configPath string) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置
configJSON := `{
"command": "python3",
@@ -67,41 +68,41 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
"timeout": 300,
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Command != "python3" {
t.Errorf("期望command为python3实际%s", response.Config.Command)
}
@@ -122,48 +123,48 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置
configJSON := `{
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Transport != "http" {
t.Errorf("期望transport为http实际%s", response.Config.Transport)
}
@@ -178,7 +179,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
testCases := []struct {
name string
configJSON string
@@ -187,7 +188,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
{
name: "缺少command和url",
configJSON: `{"enabled": true}`,
expectedErr: "需要指定commandstdio模式或urlhttp模式",
expectedErr: "需要指定commandstdio模式或urlhttp/sse模式)",
},
{
name: "stdio模式缺少command",
@@ -205,34 +206,34 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
expectedErr: "不支持的传输模式",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
errorMsg := response["error"].(string)
// 对于stdio模式缺少command的情况错误信息可能略有不同
if tc.name == "stdio模式缺少command" {
@@ -249,28 +250,28 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
// 删除配置
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已删除
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("期望状态码404实际%d: %s", w2.Code, w2.Body.String())
}
@@ -278,7 +279,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
@@ -288,20 +289,20 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
servers := response["servers"].(map[string]interface{})
if len(servers) != 2 {
t.Errorf("期望2个服务器实际%d", len(servers))
@@ -312,7 +313,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
if _, ok := servers["test2"]; !ok {
t.Error("期望包含test2")
}
stats := response["stats"].(map[string]interface{})
if int(stats["total"].(float64)) != 2 {
t.Errorf("期望总数为2实际%d", int(stats["total"].(float64)))
@@ -321,7 +322,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
@@ -336,20 +337,20 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
Enabled: false,
Disabled: true,
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if int(stats["total"].(float64)) != 3 {
t.Errorf("期望总数为3实际%d", int(stats["total"].(float64)))
}
@@ -364,19 +365,19 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
// 测试启动(可能会失败,因为没有真实的服务器)
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 启动可能会失败,但应该返回合理的状态码
if w.Code != http.StatusOK {
// 如果启动失败应该是400或500
@@ -384,12 +385,12 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
t.Errorf("期望状态码200/400/500实际%d: %s", w.Code, w.Body.String())
}
}
// 测试停止
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("期望状态码200实际%d: %s", w2.Code, w2.Body.String())
}
@@ -397,11 +398,11 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
router, _, _ := setupTestRouter()
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码404实际%d: %s", w.Code, w.Body.String())
}
@@ -410,11 +411,11 @@ func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 删除不存在的配置可能返回200幂等操作或404都是合理的
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
t.Errorf("期望状态码404或200实际%d: %s", w.Code, w.Body.String())
@@ -423,23 +424,23 @@ func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 空名称应该返回404或400
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
t.Errorf("期望状态码404或400实际%d: %s", w.Code, w.Body.String())
@@ -448,15 +449,15 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
router, _, _ := setupTestRouter()
// 发送无效的JSON
body := []byte(`{"config": invalid json}`)
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400实际%d: %s", w.Code, w.Body.String())
}
@@ -465,49 +466,49 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: config2,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已更新
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
@@ -515,4 +516,3 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
t.Errorf("期望command为空实际%s", response.Config.Command)
}
}

549
internal/mcp/client_sdk.go Normal file
View File

@@ -0,0 +1,549 @@
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/zap"
)
const (
clientName = "CyberStrikeAI"
clientVersion = "1.0.0"
)
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
type sdkClient struct {
session *mcp.ClientSession
client *mcp.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
}
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
return &sdkClient{
session: session,
client: client,
logger: logger,
status: "connected",
}
}
// lazySDKClient 延迟连接Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
type lazySDKClient struct {
serverCfg config.ExternalMCPServerConfig
logger *zap.Logger
inner ExternalMCPClient // 连接成功后为 *sdkClient
mu sync.RWMutex
status string
}
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
return &lazySDKClient{
serverCfg: serverCfg,
logger: logger,
status: "connecting",
}
}
func (c *lazySDKClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *lazySDKClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.inner != nil {
return c.inner.GetStatus()
}
return c.status
}
func (c *lazySDKClient) IsConnected() bool {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner != nil {
return inner.IsConnected()
}
return false
}
func (c *lazySDKClient) Initialize(ctx context.Context) error {
c.mu.Lock()
if c.inner != nil {
c.mu.Unlock()
return nil
}
c.mu.Unlock()
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
if err != nil {
c.setStatus("error")
return err
}
c.mu.Lock()
c.inner = inner
c.mu.Unlock()
c.setStatus("connected")
return nil
}
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.ListTools(ctx)
}
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.CallTool(ctx, name, args)
}
func (c *lazySDKClient) Close() error {
c.mu.Lock()
inner := c.inner
c.inner = nil
c.mu.Unlock()
c.setStatus("disconnected")
if inner != nil {
return inner.Close()
}
return nil
}
func (c *sdkClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *sdkClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *sdkClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *sdkClient) Initialize(ctx context.Context) error {
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
return nil
}
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
res, err := c.session.ListTools(ctx, nil)
if err != nil {
return nil, err
}
if res == nil {
return nil, nil
}
return sdkToolsToOur(res.Tools), nil
}
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
params := &mcp.CallToolParams{
Name: name,
Arguments: args,
}
res, err := c.session.CallTool(ctx, params)
if err != nil {
return nil, err
}
return sdkCallToolResultToOurs(res), nil
}
func (c *sdkClient) Close() error {
c.setStatus("disconnected")
if c.session != nil {
err := c.session.Close()
c.session = nil
return err
}
return nil
}
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
schema := make(map[string]interface{})
if t.InputSchema != nil {
// SDK InputSchema 可能为 *jsonschema.Schema 或 map统一转为 map
if m, ok := t.InputSchema.(map[string]interface{}); ok {
schema = m
} else {
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
}
}
desc := t.Description
shortDesc := desc
if t.Annotations != nil && t.Annotations.Title != "" {
shortDesc = t.Annotations.Title
}
out = append(out, Tool{
Name: t.Name,
Description: desc,
ShortDescription: shortDesc,
InputSchema: schema,
})
}
return out
}
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
if res == nil {
return &ToolResult{Content: []Content{}}
}
content := sdkContentToOurs(res.Content)
return &ToolResult{
Content: content,
IsError: res.IsError,
}
}
func sdkContentToOurs(list []mcp.Content) []Content {
if len(list) == 0 {
return nil
}
out := make([]Content, 0, len(list))
for _, c := range list {
switch v := c.(type) {
case *mcp.TextContent:
out = append(out, Content{Type: "text", Text: v.Text})
default:
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
}
}
return out
}
func mustJSON(v interface{}) []byte {
b, _ := json.Marshal(v)
return b
}
// simpleHTTPClient 简单 JSON-RPC over HTTP每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
// 用于自建 MCP如 http://127.0.0.1:8081/mcp或其它仅支持简单 POST 的端点。
type simpleHTTPClient struct {
url string
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string
}
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
c := &simpleHTTPClient{
url: url,
client: httpClientWithTimeoutAndHeaders(timeout, headers),
logger: logger,
status: "connecting",
}
if err := c.initialize(ctx); err != nil {
return nil, err
}
c.mu.Lock()
c.status = "connected"
c.mu.Unlock()
return c, nil
}
func (c *simpleHTTPClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *simpleHTTPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *simpleHTTPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *simpleHTTPClient) Initialize(context.Context) error {
return nil // 已在 newSimpleHTTPClient 中完成
}
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return fmt.Errorf("initialize: %w", err)
}
if resp.Error != nil {
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
// 发送 notifications/initialized协议要求
notify := &Message{
ID: MessageID{value: nil},
Method: "notifications/initialized",
Version: "2.0",
Params: json.RawMessage("{}"),
}
_ = c.sendNotification(notify)
return nil
}
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
body, err := json.Marshal(msg)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
}
var out Message
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, err
}
return &out, nil
}
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
body, _ := json.Marshal(msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
Params: json.RawMessage("{}"),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, err
}
return listResp.Tools, nil
}
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
params := CallToolRequest{Name: name, Arguments: args}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, err
}
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
}
func (c *simpleHTTPClient) Close() error {
c.setStatus("disconnected")
return nil
}
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK返回实现 ExternalMCPClient 的 *sdkClient
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
transport := serverCfg.Transport
if transport == "" {
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
transport = "http"
} else {
return nil, fmt.Errorf("配置缺少 command 或 url")
}
}
client := mcp.NewClient(&mcp.Implementation{
Name: clientName,
Version: clientVersion,
}, nil)
var t mcp.Transport
switch transport {
case "stdio":
if serverCfg.Command == "" {
return nil, fmt.Errorf("stdio 模式需要配置 command")
}
cmd := exec.CommandContext(ctx, serverCfg.Command, serverCfg.Args...) // 使用 ctx 控制超时与取消
if len(serverCfg.Env) > 0 {
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
}
t = &mcp.CommandTransport{Command: cmd}
case "sse":
if serverCfg.URL == "" {
return nil, fmt.Errorf("sse 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.SSEClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "http":
if serverCfg.URL == "" {
return nil, fmt.Errorf("http 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.StreamableClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "simple_http":
// 简单 JSON-RPC HTTP每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp
if serverCfg.URL == "" {
return nil, fmt.Errorf("simple_http 模式需要配置 url")
}
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
default:
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
}
session, err := client.Connect(ctx, t, nil)
if err != nil {
return nil, fmt.Errorf("连接失败: %w", err)
}
return newSDKClientFromSession(session, client, logger), nil
}
func envMapToSlice(env map[string]string) []string {
m := make(map[string]string)
for _, s := range os.Environ() {
if i := strings.IndexByte(s, '='); i > 0 {
m[s[:i]] = s[i+1:]
}
}
for k, v := range env {
m[k] = v
}
out := make([]string, 0, len(m))
for k, v := range m {
out = append(out, k+"="+v)
}
return out
}
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
type headerRoundTripper struct {
headers map[string]string
base http.RoundTripper
}
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Set(k, v)
}
return h.base.RoundTrip(req)
}

View File

@@ -196,7 +196,8 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Lock()
delete(m.errors, name)
m.mu.Unlock()
// 连接成功,立即刷新工具数量
// 延迟再刷新工具数量,避免 SSE/Streamable 连接尚未就绪时立即请求导致 EOF如值得买等远端
time.Sleep(2 * time.Second)
m.triggerToolCountRefresh()
}
}()
@@ -630,11 +631,20 @@ func (m *ExternalMCPManager) refreshToolCounts() {
cancel()
if err != nil {
m.logger.Debug("获取外部MCP工具数量失败",
zap.String("name", n),
zap.Error(err),
)
// 如果获取失败,保留旧值(在更新时处理)
errStr := err.Error()
// SSE 连接 EOF远端可能关闭了流或未按规范在流上推送响应仅首次用 Warn 提示
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
m.logger.Warn("获取外部MCP工具数量失败SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
zap.String("name", n),
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
zap.Error(err),
)
} else {
m.logger.Warn("获取外部MCP工具数量失败请检查连接或服务端 tools/list",
zap.String("name", n),
zap.Error(err),
)
}
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
return
}
@@ -707,21 +717,13 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() {
go m.refreshToolCounts()
}
// createClient 创建客户端(不连接)
// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
// 根据传输模式创建客户端
transport := serverCfg.Transport
if transport == "" {
// 如果没有指定transport根据是否有command或url判断
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
// 默认使用http但可以通过transport字段指定sse
transport = "http"
} else {
return nil
@@ -733,17 +735,23 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.URL == "" {
return nil
}
return NewHTTPMCPClient(serverCfg.URL, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
case "simple_http":
// 简单 HTTP一次 POST 一次响应),用于自建 MCP 等
if serverCfg.URL == "" {
return nil
}
return newLazySDKClient(serverCfg, m.logger)
case "stdio":
if serverCfg.Command == "" {
return nil
}
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
case "sse":
if serverCfg.URL == "" {
return nil
}
return NewSSEMCPClient(serverCfg.URL, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
default:
return nil
}
@@ -773,12 +781,7 @@ func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCP
// setClientStatus 设置客户端状态(通过类型断言)
func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) {
switch c := client.(type) {
case *HTTPMCPClient:
c.setStatus(status)
case *StdioMCPClient:
c.setStatus(status)
case *SSEMCPClient:
if c, ok := client.(*lazySDKClient); ok {
c.setStatus(status)
}
}

View File

@@ -151,48 +151,26 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
}
}
func TestHTTPMCPClient_Initialize(t *testing.T) {
// 注意这个测试需要一个真实的HTTP MCP服务器
// 如果没有服务器,这个测试会失败
// 在实际测试中可以使用mock服务器
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
client := NewHTTPMCPClient("http://127.0.0.1:8081/mcp", 5*time.Second, logger)
// 使用不存在的 HTTP 地址Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Transport: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
c := newLazySDKClient(cfg, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 这个测试可能会失败,如果没有真实的服务器
// 在实际环境中应该使用mock服务器
err := client.Initialize(ctx)
if err != nil {
t.Logf("初始化失败(可能是没有服务器): %v", err)
err := c.Initialize(ctx)
if err == nil {
t.Fatal("expected error when connecting to invalid server")
}
status := client.GetStatus()
if status == "" {
t.Error("状态不应该为空")
if c.GetStatus() != "error" {
t.Errorf("expected status error, got %s", c.GetStatus())
}
client.Close()
}
func TestStdioMCPClient_Initialize(t *testing.T) {
// 注意这个测试需要一个真实的stdio MCP服务器
// 如果没有服务器,这个测试会失败
logger := zap.NewNop()
client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 这个测试可能会失败因为echo不是MCP服务器
// 在实际环境中应该使用真实的MCP服务器或mock
err := client.Initialize(ctx)
if err != nil {
t.Logf("初始化失败echo不是MCP服务器: %v", err)
}
client.Close()
c.Close()
}
func TestExternalMCPManager_StartStopClient(t *testing.T) {

View File

@@ -125,6 +125,13 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回
if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" {
s.serveSSESessionMessage(w, r, sessionID)
return
}
// 简单 POST请求体为 JSON-RPC响应在 body 中返回
body, err := io.ReadAll(r.Body)
if err != nil {
s.sendError(w, nil, -32700, "Parse error", err.Error())
@@ -137,14 +144,56 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 处理消息
response := s.handleMessage(&msg)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleSSE 处理SSE连接用于MCP HTTP传输的事件通道
// serveSSESessionMessage 处理发往 SSE 会话的 POST读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送
func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) {
s.mu.RLock()
client, exists := s.sseClients[sessionID]
s.mu.RUnlock()
if !exists || client == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}
var msg Message
if err := json.Unmarshal(body, &msg); err != nil {
http.Error(w, "failed to parse body", http.StatusBadRequest)
return
}
response := s.handleMessage(&msg)
if response == nil {
w.WriteHeader(http.StatusAccepted)
return
}
respBytes, err := json.Marshal(response)
if err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
return
}
select {
case client.send <- respBytes:
w.WriteHeader(http.StatusAccepted)
default:
http.Error(w, "session send buffer full", http.StatusServiceUnavailable)
}
}
// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范:
// 1. 首个事件必须为 event: endpointdata 为客户端 POST 消息的 URL含 sessionid
// 2. 后续事件为 event: messagedata 为 JSON-RPC 响应
func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
@@ -157,16 +206,25 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
sessionID := uuid.New().String()
client := &sseClient{
id: uuid.New().String(),
send: make(chan []byte, 8),
id: sessionID,
send: make(chan []byte, 32),
}
s.addSSEClient(client)
defer s.removeSSEClient(client.id)
// 发送初始ready事件告知客户端连接成功
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
// 官方规范:首个事件为 endpointdata 为消息端点 URL客户端将向该 URL POST 请求)
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if r.URL.Scheme != "" {
scheme = r.URL.Scheme
}
endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID)
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL)
flusher.Flush()
ticker := time.NewTicker(15 * time.Second)
@@ -183,7 +241,6 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
flusher.Flush()
case <-ticker.C:
// 心跳保持连接
fmt.Fprintf(w, ": ping\n\n")
flusher.Flush()
}

View File

@@ -1,11 +1,22 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"time"
)
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
type ExternalMCPClient interface {
Initialize(ctx context.Context) error
ListTools(ctx context.Context) ([]Tool, error)
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
Close() error
IsConnected() bool
GetStatus() string
}
// MCP消息类型
const (
MessageTypeRequest = "request"
@@ -29,21 +40,21 @@ func (m *MessageID) UnmarshalJSON(data []byte) error {
m.value = nil
return nil
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
m.value = str
return nil
}
// 尝试解析为数字
var num json.Number
if err := json.Unmarshal(data, &num); err == nil {
m.value = num
return nil
}
return fmt.Errorf("invalid id type")
}
@@ -81,15 +92,15 @@ type Message struct {
// Error 表示MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"` // 详细描述
Description string `json:"description"` // 详细描述
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述用于工具列表减少token消耗
InputSchema map[string]interface{} `json:"inputSchema"`
}
@@ -127,9 +138,9 @@ type ClientInfo struct {
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
@@ -178,31 +189,31 @@ type CallToolResponse struct {
// ToolExecution 工具执行记录
type ToolExecution struct {
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
}
// ToolStats 工具统计信息
type ToolStats struct {
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
}
// Prompt 提示词模板
type Prompt struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
}
// PromptArgument 提示词参数
@@ -257,11 +268,11 @@ type ResourceContent struct {
// SamplingRequest 采样请求
type SamplingRequest struct {
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// SamplingMessage 采样消息
@@ -272,9 +283,9 @@ type SamplingMessage struct {
// SamplingResponse 采样响应
type SamplingResponse struct {
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
}
// SamplingContent 采样内容
@@ -282,4 +293,3 @@ type SamplingContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}

View File

@@ -994,6 +994,20 @@ async function loadExternalMCPs() {
}
}
// 延迟刷新外部MCP列表用于在保存/连接后拉取后端异步更新的工具数量)
// 可选传入单次延迟毫秒数不传则执行两次2.5s 与 5s覆盖启动后后端异步更新较慢的情况
function scheduleExternalMCPToolCountRefresh(delayMs) {
const delays = delayMs != null ? [delayMs] : [2500, 5000];
delays.forEach((d) => {
setTimeout(async () => {
await loadExternalMCPs();
if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') {
window.refreshMentionTools();
}
}, d);
});
}
// 渲染外部MCP列表
function renderExternalMCPList(servers) {
const list = document.getElementById('external-mcp-list');
@@ -1354,6 +1368,8 @@ async function saveExternalMCP() {
if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') {
window.refreshMentionTools();
}
// 后端在连接成功约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数
scheduleExternalMCPToolCountRefresh();
alert('保存成功');
} catch (error) {
console.error('保存外部MCP失败:', error);
@@ -1433,6 +1449,8 @@ async function toggleExternalMCP(name, currentStatus) {
if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') {
window.refreshMentionTools();
}
// 后端约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数
scheduleExternalMCPToolCountRefresh();
return;
}
}
@@ -1496,6 +1514,8 @@ async function pollExternalMCPStatus(name, maxAttempts = 30) {
if (typeof window !== 'undefined' && typeof window.refreshMentionTools === 'function') {
window.refreshMentionTools();
}
// 后端约 2 秒后才更新工具数量,延迟再拉取一次以显示正确工具数
scheduleExternalMCPToolCountRefresh();
return;
} else if (status === 'error' || status === 'disconnected') {
// 连接失败,刷新列表并显示错误