Add files via upload

This commit is contained in:
公明
2025-11-15 17:50:47 +08:00
committed by GitHub
parent f7344c0090
commit d8ed2d8cb3
18 changed files with 4749 additions and 130 deletions
+309 -22
View File
@@ -2,6 +2,7 @@ package handler
import (
"bytes"
"context"
"fmt"
"net/http"
"os"
@@ -9,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
@@ -20,13 +22,14 @@ import (
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置
logger *zap.Logger
mu sync.RWMutex
configPath string
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
mu sync.RWMutex
}
// AgentUpdater Agent更新接口
@@ -36,14 +39,15 @@ type AgentUpdater interface {
}
// NewConfigHandler 创建新的配置处理器
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler {
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
return &ConfigHandler{
configPath: configPath,
config: cfg,
mcpServer: mcpServer,
executor: executor,
agent: agent,
logger: logger,
configPath: configPath,
config: cfg,
mcpServer: mcpServer,
executor: executor,
agent: agent,
externalMCPMgr: externalMCPMgr,
logger: logger,
}
}
@@ -60,6 +64,8 @@ type ToolConfigInfo struct {
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
}
// GetConfig 获取当前配置
@@ -67,13 +73,14 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
// 获取工具列表
// 获取工具列表(包含内部和外部工具)
tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
for _, tool := range h.config.Security.Tools {
tools = append(tools, ToolConfigInfo{
Name: tool.Name,
Description: tool.ShortDescription,
Enabled: tool.Enabled,
IsExternal: false,
})
// 如果没有简短描述,使用详细描述的前100个字符
if tools[len(tools)-1].Description == "" {
@@ -85,6 +92,65 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
}
}
// 获取外部MCP工具
if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
if err == nil {
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
mcpName = externalTool.Name[:idx]
actualToolName = externalTool.Name[idx+2:]
} else {
continue
}
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
enabled = false // MCP未启用,所有工具都禁用
} else {
// MCP已启用,检查单个工具的启用状态
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
if cfg.ToolEnabled == nil {
enabled = true // 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
enabled = toolEnabled // 使用配置的工具状态
} else {
enabled = true // 工具未在配置中,默认为启用
}
}
}
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
enabled = false
}
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
}
tools = append(tools, ToolConfigInfo{
Name: actualToolName,
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
}
}
c.JSON(http.StatusOK, GetConfigResponse{
OpenAI: h.config.OpenAI,
MCP: h.config.MCP,
@@ -128,13 +194,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
searchTermLower = strings.ToLower(searchTerm)
}
// 获取所有工具并应用搜索过滤
// 获取所有内部工具并应用搜索过滤
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
for _, tool := range h.config.Security.Tools {
toolInfo := ToolConfigInfo{
Name: tool.Name,
Description: tool.ShortDescription,
Enabled: tool.Enabled,
IsExternal: false,
}
// 如果没有简短描述,使用详细描述的前100个字符
if toolInfo.Description == "" {
@@ -157,6 +224,81 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
allTools = append(allTools, toolInfo)
}
// 获取外部MCP工具
if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
if err != nil {
h.logger.Warn("获取外部MCP工具失败", zap.Error(err))
} else {
// 获取外部MCP配置,用于判断启用状态
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
mcpName = externalTool.Name[:idx]
actualToolName = externalTool.Name[idx+2:]
} else {
continue // 跳过格式不正确的工具
}
// 获取外部工具的启用状态
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
enabled = false // MCP未启用,所有工具都禁用
} else {
// MCP已启用,检查单个工具的启用状态
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
if cfg.ToolEnabled == nil {
enabled = true // 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
enabled = toolEnabled // 使用配置的工具状态
} else {
enabled = true // 工具未在配置中,默认为启用
}
}
}
// 检查外部MCP是否已连接
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
enabled = false // 未连接时视为禁用
}
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(actualToolName)
descLower := strings.ToLower(description)
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
continue // 不匹配,跳过
}
}
allTools = append(allTools, ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
}
}
total := len(allTools)
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
@@ -196,8 +338,10 @@ type UpdateConfigRequest struct {
// ToolEnableStatus 工具启用状态
type ToolEnableStatus struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
}
// UpdateConfig 更新配置
@@ -240,14 +384,28 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 更新工具启用状态
if req.Tools != nil {
toolMap := make(map[string]bool)
// 分离内部工具和外部工具
internalToolMap := make(map[string]bool)
// 外部工具状态:MCP名称 -> 工具名称 -> 启用状态
externalMCPToolMap := make(map[string]map[string]bool)
for _, toolStatus := range req.Tools {
toolMap[toolStatus.Name] = toolStatus.Enabled
if toolStatus.IsExternal && toolStatus.ExternalMCP != "" {
// 外部工具:保存每个工具的独立状态
mcpName := toolStatus.ExternalMCP
if externalMCPToolMap[mcpName] == nil {
externalMCPToolMap[mcpName] = make(map[string]bool)
}
externalMCPToolMap[mcpName][toolStatus.Name] = toolStatus.Enabled
} else {
// 内部工具
internalToolMap[toolStatus.Name] = toolStatus.Enabled
}
}
// 更新配置中的工具状态
// 更新内部工具状态
for i := range h.config.Security.Tools {
if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok {
if enabled, ok := internalToolMap[h.config.Security.Tools[i].Name]; ok {
h.config.Security.Tools[i].Enabled = enabled
h.logger.Info("更新工具启用状态",
zap.String("tool", h.config.Security.Tools[i].Name),
@@ -255,6 +413,80 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
)
}
}
// 更新外部MCP工具状态
if h.externalMCPMgr != nil {
for mcpName, toolStates := range externalMCPToolMap {
// 更新配置中的工具启用状态
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg, exists := h.config.ExternalMCP.Servers[mcpName]
if !exists {
h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName))
continue
}
// 初始化ToolEnabled map
if cfg.ToolEnabled == nil {
cfg.ToolEnabled = make(map[string]bool)
}
// 更新每个工具的启用状态
for toolName, enabled := range toolStates {
cfg.ToolEnabled[toolName] = enabled
h.logger.Info("更新外部工具启用状态",
zap.String("mcp", mcpName),
zap.String("tool", toolName),
zap.Bool("enabled", enabled),
)
}
// 检查是否有任何工具启用,如果有则启用MCP
hasEnabledTool := false
for _, enabled := range cfg.ToolEnabled {
if enabled {
hasEnabledTool = true
break
}
}
// 如果MCP之前未启用,但现在有工具启用,则启用MCP
// 如果MCP之前已启用,保持启用状态(允许部分工具禁用)
if !cfg.ExternalMCPEnable && hasEnabledTool {
cfg.ExternalMCPEnable = true
h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName))
}
h.config.ExternalMCP.Servers[mcpName] = cfg
}
// 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置
// 在循环外部统一更新,避免重复调用
h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP)
// 处理MCP连接状态
for mcpName := range externalMCPToolMap {
cfg := h.config.ExternalMCP.Servers[mcpName]
// 如果MCP需要启用,确保客户端已启动
if cfg.ExternalMCPEnable {
// 启动外部MCP(如果未启动)
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
if err := h.externalMCPMgr.StartClient(mcpName); err != nil {
h.logger.Warn("启动外部MCP失败",
zap.String("mcp", mcpName),
zap.Error(err),
)
} else {
h.logger.Info("启动外部MCP",
zap.String("mcp", mcpName),
)
}
}
}
}
}
}
// 保存配置到文件
@@ -318,6 +550,33 @@ func (h *ConfigHandler) saveConfig() error {
updateAgentConfig(root, h.config.Agent.MaxIterations)
updateMCPConfig(root, h.config.MCP)
updateOpenAIConfig(root, h.config.OpenAI)
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
// 读取原始配置以保持向后兼容
originalConfigs := make(map[string]map[string]bool)
externalMCPNode := findMapValue(root, "external_mcp")
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
serversNode := findMapValue(externalMCPNode, "servers")
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
for i := 0; i < len(serversNode.Content); i += 2 {
if i+1 >= len(serversNode.Content) {
break
}
nameNode := serversNode.Content[i]
serverNode := serversNode.Content[i+1]
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
serverName := nameNode.Value
originalConfigs[serverName] = make(map[string]bool)
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
originalConfigs[serverName]["enabled"] = *enabledVal
}
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
originalConfigs[serverName]["disabled"] = *disabledVal
}
}
}
}
}
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
@@ -504,6 +763,34 @@ func setIntInMap(mapNode *yaml.Node, key string, value int) {
valueNode.Value = fmt.Sprintf("%d", value)
}
func findBoolInMap(mapNode *yaml.Node, key string) *bool {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return nil
}
for i := 0; i < len(mapNode.Content); i += 2 {
if i+1 >= len(mapNode.Content) {
break
}
keyNode := mapNode.Content[i]
valueNode := mapNode.Content[i+1]
if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key {
if valueNode.Kind == yaml.ScalarNode {
if valueNode.Value == "true" {
result := true
return &result
} else if valueNode.Value == "false" {
result := false
return &result
}
}
return nil
}
}
return nil
}
func setBoolInMap(mapNode *yaml.Node, key string, value bool) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.ScalarNode
+510
View File
@@ -0,0 +1,510 @@
package handler
import (
"fmt"
"net/http"
"os"
"sync"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// ExternalMCPHandler 外部MCP处理器
type ExternalMCPHandler struct {
manager *mcp.ExternalMCPManager
config *config.Config
configPath string
logger *zap.Logger
mu sync.RWMutex
}
// NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetExternalMCPs 获取所有外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
// 获取所有外部MCP的工具数量
toolCounts := h.manager.GetToolCounts()
// 转换为响应格式
result := make(map[string]ExternalMCPResponse)
for name, cfg := range configs {
client, exists := h.manager.GetClient(name)
status := "disconnected"
if exists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
toolCount := toolCounts[name]
result[name] = ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
}
}
c.JSON(http.StatusOK, gin.H{
"servers": result,
"stats": h.manager.GetStats(),
})
}
// GetExternalMCP 获取单个外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
cfg, exists := configs[name]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
return
}
client, clientExists := h.manager.GetClient(name)
status := "disconnected"
if clientExists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
// 获取工具数量
toolCount := 0
if clientExists && client.IsConnected() {
if count, err := h.manager.GetToolCount(name); err == nil {
toolCount = count
}
}
c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
})
}
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
var req AddOrUpdateExternalMCPRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
name := c.Param("name")
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
return
}
// 验证配置
if err := h.validateConfig(req.Config); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 添加或更新配置
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
return
}
// 更新内存中的配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
// 同时将值迁移到 external_mcp_enable
cfg := req.Config
if req.Config.Disabled {
// 用户设置了 disabled: true
cfg.ExternalMCPEnable = false
cfg.Disabled = true
cfg.Enabled = false
} else if req.Config.Enabled {
// 用户设置了 enabled: true
cfg.ExternalMCPEnable = true
cfg.Enabled = true
cfg.Disabled = false
} else if !req.Config.ExternalMCPEnable {
// 用户没有设置任何字段,且 external_mcp_enable 为 false
// 检查现有配置是否有旧字段
if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists {
// 保留现有的旧字段
cfg.Enabled = existingCfg.Enabled
cfg.Disabled = existingCfg.Disabled
}
} else {
// 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段
// 为了向后兼容,我们设置 enabled: true
// 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true
cfg.Enabled = true
cfg.Disabled = false
}
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
// DeleteExternalMCP 删除外部MCP配置
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 移除配置
if err := h.manager.RemoveConfig(name); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
return
}
// 从内存配置中删除
if h.config.ExternalMCP.Servers != nil {
delete(h.config.ExternalMCP.Servers, name)
}
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
// StartExternalMCP 启动外部MCP
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 更新配置为启用
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = true
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
// 启动客户端(这可能会花费一些时间)
h.logger.Info("开始启动外部MCP", zap.String("name", name))
if err := h.manager.StartClient(name); err != nil {
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
"status": "error",
})
return
}
// 获取连接状态
client, exists := h.manager.GetClient(name)
status := "disconnected"
if exists {
status = client.GetStatus()
}
h.logger.Info("外部MCP启动完成", zap.String("name", name), zap.String("status", status))
c.JSON(http.StatusOK, gin.H{
"message": "外部MCP启动完成",
"status": status,
})
}
// StopExternalMCP 停止外部MCP
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 停止客户端
if err := h.manager.StopClient(name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = false
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP已停止", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
}
// GetExternalMCPStats 获取统计信息
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
stats := h.manager.GetStats()
c.JSON(http.StatusOK, stats)
}
// validateConfig 验证配置
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
transport := cfg.Transport
if transport == "" {
// 如果没有指定transport,根据是否有command或url判断
if cfg.Command != "" {
transport = "stdio"
} else if cfg.URL != "" {
transport = "http"
} else {
return fmt.Errorf("需要指定commandstdio模式)或urlhttp模式)")
}
}
switch transport {
case "http":
if cfg.URL == "" {
return fmt.Errorf("HTTP模式需要URL")
}
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("stdio模式需要command")
}
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport)
}
return nil
}
// isEnabled 检查是否启用
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
// 优先使用 ExternalMCPEnable 字段
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
if cfg.ExternalMCPEnable {
return true
}
// 向后兼容:检查旧字段
if cfg.Disabled {
return false
}
if cfg.Enabled {
return true
}
// 都没有设置,默认为启用
return true
}
// saveConfig 保存配置到文件
func (h *ExternalMCPHandler) saveConfig() error {
// 读取现有配置文件并创建备份
data, err := os.ReadFile(h.configPath)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
h.logger.Warn("创建配置备份失败", zap.Error(err))
}
root, err := loadYAMLDocument(h.configPath)
if err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
// 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容
originalConfigs := make(map[string]map[string]bool)
externalMCPNode := findMapValue(root.Content[0], "external_mcp")
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
serversNode := findMapValue(externalMCPNode, "servers")
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
// 遍历现有的服务器配置,保存 enabled/disabled 字段
for i := 0; i < len(serversNode.Content); i += 2 {
if i+1 >= len(serversNode.Content) {
break
}
nameNode := serversNode.Content[i]
serverNode := serversNode.Content[i+1]
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
serverName := nameNode.Value
originalConfigs[serverName] = make(map[string]bool)
// 检查是否有 enabled 字段
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
originalConfigs[serverName]["enabled"] = *enabledVal
}
// 检查是否有 disabled 字段
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
originalConfigs[serverName]["disabled"] = *disabledVal
}
}
}
}
}
// 更新外部MCP配置
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
}
h.logger.Info("配置已保存", zap.String("path", h.configPath))
return nil
}
// updateExternalMCPConfig 更新外部MCP配置
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) {
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
// 清空现有服务器配置
serversNode.Content = nil
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
// 添加服务器名称键
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// 设置服务器配置字段
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
}
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport)
}
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
if serverCfg.Description != "" {
setStringInMap(serverNode, "description", serverCfg.Description)
}
if serverCfg.Timeout > 0 {
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
}
// 保存 external_mcp_enable 字段(新字段)
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
// 保存 tool_enabled 字段(每个工具的启用状态)
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
for toolName, enabled := range serverCfg.ToolEnabled {
setBoolInMap(toolEnabledNode, toolName, enabled)
}
}
// 保留旧的 enabled/disabled 字段以保持向后兼容
originalFields, hasOriginal := originalConfigs[name]
// 如果原始配置中有 enabled 字段,保留它
if hasOriginal {
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
setBoolInMap(serverNode, "enabled", enabledVal)
}
// 如果原始配置中有 disabled 字段,保留它
// 注意:由于 omitemptydisabled: false 不会被保存,但 disabled: true 会被保存
if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled {
if disabledVal {
setBoolInMap(serverNode, "disabled", disabledVal)
} else {
// 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示
// 因为 disabled: false 等价于 enabled: true
setBoolInMap(serverNode, "enabled", true)
}
}
}
// 如果用户在当前请求中明确设置了这些字段,也保存它们
if serverCfg.Enabled {
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
}
if serverCfg.Disabled {
setBoolInMap(serverNode, "disabled", serverCfg.Disabled)
} else if !hasOriginal && serverCfg.ExternalMCPEnable {
// 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容
setBoolInMap(serverNode, "enabled", true)
}
}
}
// setStringArrayInMap 设置字符串数组
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
for _, v := range values {
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
valueNode.Content = append(valueNode.Content, itemNode)
}
}
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
type AddOrUpdateExternalMCPRequest struct {
Config config.ExternalMCPServerConfig `json:"config"`
}
// ExternalMCPResponse 外部MCP响应
type ExternalMCPResponse struct {
Config config.ExternalMCPServerConfig `json:"config"`
Status string `json:"status"` // "connected", "disconnected", "disabled", "error"
ToolCount int `json:"tool_count"` // 工具数量
}
+518
View File
@@ -0,0 +1,518 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建临时配置文件
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
if err != nil {
panic(err)
}
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
tmpFile.Close()
configPath := tmpFile.Name()
logger := zap.NewNop()
manager := mcp.NewExternalMCPManager(logger)
cfg := &config.Config{
ExternalMCP: config.ExternalMCPConfig{
Servers: make(map[string]config.ExternalMCPServerConfig),
},
}
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
api := router.Group("/api")
api.GET("/external-mcp", handler.GetExternalMCPs)
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
api.GET("/external-mcp/:name", handler.GetExternalMCP)
api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP)
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
return router, handler, configPath
}
func cleanupTestConfig(configPath string) {
os.Remove(configPath)
os.Remove(configPath + ".backup")
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置
configJSON := `{
"command": "python3",
"args": ["/path/to/script.py", "--server", "http://example.com"],
"description": "Test stdio MCP",
"timeout": 300,
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Command != "python3" {
t.Errorf("期望command为python3,实际%s", response.Config.Command)
}
if len(response.Config.Args) != 3 {
t.Errorf("期望args长度为3,实际%d", len(response.Config.Args))
}
if response.Config.Description != "Test stdio MCP" {
t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description)
}
if response.Config.Timeout != 300 {
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置
configJSON := `{
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Transport != "http" {
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
testCases := []struct {
name string
configJSON string
expectedErr string
}{
{
name: "缺少command和url",
configJSON: `{"enabled": true}`,
expectedErr: "需要指定commandstdio模式)或urlhttp模式)",
},
{
name: "stdio模式缺少command",
configJSON: `{"args": ["test"], "enabled": true}`,
expectedErr: "stdio模式需要command",
},
{
name: "http模式缺少url",
configJSON: `{"transport": "http", "enabled": true}`,
expectedErr: "HTTP模式需要URL",
},
{
name: "无效的transport",
configJSON: `{"transport": "invalid", "enabled": true}`,
expectedErr: "不支持的传输模式",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
errorMsg := response["error"].(string)
// 对于stdio模式缺少command的情况,错误信息可能略有不同
if tc.name == "stdio模式缺少command" {
if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") {
t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg)
}
} else if !strings.Contains(errorMsg, tc.expectedErr) {
t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg)
}
})
}
}
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
// 删除配置
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已删除
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
servers := response["servers"].(map[string]interface{})
if len(servers) != 2 {
t.Errorf("期望2个服务器,实际%d", len(servers))
}
if _, ok := servers["test1"]; !ok {
t.Error("期望包含test1")
}
if _, ok := servers["test2"]; !ok {
t.Error("期望包含test2")
}
stats := response["stats"].(map[string]interface{})
if int(stats["total"].(float64)) != 2 {
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
}
}
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if int(stats["total"].(float64)) != 3 {
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
}
if int(stats["enabled"].(float64)) != 2 {
t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64)))
}
if int(stats["disabled"].(float64)) != 1 {
t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64)))
}
}
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
// 测试启动(可能会失败,因为没有真实的服务器)
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 启动可能会失败,但应该返回合理的状态码
if w.Code != http.StatusOK {
// 如果启动失败,应该是400或500
if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError {
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
}
}
// 测试停止
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
router, _, _ := setupTestRouter()
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 空名称应该返回404或400
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
router, _, _ := setupTestRouter()
// 发送无效的JSON
body := []byte(`{"config": invalid json}`)
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: config2,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已更新
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if response.Config.Command != "" {
t.Errorf("期望command为空,实际%s", response.Config.Command)
}
}
+77 -16
View File
@@ -14,22 +14,29 @@ import (
// MonitorHandler 监控处理器
type MonitorHandler struct {
mcpServer *mcp.Server
executor *security.Executor
db *database.DB
logger *zap.Logger
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
executor *security.Executor
db *database.DB
logger *zap.Logger
}
// NewMonitorHandler 创建新的监控处理器
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler {
return &MonitorHandler{
mcpServer: mcpServer,
executor: executor,
db: db,
logger: logger,
mcpServer: mcpServer,
externalMCPMgr: nil, // 将在创建后设置
executor: executor,
db: db,
logger: logger,
}
}
// SetExternalMCPManager 设置外部MCP管理器
func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
h.externalMCPMgr = mgr
}
// MonitorResponse 监控响应
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
@@ -128,15 +135,49 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int) ([]*mc
}
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
// 合并内部MCP服务器和外部MCP管理器的统计信息
stats := make(map[string]*mcp.ToolStats)
// 加载内部MCP服务器的统计信息
if h.db == nil {
return h.mcpServer.GetStats()
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
dbStats, err := h.db.LoadToolStats()
if err != nil {
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
for k, v := range dbStats {
stats[k] = v
}
}
}
stats, err := h.db.LoadToolStats()
if err != nil {
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
return h.mcpServer.GetStats()
// 合并外部MCP管理器的统计信息
if h.externalMCPMgr != nil {
externalStats := h.externalMCPMgr.GetToolStats()
for k, v := range externalStats {
// 如果已存在,合并统计信息
if existing, exists := stats[k]; exists {
existing.TotalCalls += v.TotalCalls
existing.SuccessCalls += v.SuccessCalls
existing.FailedCalls += v.FailedCalls
// 使用最新的调用时间
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
existing.LastCallTime = v.LastCallTime
}
} else {
stats[k] = v
}
}
}
return stats
}
@@ -145,13 +186,32 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
// 先从内部MCP服务器查找
exec, exists := h.mcpServer.GetExecution(id)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
if exists {
c.JSON(http.StatusOK, exec)
return
}
c.JSON(http.StatusOK, exec)
// 如果找不到,尝试从外部MCP管理器查找
if h.externalMCPMgr != nil {
exec, exists = h.externalMCPMgr.GetExecution(id)
if exists {
c.JSON(http.StatusOK, exec)
return
}
}
// 如果都找不到,尝试从数据库查找(如果使用数据库存储)
if h.db != nil {
exec, err := h.db.GetToolExecution(id)
if err == nil && exec != nil {
c.JSON(http.StatusOK, exec)
return
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
}
// GetStats 获取统计信息
@@ -160,3 +220,4 @@ func (h *MonitorHandler) GetStats(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}