From af71c6aa24d4b8e806d40615d3be419221c38f13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:08:15 +0800 Subject: [PATCH] Add files via upload --- internal/mcp/client_sdk.go | 17 ++ internal/mcp/connection_recovery.go | 192 ++++++++++++++++++++ internal/mcp/connection_recovery_test.go | 215 +++++++++++++++++++++++ internal/mcp/external_manager.go | 75 +++++--- 4 files changed, 479 insertions(+), 20 deletions(-) create mode 100644 internal/mcp/connection_recovery.go create mode 100644 internal/mcp/connection_recovery_test.go diff --git a/internal/mcp/client_sdk.go b/internal/mcp/client_sdk.go index 04e73d23..0d7ebfb3 100644 --- a/internal/mcp/client_sdk.go +++ b/internal/mcp/client_sdk.go @@ -190,6 +190,23 @@ func (c *lazySDKClient) Close() error { return nil } +// markDisconnected 在检测到传输层断连时关闭底层 session,避免 IsConnected 仍返回 true。 +func (c *lazySDKClient) markDisconnected() { + c.mu.Lock() + inner := c.inner + sessionCancel := c.sessionCancel + c.inner = nil + c.sessionCancel = nil + c.mu.Unlock() + if sessionCancel != nil { + sessionCancel() + } + if inner != nil { + _ = inner.Close() + } + c.setStatus("disconnected") +} + func (c *sdkClient) setStatus(s string) { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/mcp/connection_recovery.go b/internal/mcp/connection_recovery.go new file mode 100644 index 00000000..a2ed9bfb --- /dev/null +++ b/internal/mcp/connection_recovery.go @@ -0,0 +1,192 @@ +package mcp + +import ( + "context" + "errors" + "io" + "strings" + "time" + + "go.uber.org/zap" +) + +const ( + // externalReconnectMinInterval 两次自动重连之间的最短间隔 + externalReconnectMinInterval = 30 * time.Second + // externalReconnectMaxBackoff 指数退避上限 + externalReconnectMaxBackoff = 5 * time.Minute +) + +// isConnectionDeadError 判断错误是否表示底层传输已断开(而非调用方主动取消或超时)。 +func isConnectionDeadError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + if errors.Is(err, io.EOF) { + return true + } + s := strings.ToLower(err.Error()) + return strings.Contains(s, "eof") || + strings.Contains(s, "client is closing") || + strings.Contains(s, "connection closed") || + strings.Contains(s, "connection reset") || + strings.Contains(s, "broken pipe") +} + +// handleConnectionDead 在 ListTools/CallTool 等操作失败且判定为断连时,标记客户端并调度重连。 +func (m *ExternalMCPManager) handleConnectionDead(name string, client ExternalMCPClient, err error) { + if !isConnectionDeadError(err) { + return + } + m.logger.Warn("检测到外部MCP连接已断开,将尝试自动重连", + zap.String("name", name), + zap.Error(err), + ) + m.markClientDisconnected(name, client, err) + m.scheduleReconnect(name) +} + +func (m *ExternalMCPManager) markClientDisconnected(name string, client ExternalMCPClient, err error) { + if lazy, ok := client.(*lazySDKClient); ok { + lazy.markDisconnected() + } + m.mu.Lock() + if err != nil { + m.errors[name] = "连接已断开: " + err.Error() + } + m.mu.Unlock() + m.toolCountsMu.Lock() + m.toolCounts[name] = 0 + m.toolCountsMu.Unlock() +} + +func (m *ExternalMCPManager) onClientConnected(name string) { + m.clearReconnectState(name) +} + +func (m *ExternalMCPManager) clearReconnectState(name string) { + m.reconnectMu.Lock() + delete(m.reconnectAttempts, name) + delete(m.reconnectLastTry, name) + delete(m.reconnecting, name) + m.reconnectMu.Unlock() +} + +func (m *ExternalMCPManager) reconnectBackoff(attempts int) time.Duration { + if attempts <= 0 { + return 0 + } + d := externalReconnectMinInterval + for i := 1; i < attempts && d < externalReconnectMaxBackoff; i++ { + d *= 2 + } + if d > externalReconnectMaxBackoff { + return externalReconnectMaxBackoff + } + return d +} + +func (m *ExternalMCPManager) scheduleReconnect(name string) { + m.mu.RLock() + cfg, exists := m.configs[name] + enabled := exists && m.isEnabled(cfg) + m.mu.RUnlock() + if !enabled { + return + } + go m.tryReconnect(name) +} + +func (m *ExternalMCPManager) tryReconnect(name string) { + m.reconnectMu.Lock() + if m.reconnecting[name] { + m.reconnectMu.Unlock() + return + } + attempts := m.reconnectAttempts[name] + if wait := m.reconnectBackoff(attempts); wait > 0 { + if last, ok := m.reconnectLastTry[name]; ok { + if elapsed := time.Since(last); elapsed < wait { + remaining := wait - elapsed + m.reconnectMu.Unlock() + m.scheduleReconnectAfter(name, remaining) + return + } + } + } + m.reconnecting[name] = true + m.reconnectMu.Unlock() + + defer func() { + m.reconnectMu.Lock() + delete(m.reconnecting, name) + m.reconnectMu.Unlock() + }() + + m.mu.RLock() + cfg, exists := m.configs[name] + enabled := exists && m.isEnabled(cfg) + client, hasClient := m.clients[name] + connecting := hasClient && client.GetStatus() == "connecting" + m.mu.RUnlock() + + if !enabled { + m.logger.Debug("跳过自动重连(外部MCP已停用)", zap.String("name", name)) + return + } + if connecting { + m.logger.Debug("跳过自动重连(连接正在进行中)", zap.String("name", name)) + return + } + + m.reconnectMu.Lock() + m.reconnectLastTry[name] = time.Now() + m.reconnectAttempts[name] = attempts + 1 + attemptNum := m.reconnectAttempts[name] + m.reconnectMu.Unlock() + + m.logger.Info("正在自动重连外部MCP", + zap.String("name", name), + zap.Int("attempt", attemptNum), + ) + + if err := m.startClient(name, true); err != nil { + m.logger.Warn("自动重连外部MCP失败", + zap.String("name", name), + zap.Error(err), + ) + } +} + +// scheduleReconnectAfterFailure 在自动重连失败后,按当前退避间隔预约下一次重试。 +func (m *ExternalMCPManager) scheduleReconnectAfterFailure(name string) { + m.mu.RLock() + cfg, exists := m.configs[name] + enabled := exists && m.isEnabled(cfg) + m.mu.RUnlock() + if !enabled { + return + } + m.reconnectMu.Lock() + wait := m.reconnectBackoff(m.reconnectAttempts[name]) + m.reconnectMu.Unlock() + m.logger.Info("自动重连失败,将按退避间隔再次尝试", + zap.String("name", name), + zap.Duration("after", wait), + ) + m.scheduleReconnectAfter(name, wait) +} + +// scheduleReconnectAfter 在 delay 后触发 tryReconnect(delay<=0 时立即执行)。 +func (m *ExternalMCPManager) scheduleReconnectAfter(name string, delay time.Duration) { + if delay <= 0 { + go m.tryReconnect(name) + return + } + time.AfterFunc(delay, func() { + m.tryReconnect(name) + }) +} diff --git a/internal/mcp/connection_recovery_test.go b/internal/mcp/connection_recovery_test.go new file mode 100644 index 00000000..f04e4622 --- /dev/null +++ b/internal/mcp/connection_recovery_test.go @@ -0,0 +1,215 @@ +package mcp + +import ( + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +func TestIsConnectionDeadError(t *testing.T) { + t.Parallel() + cases := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"eof", io.EOF, true}, + {"wrapped eof", fmt.Errorf("connection closed: %w", io.EOF), true}, + {"client closing", errors.New(`calling "tools/list": client is closing: EOF`), true}, + {"connection reset", errors.New("read tcp: connection reset by peer"), true}, + {"canceled", context.Canceled, false}, + {"deadline", context.DeadlineExceeded, false}, + {"other", errors.New("invalid params"), false}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isConnectionDeadError(tc.err); got != tc.want { + t.Fatalf("isConnectionDeadError(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} + +func TestLazySDKClient_MarkDisconnected(t *testing.T) { + c := &lazySDKClient{status: "connected"} + c.inner = &sdkClient{status: "connected"} + c.markDisconnected() + if c.IsConnected() { + t.Fatal("expected disconnected after markDisconnected") + } + if c.GetStatus() != "disconnected" { + t.Fatalf("expected status disconnected, got %s", c.GetStatus()) + } +} + +func TestHandleConnectionDead_MarksLazyClientDisconnected(t *testing.T) { + logger := zap.NewNop() + m := NewExternalMCPManager(logger) + + name := "dead-mcp" + cfg := config.ExternalMCPServerConfig{ + Type: "http", + URL: "http://example.com/mcp", + ExternalMCPEnable: true, + } + m.mu.Lock() + m.configs[name] = cfg + client := newLazySDKClient(cfg, logger) + client.inner = &sdkClient{status: "connected"} + client.status = "connected" + m.clients[name] = client + m.mu.Unlock() + + deadErr := errors.New(`connection closed: calling "tools/list": client is closing: EOF`) + m.handleConnectionDead(name, client, deadErr) + + if client.IsConnected() { + t.Fatal("expected disconnected after handleConnectionDead") + } + if m.GetError(name) == "" { + t.Fatal("expected error message to be recorded") + } + counts := m.GetToolCounts() + if counts[name] != 0 { + t.Fatalf("expected tool count 0 after disconnect, got %d", counts[name]) + } +} + +func TestReconnectBackoff(t *testing.T) { + t.Parallel() + if d := (&ExternalMCPManager{}).reconnectBackoff(0); d != 0 { + t.Fatalf("attempt 0: got %v", d) + } + if d := (&ExternalMCPManager{}).reconnectBackoff(1); d != externalReconnectMinInterval { + t.Fatalf("attempt 1: got %v", d) + } + if d := (&ExternalMCPManager{}).reconnectBackoff(10); d != externalReconnectMaxBackoff { + t.Fatalf("attempt 10: got %v, want cap %v", d, externalReconnectMaxBackoff) + } +} + +func TestTryReconnect_RateLimited(t *testing.T) { + logger := zap.NewNop() + m := NewExternalMCPManager(logger) + + name := "rate-limited" + m.reconnectMu.Lock() + m.reconnectLastTry[name] = time.Now() + m.reconnectAttempts[name] = 2 + m.reconnectMu.Unlock() + + m.tryReconnect(name) + + m.reconnectMu.Lock() + attempts := m.reconnectAttempts[name] + m.reconnectMu.Unlock() + if attempts != 2 { + t.Fatalf("rate limited reconnect should not increment attempts, got %d", attempts) + } +} + +func TestTryReconnect_SkipsWhenDisabled(t *testing.T) { + logger := zap.NewNop() + m := NewExternalMCPManager(logger) + + name := "disabled-mcp" + m.mu.Lock() + m.configs[name] = config.ExternalMCPServerConfig{ + Type: "http", + URL: "http://example.com/mcp", + ExternalMCPEnable: false, + } + m.mu.Unlock() + + m.tryReconnect(name) + + m.reconnectMu.Lock() + attempts := m.reconnectAttempts[name] + m.reconnectMu.Unlock() + if attempts != 0 { + t.Fatalf("disabled MCP should not increment reconnect attempts, got %d", attempts) + } +} + +func TestTryReconnect_SkipsWhenConnecting(t *testing.T) { + logger := zap.NewNop() + m := NewExternalMCPManager(logger) + + name := "connecting-mcp" + cfg := config.ExternalMCPServerConfig{ + Type: "http", + URL: "http://example.com/mcp", + ExternalMCPEnable: true, + } + client := newLazySDKClient(cfg, logger) + client.setStatus("connecting") + + m.mu.Lock() + m.configs[name] = cfg + m.clients[name] = client + m.mu.Unlock() + + m.tryReconnect(name) + + m.reconnectMu.Lock() + attempts := m.reconnectAttempts[name] + m.reconnectMu.Unlock() + if attempts != 0 { + t.Fatalf("connecting MCP should not increment reconnect attempts, got %d", attempts) + } +} + +func TestStartClientAutoReconnect_SkipsWhenDisabled(t *testing.T) { + logger := zap.NewNop() + m := NewExternalMCPManager(logger) + m.stopRefresh = make(chan struct{}) + + name := "stopped" + m.mu.Lock() + m.configs[name] = config.ExternalMCPServerConfig{ + Type: "http", + URL: "http://example.com/mcp", + ExternalMCPEnable: false, + } + m.mu.Unlock() + + if err := m.startClient(name, true); err != nil { + t.Fatalf("startClient: %v", err) + } + + m.mu.RLock() + cfg := m.configs[name] + _, hasClient := m.clients[name] + m.mu.RUnlock() + if cfg.ExternalMCPEnable { + t.Fatal("auto reconnect should not enable stopped MCP") + } + if hasClient { + t.Fatal("auto reconnect should not create client when disabled") + } +} + +func TestOnClientConnected_ClearsReconnectState(t *testing.T) { + m := &ExternalMCPManager{ + reconnectAttempts: map[string]int{"x": 3}, + reconnectLastTry: map[string]time.Time{"x": time.Now()}, + reconnecting: map[string]bool{"x": true}, + } + m.onClientConnected("x") + + m.reconnectMu.Lock() + defer m.reconnectMu.Unlock() + if len(m.reconnectAttempts) != 0 || len(m.reconnectLastTry) != 0 || len(m.reconnecting) != 0 { + t.Fatal("expected reconnect state cleared") + } +} diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go index 470e9715..8e8182d8 100644 --- a/internal/mcp/external_manager.go +++ b/internal/mcp/external_manager.go @@ -54,8 +54,12 @@ type ExternalMCPManager struct { refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积 mu sync.RWMutex - runningCancels map[string]context.CancelFunc - abortUserNotes map[string]string + runningCancels map[string]context.CancelFunc + abortUserNotes map[string]string + reconnectMu sync.Mutex + reconnecting map[string]bool + reconnectLastTry map[string]time.Time + reconnectAttempts map[string]int } // NewExternalMCPManager 创建外部MCP管理器 @@ -77,8 +81,11 @@ func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage toolCache: make(map[string]toolListCacheEntry), listToolsInflight: make(map[string]*listToolsInflight), stopRefresh: make(chan struct{}), - runningCancels: make(map[string]context.CancelFunc), - abortUserNotes: make(map[string]string), + runningCancels: make(map[string]context.CancelFunc), + abortUserNotes: make(map[string]string), + reconnecting: make(map[string]bool), + reconnectLastTry: make(map[string]time.Time), + reconnectAttempts: make(map[string]int), } // 启动后台刷新工具数量的goroutine manager.startToolCountRefresh() @@ -145,6 +152,7 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error { } delete(m.configs, name) + m.clearReconnectState(name) // 清理工具数量缓存 m.toolCountsMu.Lock() @@ -159,8 +167,13 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error { return nil } -// StartClient 启动客户端 +// StartClient 启动客户端(用户手动启动;连接失败不自动重试) func (m *ExternalMCPManager) StartClient(name string) error { + return m.startClient(name, false) +} + +// startClient 启动客户端。autoReconnect 为 true 时用于断连自愈:尊重停用状态,失败后按退避继续重试。 +func (m *ExternalMCPManager) startClient(name string, autoReconnect bool) error { m.mu.Lock() serverCfg, exists := m.configs[name] m.mu.Unlock() @@ -169,6 +182,10 @@ func (m *ExternalMCPManager) StartClient(name string) error { return fmt.Errorf("配置不存在: %s", name) } + if autoReconnect && !m.isEnabled(serverCfg) { + return nil + } + // 检查是否已经有连接的客户端 m.mu.RLock() existingClient, hasClient := m.clients[name] @@ -178,11 +195,12 @@ func (m *ExternalMCPManager) StartClient(name string) error { // 检查客户端是否已连接 if existingClient.IsConnected() { // 客户端已连接,直接返回成功(目标状态已达成) - // 更新配置为启用(确保配置一致) - m.mu.Lock() - serverCfg.ExternalMCPEnable = true - m.configs[name] = serverCfg - m.mu.Unlock() + if !autoReconnect { + m.mu.Lock() + serverCfg.ExternalMCPEnable = true + m.configs[name] = serverCfg + m.mu.Unlock() + } return nil } // 如果有客户端但未连接,先关闭 @@ -192,6 +210,16 @@ func (m *ExternalMCPManager) StartClient(name string) error { m.mu.Unlock() } + if autoReconnect { + m.mu.RLock() + serverCfg, exists = m.configs[name] + enabled := exists && m.isEnabled(serverCfg) + m.mu.RUnlock() + if !enabled { + return nil + } + } + // 更新配置为启用 m.mu.Lock() serverCfg.ExternalMCPEnable = true @@ -215,10 +243,11 @@ func (m *ExternalMCPManager) StartClient(name string) error { m.mu.Unlock() // 在后台异步进行实际连接 - go func() { + go func(reconnect bool) { if err := m.doConnect(name, serverCfg, client); err != nil { m.logger.Error("连接外部MCP客户端失败", zap.String("name", name), + zap.Bool("auto_reconnect", reconnect), zap.Error(err), ) // 连接失败,设置状态为error并保存错误信息 @@ -228,15 +257,19 @@ func (m *ExternalMCPManager) StartClient(name string) error { m.mu.Unlock() // 触发工具数量刷新(连接失败,工具数量应为0) m.triggerToolCountRefresh() + if reconnect { + m.scheduleReconnectAfterFailure(name) + } } else { // 连接成功,清除错误信息 m.mu.Lock() delete(m.errors, name) m.mu.Unlock() + m.onClientConnected(name) // 异步拉取工具列表(singleflight 去重,结果同时写入 toolCache 与 toolCounts) go m.refreshToolCache(name, client) } - }() + }(autoReconnect) return nil } @@ -273,6 +306,8 @@ func (m *ExternalMCPManager) StopClient(name string) error { serverCfg.ExternalMCPEnable = false m.configs[name] = serverCfg + m.clearReconnectState(name) + return nil } @@ -465,6 +500,7 @@ func (m *ExternalMCPManager) listToolsDeduped(ctx context.Context, name string, m.listToolsMu.Unlock() if inflight.err != nil { + m.handleConnectionDead(name, client, inflight.err) return nil, inflight.err } return cloneTools(inflight.tools), nil @@ -579,6 +615,9 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args // 调用工具 result, err := client.CallTool(execCtx, actualToolName, args) + if err != nil { + m.handleConnectionDead(mcpName, client, err) + } cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) // 更新执行记录 @@ -980,14 +1019,7 @@ func (m *ExternalMCPManager) refreshToolCounts() { cancel() if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") { - m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)", - zap.String("name", n), - zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"), - zap.Error(err), - ) - } else { + if !isConnectionDeadError(err) { m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", zap.String("name", n), zap.Error(err), @@ -1181,6 +1213,8 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa zap.String("name", name), ) + m.onClientConnected(name) + // 连接成功,触发工具数量刷新和工具列表缓存刷新 m.triggerToolCountRefresh() m.mu.RLock() @@ -1265,6 +1299,7 @@ func (m *ExternalMCPManager) StopAll() { for name, client := range m.clients { client.Close() delete(m.clients, name) + m.clearReconnectState(name) } // 清理所有工具数量缓存