mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-03 04:48:08 +02:00
Add files via upload
This commit is contained in:
+309
-22
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -20,13 +22,14 @@ import (
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
configPath string
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AgentUpdater Agent更新接口
|
||||
@@ -36,14 +39,15 @@ type AgentUpdater interface {
|
||||
}
|
||||
|
||||
// NewConfigHandler 创建新的配置处理器
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler {
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
logger: logger,
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +64,8 @@ type ToolConfigInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
@@ -67,13 +73,14 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
// 获取工具列表
|
||||
// 获取工具列表(包含内部和外部工具)
|
||||
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: tool.ShortDescription,
|
||||
Enabled: tool.Enabled,
|
||||
IsExternal: false,
|
||||
})
|
||||
// 如果没有简短描述,使用详细描述的前100个字符
|
||||
if tools[len(tools)-1].Description == "" {
|
||||
@@ -85,6 +92,65 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取外部MCP工具
|
||||
if h.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
|
||||
if err == nil {
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
for _, externalTool := range externalTools {
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
mcpName = externalTool.Name[:idx]
|
||||
actualToolName = externalTool.Name[idx+2:]
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := false
|
||||
if cfg, exists := externalMCPConfigs[mcpName]; exists {
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
enabled = false // MCP未启用,所有工具都禁用
|
||||
} else {
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
||||
if cfg.ToolEnabled == nil {
|
||||
enabled = true // 未设置工具状态,默认为启用
|
||||
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
|
||||
enabled = toolEnabled // 使用配置的工具状态
|
||||
} else {
|
||||
enabled = true // 工具未在配置中,默认为启用
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
enabled = false
|
||||
}
|
||||
|
||||
description := externalTool.ShortDescription
|
||||
if description == "" {
|
||||
description = externalTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: actualToolName,
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetConfigResponse{
|
||||
OpenAI: h.config.OpenAI,
|
||||
MCP: h.config.MCP,
|
||||
@@ -128,13 +194,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
searchTermLower = strings.ToLower(searchTerm)
|
||||
}
|
||||
|
||||
// 获取所有工具并应用搜索过滤
|
||||
// 获取所有内部工具并应用搜索过滤
|
||||
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
for _, tool := range h.config.Security.Tools {
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: tool.ShortDescription,
|
||||
Enabled: tool.Enabled,
|
||||
IsExternal: false,
|
||||
}
|
||||
// 如果没有简短描述,使用详细描述的前100个字符
|
||||
if toolInfo.Description == "" {
|
||||
@@ -157,6 +224,81 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
|
||||
// 获取外部MCP工具
|
||||
if h.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取外部MCP工具失败", zap.Error(err))
|
||||
} else {
|
||||
// 获取外部MCP配置,用于判断启用状态
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
|
||||
for _, externalTool := range externalTools {
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
mcpName = externalTool.Name[:idx]
|
||||
actualToolName = externalTool.Name[idx+2:]
|
||||
} else {
|
||||
continue // 跳过格式不正确的工具
|
||||
}
|
||||
|
||||
// 获取外部工具的启用状态
|
||||
enabled := false
|
||||
if cfg, exists := externalMCPConfigs[mcpName]; exists {
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
enabled = false // MCP未启用,所有工具都禁用
|
||||
} else {
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
||||
if cfg.ToolEnabled == nil {
|
||||
enabled = true // 未设置工具状态,默认为启用
|
||||
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
|
||||
enabled = toolEnabled // 使用配置的工具状态
|
||||
} else {
|
||||
enabled = true // 工具未在配置中,默认为启用
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查外部MCP是否已连接
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
enabled = false // 未连接时视为禁用
|
||||
}
|
||||
|
||||
description := externalTool.ShortDescription
|
||||
if description == "" {
|
||||
description = externalTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(actualToolName)
|
||||
descLower := strings.ToLower(description)
|
||||
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
|
||||
continue // 不匹配,跳过
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, ToolConfigInfo{
|
||||
Name: actualToolName, // 显示实际工具名称,不带前缀
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total := len(allTools)
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
@@ -196,8 +338,10 @@ type UpdateConfigRequest struct {
|
||||
|
||||
// ToolEnableStatus 工具启用状态
|
||||
type ToolEnableStatus struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
}
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
@@ -240,14 +384,28 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
// 更新工具启用状态
|
||||
if req.Tools != nil {
|
||||
toolMap := make(map[string]bool)
|
||||
// 分离内部工具和外部工具
|
||||
internalToolMap := make(map[string]bool)
|
||||
// 外部工具状态:MCP名称 -> 工具名称 -> 启用状态
|
||||
externalMCPToolMap := make(map[string]map[string]bool)
|
||||
|
||||
for _, toolStatus := range req.Tools {
|
||||
toolMap[toolStatus.Name] = toolStatus.Enabled
|
||||
if toolStatus.IsExternal && toolStatus.ExternalMCP != "" {
|
||||
// 外部工具:保存每个工具的独立状态
|
||||
mcpName := toolStatus.ExternalMCP
|
||||
if externalMCPToolMap[mcpName] == nil {
|
||||
externalMCPToolMap[mcpName] = make(map[string]bool)
|
||||
}
|
||||
externalMCPToolMap[mcpName][toolStatus.Name] = toolStatus.Enabled
|
||||
} else {
|
||||
// 内部工具
|
||||
internalToolMap[toolStatus.Name] = toolStatus.Enabled
|
||||
}
|
||||
}
|
||||
|
||||
// 更新配置中的工具状态
|
||||
// 更新内部工具状态
|
||||
for i := range h.config.Security.Tools {
|
||||
if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok {
|
||||
if enabled, ok := internalToolMap[h.config.Security.Tools[i].Name]; ok {
|
||||
h.config.Security.Tools[i].Enabled = enabled
|
||||
h.logger.Info("更新工具启用状态",
|
||||
zap.String("tool", h.config.Security.Tools[i].Name),
|
||||
@@ -255,6 +413,80 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新外部MCP工具状态
|
||||
if h.externalMCPMgr != nil {
|
||||
for mcpName, toolStates := range externalMCPToolMap {
|
||||
// 更新配置中的工具启用状态
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
cfg, exists := h.config.ExternalMCP.Servers[mcpName]
|
||||
if !exists {
|
||||
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
|
||||
continue
|
||||
}
|
||||
|
||||
// 初始化ToolEnabled map
|
||||
if cfg.ToolEnabled == nil {
|
||||
cfg.ToolEnabled = make(map[string]bool)
|
||||
}
|
||||
|
||||
// 更新每个工具的启用状态
|
||||
for toolName, enabled := range toolStates {
|
||||
cfg.ToolEnabled[toolName] = enabled
|
||||
h.logger.Info("更新外部工具启用状态",
|
||||
zap.String("mcp", mcpName),
|
||||
zap.String("tool", toolName),
|
||||
zap.Bool("enabled", enabled),
|
||||
)
|
||||
}
|
||||
|
||||
// 检查是否有任何工具启用,如果有则启用MCP
|
||||
hasEnabledTool := false
|
||||
for _, enabled := range cfg.ToolEnabled {
|
||||
if enabled {
|
||||
hasEnabledTool = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果MCP之前未启用,但现在有工具启用,则启用MCP
|
||||
// 如果MCP之前已启用,保持启用状态(允许部分工具禁用)
|
||||
if !cfg.ExternalMCPEnable && hasEnabledTool {
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName))
|
||||
}
|
||||
|
||||
h.config.ExternalMCP.Servers[mcpName] = cfg
|
||||
}
|
||||
|
||||
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
|
||||
// 在循环外部统一更新,避免重复调用
|
||||
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
|
||||
|
||||
// 处理MCP连接状态
|
||||
for mcpName := range externalMCPToolMap {
|
||||
cfg := h.config.ExternalMCP.Servers[mcpName]
|
||||
// 如果MCP需要启用,确保客户端已启动
|
||||
if cfg.ExternalMCPEnable {
|
||||
// 启动外部MCP(如果未启动)
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
if err := h.externalMCPMgr.StartClient(mcpName); err != nil {
|
||||
h.logger.Warn("启动外部MCP失败",
|
||||
zap.String("mcp", mcpName),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
h.logger.Info("启动外部MCP",
|
||||
zap.String("mcp", mcpName),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存配置到文件
|
||||
@@ -318,6 +550,33 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
||||
updateMCPConfig(root, h.config.MCP)
|
||||
updateOpenAIConfig(root, h.config.OpenAI)
|
||||
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
|
||||
// 读取原始配置以保持向后兼容
|
||||
originalConfigs := make(map[string]map[string]bool)
|
||||
externalMCPNode := findMapValue(root, "external_mcp")
|
||||
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
|
||||
serversNode := findMapValue(externalMCPNode, "servers")
|
||||
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
|
||||
for i := 0; i < len(serversNode.Content); i += 2 {
|
||||
if i+1 >= len(serversNode.Content) {
|
||||
break
|
||||
}
|
||||
nameNode := serversNode.Content[i]
|
||||
serverNode := serversNode.Content[i+1]
|
||||
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
|
||||
serverName := nameNode.Value
|
||||
originalConfigs[serverName] = make(map[string]bool)
|
||||
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
|
||||
originalConfigs[serverName]["enabled"] = *enabledVal
|
||||
}
|
||||
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
|
||||
originalConfigs[serverName]["disabled"] = *disabledVal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
|
||||
|
||||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
@@ -504,6 +763,34 @@ func setIntInMap(mapNode *yaml.Node, key string, value int) {
|
||||
valueNode.Value = fmt.Sprintf("%d", value)
|
||||
}
|
||||
|
||||
func findBoolInMap(mapNode *yaml.Node, key string) *bool {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < len(mapNode.Content); i += 2 {
|
||||
if i+1 >= len(mapNode.Content) {
|
||||
break
|
||||
}
|
||||
keyNode := mapNode.Content[i]
|
||||
valueNode := mapNode.Content[i+1]
|
||||
|
||||
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
|
||||
if valueNode.Kind == yaml.ScalarNode {
|
||||
if valueNode.Value == "true" {
|
||||
result := true
|
||||
return &result
|
||||
} else if valueNode.Value == "false" {
|
||||
result := false
|
||||
return &result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setBoolInMap(mapNode *yaml.Node, key string, value bool) {
|
||||
_, valueNode := ensureKeyValue(mapNode, key)
|
||||
valueNode.Kind = yaml.ScalarNode
|
||||
|
||||
@@ -0,0 +1,510 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ExternalMCPHandler 外部MCP处理器
|
||||
type ExternalMCPHandler struct {
|
||||
manager *mcp.ExternalMCPManager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewExternalMCPHandler 创建外部MCP处理器
|
||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||
return &ExternalMCPHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetExternalMCPs 获取所有外部MCP配置
|
||||
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
|
||||
// 获取所有外部MCP的工具数量
|
||||
toolCounts := h.manager.GetToolCounts()
|
||||
|
||||
// 转换为响应格式
|
||||
result := make(map[string]ExternalMCPResponse)
|
||||
for name, cfg := range configs {
|
||||
client, exists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if exists {
|
||||
status = client.GetStatus()
|
||||
} else if h.isEnabled(cfg) {
|
||||
status = "disconnected"
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
toolCount := toolCounts[name]
|
||||
|
||||
result[name] = ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
ToolCount: toolCount,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"servers": result,
|
||||
"stats": h.manager.GetStats(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetExternalMCP 获取单个外部MCP配置
|
||||
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
cfg, exists := configs[name]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
client, clientExists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if clientExists {
|
||||
status = client.GetStatus()
|
||||
} else if h.isEnabled(cfg) {
|
||||
status = "disconnected"
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
// 获取工具数量
|
||||
toolCount := 0
|
||||
if clientExists && client.IsConnected() {
|
||||
if count, err := h.manager.GetToolCount(name); err == nil {
|
||||
toolCount = count
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
ToolCount: toolCount,
|
||||
})
|
||||
}
|
||||
|
||||
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
|
||||
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
var req AddOrUpdateExternalMCPRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
name := c.Param("name")
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := h.validateConfig(req.Config); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 添加或更新配置
|
||||
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
|
||||
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新内存中的配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
|
||||
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
|
||||
// 同时将值迁移到 external_mcp_enable
|
||||
cfg := req.Config
|
||||
|
||||
if req.Config.Disabled {
|
||||
// 用户设置了 disabled: true
|
||||
cfg.ExternalMCPEnable = false
|
||||
cfg.Disabled = true
|
||||
cfg.Enabled = false
|
||||
} else if req.Config.Enabled {
|
||||
// 用户设置了 enabled: true
|
||||
cfg.ExternalMCPEnable = true
|
||||
cfg.Enabled = true
|
||||
cfg.Disabled = false
|
||||
} else if !req.Config.ExternalMCPEnable {
|
||||
// 用户没有设置任何字段,且 external_mcp_enable 为 false
|
||||
// 检查现有配置是否有旧字段
|
||||
if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists {
|
||||
// 保留现有的旧字段
|
||||
cfg.Enabled = existingCfg.Enabled
|
||||
cfg.Disabled = existingCfg.Disabled
|
||||
}
|
||||
} else {
|
||||
// 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段
|
||||
// 为了向后兼容,我们设置 enabled: true
|
||||
// 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true
|
||||
cfg.Enabled = true
|
||||
cfg.Disabled = false
|
||||
}
|
||||
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
// DeleteExternalMCP 删除外部MCP配置
|
||||
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 移除配置
|
||||
if err := h.manager.RemoveConfig(name); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 从内存配置中删除
|
||||
if h.config.ExternalMCP.Servers != nil {
|
||||
delete(h.config.ExternalMCP.Servers, name)
|
||||
}
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
|
||||
// StartExternalMCP 启动外部MCP
|
||||
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 更新配置为启用
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 启动客户端(这可能会花费一些时间)
|
||||
h.logger.Info("开始启动外部MCP", zap.String("name", name))
|
||||
if err := h.manager.StartClient(name); err != nil {
|
||||
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
"status": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取连接状态
|
||||
client, exists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if exists {
|
||||
status = client.GetStatus()
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP启动完成", zap.String("name", name), zap.String("status", status))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "外部MCP启动完成",
|
||||
"status": status,
|
||||
})
|
||||
}
|
||||
|
||||
// StopExternalMCP 停止外部MCP
|
||||
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 停止客户端
|
||||
if err := h.manager.StopClient(name); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = false
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP已停止", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
|
||||
}
|
||||
|
||||
// GetExternalMCPStats 获取统计信息
|
||||
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
|
||||
stats := h.manager.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// validateConfig 验证配置
|
||||
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
|
||||
transport := cfg.Transport
|
||||
if transport == "" {
|
||||
// 如果没有指定transport,根据是否有command或url判断
|
||||
if cfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if cfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return fmt.Errorf("需要指定command(stdio模式)或url(http模式)")
|
||||
}
|
||||
}
|
||||
|
||||
switch transport {
|
||||
case "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("HTTP模式需要URL")
|
||||
}
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("stdio模式需要command")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isEnabled 检查是否启用
|
||||
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||||
// 优先使用 ExternalMCPEnable 字段
|
||||
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
|
||||
if cfg.ExternalMCPEnable {
|
||||
return true
|
||||
}
|
||||
// 向后兼容:检查旧字段
|
||||
if cfg.Disabled {
|
||||
return false
|
||||
}
|
||||
if cfg.Enabled {
|
||||
return true
|
||||
}
|
||||
// 都没有设置,默认为启用
|
||||
return true
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到文件
|
||||
func (h *ExternalMCPHandler) saveConfig() error {
|
||||
// 读取现有配置文件并创建备份
|
||||
data, err := os.ReadFile(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
|
||||
h.logger.Warn("创建配置备份失败", zap.Error(err))
|
||||
}
|
||||
|
||||
root, err := loadYAMLDocument(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容
|
||||
originalConfigs := make(map[string]map[string]bool)
|
||||
externalMCPNode := findMapValue(root.Content[0], "external_mcp")
|
||||
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
|
||||
serversNode := findMapValue(externalMCPNode, "servers")
|
||||
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
|
||||
// 遍历现有的服务器配置,保存 enabled/disabled 字段
|
||||
for i := 0; i < len(serversNode.Content); i += 2 {
|
||||
if i+1 >= len(serversNode.Content) {
|
||||
break
|
||||
}
|
||||
nameNode := serversNode.Content[i]
|
||||
serverNode := serversNode.Content[i+1]
|
||||
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
|
||||
serverName := nameNode.Value
|
||||
originalConfigs[serverName] = make(map[string]bool)
|
||||
// 检查是否有 enabled 字段
|
||||
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
|
||||
originalConfigs[serverName]["enabled"] = *enabledVal
|
||||
}
|
||||
// 检查是否有 disabled 字段
|
||||
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
|
||||
originalConfigs[serverName]["disabled"] = *disabledVal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新外部MCP配置
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
|
||||
|
||||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("配置已保存", zap.String("path", h.configPath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateExternalMCPConfig 更新外部MCP配置
|
||||
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) {
|
||||
root := doc.Content[0]
|
||||
externalMCPNode := ensureMap(root, "external_mcp")
|
||||
serversNode := ensureMap(externalMCPNode, "servers")
|
||||
|
||||
// 清空现有服务器配置
|
||||
serversNode.Content = nil
|
||||
|
||||
// 添加新的服务器配置
|
||||
for name, serverCfg := range cfg.Servers {
|
||||
// 添加服务器名称键
|
||||
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
|
||||
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
|
||||
|
||||
// 设置服务器配置字段
|
||||
if serverCfg.Command != "" {
|
||||
setStringInMap(serverNode, "command", serverCfg.Command)
|
||||
}
|
||||
if len(serverCfg.Args) > 0 {
|
||||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||||
}
|
||||
if serverCfg.Transport != "" {
|
||||
setStringInMap(serverNode, "transport", serverCfg.Transport)
|
||||
}
|
||||
if serverCfg.URL != "" {
|
||||
setStringInMap(serverNode, "url", serverCfg.URL)
|
||||
}
|
||||
if serverCfg.Description != "" {
|
||||
setStringInMap(serverNode, "description", serverCfg.Description)
|
||||
}
|
||||
if serverCfg.Timeout > 0 {
|
||||
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
|
||||
}
|
||||
// 保存 external_mcp_enable 字段(新字段)
|
||||
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
|
||||
// 保存 tool_enabled 字段(每个工具的启用状态)
|
||||
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
|
||||
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
|
||||
for toolName, enabled := range serverCfg.ToolEnabled {
|
||||
setBoolInMap(toolEnabledNode, toolName, enabled)
|
||||
}
|
||||
}
|
||||
// 保留旧的 enabled/disabled 字段以保持向后兼容
|
||||
originalFields, hasOriginal := originalConfigs[name]
|
||||
|
||||
// 如果原始配置中有 enabled 字段,保留它
|
||||
if hasOriginal {
|
||||
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
|
||||
setBoolInMap(serverNode, "enabled", enabledVal)
|
||||
}
|
||||
// 如果原始配置中有 disabled 字段,保留它
|
||||
// 注意:由于 omitempty,disabled: false 不会被保存,但 disabled: true 会被保存
|
||||
if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled {
|
||||
if disabledVal {
|
||||
setBoolInMap(serverNode, "disabled", disabledVal)
|
||||
} else {
|
||||
// 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示
|
||||
// 因为 disabled: false 等价于 enabled: true
|
||||
setBoolInMap(serverNode, "enabled", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果用户在当前请求中明确设置了这些字段,也保存它们
|
||||
if serverCfg.Enabled {
|
||||
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
|
||||
}
|
||||
if serverCfg.Disabled {
|
||||
setBoolInMap(serverNode, "disabled", serverCfg.Disabled)
|
||||
} else if !hasOriginal && serverCfg.ExternalMCPEnable {
|
||||
// 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容
|
||||
setBoolInMap(serverNode, "enabled", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setStringArrayInMap 设置字符串数组
|
||||
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
|
||||
_, valueNode := ensureKeyValue(mapNode, key)
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
for _, v := range values {
|
||||
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
|
||||
valueNode.Content = append(valueNode.Content, itemNode)
|
||||
}
|
||||
}
|
||||
|
||||
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
|
||||
type AddOrUpdateExternalMCPRequest struct {
|
||||
Config config.ExternalMCPServerConfig `json:"config"`
|
||||
}
|
||||
|
||||
// ExternalMCPResponse 外部MCP响应
|
||||
type ExternalMCPResponse struct {
|
||||
Config config.ExternalMCPServerConfig `json:"config"`
|
||||
Status string `json:"status"` // "connected", "disconnected", "disabled", "error"
|
||||
ToolCount int `json:"tool_count"` // 工具数量
|
||||
}
|
||||
|
||||
@@ -0,0 +1,518 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建临时配置文件
|
||||
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
|
||||
tmpFile.Close()
|
||||
configPath := tmpFile.Name()
|
||||
|
||||
logger := zap.NewNop()
|
||||
manager := mcp.NewExternalMCPManager(logger)
|
||||
cfg := &config.Config{
|
||||
ExternalMCP: config.ExternalMCPConfig{
|
||||
Servers: make(map[string]config.ExternalMCPServerConfig),
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
|
||||
|
||||
api := router.Group("/api")
|
||||
api.GET("/external-mcp", handler.GetExternalMCPs)
|
||||
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
|
||||
api.GET("/external-mcp/:name", handler.GetExternalMCP)
|
||||
api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP)
|
||||
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
|
||||
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
|
||||
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
|
||||
|
||||
return router, handler, configPath
|
||||
}
|
||||
|
||||
func cleanupTestConfig(configPath string) {
|
||||
os.Remove(configPath)
|
||||
os.Remove(configPath + ".backup")
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加stdio模式的配置
|
||||
configJSON := `{
|
||||
"command": "python3",
|
||||
"args": ["/path/to/script.py", "--server", "http://example.com"],
|
||||
"description": "Test stdio MCP",
|
||||
"timeout": 300,
|
||||
"enabled": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.Command != "python3" {
|
||||
t.Errorf("期望command为python3,实际%s", response.Config.Command)
|
||||
}
|
||||
if len(response.Config.Args) != 3 {
|
||||
t.Errorf("期望args长度为3,实际%d", len(response.Config.Args))
|
||||
}
|
||||
if response.Config.Description != "Test stdio MCP" {
|
||||
t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description)
|
||||
}
|
||||
if response.Config.Timeout != 300 {
|
||||
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
|
||||
}
|
||||
if !response.Config.Enabled {
|
||||
t.Error("期望enabled为true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加HTTP模式的配置
|
||||
configJSON := `{
|
||||
"transport": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"enabled": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.Transport != "http" {
|
||||
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
|
||||
}
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
if !response.Config.Enabled {
|
||||
t.Error("期望enabled为true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "缺少command和url",
|
||||
configJSON: `{"enabled": true}`,
|
||||
expectedErr: "需要指定command(stdio模式)或url(http模式)",
|
||||
},
|
||||
{
|
||||
name: "stdio模式缺少command",
|
||||
configJSON: `{"args": ["test"], "enabled": true}`,
|
||||
expectedErr: "stdio模式需要command",
|
||||
},
|
||||
{
|
||||
name: "http模式缺少url",
|
||||
configJSON: `{"transport": "http", "enabled": true}`,
|
||||
expectedErr: "HTTP模式需要URL",
|
||||
},
|
||||
{
|
||||
name: "无效的transport",
|
||||
configJSON: `{"transport": "invalid", "enabled": true}`,
|
||||
expectedErr: "不支持的传输模式",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
errorMsg := response["error"].(string)
|
||||
// 对于stdio模式缺少command的情况,错误信息可能略有不同
|
||||
if tc.name == "stdio模式缺少command" {
|
||||
if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") {
|
||||
t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg)
|
||||
}
|
||||
} else if !strings.Contains(errorMsg, tc.expectedErr) {
|
||||
t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 先添加一个配置
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||
|
||||
// 删除配置
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已删除
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
// 添加多个配置
|
||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: false,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
servers := response["servers"].(map[string]interface{})
|
||||
if len(servers) != 2 {
|
||||
t.Errorf("期望2个服务器,实际%d", len(servers))
|
||||
}
|
||||
if _, ok := servers["test1"]; !ok {
|
||||
t.Error("期望包含test1")
|
||||
}
|
||||
if _, ok := servers["test2"]; !ok {
|
||||
t.Error("期望包含test2")
|
||||
}
|
||||
|
||||
stats := response["stats"].(map[string]interface{})
|
||||
if int(stats["total"].(float64)) != 2 {
|
||||
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
// 添加配置
|
||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total"].(float64)) != 3 {
|
||||
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
|
||||
}
|
||||
if int(stats["enabled"].(float64)) != 2 {
|
||||
t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64)))
|
||||
}
|
||||
if int(stats["disabled"].(float64)) != 1 {
|
||||
t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 添加一个禁用的配置
|
||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 启动可能会失败,但应该返回合理的状态码
|
||||
if w.Code != http.StatusOK {
|
||||
// 如果启动失败,应该是400或500
|
||||
if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// 测试停止
|
||||
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 空名称应该返回404或400
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
// 发送无效的JSON
|
||||
body := []byte(`{"config": invalid json}`)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 先添加配置
|
||||
config1 := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||
|
||||
// 更新配置
|
||||
config2 := config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: config2,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已更新
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
if response.Config.Command != "" {
|
||||
t.Errorf("期望command为空,实际%s", response.Config.Command)
|
||||
}
|
||||
}
|
||||
|
||||
+77
-16
@@ -14,22 +14,29 @@ import (
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler {
|
||||
return &MonitorHandler{
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
db: db,
|
||||
logger: logger,
|
||||
mcpServer: mcpServer,
|
||||
externalMCPMgr: nil, // 将在创建后设置
|
||||
executor: executor,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetExternalMCPManager 设置外部MCP管理器
|
||||
func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
h.externalMCPMgr = mgr
|
||||
}
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
@@ -128,15 +135,49 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int) ([]*mc
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
||||
// 合并内部MCP服务器和外部MCP管理器的统计信息
|
||||
stats := make(map[string]*mcp.ToolStats)
|
||||
|
||||
// 加载内部MCP服务器的统计信息
|
||||
if h.db == nil {
|
||||
return h.mcpServer.GetStats()
|
||||
internalStats := h.mcpServer.GetStats()
|
||||
for k, v := range internalStats {
|
||||
stats[k] = v
|
||||
}
|
||||
} else {
|
||||
dbStats, err := h.db.LoadToolStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
|
||||
internalStats := h.mcpServer.GetStats()
|
||||
for k, v := range internalStats {
|
||||
stats[k] = v
|
||||
}
|
||||
} else {
|
||||
for k, v := range dbStats {
|
||||
stats[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.db.LoadToolStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
|
||||
return h.mcpServer.GetStats()
|
||||
// 合并外部MCP管理器的统计信息
|
||||
if h.externalMCPMgr != nil {
|
||||
externalStats := h.externalMCPMgr.GetToolStats()
|
||||
for k, v := range externalStats {
|
||||
// 如果已存在,合并统计信息
|
||||
if existing, exists := stats[k]; exists {
|
||||
existing.TotalCalls += v.TotalCalls
|
||||
existing.SuccessCalls += v.SuccessCalls
|
||||
existing.FailedCalls += v.FailedCalls
|
||||
// 使用最新的调用时间
|
||||
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
|
||||
existing.LastCallTime = v.LastCallTime
|
||||
}
|
||||
} else {
|
||||
stats[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
@@ -145,13 +186,32 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 先从内部MCP服务器查找
|
||||
exec, exists := h.mcpServer.GetExecution(id)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, exec)
|
||||
// 如果找不到,尝试从外部MCP管理器查找
|
||||
if h.externalMCPMgr != nil {
|
||||
exec, exists = h.externalMCPMgr.GetExecution(id)
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果都找不到,尝试从数据库查找(如果使用数据库存储)
|
||||
if h.db != nil {
|
||||
exec, err := h.db.GetToolExecution(id)
|
||||
if err == nil && exec != nil {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
@@ -160,3 +220,4 @@ func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user