mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
938 lines
23 KiB
Go
938 lines
23 KiB
Go
package mcp
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/config"
|
||
|
||
"github.com/google/uuid"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// ExternalMCPManager 外部MCP管理器
|
||
type ExternalMCPManager struct {
|
||
clients map[string]ExternalMCPClient
|
||
configs map[string]config.ExternalMCPServerConfig
|
||
logger *zap.Logger
|
||
storage MonitorStorage // 可选的持久化存储
|
||
executions map[string]*ToolExecution // 执行记录
|
||
stats map[string]*ToolStats // 工具统计信息
|
||
errors map[string]string // 错误信息
|
||
toolCounts map[string]int // 工具数量缓存
|
||
toolCountsMu sync.RWMutex // 工具数量缓存的锁
|
||
stopRefresh chan struct{} // 停止后台刷新的信号
|
||
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// NewExternalMCPManager 创建外部MCP管理器
|
||
func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager {
|
||
return NewExternalMCPManagerWithStorage(logger, nil)
|
||
}
|
||
|
||
// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储)
|
||
func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager {
|
||
manager := &ExternalMCPManager{
|
||
clients: make(map[string]ExternalMCPClient),
|
||
configs: make(map[string]config.ExternalMCPServerConfig),
|
||
logger: logger,
|
||
storage: storage,
|
||
executions: make(map[string]*ToolExecution),
|
||
stats: make(map[string]*ToolStats),
|
||
errors: make(map[string]string),
|
||
toolCounts: make(map[string]int),
|
||
stopRefresh: make(chan struct{}),
|
||
}
|
||
// 启动后台刷新工具数量的goroutine
|
||
manager.startToolCountRefresh()
|
||
return manager
|
||
}
|
||
|
||
// LoadConfigs 加载配置
|
||
func (m *ExternalMCPManager) LoadConfigs(cfg *config.ExternalMCPConfig) {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
if cfg == nil || cfg.Servers == nil {
|
||
return
|
||
}
|
||
|
||
m.configs = make(map[string]config.ExternalMCPServerConfig)
|
||
for name, serverCfg := range cfg.Servers {
|
||
m.configs[name] = serverCfg
|
||
}
|
||
}
|
||
|
||
// GetConfigs 获取所有配置
|
||
func (m *ExternalMCPManager) GetConfigs() map[string]config.ExternalMCPServerConfig {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
result := make(map[string]config.ExternalMCPServerConfig)
|
||
for k, v := range m.configs {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// AddOrUpdateConfig 添加或更新配置
|
||
func (m *ExternalMCPManager) AddOrUpdateConfig(name string, serverCfg config.ExternalMCPServerConfig) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
// 如果已存在客户端,先关闭
|
||
if client, exists := m.clients[name]; exists {
|
||
client.Close()
|
||
delete(m.clients, name)
|
||
}
|
||
|
||
m.configs[name] = serverCfg
|
||
|
||
// 如果启用,自动连接
|
||
if m.isEnabled(serverCfg) {
|
||
go m.connectClient(name, serverCfg)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// RemoveConfig 移除配置
|
||
func (m *ExternalMCPManager) RemoveConfig(name string) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
// 关闭客户端
|
||
if client, exists := m.clients[name]; exists {
|
||
client.Close()
|
||
delete(m.clients, name)
|
||
}
|
||
|
||
delete(m.configs, name)
|
||
|
||
// 清理工具数量缓存
|
||
m.toolCountsMu.Lock()
|
||
delete(m.toolCounts, name)
|
||
m.toolCountsMu.Unlock()
|
||
|
||
return nil
|
||
}
|
||
|
||
// StartClient 启动客户端
|
||
func (m *ExternalMCPManager) StartClient(name string) error {
|
||
m.mu.Lock()
|
||
serverCfg, exists := m.configs[name]
|
||
m.mu.Unlock()
|
||
|
||
if !exists {
|
||
return fmt.Errorf("配置不存在: %s", name)
|
||
}
|
||
|
||
// 检查是否已经有连接的客户端
|
||
m.mu.RLock()
|
||
existingClient, hasClient := m.clients[name]
|
||
m.mu.RUnlock()
|
||
|
||
if hasClient {
|
||
// 检查客户端是否已连接
|
||
if existingClient.IsConnected() {
|
||
// 客户端已连接,直接返回成功(目标状态已达成)
|
||
// 更新配置为启用(确保配置一致)
|
||
m.mu.Lock()
|
||
serverCfg.ExternalMCPEnable = true
|
||
m.configs[name] = serverCfg
|
||
m.mu.Unlock()
|
||
return nil
|
||
}
|
||
// 如果有客户端但未连接,先关闭
|
||
existingClient.Close()
|
||
m.mu.Lock()
|
||
delete(m.clients, name)
|
||
m.mu.Unlock()
|
||
}
|
||
|
||
// 更新配置为启用
|
||
m.mu.Lock()
|
||
serverCfg.ExternalMCPEnable = true
|
||
m.configs[name] = serverCfg
|
||
// 清除之前的错误信息(重新启动时)
|
||
delete(m.errors, name)
|
||
m.mu.Unlock()
|
||
|
||
// 立即创建客户端并设置为"connecting"状态,这样前端可以立即看到状态
|
||
client := m.createClient(serverCfg)
|
||
if client == nil {
|
||
return fmt.Errorf("无法创建客户端:不支持的传输模式")
|
||
}
|
||
|
||
// 设置状态为connecting
|
||
m.setClientStatus(client, "connecting")
|
||
|
||
// 立即保存客户端,这样前端查询时就能看到"connecting"状态
|
||
m.mu.Lock()
|
||
m.clients[name] = client
|
||
m.mu.Unlock()
|
||
|
||
// 在后台异步进行实际连接
|
||
go func() {
|
||
if err := m.doConnect(name, serverCfg, client); err != nil {
|
||
m.logger.Error("连接外部MCP客户端失败",
|
||
zap.String("name", name),
|
||
zap.Error(err),
|
||
)
|
||
// 连接失败,设置状态为error并保存错误信息
|
||
m.setClientStatus(client, "error")
|
||
m.mu.Lock()
|
||
m.errors[name] = err.Error()
|
||
m.mu.Unlock()
|
||
// 触发工具数量刷新(连接失败,工具数量应为0)
|
||
m.triggerToolCountRefresh()
|
||
} else {
|
||
// 连接成功,清除错误信息
|
||
m.mu.Lock()
|
||
delete(m.errors, name)
|
||
m.mu.Unlock()
|
||
// 连接成功,立即刷新工具数量
|
||
m.triggerToolCountRefresh()
|
||
}
|
||
}()
|
||
|
||
return nil
|
||
}
|
||
|
||
// StopClient 停止客户端
|
||
func (m *ExternalMCPManager) StopClient(name string) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
serverCfg, exists := m.configs[name]
|
||
if !exists {
|
||
return fmt.Errorf("配置不存在: %s", name)
|
||
}
|
||
|
||
// 关闭客户端
|
||
if client, exists := m.clients[name]; exists {
|
||
client.Close()
|
||
delete(m.clients, name)
|
||
}
|
||
|
||
// 清除错误信息
|
||
delete(m.errors, name)
|
||
|
||
// 更新工具数量缓存(停止后工具数量为0)
|
||
m.toolCountsMu.Lock()
|
||
m.toolCounts[name] = 0
|
||
m.toolCountsMu.Unlock()
|
||
|
||
// 更新配置为禁用
|
||
serverCfg.ExternalMCPEnable = false
|
||
m.configs[name] = serverCfg
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetClient 获取客户端
|
||
func (m *ExternalMCPManager) GetClient(name string) (ExternalMCPClient, bool) {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
client, exists := m.clients[name]
|
||
return client, exists
|
||
}
|
||
|
||
// GetError 获取错误信息
|
||
func (m *ExternalMCPManager) GetError(name string) string {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
return m.errors[name]
|
||
}
|
||
|
||
// GetAllTools 获取所有外部MCP的工具
|
||
func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
|
||
m.mu.RLock()
|
||
clients := make(map[string]ExternalMCPClient)
|
||
for k, v := range m.clients {
|
||
clients[k] = v
|
||
}
|
||
m.mu.RUnlock()
|
||
|
||
var allTools []Tool
|
||
for name, client := range clients {
|
||
if !client.IsConnected() {
|
||
continue
|
||
}
|
||
|
||
tools, err := client.ListTools(ctx)
|
||
if err != nil {
|
||
m.logger.Warn("获取外部MCP工具列表失败",
|
||
zap.String("name", name),
|
||
zap.Error(err),
|
||
)
|
||
continue
|
||
}
|
||
|
||
// 为工具添加前缀,避免冲突
|
||
for _, tool := range tools {
|
||
tool.Name = fmt.Sprintf("%s::%s", name, tool.Name)
|
||
allTools = append(allTools, tool)
|
||
}
|
||
}
|
||
|
||
return allTools, nil
|
||
}
|
||
|
||
// CallTool 调用外部MCP工具(返回执行ID)
|
||
func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
|
||
// 解析工具名称:name::toolName
|
||
var mcpName, actualToolName string
|
||
if idx := findSubstring(toolName, "::"); idx > 0 {
|
||
mcpName = toolName[:idx]
|
||
actualToolName = toolName[idx+2:]
|
||
} else {
|
||
return nil, "", fmt.Errorf("无效的工具名称格式: %s", toolName)
|
||
}
|
||
|
||
client, exists := m.GetClient(mcpName)
|
||
if !exists {
|
||
return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName)
|
||
}
|
||
|
||
if !client.IsConnected() {
|
||
return nil, "", fmt.Errorf("外部MCP客户端未连接: %s", mcpName)
|
||
}
|
||
|
||
// 创建执行记录
|
||
executionID := uuid.New().String()
|
||
execution := &ToolExecution{
|
||
ID: executionID,
|
||
ToolName: toolName, // 使用完整工具名称(包含MCP名称)
|
||
Arguments: args,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
}
|
||
|
||
m.mu.Lock()
|
||
m.executions[executionID] = execution
|
||
// 如果内存中的执行记录超过限制,清理最旧的记录
|
||
m.cleanupOldExecutions()
|
||
m.mu.Unlock()
|
||
|
||
if m.storage != nil {
|
||
if err := m.storage.SaveToolExecution(execution); err != nil {
|
||
m.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 调用工具
|
||
result, err := client.CallTool(ctx, actualToolName, args)
|
||
|
||
// 更新执行记录
|
||
m.mu.Lock()
|
||
now := time.Now()
|
||
execution.EndTime = &now
|
||
execution.Duration = now.Sub(execution.StartTime)
|
||
|
||
if err != nil {
|
||
execution.Status = "failed"
|
||
execution.Error = err.Error()
|
||
} else if result != nil && result.IsError {
|
||
execution.Status = "failed"
|
||
if len(result.Content) > 0 {
|
||
execution.Error = result.Content[0].Text
|
||
} else {
|
||
execution.Error = "工具执行返回错误结果"
|
||
}
|
||
execution.Result = result
|
||
} else {
|
||
execution.Status = "completed"
|
||
if result == nil {
|
||
result = &ToolResult{
|
||
Content: []Content{
|
||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||
},
|
||
}
|
||
}
|
||
execution.Result = result
|
||
}
|
||
m.mu.Unlock()
|
||
|
||
if m.storage != nil {
|
||
if err := m.storage.SaveToolExecution(execution); err != nil {
|
||
m.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 更新统计信息
|
||
failed := err != nil || (result != nil && result.IsError)
|
||
m.updateStats(toolName, failed)
|
||
|
||
// 如果使用存储,从内存中删除(已持久化)
|
||
if m.storage != nil {
|
||
m.mu.Lock()
|
||
delete(m.executions, executionID)
|
||
m.mu.Unlock()
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, executionID, err
|
||
}
|
||
|
||
return result, executionID, nil
|
||
}
|
||
|
||
// cleanupOldExecutions 清理旧的执行记录(保持内存中的记录数量在限制内)
|
||
func (m *ExternalMCPManager) cleanupOldExecutions() {
|
||
const maxExecutionsInMemory = 1000
|
||
if len(m.executions) <= maxExecutionsInMemory {
|
||
return
|
||
}
|
||
|
||
// 按开始时间排序,删除最旧的记录
|
||
type execTime struct {
|
||
id string
|
||
startTime time.Time
|
||
}
|
||
var execs []execTime
|
||
for id, exec := range m.executions {
|
||
execs = append(execs, execTime{id: id, startTime: exec.StartTime})
|
||
}
|
||
|
||
// 按时间排序
|
||
for i := 0; i < len(execs)-1; i++ {
|
||
for j := i + 1; j < len(execs); j++ {
|
||
if execs[i].startTime.After(execs[j].startTime) {
|
||
execs[i], execs[j] = execs[j], execs[i]
|
||
}
|
||
}
|
||
}
|
||
|
||
// 删除最旧的记录
|
||
toDelete := len(m.executions) - maxExecutionsInMemory
|
||
for i := 0; i < toDelete && i < len(execs); i++ {
|
||
delete(m.executions, execs[i].id)
|
||
}
|
||
}
|
||
|
||
// GetExecution 获取执行记录(先从内存查找,再从数据库查找)
|
||
func (m *ExternalMCPManager) GetExecution(id string) (*ToolExecution, bool) {
|
||
m.mu.RLock()
|
||
exec, exists := m.executions[id]
|
||
m.mu.RUnlock()
|
||
|
||
if exists {
|
||
return exec, true
|
||
}
|
||
|
||
if m.storage != nil {
|
||
exec, err := m.storage.GetToolExecution(id)
|
||
if err == nil {
|
||
return exec, true
|
||
}
|
||
}
|
||
|
||
return nil, false
|
||
}
|
||
|
||
// updateStats 更新统计信息
|
||
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
|
||
now := time.Now()
|
||
if m.storage != nil {
|
||
totalCalls := 1
|
||
successCalls := 0
|
||
failedCalls := 0
|
||
if failed {
|
||
failedCalls = 1
|
||
} else {
|
||
successCalls = 1
|
||
}
|
||
if err := m.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil {
|
||
m.logger.Warn("保存统计信息到数据库失败", zap.Error(err))
|
||
}
|
||
return
|
||
}
|
||
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
if m.stats[toolName] == nil {
|
||
m.stats[toolName] = &ToolStats{
|
||
ToolName: toolName,
|
||
}
|
||
}
|
||
|
||
stats := m.stats[toolName]
|
||
stats.TotalCalls++
|
||
stats.LastCallTime = &now
|
||
|
||
if failed {
|
||
stats.FailedCalls++
|
||
} else {
|
||
stats.SuccessCalls++
|
||
}
|
||
}
|
||
|
||
// GetStats 获取MCP服务器统计信息
|
||
func (m *ExternalMCPManager) GetStats() map[string]interface{} {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
total := len(m.configs)
|
||
enabled := 0
|
||
disabled := 0
|
||
connected := 0
|
||
|
||
for name, cfg := range m.configs {
|
||
if m.isEnabled(cfg) {
|
||
enabled++
|
||
if client, exists := m.clients[name]; exists && client.IsConnected() {
|
||
connected++
|
||
}
|
||
} else {
|
||
disabled++
|
||
}
|
||
}
|
||
|
||
return map[string]interface{}{
|
||
"total": total,
|
||
"enabled": enabled,
|
||
"disabled": disabled,
|
||
"connected": connected,
|
||
}
|
||
}
|
||
|
||
// GetToolStats 获取工具统计信息(合并内存和数据库)
|
||
// 只返回外部MCP工具的统计信息(工具名称包含 "::")
|
||
func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats {
|
||
result := make(map[string]*ToolStats)
|
||
|
||
// 从数据库加载统计信息(如果使用数据库存储)
|
||
if m.storage != nil {
|
||
dbStats, err := m.storage.LoadToolStats()
|
||
if err == nil {
|
||
// 只保留外部MCP工具的统计信息(工具名称包含 "::")
|
||
for k, v := range dbStats {
|
||
if findSubstring(k, "::") > 0 {
|
||
result[k] = v
|
||
}
|
||
}
|
||
} else {
|
||
m.logger.Warn("从数据库加载统计信息失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 合并内存中的统计信息
|
||
m.mu.RLock()
|
||
for k, v := range m.stats {
|
||
// 如果数据库中已有该工具的统计信息,合并它们
|
||
if existing, exists := result[k]; exists {
|
||
// 创建新的统计信息对象,避免修改共享对象
|
||
merged := &ToolStats{
|
||
ToolName: k,
|
||
TotalCalls: existing.TotalCalls + v.TotalCalls,
|
||
SuccessCalls: existing.SuccessCalls + v.SuccessCalls,
|
||
FailedCalls: existing.FailedCalls + v.FailedCalls,
|
||
}
|
||
// 使用最新的调用时间
|
||
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
|
||
merged.LastCallTime = v.LastCallTime
|
||
} else if existing.LastCallTime != nil {
|
||
timeCopy := *existing.LastCallTime
|
||
merged.LastCallTime = &timeCopy
|
||
}
|
||
result[k] = merged
|
||
} else {
|
||
// 如果数据库中没有,直接使用内存中的统计信息
|
||
statCopy := *v
|
||
result[k] = &statCopy
|
||
}
|
||
}
|
||
m.mu.RUnlock()
|
||
|
||
return result
|
||
}
|
||
|
||
// GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞)
|
||
func (m *ExternalMCPManager) GetToolCount(name string) (int, error) {
|
||
// 先从缓存读取
|
||
m.toolCountsMu.RLock()
|
||
if count, exists := m.toolCounts[name]; exists {
|
||
m.toolCountsMu.RUnlock()
|
||
return count, nil
|
||
}
|
||
m.toolCountsMu.RUnlock()
|
||
|
||
// 如果缓存中没有,检查客户端状态
|
||
client, exists := m.GetClient(name)
|
||
if !exists {
|
||
return 0, fmt.Errorf("客户端不存在: %s", name)
|
||
}
|
||
|
||
if !client.IsConnected() {
|
||
// 未连接,缓存为0
|
||
m.toolCountsMu.Lock()
|
||
m.toolCounts[name] = 0
|
||
m.toolCountsMu.Unlock()
|
||
return 0, nil
|
||
}
|
||
|
||
// 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞)
|
||
m.triggerToolCountRefresh()
|
||
return 0, nil
|
||
}
|
||
|
||
// GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞)
|
||
func (m *ExternalMCPManager) GetToolCounts() map[string]int {
|
||
m.toolCountsMu.RLock()
|
||
defer m.toolCountsMu.RUnlock()
|
||
|
||
// 返回缓存的副本,避免外部修改
|
||
result := make(map[string]int)
|
||
for k, v := range m.toolCounts {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
|
||
func (m *ExternalMCPManager) refreshToolCounts() {
|
||
m.mu.RLock()
|
||
clients := make(map[string]ExternalMCPClient)
|
||
for k, v := range m.clients {
|
||
clients[k] = v
|
||
}
|
||
m.mu.RUnlock()
|
||
|
||
newCounts := make(map[string]int)
|
||
|
||
// 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞
|
||
type countResult struct {
|
||
name string
|
||
count int
|
||
}
|
||
resultChan := make(chan countResult, len(clients))
|
||
|
||
for name, client := range clients {
|
||
go func(n string, c ExternalMCPClient) {
|
||
if !c.IsConnected() {
|
||
resultChan <- countResult{name: n, count: 0}
|
||
return
|
||
}
|
||
|
||
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞
|
||
// 由于这是后台异步刷新,超时不会影响前端响应
|
||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||
tools, err := c.ListTools(ctx)
|
||
cancel()
|
||
|
||
if err != nil {
|
||
m.logger.Debug("获取外部MCP工具数量失败",
|
||
zap.String("name", n),
|
||
zap.Error(err),
|
||
)
|
||
// 如果获取失败,保留旧值(在更新时处理)
|
||
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
|
||
return
|
||
}
|
||
|
||
resultChan <- countResult{name: n, count: len(tools)}
|
||
}(name, client)
|
||
}
|
||
|
||
// 收集结果
|
||
m.toolCountsMu.RLock()
|
||
oldCounts := make(map[string]int)
|
||
for k, v := range m.toolCounts {
|
||
oldCounts[k] = v
|
||
}
|
||
m.toolCountsMu.RUnlock()
|
||
|
||
for i := 0; i < len(clients); i++ {
|
||
result := <-resultChan
|
||
if result.count >= 0 {
|
||
newCounts[result.name] = result.count
|
||
} else {
|
||
// 获取失败,保留旧值
|
||
if oldCount, exists := oldCounts[result.name]; exists {
|
||
newCounts[result.name] = oldCount
|
||
} else {
|
||
newCounts[result.name] = 0
|
||
}
|
||
}
|
||
}
|
||
|
||
// 更新缓存
|
||
m.toolCountsMu.Lock()
|
||
// 更新所有获取到的值
|
||
for name, count := range newCounts {
|
||
m.toolCounts[name] = count
|
||
}
|
||
// 对于未连接的客户端,设置为0
|
||
for name, client := range clients {
|
||
if !client.IsConnected() {
|
||
m.toolCounts[name] = 0
|
||
}
|
||
}
|
||
m.toolCountsMu.Unlock()
|
||
}
|
||
|
||
// startToolCountRefresh 启动后台刷新工具数量的goroutine
|
||
func (m *ExternalMCPManager) startToolCountRefresh() {
|
||
m.refreshWg.Add(1)
|
||
go func() {
|
||
defer m.refreshWg.Done()
|
||
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次
|
||
defer ticker.Stop()
|
||
|
||
// 立即执行一次刷新
|
||
m.refreshToolCounts()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
m.refreshToolCounts()
|
||
case <-m.stopRefresh:
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// triggerToolCountRefresh 触发立即刷新工具数量(异步)
|
||
func (m *ExternalMCPManager) triggerToolCountRefresh() {
|
||
go m.refreshToolCounts()
|
||
}
|
||
|
||
// createClient 创建客户端(不连接)
|
||
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
|
||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 30 * time.Second
|
||
}
|
||
|
||
// 根据传输模式创建客户端
|
||
transport := serverCfg.Transport
|
||
if transport == "" {
|
||
// 如果没有指定transport,根据是否有command或url判断
|
||
if serverCfg.Command != "" {
|
||
transport = "stdio"
|
||
} else if serverCfg.URL != "" {
|
||
// 默认使用http,但可以通过transport字段指定sse
|
||
transport = "http"
|
||
} else {
|
||
return nil
|
||
}
|
||
}
|
||
|
||
switch transport {
|
||
case "http":
|
||
if serverCfg.URL == "" {
|
||
return nil
|
||
}
|
||
return NewHTTPMCPClient(serverCfg.URL, timeout, m.logger)
|
||
case "stdio":
|
||
if serverCfg.Command == "" {
|
||
return nil
|
||
}
|
||
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
|
||
case "sse":
|
||
if serverCfg.URL == "" {
|
||
return nil
|
||
}
|
||
return NewSSEMCPClient(serverCfg.URL, timeout, m.logger)
|
||
default:
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// doConnect 执行实际连接
|
||
func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCPServerConfig, client ExternalMCPClient) error {
|
||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 30 * time.Second
|
||
}
|
||
|
||
// 初始化连接
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
if err := client.Initialize(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
m.logger.Info("外部MCP客户端已连接",
|
||
zap.String("name", name),
|
||
)
|
||
|
||
return nil
|
||
}
|
||
|
||
// setClientStatus 设置客户端状态(通过类型断言)
|
||
func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) {
|
||
switch c := client.(type) {
|
||
case *HTTPMCPClient:
|
||
c.setStatus(status)
|
||
case *StdioMCPClient:
|
||
c.setStatus(status)
|
||
case *SSEMCPClient:
|
||
c.setStatus(status)
|
||
}
|
||
}
|
||
|
||
// connectClient 连接客户端(异步)- 保留用于向后兼容
|
||
func (m *ExternalMCPManager) connectClient(name string, serverCfg config.ExternalMCPServerConfig) error {
|
||
client := m.createClient(serverCfg)
|
||
if client == nil {
|
||
return fmt.Errorf("无法创建客户端:不支持的传输模式")
|
||
}
|
||
|
||
// 设置状态为connecting
|
||
m.setClientStatus(client, "connecting")
|
||
|
||
// 初始化连接
|
||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 30 * time.Second
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
if err := client.Initialize(ctx); err != nil {
|
||
m.logger.Error("初始化外部MCP客户端失败",
|
||
zap.String("name", name),
|
||
zap.Error(err),
|
||
)
|
||
return err
|
||
}
|
||
|
||
// 保存客户端
|
||
m.mu.Lock()
|
||
m.clients[name] = client
|
||
m.mu.Unlock()
|
||
|
||
m.logger.Info("外部MCP客户端已连接",
|
||
zap.String("name", name),
|
||
)
|
||
|
||
// 连接成功,触发工具数量刷新
|
||
m.triggerToolCountRefresh()
|
||
|
||
return nil
|
||
}
|
||
|
||
// isEnabled 检查是否启用
|
||
func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||
// 优先使用 ExternalMCPEnable 字段
|
||
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
|
||
if cfg.ExternalMCPEnable {
|
||
return true
|
||
}
|
||
// 向后兼容:检查旧字段
|
||
if cfg.Disabled {
|
||
return false
|
||
}
|
||
if cfg.Enabled {
|
||
return true
|
||
}
|
||
// 都没有设置,默认为启用
|
||
return true
|
||
}
|
||
|
||
// findSubstring 查找子字符串(简单实现)
|
||
func findSubstring(s, substr string) int {
|
||
for i := 0; i <= len(s)-len(substr); i++ {
|
||
if s[i:i+len(substr)] == substr {
|
||
return i
|
||
}
|
||
}
|
||
return -1
|
||
}
|
||
|
||
// StartAllEnabled 启动所有启用的客户端
|
||
func (m *ExternalMCPManager) StartAllEnabled() {
|
||
m.mu.RLock()
|
||
configs := make(map[string]config.ExternalMCPServerConfig)
|
||
for k, v := range m.configs {
|
||
configs[k] = v
|
||
}
|
||
m.mu.RUnlock()
|
||
|
||
for name, cfg := range configs {
|
||
if m.isEnabled(cfg) {
|
||
go func(n string, c config.ExternalMCPServerConfig) {
|
||
if err := m.connectClient(n, c); err != nil {
|
||
// 检查是否是连接被拒绝的错误(服务可能还没启动)
|
||
errStr := strings.ToLower(err.Error())
|
||
isConnectionRefused := strings.Contains(errStr, "connection refused") ||
|
||
strings.Contains(errStr, "dial tcp") ||
|
||
strings.Contains(errStr, "connect: connection refused")
|
||
|
||
if isConnectionRefused {
|
||
// 连接被拒绝,说明目标服务可能还没启动,这是正常的
|
||
// 使用 Warn 级别,提示用户这是正常的,可以通过手动启动或等待服务启动后自动连接
|
||
fields := []zap.Field{
|
||
zap.String("name", n),
|
||
zap.String("message", "目标服务可能尚未启动,这是正常的。服务启动后可通过界面手动连接,或等待自动重试"),
|
||
zap.Error(err),
|
||
}
|
||
|
||
// 根据传输模式添加相应的信息
|
||
transport := c.Transport
|
||
if transport == "" {
|
||
if c.Command != "" {
|
||
transport = "stdio"
|
||
} else if c.URL != "" {
|
||
transport = "http"
|
||
}
|
||
}
|
||
|
||
if transport == "http" && c.URL != "" {
|
||
fields = append(fields, zap.String("url", c.URL))
|
||
} else if transport == "stdio" && c.Command != "" {
|
||
fields = append(fields, zap.String("command", c.Command))
|
||
}
|
||
|
||
m.logger.Warn("外部MCP服务暂未就绪", fields...)
|
||
} else {
|
||
// 其他错误,使用 Error 级别
|
||
m.logger.Error("启动外部MCP客户端失败",
|
||
zap.String("name", n),
|
||
zap.Error(err),
|
||
)
|
||
}
|
||
}
|
||
}(name, cfg)
|
||
}
|
||
}
|
||
}
|
||
|
||
// StopAll 停止所有客户端
|
||
func (m *ExternalMCPManager) StopAll() {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
for name, client := range m.clients {
|
||
client.Close()
|
||
delete(m.clients, name)
|
||
}
|
||
|
||
// 清理所有工具数量缓存
|
||
m.toolCountsMu.Lock()
|
||
m.toolCounts = make(map[string]int)
|
||
m.toolCountsMu.Unlock()
|
||
|
||
// 停止后台刷新(使用 select 避免重复关闭 channel)
|
||
select {
|
||
case <-m.stopRefresh:
|
||
// 已经关闭,不需要再次关闭
|
||
default:
|
||
close(m.stopRefresh)
|
||
m.refreshWg.Wait()
|
||
}
|
||
}
|