diff --git a/internal/mcp/client_sdk.go b/internal/mcp/client_sdk.go index bfbbcb15..04e73d23 100644 --- a/internal/mcp/client_sdk.go +++ b/internal/mcp/client_sdk.go @@ -44,11 +44,12 @@ func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, log // lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient type lazySDKClient struct { - serverCfg config.ExternalMCPServerConfig - logger *zap.Logger - inner ExternalMCPClient // 连接成功后为 *sdkClient - mu sync.RWMutex - status string + serverCfg config.ExternalMCPServerConfig + logger *zap.Logger + sessionCancel context.CancelFunc + inner ExternalMCPClient // connected SDK client + mu sync.RWMutex + status string } func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient { @@ -92,14 +93,61 @@ func (c *lazySDKClient) Initialize(ctx context.Context) error { } c.mu.Unlock() - inner, err := createSDKClient(ctx, c.serverCfg, c.logger) - if err != nil { + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + type connectResult struct { + inner ExternalMCPClient + err error + } + resultCh := make(chan connectResult) + abandoned := make(chan struct{}) + go func() { + inner, err := createSDKClient(sessionCtx, c.serverCfg, c.logger) + select { + case resultCh <- connectResult{inner: inner, err: err}: + case <-abandoned: + if inner != nil { + _ = inner.Close() + } + sessionCancel() + } + }() + + var result connectResult + select { + case result = <-resultCh: + case <-ctx.Done(): + close(abandoned) + sessionCancel() + c.setStatus("error") + return ctx.Err() + } + + if err := ctx.Err(); err != nil { + sessionCancel() + if result.inner != nil { + _ = result.inner.Close() + } c.setStatus("error") return err } + if result.err != nil { + sessionCancel() + c.setStatus("error") + return result.err + } + c.mu.Lock() - c.inner = inner + if c.inner != nil { + c.mu.Unlock() + sessionCancel() + if result.inner != nil { + _ = result.inner.Close() + } + return nil + } + c.inner = result.inner + c.sessionCancel = sessionCancel c.mu.Unlock() c.setStatus("connected") return nil @@ -128,9 +176,14 @@ func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[stri func (c *lazySDKClient) Close() error { c.mu.Lock() inner := c.inner + sessionCancel := c.sessionCancel c.inner = nil + c.sessionCancel = nil c.mu.Unlock() c.setStatus("disconnected") + if sessionCancel != nil { + sessionCancel() + } if inner != nil { return inner.Close() }