mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-26 12:46:32 +02:00
460 lines
13 KiB
Go
460 lines
13 KiB
Go
package handler
|
||
|
||
import (
|
||
"fmt"
|
||
"net/http"
|
||
"os"
|
||
"sync"
|
||
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/mcp"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// ExternalMCPHandler 外部MCP处理器
|
||
type ExternalMCPHandler struct {
|
||
manager *mcp.ExternalMCPManager
|
||
config *config.Config
|
||
configPath string
|
||
logger *zap.Logger
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// NewExternalMCPHandler 创建外部MCP处理器
|
||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||
return &ExternalMCPHandler{
|
||
manager: manager,
|
||
config: cfg,
|
||
configPath: configPath,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// GetExternalMCPs 获取所有外部MCP配置
|
||
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
configs := h.manager.GetConfigs()
|
||
|
||
// 获取所有外部MCP的工具数量
|
||
toolCounts := h.manager.GetToolCounts()
|
||
|
||
// 转换为响应格式
|
||
result := make(map[string]ExternalMCPResponse)
|
||
for name, cfg := range configs {
|
||
client, exists := h.manager.GetClient(name)
|
||
status := "disconnected"
|
||
if exists {
|
||
status = client.GetStatus()
|
||
} else if h.isEnabled(cfg) {
|
||
status = "disconnected"
|
||
} else {
|
||
status = "disabled"
|
||
}
|
||
|
||
toolCount := toolCounts[name]
|
||
errorMsg := ""
|
||
if status == "error" {
|
||
errorMsg = h.manager.GetError(name)
|
||
}
|
||
|
||
result[name] = ExternalMCPResponse{
|
||
Config: cfg,
|
||
Status: status,
|
||
ToolCount: toolCount,
|
||
Error: errorMsg,
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"servers": result,
|
||
"stats": h.manager.GetStats(),
|
||
})
|
||
}
|
||
|
||
// GetExternalMCP 获取单个外部MCP配置
|
||
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||
name := c.Param("name")
|
||
|
||
h.mu.RLock()
|
||
defer h.mu.RUnlock()
|
||
|
||
configs := h.manager.GetConfigs()
|
||
cfg, exists := configs[name]
|
||
if !exists {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
|
||
return
|
||
}
|
||
|
||
client, clientExists := h.manager.GetClient(name)
|
||
status := "disconnected"
|
||
if clientExists {
|
||
status = client.GetStatus()
|
||
} else if h.isEnabled(cfg) {
|
||
status = "disconnected"
|
||
} else {
|
||
status = "disabled"
|
||
}
|
||
|
||
// 获取工具数量
|
||
toolCount := 0
|
||
if clientExists && client.IsConnected() {
|
||
if count, err := h.manager.GetToolCount(name); err == nil {
|
||
toolCount = count
|
||
}
|
||
}
|
||
|
||
// 获取错误信息
|
||
errorMsg := ""
|
||
if status == "error" {
|
||
errorMsg = h.manager.GetError(name)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, ExternalMCPResponse{
|
||
Config: cfg,
|
||
Status: status,
|
||
ToolCount: toolCount,
|
||
Error: errorMsg,
|
||
})
|
||
}
|
||
|
||
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
|
||
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||
var req AddOrUpdateExternalMCPRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
name := c.Param("name")
|
||
if name == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
|
||
return
|
||
}
|
||
|
||
// 验证配置
|
||
if err := h.validateConfig(req.Config); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 添加或更新配置
|
||
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
|
||
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 更新内存中的配置
|
||
if h.config.ExternalMCP.Servers == nil {
|
||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||
}
|
||
|
||
cfg := req.Config
|
||
|
||
// 官方 disabled 字段 → ExternalMCPEnable 取反
|
||
if cfg.Disabled {
|
||
cfg.ExternalMCPEnable = false
|
||
} else if !cfg.ExternalMCPEnable {
|
||
// 用户未显式设置 external_mcp_enable,官方配置默认就是启用的
|
||
cfg.ExternalMCPEnable = true
|
||
}
|
||
|
||
// 展开 ${VAR} 环境变量
|
||
config.ExpandConfigEnv(&cfg)
|
||
|
||
h.config.ExternalMCP.Servers[name] = cfg
|
||
|
||
// 保存到配置文件
|
||
if err := h.saveConfig(); err != nil {
|
||
h.logger.Error("保存配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||
}
|
||
|
||
// DeleteExternalMCP 删除外部MCP配置
|
||
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||
name := c.Param("name")
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 移除配置
|
||
if err := h.manager.RemoveConfig(name); err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
|
||
return
|
||
}
|
||
|
||
// 从内存配置中删除
|
||
if h.config.ExternalMCP.Servers != nil {
|
||
delete(h.config.ExternalMCP.Servers, name)
|
||
}
|
||
|
||
// 保存到配置文件
|
||
if err := h.saveConfig(); err != nil {
|
||
h.logger.Error("保存配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||
}
|
||
|
||
// StartExternalMCP 启动外部MCP
|
||
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||
name := c.Param("name")
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 更新配置为启用
|
||
if h.config.ExternalMCP.Servers == nil {
|
||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||
}
|
||
cfg := h.config.ExternalMCP.Servers[name]
|
||
cfg.ExternalMCPEnable = true
|
||
h.config.ExternalMCP.Servers[name] = cfg
|
||
|
||
// 保存到配置文件
|
||
if err := h.saveConfig(); err != nil {
|
||
h.logger.Error("保存配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
|
||
h.logger.Info("开始启动外部MCP", zap.String("name", name))
|
||
if err := h.manager.StartClient(name); err != nil {
|
||
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": err.Error(),
|
||
"status": "error",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 获取客户端状态(应该是connecting)
|
||
client, exists := h.manager.GetClient(name)
|
||
status := "connecting"
|
||
if exists {
|
||
status = client.GetStatus()
|
||
}
|
||
|
||
// 立即返回,不等待连接完成
|
||
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "外部MCP启动请求已提交,正在后台连接中",
|
||
"status": status,
|
||
})
|
||
}
|
||
|
||
// StopExternalMCP 停止外部MCP
|
||
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
|
||
name := c.Param("name")
|
||
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 停止客户端
|
||
if err := h.manager.StopClient(name); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 更新配置
|
||
if h.config.ExternalMCP.Servers == nil {
|
||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||
}
|
||
cfg := h.config.ExternalMCP.Servers[name]
|
||
cfg.ExternalMCPEnable = false
|
||
h.config.ExternalMCP.Servers[name] = cfg
|
||
|
||
// 保存到配置文件
|
||
if err := h.saveConfig(); err != nil {
|
||
h.logger.Error("保存配置失败", zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
h.logger.Info("外部MCP已停止", zap.String("name", name))
|
||
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
|
||
}
|
||
|
||
// GetExternalMCPStats 获取统计信息
|
||
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
|
||
stats := h.manager.GetStats()
|
||
c.JSON(http.StatusOK, stats)
|
||
}
|
||
|
||
// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段)
|
||
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
|
||
transport := cfg.GetTransportType()
|
||
if transport == "" {
|
||
return fmt.Errorf("需要指定 command(stdio模式)或 url + type(http/sse模式)")
|
||
}
|
||
|
||
switch transport {
|
||
case "http":
|
||
if cfg.URL == "" {
|
||
return fmt.Errorf("HTTP模式需要 url")
|
||
}
|
||
case "stdio":
|
||
if cfg.Command == "" {
|
||
return fmt.Errorf("stdio模式需要 command")
|
||
}
|
||
case "sse":
|
||
if cfg.URL == "" {
|
||
return fmt.Errorf("SSE模式需要 url")
|
||
}
|
||
default:
|
||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// isEnabled 检查是否启用
|
||
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||
return cfg.ExternalMCPEnable
|
||
}
|
||
|
||
// saveConfig 保存配置到文件
|
||
func (h *ExternalMCPHandler) saveConfig() error {
|
||
data, err := os.ReadFile(h.configPath)
|
||
if err != nil {
|
||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||
}
|
||
|
||
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
|
||
h.logger.Warn("创建配置备份失败", zap.Error(err))
|
||
}
|
||
|
||
root, err := loadYAMLDocument(h.configPath)
|
||
if err != nil {
|
||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||
}
|
||
|
||
updateExternalMCPConfig(root, h.config.ExternalMCP)
|
||
|
||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||
}
|
||
|
||
h.logger.Info("配置已保存", zap.String("path", h.configPath))
|
||
return nil
|
||
}
|
||
|
||
// updateExternalMCPConfig 更新外部MCP配置
|
||
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) {
|
||
root := doc.Content[0]
|
||
externalMCPNode := ensureMap(root, "external_mcp")
|
||
serversNode := ensureMap(externalMCPNode, "servers")
|
||
|
||
// 清空现有服务器配置
|
||
serversNode.Content = nil
|
||
|
||
// 添加新的服务器配置
|
||
for name, serverCfg := range cfg.Servers {
|
||
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
|
||
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
|
||
|
||
// type(官方 MCP 传输类型)
|
||
effectiveType := serverCfg.GetTransportType()
|
||
if effectiveType != "" && effectiveType != "stdio" {
|
||
// stdio 可省略(有 command 时自动推断)
|
||
setStringInMap(serverNode, "type", effectiveType)
|
||
}
|
||
if serverCfg.Command != "" {
|
||
setStringInMap(serverNode, "command", serverCfg.Command)
|
||
}
|
||
if len(serverCfg.Args) > 0 {
|
||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||
}
|
||
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
|
||
envNode := ensureMap(serverNode, "env")
|
||
for envKey, envValue := range serverCfg.Env {
|
||
setStringInMap(envNode, envKey, envValue)
|
||
}
|
||
}
|
||
if serverCfg.URL != "" {
|
||
setStringInMap(serverNode, "url", serverCfg.URL)
|
||
}
|
||
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
|
||
headersNode := ensureMap(serverNode, "headers")
|
||
for k, v := range serverCfg.Headers {
|
||
setStringInMap(headersNode, k, v)
|
||
}
|
||
}
|
||
if serverCfg.Description != "" {
|
||
setStringInMap(serverNode, "description", serverCfg.Description)
|
||
}
|
||
if serverCfg.Timeout > 0 {
|
||
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
|
||
}
|
||
// 官方标准字段
|
||
if serverCfg.Disabled {
|
||
setBoolInMap(serverNode, "disabled", true)
|
||
}
|
||
if len(serverCfg.AutoApprove) > 0 {
|
||
setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove)
|
||
}
|
||
|
||
// SDK 高级配置
|
||
if serverCfg.MaxRetries > 0 {
|
||
setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries)
|
||
}
|
||
if serverCfg.TerminateDuration > 0 {
|
||
setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration)
|
||
}
|
||
if serverCfg.KeepAlive > 0 {
|
||
setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive)
|
||
}
|
||
|
||
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
|
||
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
|
||
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
|
||
for toolName, enabled := range serverCfg.ToolEnabled {
|
||
setBoolInMap(toolEnabledNode, toolName, enabled)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// setStringArrayInMap 设置字符串数组
|
||
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
|
||
_, valueNode := ensureKeyValue(mapNode, key)
|
||
valueNode.Kind = yaml.SequenceNode
|
||
valueNode.Tag = "!!seq"
|
||
valueNode.Content = nil
|
||
for _, v := range values {
|
||
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
|
||
valueNode.Content = append(valueNode.Content, itemNode)
|
||
}
|
||
}
|
||
|
||
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
|
||
type AddOrUpdateExternalMCPRequest struct {
|
||
Config config.ExternalMCPServerConfig `json:"config"`
|
||
}
|
||
|
||
// ExternalMCPResponse 外部MCP响应
|
||
type ExternalMCPResponse struct {
|
||
Config config.ExternalMCPServerConfig `json:"config"`
|
||
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
|
||
ToolCount int `json:"tool_count"` // 工具数量
|
||
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
|
||
}
|