diff --git a/mcp/builtin/constants.go b/mcp/builtin/constants.go deleted file mode 100644 index 29d2fad7..00000000 --- a/mcp/builtin/constants.go +++ /dev/null @@ -1,133 +0,0 @@ -package builtin - -// 内置工具名称常量 -// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串 -const ( - // 漏洞管理工具 - ToolRecordVulnerability = "record_vulnerability" - - // 知识库工具 - ToolListKnowledgeRiskTypes = "list_knowledge_risk_types" - ToolSearchKnowledgeBase = "search_knowledge_base" - - // WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用) - ToolWebshellExec = "webshell_exec" - ToolWebshellFileList = "webshell_file_list" - ToolWebshellFileRead = "webshell_file_read" - ToolWebshellFileWrite = "webshell_file_write" - - // WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接) - ToolManageWebshellList = "manage_webshell_list" - ToolManageWebshellAdd = "manage_webshell_add" - ToolManageWebshellUpdate = "manage_webshell_update" - ToolManageWebshellDelete = "manage_webshell_delete" - ToolManageWebshellTest = "manage_webshell_test" - - // 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列) - ToolBatchTaskList = "batch_task_list" - ToolBatchTaskGet = "batch_task_get" - ToolBatchTaskCreate = "batch_task_create" - ToolBatchTaskStart = "batch_task_start" - ToolBatchTaskRerun = "batch_task_rerun" - ToolBatchTaskPause = "batch_task_pause" - ToolBatchTaskDelete = "batch_task_delete" - ToolBatchTaskUpdateMetadata = "batch_task_update_metadata" - ToolBatchTaskUpdateSchedule = "batch_task_update_schedule" - ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled" - ToolBatchTaskAdd = "batch_task_add_task" - ToolBatchTaskUpdate = "batch_task_update_task" - ToolBatchTaskRemove = "batch_task_remove_task" - - // C2 工具集(合并同类项,8 个统一工具) - ToolC2Listener = "c2_listener" // 监听器管理(create/start/stop/list/get/update/delete) - ToolC2Session = "c2_session" // 会话管理(list/get/set_sleep/kill/delete) - ToolC2Task = "c2_task" // 任务下发(统一 task_type 参数) - ToolC2TaskManage = "c2_task_manage" // 任务管理(get_result/wait/list/cancel) - ToolC2Payload = "c2_payload" // Payload 生成(oneliner/build) - ToolC2Event = "c2_event" // 事件查询 - ToolC2Profile = "c2_profile" // Malleable Profile 管理(list/get/create/update/delete) - ToolC2File = "c2_file" // 文件管理(list/get_result) -) - -// IsBuiltinTool 检查工具名称是否是内置工具 -func IsBuiltinTool(toolName string) bool { - switch toolName { - case ToolRecordVulnerability, - ToolListKnowledgeRiskTypes, - ToolSearchKnowledgeBase, - ToolWebshellExec, - ToolWebshellFileList, - ToolWebshellFileRead, - ToolWebshellFileWrite, - ToolManageWebshellList, - ToolManageWebshellAdd, - ToolManageWebshellUpdate, - ToolManageWebshellDelete, - ToolManageWebshellTest, - ToolBatchTaskList, - ToolBatchTaskGet, - ToolBatchTaskCreate, - ToolBatchTaskStart, - ToolBatchTaskRerun, - ToolBatchTaskPause, - ToolBatchTaskDelete, - ToolBatchTaskUpdateMetadata, - ToolBatchTaskUpdateSchedule, - ToolBatchTaskScheduleEnabled, - ToolBatchTaskAdd, - ToolBatchTaskUpdate, - ToolBatchTaskRemove, - // C2 工具 - ToolC2Listener, - ToolC2Session, - ToolC2Task, - ToolC2TaskManage, - ToolC2Payload, - ToolC2Event, - ToolC2Profile, - ToolC2File: - return true - default: - return false - } -} - -// GetAllBuiltinTools 返回所有内置工具名称列表 -func GetAllBuiltinTools() []string { - return []string{ - ToolRecordVulnerability, - ToolListKnowledgeRiskTypes, - ToolSearchKnowledgeBase, - ToolWebshellExec, - ToolWebshellFileList, - ToolWebshellFileRead, - ToolWebshellFileWrite, - ToolManageWebshellList, - ToolManageWebshellAdd, - ToolManageWebshellUpdate, - ToolManageWebshellDelete, - ToolManageWebshellTest, - ToolBatchTaskList, - ToolBatchTaskGet, - ToolBatchTaskCreate, - ToolBatchTaskStart, - ToolBatchTaskRerun, - ToolBatchTaskPause, - ToolBatchTaskDelete, - ToolBatchTaskUpdateMetadata, - ToolBatchTaskUpdateSchedule, - ToolBatchTaskScheduleEnabled, - ToolBatchTaskAdd, - ToolBatchTaskUpdate, - ToolBatchTaskRemove, - // C2 工具 - ToolC2Listener, - ToolC2Session, - ToolC2Task, - ToolC2TaskManage, - ToolC2Payload, - ToolC2Event, - ToolC2Profile, - ToolC2File, - } -} diff --git a/mcp/client_sdk.go b/mcp/client_sdk.go deleted file mode 100644 index bfbbcb15..00000000 --- a/mcp/client_sdk.go +++ /dev/null @@ -1,405 +0,0 @@ -// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性 -package mcp - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "os/exec" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/config" - - "github.com/modelcontextprotocol/go-sdk/mcp" - "go.uber.org/zap" -) - -const ( - clientName = "CyberStrikeAI" - clientVersion = "1.0.0" -) - -// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口 -type sdkClient struct { - session *mcp.ClientSession - client *mcp.Client - logger *zap.Logger - mu sync.RWMutex - status string // "disconnected", "connecting", "connected", "error" -} - -// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用) -func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient { - return &sdkClient{ - session: session, - client: client, - logger: logger, - status: "connected", - } -} - -// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient -type lazySDKClient struct { - serverCfg config.ExternalMCPServerConfig - logger *zap.Logger - inner ExternalMCPClient // 连接成功后为 *sdkClient - mu sync.RWMutex - status string -} - -func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient { - return &lazySDKClient{ - serverCfg: serverCfg, - logger: logger, - status: "connecting", - } -} - -func (c *lazySDKClient) setStatus(s string) { - c.mu.Lock() - defer c.mu.Unlock() - c.status = s -} - -func (c *lazySDKClient) GetStatus() string { - c.mu.RLock() - defer c.mu.RUnlock() - if c.inner != nil { - return c.inner.GetStatus() - } - return c.status -} - -func (c *lazySDKClient) IsConnected() bool { - c.mu.RLock() - inner := c.inner - c.mu.RUnlock() - if inner != nil { - return inner.IsConnected() - } - return false -} - -func (c *lazySDKClient) Initialize(ctx context.Context) error { - c.mu.Lock() - if c.inner != nil { - c.mu.Unlock() - return nil - } - c.mu.Unlock() - - inner, err := createSDKClient(ctx, c.serverCfg, c.logger) - if err != nil { - c.setStatus("error") - return err - } - - c.mu.Lock() - c.inner = inner - c.mu.Unlock() - c.setStatus("connected") - return nil -} - -func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) { - c.mu.RLock() - inner := c.inner - c.mu.RUnlock() - if inner == nil { - return nil, fmt.Errorf("未连接") - } - return inner.ListTools(ctx) -} - -func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { - c.mu.RLock() - inner := c.inner - c.mu.RUnlock() - if inner == nil { - return nil, fmt.Errorf("未连接") - } - return inner.CallTool(ctx, name, args) -} - -func (c *lazySDKClient) Close() error { - c.mu.Lock() - inner := c.inner - c.inner = nil - c.mu.Unlock() - c.setStatus("disconnected") - if inner != nil { - return inner.Close() - } - return nil -} - -func (c *sdkClient) setStatus(s string) { - c.mu.Lock() - defer c.mu.Unlock() - c.status = s -} - -func (c *sdkClient) GetStatus() string { - c.mu.RLock() - defer c.mu.RUnlock() - return c.status -} - -func (c *sdkClient) IsConnected() bool { - return c.GetStatus() == "connected" -} - -func (c *sdkClient) Initialize(ctx context.Context) error { - // sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接 - // 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成 - return nil -} - -func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) { - if c.session == nil { - return nil, fmt.Errorf("未连接") - } - res, err := c.session.ListTools(ctx, nil) - if err != nil { - return nil, err - } - if res == nil { - return nil, nil - } - return sdkToolsToOur(res.Tools), nil -} - -func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { - if c.session == nil { - return nil, fmt.Errorf("未连接") - } - params := &mcp.CallToolParams{ - Name: name, - Arguments: args, - } - res, err := c.session.CallTool(ctx, params) - if err != nil { - return nil, err - } - return sdkCallToolResultToOurs(res), nil -} - -func (c *sdkClient) Close() error { - c.setStatus("disconnected") - if c.session != nil { - err := c.session.Close() - c.session = nil - return err - } - return nil -} - -// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool -func sdkToolsToOur(tools []*mcp.Tool) []Tool { - if len(tools) == 0 { - return nil - } - out := make([]Tool, 0, len(tools)) - for _, t := range tools { - if t == nil { - continue - } - schema := make(map[string]interface{}) - if t.InputSchema != nil { - // SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map - if m, ok := t.InputSchema.(map[string]interface{}); ok { - schema = m - } else { - _ = json.Unmarshal(mustJSON(t.InputSchema), &schema) - } - } - desc := t.Description - shortDesc := desc - if t.Annotations != nil && t.Annotations.Title != "" { - shortDesc = t.Annotations.Title - } - out = append(out, Tool{ - Name: t.Name, - Description: desc, - ShortDescription: shortDesc, - InputSchema: schema, - }) - } - return out -} - -// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult -func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult { - if res == nil { - return &ToolResult{Content: []Content{}} - } - content := sdkContentToOurs(res.Content) - return &ToolResult{ - Content: content, - IsError: res.IsError, - } -} - -func sdkContentToOurs(list []mcp.Content) []Content { - if len(list) == 0 { - return nil - } - out := make([]Content, 0, len(list)) - for _, c := range list { - switch v := c.(type) { - case *mcp.TextContent: - out = append(out, Content{Type: "text", Text: v.Text}) - default: - out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)}) - } - } - return out -} - -func mustJSON(v interface{}) []byte { - b, _ := json.Marshal(v) - return b -} - -// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient -// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。 -func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) { - timeout := time.Duration(serverCfg.Timeout) * time.Second - if timeout <= 0 { - timeout = 30 * time.Second - } - - transport := serverCfg.GetTransportType() - if transport == "" { - return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport") - } - - // 构造 ClientOptions:KeepAlive 心跳 - var clientOpts *mcp.ClientOptions - if serverCfg.KeepAlive > 0 { - clientOpts = &mcp.ClientOptions{ - KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second, - } - } - - client := mcp.NewClient(&mcp.Implementation{ - Name: clientName, - Version: clientVersion, - }, clientOpts) - - var t mcp.Transport - switch transport { - case "stdio": - if serverCfg.Command == "" { - return nil, fmt.Errorf("stdio 模式需要配置 command") - } - // 必须用 exec.Command 而非 CommandContext:doConnect 返回后 ctx 会被 cancel, - // 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具 - cmd := exec.Command(serverCfg.Command, serverCfg.Args...) - if len(serverCfg.Env) > 0 { - cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...) - } - ct := &mcp.CommandTransport{Command: cmd} - if serverCfg.TerminateDuration > 0 { - ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second - } - t = ct - case "sse": - if serverCfg.URL == "" { - return nil, fmt.Errorf("sse 模式需要配置 url") - } - // SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。 - // 超时由每次 ListTools/CallTool 的 context 单独控制。 - httpClient := httpClientForLongLived(serverCfg.Headers) - t = &mcp.SSEClientTransport{ - Endpoint: serverCfg.URL, - HTTPClient: httpClient, - } - case "http": - if serverCfg.URL == "" { - return nil, fmt.Errorf("http 模式需要配置 url") - } - httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) - st := &mcp.StreamableClientTransport{ - Endpoint: serverCfg.URL, - HTTPClient: httpClient, - } - if serverCfg.MaxRetries > 0 { - st.MaxRetries = serverCfg.MaxRetries - } - t = st - default: - return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http)", transport) - } - - session, err := client.Connect(ctx, t, nil) - if err != nil { - return nil, fmt.Errorf("连接失败: %w", err) - } - - return newSDKClientFromSession(session, client, logger), nil -} - -func envMapToSlice(env map[string]string) []string { - m := make(map[string]string) - for _, s := range os.Environ() { - if i := strings.IndexByte(s, '='); i > 0 { - m[s[:i]] = s[i+1:] - } - } - for k, v := range env { - m[k] = v - } - out := make([]string, 0, len(m)) - for k, v := range m { - out = append(out, k+"="+v) - } - return out -} - -func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client { - transport := http.DefaultTransport - if len(headers) > 0 { - transport = &headerRoundTripper{ - headers: headers, - base: http.DefaultTransport, - } - } - return &http.Client{ - Timeout: timeout, - Transport: transport, - } -} - -// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。 -// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。 -// 超时由调用方通过 context 控制。 -func httpClientForLongLived(headers map[string]string) *http.Client { - transport := http.DefaultTransport - if len(headers) > 0 { - transport = &headerRoundTripper{ - headers: headers, - base: http.DefaultTransport, - } - } - return &http.Client{ - Transport: transport, - // 不设 Timeout,SSE 长连接的超时由 per-request context 控制 - } -} - -type headerRoundTripper struct { - headers map[string]string - base http.RoundTripper -} - -func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - for k, v := range h.headers { - req.Header.Set(k, v) - } - return h.base.RoundTrip(req) -} diff --git a/mcp/external_manager.go b/mcp/external_manager.go deleted file mode 100644 index 036f243a..00000000 --- a/mcp/external_manager.go +++ /dev/null @@ -1,1182 +0,0 @@ -package mcp - -import ( - "context" - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "cyberstrike-ai/internal/config" - - "github.com/google/uuid" - - "go.uber.org/zap" -) - -// ExternalMCPManager 外部MCP管理器 -type ExternalMCPManager struct { - clients map[string]ExternalMCPClient - configs map[string]config.ExternalMCPServerConfig - logger *zap.Logger - storage MonitorStorage // 可选的持久化存储 - executions map[string]*ToolExecution // 执行记录 - stats map[string]*ToolStats // 工具统计信息 - errors map[string]string // 错误信息 - toolCounts map[string]int // 工具数量缓存 - toolCountsMu sync.RWMutex // 工具数量缓存的锁 - toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表 - toolCacheMu sync.RWMutex // 工具列表缓存的锁 - stopRefresh chan struct{} // 停止后台刷新的信号 - refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 - refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积 - mu sync.RWMutex - runningCancels map[string]context.CancelFunc - abortUserNotes map[string]string -} - -// NewExternalMCPManager 创建外部MCP管理器 -func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager { - return NewExternalMCPManagerWithStorage(logger, nil) -} - -// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储) -func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager { - manager := &ExternalMCPManager{ - clients: make(map[string]ExternalMCPClient), - configs: make(map[string]config.ExternalMCPServerConfig), - logger: logger, - storage: storage, - executions: make(map[string]*ToolExecution), - stats: make(map[string]*ToolStats), - errors: make(map[string]string), - toolCounts: make(map[string]int), - toolCache: make(map[string][]Tool), - stopRefresh: make(chan struct{}), - runningCancels: make(map[string]context.CancelFunc), - abortUserNotes: make(map[string]string), - } - // 启动后台刷新工具数量的goroutine - manager.startToolCountRefresh() - return manager -} - -// LoadConfigs 加载配置 -func (m *ExternalMCPManager) LoadConfigs(cfg *config.ExternalMCPConfig) { - m.mu.Lock() - defer m.mu.Unlock() - - if cfg == nil || cfg.Servers == nil { - return - } - - m.configs = make(map[string]config.ExternalMCPServerConfig) - for name, serverCfg := range cfg.Servers { - m.configs[name] = serverCfg - } -} - -// GetConfigs 获取所有配置 -func (m *ExternalMCPManager) GetConfigs() map[string]config.ExternalMCPServerConfig { - m.mu.RLock() - defer m.mu.RUnlock() - - result := make(map[string]config.ExternalMCPServerConfig) - for k, v := range m.configs { - result[k] = v - } - return result -} - -// AddOrUpdateConfig 添加或更新配置 -func (m *ExternalMCPManager) AddOrUpdateConfig(name string, serverCfg config.ExternalMCPServerConfig) error { - m.mu.Lock() - defer m.mu.Unlock() - - // 如果已存在客户端,先关闭 - if client, exists := m.clients[name]; exists { - client.Close() - delete(m.clients, name) - } - - m.configs[name] = serverCfg - - // 如果启用,自动连接 - if m.isEnabled(serverCfg) { - go m.connectClient(name, serverCfg) - } - - return nil -} - -// RemoveConfig 移除配置 -func (m *ExternalMCPManager) RemoveConfig(name string) error { - m.mu.Lock() - defer m.mu.Unlock() - - // 关闭客户端 - if client, exists := m.clients[name]; exists { - client.Close() - delete(m.clients, name) - } - - delete(m.configs, name) - - // 清理工具数量缓存 - m.toolCountsMu.Lock() - delete(m.toolCounts, name) - m.toolCountsMu.Unlock() - - // 清理工具列表缓存 - m.toolCacheMu.Lock() - delete(m.toolCache, name) - m.toolCacheMu.Unlock() - - return nil -} - -// StartClient 启动客户端 -func (m *ExternalMCPManager) StartClient(name string) error { - m.mu.Lock() - serverCfg, exists := m.configs[name] - m.mu.Unlock() - - if !exists { - return fmt.Errorf("配置不存在: %s", name) - } - - // 检查是否已经有连接的客户端 - m.mu.RLock() - existingClient, hasClient := m.clients[name] - m.mu.RUnlock() - - if hasClient { - // 检查客户端是否已连接 - if existingClient.IsConnected() { - // 客户端已连接,直接返回成功(目标状态已达成) - // 更新配置为启用(确保配置一致) - m.mu.Lock() - serverCfg.ExternalMCPEnable = true - m.configs[name] = serverCfg - m.mu.Unlock() - return nil - } - // 如果有客户端但未连接,先关闭 - existingClient.Close() - m.mu.Lock() - delete(m.clients, name) - m.mu.Unlock() - } - - // 更新配置为启用 - m.mu.Lock() - serverCfg.ExternalMCPEnable = true - m.configs[name] = serverCfg - // 清除之前的错误信息(重新启动时) - delete(m.errors, name) - m.mu.Unlock() - - // 立即创建客户端并设置为"connecting"状态,这样前端可以立即看到状态 - client := m.createClient(serverCfg) - if client == nil { - return fmt.Errorf("无法创建客户端:不支持的传输模式") - } - - // 设置状态为connecting - m.setClientStatus(client, "connecting") - - // 立即保存客户端,这样前端查询时就能看到"connecting"状态 - m.mu.Lock() - m.clients[name] = client - m.mu.Unlock() - - // 在后台异步进行实际连接 - go func() { - if err := m.doConnect(name, serverCfg, client); err != nil { - m.logger.Error("连接外部MCP客户端失败", - zap.String("name", name), - zap.Error(err), - ) - // 连接失败,设置状态为error并保存错误信息 - m.setClientStatus(client, "error") - m.mu.Lock() - m.errors[name] = err.Error() - m.mu.Unlock() - // 触发工具数量刷新(连接失败,工具数量应为0) - m.triggerToolCountRefresh() - } else { - // 连接成功,清除错误信息 - m.mu.Lock() - delete(m.errors, name) - m.mu.Unlock() - // 立即刷新工具数量和工具列表缓存 - m.triggerToolCountRefresh() - m.refreshToolCache(name, client) - // 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端 - go func() { - time.Sleep(2 * time.Second) - m.triggerToolCountRefresh() - m.refreshToolCache(name, client) - }() - } - }() - - return nil -} - -// StopClient 停止客户端 -func (m *ExternalMCPManager) StopClient(name string) error { - m.mu.Lock() - defer m.mu.Unlock() - - serverCfg, exists := m.configs[name] - if !exists { - return fmt.Errorf("配置不存在: %s", name) - } - - // 关闭客户端 - if client, exists := m.clients[name]; exists { - client.Close() - delete(m.clients, name) - } - - // 清除错误信息 - delete(m.errors, name) - - // 更新工具数量缓存(停止后工具数量为0) - m.toolCountsMu.Lock() - m.toolCounts[name] = 0 - m.toolCountsMu.Unlock() - - // 更新配置为禁用 - serverCfg.ExternalMCPEnable = false - m.configs[name] = serverCfg - - return nil -} - -// GetClient 获取客户端 -func (m *ExternalMCPManager) GetClient(name string) (ExternalMCPClient, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - - client, exists := m.clients[name] - return client, exists -} - -// GetError 获取错误信息 -func (m *ExternalMCPManager) GetError(name string) string { - m.mu.RLock() - defer m.mu.RUnlock() - - return m.errors[name] -} - -// GetAllTools 获取所有外部MCP的工具 -// 优先从已连接的客户端获取,如果连接断开则返回缓存的工具列表 -// 策略: -// - error 状态:不使用缓存,直接跳过(配置错误或服务不可用) -// - disconnected/connecting 状态:使用缓存(临时断开) -// - connected 状态:正常获取,失败时降级使用缓存 -func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) { - m.mu.RLock() - clients := make(map[string]ExternalMCPClient) - for k, v := range m.clients { - clients[k] = v - } - m.mu.RUnlock() - - var allTools []Tool - var hasError bool - var lastError error - - // 使用较短的超时时间进行快速检查(3秒),避免阻塞 - quickCtx, quickCancel := context.WithTimeout(ctx, 3*time.Second) - defer quickCancel() - - for name, client := range clients { - tools, err := m.getToolsForClient(name, client, quickCtx) - if err != nil { - // 记录错误,但继续处理其他客户端 - hasError = true - if lastError == nil { - lastError = err - } - continue - } - - // 为工具添加前缀,避免冲突 - for _, tool := range tools { - tool.Name = fmt.Sprintf("%s::%s", name, tool.Name) - allTools = append(allTools, tool) - } - } - - // 如果有错误但至少返回了一些工具,不返回错误(部分成功) - if hasError && len(allTools) == 0 { - return nil, fmt.Errorf("获取外部MCP工具失败: %w", lastError) - } - - return allTools, nil -} - -// getToolsForClient 获取指定客户端的工具列表 -// 返回工具列表和错误(如果完全无法获取) -func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPClient, ctx context.Context) ([]Tool, error) { - status := client.GetStatus() - - // error 状态:不使用缓存,直接返回错误 - if status == "error" { - m.logger.Debug("跳过连接失败的外部MCP(不使用缓存)", - zap.String("name", name), - zap.String("status", status), - ) - return nil, fmt.Errorf("外部MCP连接失败: %s", name) - } - - // 已连接:尝试获取最新工具列表 - if client.IsConnected() { - tools, err := client.ListTools(ctx) - if err != nil { - // 获取失败,尝试使用缓存 - return m.getCachedTools(name, "连接正常但获取失败", err) - } - - // 获取成功,更新缓存 - m.updateToolCache(name, tools) - return tools, nil - } - - // 未连接:根据状态决定是否使用缓存 - if status == "disconnected" || status == "connecting" { - return m.getCachedTools(name, fmt.Sprintf("客户端临时断开(状态: %s)", status), nil) - } - - // 其他未知状态,不使用缓存 - m.logger.Debug("跳过外部MCP(未知状态)", - zap.String("name", name), - zap.String("status", status), - ) - return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status) -} - -// getCachedTools 获取缓存的工具列表 -func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) { - m.toolCacheMu.RLock() - cachedTools, hasCache := m.toolCache[name] - m.toolCacheMu.RUnlock() - - if hasCache && len(cachedTools) > 0 { - m.logger.Debug("使用缓存的工具列表", - zap.String("name", name), - zap.String("reason", reason), - zap.Int("count", len(cachedTools)), - zap.Error(originalErr), - ) - return cachedTools, nil - } - - // 无缓存,返回错误 - if originalErr != nil { - return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr) - } - return nil, fmt.Errorf("外部MCP无缓存工具: %s", name) -} - -// updateToolCache 更新工具列表缓存 -func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) { - m.toolCacheMu.Lock() - m.toolCache[name] = tools - m.toolCacheMu.Unlock() - - // 如果返回空列表,记录警告 - if len(tools) == 0 { - m.logger.Warn("外部MCP返回空工具列表", - zap.String("name", name), - zap.String("hint", "服务可能暂时不可用,工具列表为空"), - ) - } else { - m.logger.Debug("工具列表缓存已更新", - zap.String("name", name), - zap.Int("count", len(tools)), - ) - } -} - -// CallTool 调用外部MCP工具(返回执行ID) -func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { - // 解析工具名称:name::toolName - var mcpName, actualToolName string - if idx := findSubstring(toolName, "::"); idx > 0 { - mcpName = toolName[:idx] - actualToolName = toolName[idx+2:] - } else { - return nil, "", fmt.Errorf("无效的工具名称格式: %s", toolName) - } - - client, exists := m.GetClient(mcpName) - if !exists { - return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName) - } - - // 检查连接状态,如果未连接或状态为error,不允许调用 - if !client.IsConnected() { - status := client.GetStatus() - if status == "error" { - // 获取错误信息(如果有) - errorMsg := m.GetError(mcpName) - if errorMsg != "" { - return nil, "", fmt.Errorf("外部MCP连接失败: %s (错误: %s)", mcpName, errorMsg) - } - return nil, "", fmt.Errorf("外部MCP连接失败: %s", mcpName) - } - return nil, "", fmt.Errorf("外部MCP客户端未连接: %s (状态: %s)", mcpName, status) - } - - // 创建执行记录 - executionID := uuid.New().String() - execution := &ToolExecution{ - ID: executionID, - ToolName: toolName, // 使用完整工具名称(包含MCP名称) - Arguments: args, - Status: "running", - StartTime: time.Now(), - } - - m.mu.Lock() - m.executions[executionID] = execution - // 如果内存中的执行记录超过限制,清理最旧的记录 - m.cleanupOldExecutions() - m.mu.Unlock() - - if m.storage != nil { - if err := m.storage.SaveToolExecution(execution); err != nil { - m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - execCtx, runCancel := context.WithCancel(ctx) - m.registerRunningCancel(executionID, runCancel) - notifyToolRunBegin(ctx, executionID) - defer func() { - notifyToolRunEnd(ctx, executionID) - runCancel() - m.unregisterRunningCancel(executionID) - }() - - // 调用工具 - result, err := client.CallTool(execCtx, actualToolName, args) - cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) - - // 更新执行记录 - m.mu.Lock() - now := time.Now() - execution.EndTime = &now - execution.Duration = now.Sub(execution.StartTime) - - if err != nil { - st, msg := executionStatusAndMessage(err) - execution.Status = st - execution.Error = msg - } else if result != nil && result.IsError { - if cancelledWithUserNote { - execution.Status = "cancelled" - execution.Error = "" - execution.Result = result - } else { - execution.Status = "failed" - if len(result.Content) > 0 { - execution.Error = result.Content[0].Text - } else { - execution.Error = "工具执行返回错误结果" - } - execution.Result = result - } - } else { - execution.Status = "completed" - if result == nil { - result = &ToolResult{ - Content: []Content{ - {Type: "text", Text: "工具执行完成,但未返回结果"}, - }, - } - } - execution.Result = result - } - m.mu.Unlock() - - if m.storage != nil { - if err := m.storage.SaveToolExecution(execution); err != nil { - m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - // 更新统计信息 - failed := err != nil || (result != nil && result.IsError) - m.updateStats(toolName, failed) - - // 如果使用存储,从内存中删除(已持久化) - if m.storage != nil { - m.mu.Lock() - delete(m.executions, executionID) - m.mu.Unlock() - } - - if err != nil { - return nil, executionID, err - } - - return result, executionID, nil -} - -func (m *ExternalMCPManager) applyAbortUserNoteToCancelledToolResult(executionID string, result **ToolResult, err *error) (cancelledWithUserNote bool) { - note := strings.TrimSpace(m.readAbortUserNote(executionID)) - if note == "" { - return false - } - hasErr := err != nil && *err != nil - hasRes := result != nil && *result != nil - if !hasErr && !hasRes { - return false - } - _ = m.takeAbortUserNote(executionID) - partial := "" - if hasRes { - partial = ToolResultPlainText(*result) - } - if partial == "" && hasErr { - partial = (*err).Error() - } - merged := MergePartialToolOutputAndAbortNote(partial, note) - *err = nil - *result = &ToolResult{Content: []Content{{Type: "text", Text: merged}}, IsError: true} - return true -} - -func (m *ExternalMCPManager) readAbortUserNote(id string) string { - m.mu.Lock() - defer m.mu.Unlock() - if m.abortUserNotes == nil { - return "" - } - return m.abortUserNotes[id] -} - -func (m *ExternalMCPManager) takeAbortUserNote(id string) string { - m.mu.Lock() - defer m.mu.Unlock() - if m.abortUserNotes == nil { - return "" - } - n := m.abortUserNotes[id] - delete(m.abortUserNotes, id) - return n -} - -// cleanupOldExecutions 清理旧的执行记录(保持内存中的记录数量在限制内) -func (m *ExternalMCPManager) cleanupOldExecutions() { - const maxExecutionsInMemory = 1000 - if len(m.executions) <= maxExecutionsInMemory { - return - } - - // 按开始时间排序,删除最旧的记录 - type execTime struct { - id string - startTime time.Time - } - var execs []execTime - for id, exec := range m.executions { - execs = append(execs, execTime{id: id, startTime: exec.StartTime}) - } - - // 按时间排序 - for i := 0; i < len(execs)-1; i++ { - for j := i + 1; j < len(execs); j++ { - if execs[i].startTime.After(execs[j].startTime) { - execs[i], execs[j] = execs[j], execs[i] - } - } - } - - // 删除最旧的记录 - toDelete := len(m.executions) - maxExecutionsInMemory - for i := 0; i < toDelete && i < len(execs); i++ { - delete(m.executions, execs[i].id) - } -} - -// GetExecution 获取执行记录(先从内存查找,再从数据库查找) -func (m *ExternalMCPManager) GetExecution(id string) (*ToolExecution, bool) { - m.mu.RLock() - exec, exists := m.executions[id] - m.mu.RUnlock() - - if exists { - return exec, true - } - - if m.storage != nil { - exec, err := m.storage.GetToolExecution(id) - if err == nil { - return exec, true - } - } - - return nil, false -} - -func (m *ExternalMCPManager) registerRunningCancel(id string, cancel context.CancelFunc) { - m.mu.Lock() - m.runningCancels[id] = cancel - m.mu.Unlock() -} - -func (m *ExternalMCPManager) unregisterRunningCancel(id string) { - m.mu.Lock() - delete(m.runningCancels, id) - m.mu.Unlock() -} - -// CancelToolExecutionWithNote 取消外部 MCP 工具;note 非空时与已返回输出合并后交给模型。 -func (m *ExternalMCPManager) CancelToolExecutionWithNote(id string, note string) bool { - m.mu.Lock() - cancel, ok := m.runningCancels[id] - if !ok || cancel == nil { - m.mu.Unlock() - return false - } - if strings.TrimSpace(note) != "" { - if m.abortUserNotes == nil { - m.abortUserNotes = make(map[string]string) - } - m.abortUserNotes[id] = strings.TrimSpace(note) - } - m.mu.Unlock() - cancel() - return true -} - -// CancelToolExecution 取消正在执行的外部 MCP 工具(无用户说明)。 -func (m *ExternalMCPManager) CancelToolExecution(id string) bool { - return m.CancelToolExecutionWithNote(id, "") -} - -// updateStats 更新统计信息 -func (m *ExternalMCPManager) updateStats(toolName string, failed bool) { - now := time.Now() - if m.storage != nil { - totalCalls := 1 - successCalls := 0 - failedCalls := 0 - if failed { - failedCalls = 1 - } else { - successCalls = 1 - } - if err := m.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil { - m.logger.Warn("保存统计信息到数据库失败", zap.Error(err)) - } - return - } - - m.mu.Lock() - defer m.mu.Unlock() - - if m.stats[toolName] == nil { - m.stats[toolName] = &ToolStats{ - ToolName: toolName, - } - } - - stats := m.stats[toolName] - stats.TotalCalls++ - stats.LastCallTime = &now - - if failed { - stats.FailedCalls++ - } else { - stats.SuccessCalls++ - } -} - -// GetStats 获取MCP服务器统计信息 -func (m *ExternalMCPManager) GetStats() map[string]interface{} { - m.mu.RLock() - defer m.mu.RUnlock() - - total := len(m.configs) - enabled := 0 - disabled := 0 - connected := 0 - - for name, cfg := range m.configs { - if m.isEnabled(cfg) { - enabled++ - if client, exists := m.clients[name]; exists && client.IsConnected() { - connected++ - } - } else { - disabled++ - } - } - - return map[string]interface{}{ - "total": total, - "enabled": enabled, - "disabled": disabled, - "connected": connected, - } -} - -// GetToolStats 获取工具统计信息(合并内存和数据库) -// 只返回外部MCP工具的统计信息(工具名称包含 "::") -func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats { - result := make(map[string]*ToolStats) - - // 从数据库加载统计信息(如果使用数据库存储) - if m.storage != nil { - dbStats, err := m.storage.LoadToolStats() - if err == nil { - // 只保留外部MCP工具的统计信息(工具名称包含 "::") - for k, v := range dbStats { - if findSubstring(k, "::") > 0 { - result[k] = v - } - } - } else { - m.logger.Warn("从数据库加载统计信息失败", zap.Error(err)) - } - } - - // 合并内存中的统计信息 - m.mu.RLock() - for k, v := range m.stats { - // 如果数据库中已有该工具的统计信息,合并它们 - if existing, exists := result[k]; exists { - // 创建新的统计信息对象,避免修改共享对象 - merged := &ToolStats{ - ToolName: k, - TotalCalls: existing.TotalCalls + v.TotalCalls, - SuccessCalls: existing.SuccessCalls + v.SuccessCalls, - FailedCalls: existing.FailedCalls + v.FailedCalls, - } - // 使用最新的调用时间 - if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { - merged.LastCallTime = v.LastCallTime - } else if existing.LastCallTime != nil { - timeCopy := *existing.LastCallTime - merged.LastCallTime = &timeCopy - } - result[k] = merged - } else { - // 如果数据库中没有,直接使用内存中的统计信息 - statCopy := *v - result[k] = &statCopy - } - } - m.mu.RUnlock() - - return result -} - -// GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞) -func (m *ExternalMCPManager) GetToolCount(name string) (int, error) { - // 先从缓存读取 - m.toolCountsMu.RLock() - if count, exists := m.toolCounts[name]; exists { - m.toolCountsMu.RUnlock() - return count, nil - } - m.toolCountsMu.RUnlock() - - // 如果缓存中没有,检查客户端状态 - client, exists := m.GetClient(name) - if !exists { - return 0, fmt.Errorf("客户端不存在: %s", name) - } - - if !client.IsConnected() { - // 未连接,缓存为0 - m.toolCountsMu.Lock() - m.toolCounts[name] = 0 - m.toolCountsMu.Unlock() - return 0, nil - } - - // 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞) - m.triggerToolCountRefresh() - return 0, nil -} - -// GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞) -func (m *ExternalMCPManager) GetToolCounts() map[string]int { - m.toolCountsMu.RLock() - defer m.toolCountsMu.RUnlock() - - // 返回缓存的副本,避免外部修改 - result := make(map[string]int) - for k, v := range m.toolCounts { - result[k] = v - } - return result -} - -// refreshToolCounts 刷新工具数量缓存(后台异步执行) -// 使用 atomic flag 防止并发堆积:如果上一次刷新尚未完成,本次触发直接跳过。 -func (m *ExternalMCPManager) refreshToolCounts() { - if !m.refreshing.CompareAndSwap(false, true) { - return // 上一次刷新尚未完成,跳过 - } - defer m.refreshing.Store(false) - - m.mu.RLock() - clients := make(map[string]ExternalMCPClient) - for k, v := range m.clients { - clients[k] = v - } - m.mu.RUnlock() - - newCounts := make(map[string]int) - - // 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞 - type countResult struct { - name string - count int - } - resultChan := make(chan countResult, len(clients)) - - for name, client := range clients { - go func(n string, c ExternalMCPClient) { - if !c.IsConnected() { - resultChan <- countResult{name: n, count: 0} - return - } - - // 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞 - // 由于这是后台异步刷新,超时不会影响前端响应 - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - tools, err := c.ListTools(ctx) - cancel() - - if err != nil { - errStr := err.Error() - // SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示 - if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") { - m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)", - zap.String("name", n), - zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"), - zap.Error(err), - ) - } else { - m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", - zap.String("name", n), - zap.Error(err), - ) - } - resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值 - return - } - - resultChan <- countResult{name: n, count: len(tools)} - }(name, client) - } - - // 收集结果 - m.toolCountsMu.RLock() - oldCounts := make(map[string]int) - for k, v := range m.toolCounts { - oldCounts[k] = v - } - m.toolCountsMu.RUnlock() - - for i := 0; i < len(clients); i++ { - result := <-resultChan - if result.count >= 0 { - newCounts[result.name] = result.count - } else { - // 获取失败,保留旧值 - if oldCount, exists := oldCounts[result.name]; exists { - newCounts[result.name] = oldCount - } else { - newCounts[result.name] = 0 - } - } - } - - // 更新缓存 - m.toolCountsMu.Lock() - // 更新所有获取到的值 - for name, count := range newCounts { - m.toolCounts[name] = count - } - // 对于未连接的客户端,设置为0 - for name, client := range clients { - if !client.IsConnected() { - m.toolCounts[name] = 0 - } - } - m.toolCountsMu.Unlock() -} - -// refreshToolCache 刷新指定MCP的工具列表缓存 -func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPClient) { - if !client.IsConnected() { - return - } - - // 检查状态,如果是error状态,不更新缓存 - status := client.GetStatus() - if status == "error" { - m.logger.Debug("跳过刷新工具列表缓存(连接失败)", - zap.String("name", name), - zap.String("status", status), - ) - return - } - - // 使用较短的超时时间(5秒) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - tools, err := client.ListTools(ctx) - if err != nil { - m.logger.Debug("刷新工具列表缓存失败", - zap.String("name", name), - zap.Error(err), - ) - // 刷新失败时不更新缓存,保留旧缓存(如果有) - return - } - - // 使用统一的缓存更新方法 - m.updateToolCache(name, tools) -} - -// startToolCountRefresh 启动后台刷新工具数量的goroutine -func (m *ExternalMCPManager) startToolCountRefresh() { - m.refreshWg.Add(1) - go func() { - defer m.refreshWg.Done() - ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次 - defer ticker.Stop() - - // 立即执行一次刷新 - m.refreshToolCounts() - - for { - select { - case <-ticker.C: - m.refreshToolCounts() - case <-m.stopRefresh: - return - } - } - }() -} - -// triggerToolCountRefresh 触发立即刷新工具数量(异步) -func (m *ExternalMCPManager) triggerToolCountRefresh() { - go m.refreshToolCounts() -} - -// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。 -func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient { - transport := serverCfg.GetTransportType() - - switch transport { - case "http": - if serverCfg.URL == "" { - return nil - } - return newLazySDKClient(serverCfg, m.logger) - case "stdio": - if serverCfg.Command == "" { - return nil - } - return newLazySDKClient(serverCfg, m.logger) - case "sse": - if serverCfg.URL == "" { - return nil - } - return newLazySDKClient(serverCfg, m.logger) - default: - if transport == "" { - return nil - } - // 未知传输类型也尝试使用 lazy client - return newLazySDKClient(serverCfg, m.logger) - } -} - -// doConnect 执行实际连接 -func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCPServerConfig, client ExternalMCPClient) error { - timeout := time.Duration(serverCfg.Timeout) * time.Second - if timeout <= 0 { - timeout = 30 * time.Second - } - - // 初始化连接 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - if err := client.Initialize(ctx); err != nil { - return err - } - - m.logger.Info("外部MCP客户端已连接", - zap.String("name", name), - ) - - return nil -} - -// setClientStatus 设置客户端状态(通过类型断言) -func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) { - if c, ok := client.(*lazySDKClient); ok { - c.setStatus(status) - } -} - -// connectClient 连接客户端(异步)- 保留用于向后兼容 -func (m *ExternalMCPManager) connectClient(name string, serverCfg config.ExternalMCPServerConfig) error { - client := m.createClient(serverCfg) - if client == nil { - return fmt.Errorf("无法创建客户端:不支持的传输模式") - } - - // 设置状态为connecting - m.setClientStatus(client, "connecting") - - // 初始化连接 - timeout := time.Duration(serverCfg.Timeout) * time.Second - if timeout <= 0 { - timeout = 30 * time.Second - } - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - if err := client.Initialize(ctx); err != nil { - m.logger.Error("初始化外部MCP客户端失败", - zap.String("name", name), - zap.Error(err), - ) - return err - } - - // 保存客户端 - m.mu.Lock() - m.clients[name] = client - m.mu.Unlock() - - m.logger.Info("外部MCP客户端已连接", - zap.String("name", name), - ) - - // 连接成功,触发工具数量刷新和工具列表缓存刷新 - m.triggerToolCountRefresh() - m.mu.RLock() - if client, exists := m.clients[name]; exists { - m.refreshToolCache(name, client) - } - m.mu.RUnlock() - - return nil -} - -// isEnabled 检查是否启用 -func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool { - return cfg.ExternalMCPEnable -} - -// findSubstring 查找子字符串(简单实现) -func findSubstring(s, substr string) int { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return i - } - } - return -1 -} - -// StartAllEnabled 启动所有启用的客户端 -func (m *ExternalMCPManager) StartAllEnabled() { - m.mu.RLock() - configs := make(map[string]config.ExternalMCPServerConfig) - for k, v := range m.configs { - configs[k] = v - } - m.mu.RUnlock() - - for name, cfg := range configs { - if m.isEnabled(cfg) { - go func(n string, c config.ExternalMCPServerConfig) { - if err := m.connectClient(n, c); err != nil { - // 检查是否是连接被拒绝的错误(服务可能还没启动) - errStr := strings.ToLower(err.Error()) - isConnectionRefused := strings.Contains(errStr, "connection refused") || - strings.Contains(errStr, "dial tcp") || - strings.Contains(errStr, "connect: connection refused") - - if isConnectionRefused { - // 连接被拒绝,说明目标服务可能还没启动,这是正常的 - // 使用 Warn 级别,提示用户这是正常的,可以通过手动启动或等待服务启动后自动连接 - fields := []zap.Field{ - zap.String("name", n), - zap.String("message", "目标服务可能尚未启动,这是正常的。服务启动后可通过界面手动连接,或等待自动重试"), - zap.Error(err), - } - - transport := c.GetTransportType() - - if transport == "http" && c.URL != "" { - fields = append(fields, zap.String("url", c.URL)) - } else if transport == "stdio" && c.Command != "" { - fields = append(fields, zap.String("command", c.Command)) - } - - m.logger.Warn("外部MCP服务暂未就绪", fields...) - } else { - // 其他错误,使用 Error 级别 - m.logger.Error("启动外部MCP客户端失败", - zap.String("name", n), - zap.Error(err), - ) - } - } - }(name, cfg) - } - } -} - -// StopAll 停止所有客户端 -func (m *ExternalMCPManager) StopAll() { - m.mu.Lock() - defer m.mu.Unlock() - - for name, client := range m.clients { - client.Close() - delete(m.clients, name) - } - - // 清理所有工具数量缓存 - m.toolCountsMu.Lock() - m.toolCounts = make(map[string]int) - m.toolCountsMu.Unlock() - - // 清理所有工具列表缓存 - m.toolCacheMu.Lock() - m.toolCache = make(map[string][]Tool) - m.toolCacheMu.Unlock() - - // 停止后台刷新(使用 select 避免重复关闭 channel) - select { - case <-m.stopRefresh: - // 已经关闭,不需要再次关闭 - default: - close(m.stopRefresh) - m.refreshWg.Wait() - } -} diff --git a/mcp/external_manager_test.go b/mcp/external_manager_test.go deleted file mode 100644 index c7260f1d..00000000 --- a/mcp/external_manager_test.go +++ /dev/null @@ -1,235 +0,0 @@ -package mcp - -import ( - "context" - "testing" - "time" - - "cyberstrike-ai/internal/config" - - "go.uber.org/zap" -) - -func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - // 测试添加stdio配置 - stdioCfg := config.ExternalMCPServerConfig{ - Command: "python3", - Args: []string{"/path/to/script.py"}, - Description: "Test stdio MCP", - Timeout: 30, - ExternalMCPEnable: true, - } - - err := manager.AddOrUpdateConfig("test-stdio", stdioCfg) - if err != nil { - t.Fatalf("添加stdio配置失败: %v", err) - } - - // 测试添加HTTP配置 - httpCfg := config.ExternalMCPServerConfig{ - Type: "http", - URL: "http://127.0.0.1:8081/mcp", - Description: "Test HTTP MCP", - Timeout: 30, - ExternalMCPEnable: false, - } - - err = manager.AddOrUpdateConfig("test-http", httpCfg) - if err != nil { - t.Fatalf("添加HTTP配置失败: %v", err) - } - - // 验证配置已保存 - configs := manager.GetConfigs() - if len(configs) != 2 { - t.Fatalf("期望2个配置,实际%d个", len(configs)) - } - - if configs["test-stdio"].Command != stdioCfg.Command { - t.Errorf("stdio配置命令不匹配") - } - - if configs["test-http"].URL != httpCfg.URL { - t.Errorf("HTTP配置URL不匹配") - } -} - -func TestExternalMCPManager_RemoveConfig(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - cfg := config.ExternalMCPServerConfig{ - Command: "python3", - ExternalMCPEnable: false, - } - - manager.AddOrUpdateConfig("test-remove", cfg) - - // 移除配置 - err := manager.RemoveConfig("test-remove") - if err != nil { - t.Fatalf("移除配置失败: %v", err) - } - - configs := manager.GetConfigs() - if _, exists := configs["test-remove"]; exists { - t.Error("配置应该已被移除") - } -} - -func TestExternalMCPManager_GetStats(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - // 添加多个配置 - manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ - Command: "python3", - ExternalMCPEnable: true, - }) - - manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ - URL: "http://127.0.0.1:8081/mcp", - ExternalMCPEnable: true, - }) - - manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ - Command: "python3", - ExternalMCPEnable: false, - }) - - stats := manager.GetStats() - - if stats["total"].(int) != 3 { - t.Errorf("期望总数3,实际%d", stats["total"]) - } - - if stats["enabled"].(int) != 2 { - t.Errorf("期望启用数2,实际%d", stats["enabled"]) - } - - if stats["disabled"].(int) != 1 { - t.Errorf("期望停用数1,实际%d", stats["disabled"]) - } -} - -func TestExternalMCPManager_LoadConfigs(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - externalMCPConfig := config.ExternalMCPConfig{ - Servers: map[string]config.ExternalMCPServerConfig{ - "loaded1": { - Command: "python3", - ExternalMCPEnable: true, - }, - "loaded2": { - URL: "http://127.0.0.1:8081/mcp", - ExternalMCPEnable: false, - }, - }, - } - - manager.LoadConfigs(&externalMCPConfig) - - configs := manager.GetConfigs() - if len(configs) != 2 { - t.Fatalf("期望2个配置,实际%d个", len(configs)) - } - - if configs["loaded1"].Command != "python3" { - t.Error("配置1加载失败") - } - - if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" { - t.Error("配置2加载失败") - } -} - -// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态 -func TestLazySDKClient_InitializeFails(t *testing.T) { - logger := zap.NewNop() - // 使用不存在的 HTTP 地址,Initialize 应失败 - cfg := config.ExternalMCPServerConfig{ - Type: "http", - URL: "http://127.0.0.1:19999/nonexistent", - Timeout: 2, - } - c := newLazySDKClient(cfg, logger) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - err := c.Initialize(ctx) - if err == nil { - t.Fatal("expected error when connecting to invalid server") - } - if c.GetStatus() != "error" { - t.Errorf("expected status error, got %s", c.GetStatus()) - } - c.Close() -} - -func TestExternalMCPManager_StartStopClient(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - // 添加一个禁用的配置 - cfg := config.ExternalMCPServerConfig{ - Command: "python3", - ExternalMCPEnable: false, - } - - manager.AddOrUpdateConfig("test-start-stop", cfg) - - // 尝试启动(可能会失败,因为没有真实的服务器) - err := manager.StartClient("test-start-stop") - if err != nil { - t.Logf("启动失败(可能是没有服务器): %v", err) - } - - // 停止 - err = manager.StopClient("test-start-stop") - if err != nil { - t.Fatalf("停止失败: %v", err) - } - - // 验证配置已更新为禁用 - configs := manager.GetConfigs() - if configs["test-start-stop"].ExternalMCPEnable { - t.Error("配置应该已被禁用") - } -} - -func TestExternalMCPManager_CallTool(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - // 测试调用不存在的工具 - _, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{}) - if err == nil { - t.Error("应该返回错误") - } - - // 测试无效的工具名称格式 - _, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{}) - if err == nil { - t.Error("应该返回错误(无效格式)") - } -} - -func TestExternalMCPManager_GetAllTools(t *testing.T) { - logger := zap.NewNop() - manager := NewExternalMCPManager(logger) - - ctx := context.Background() - tools, err := manager.GetAllTools(ctx) - if err != nil { - t.Fatalf("获取工具列表失败: %v", err) - } - - // 如果没有连接的客户端,应该返回空列表 - if len(tools) != 0 { - t.Logf("获取到%d个工具", len(tools)) - } -} diff --git a/mcp/run_context.go b/mcp/run_context.go deleted file mode 100644 index 48dac642..00000000 --- a/mcp/run_context.go +++ /dev/null @@ -1,77 +0,0 @@ -package mcp - -import ( - "context" - "strings" -) - -// ToolRunRegistry 在工具开始/结束时登记当前 executionId,供对话页「仅终止当前工具」与监控页共用取消逻辑。 -type ToolRunRegistry interface { - RegisterRunningTool(conversationID, executionID string) - UnregisterRunningTool(conversationID, executionID string) -} - -type toolRunRegistryCtxKey struct{} -type mcpConversationIDCtxKey struct{} - -// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。 -func WithToolRunRegistry(ctx context.Context, reg ToolRunRegistry) context.Context { - if ctx == nil || reg == nil { - return ctx - } - return context.WithValue(ctx, toolRunRegistryCtxKey{}, reg) -} - -// ToolRunRegistryFromContext 取出登记器(无则 nil)。 -func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry { - if ctx == nil { - return nil - } - v, _ := ctx.Value(toolRunRegistryCtxKey{}).(ToolRunRegistry) - return v -} - -// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。 -func WithMCPConversationID(ctx context.Context, conversationID string) context.Context { - if ctx == nil { - return nil - } - id := strings.TrimSpace(conversationID) - if id == "" { - return ctx - } - return context.WithValue(ctx, mcpConversationIDCtxKey{}, id) -} - -// MCPConversationIDFromContext 读取对话 ID。 -func MCPConversationIDFromContext(ctx context.Context) string { - if ctx == nil { - return "" - } - v, _ := ctx.Value(mcpConversationIDCtxKey{}).(string) - return v -} - -func notifyToolRunBegin(ctx context.Context, executionID string) { - reg := ToolRunRegistryFromContext(ctx) - if reg == nil { - return - } - conv := MCPConversationIDFromContext(ctx) - if conv == "" || strings.TrimSpace(executionID) == "" { - return - } - reg.RegisterRunningTool(conv, executionID) -} - -func notifyToolRunEnd(ctx context.Context, executionID string) { - reg := ToolRunRegistryFromContext(ctx) - if reg == nil { - return - } - conv := MCPConversationIDFromContext(ctx) - if conv == "" || strings.TrimSpace(executionID) == "" { - return - } - reg.UnregisterRunningTool(conv, executionID) -} diff --git a/mcp/server.go b/mcp/server.go deleted file mode 100644 index 074beaa6..00000000 --- a/mcp/server.go +++ /dev/null @@ -1,1450 +0,0 @@ -package mcp - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - "sort" - "strings" - "sync" - "time" - - "github.com/google/uuid" - "go.uber.org/zap" -) - -// MonitorStorage 监控数据存储接口 -type MonitorStorage interface { - SaveToolExecution(exec *ToolExecution) error - LoadToolExecutions() ([]*ToolExecution, error) - GetToolExecution(id string) (*ToolExecution, error) - SaveToolStats(toolName string, stats *ToolStats) error - LoadToolStats() (map[string]*ToolStats, error) - UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error -} - -// Server MCP服务器 -type Server struct { - tools map[string]ToolHandler - toolDefs map[string]Tool // 工具定义 - executions map[string]*ToolExecution - stats map[string]*ToolStats - prompts map[string]*Prompt // 提示词模板 - resources map[string]*Resource // 资源 - storage MonitorStorage // 可选的持久化存储 - mu sync.RWMutex - logger *zap.Logger - maxExecutionsInMemory int // 内存中最大执行记录数 - sseClients map[string]*sseClient - runningCancels map[string]context.CancelFunc - runningCancelsMu sync.Mutex - abortUserNotes map[string]string // 监控页终止时附带的用户说明,与 executionID 对应 - // httpToolTimeoutMinutes 同步 agent.tool_timeout_minutes,用于 POST /api/mcp 的 tools/call(不经 Agent 包装的路径)。 - // nil 表示未配置,沿用默认 30 分钟;指向 0 表示不限制;>0 为分钟数。 - httpToolTimeoutMinutes *int - httpToolTimeoutMu sync.RWMutex -} - -type sseClient struct { - id string - send chan []byte -} - -// ToolHandler 工具处理函数 -type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error) - -func executionStatusAndMessage(err error) (status string, errMsg string) { - if errors.Is(err, context.Canceled) { - return "cancelled", "已手动终止(MCP 监控)" - } - return "failed", err.Error() -} - -// NewServer 创建新的MCP服务器 -func NewServer(logger *zap.Logger) *Server { - return NewServerWithStorage(logger, nil) -} - -// NewServerWithStorage 创建新的MCP服务器(带持久化存储) -func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server { - s := &Server{ - tools: make(map[string]ToolHandler), - toolDefs: make(map[string]Tool), - executions: make(map[string]*ToolExecution), - stats: make(map[string]*ToolStats), - prompts: make(map[string]*Prompt), - resources: make(map[string]*Resource), - storage: storage, - logger: logger, - maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录 - sseClients: make(map[string]*sseClient), - runningCancels: make(map[string]context.CancelFunc), - abortUserNotes: make(map[string]string), - } - - // 初始化默认提示词和资源 - s.initDefaultPrompts() - s.initDefaultResources() - - return s -} - -// ConfigureHTTPToolCallTimeoutFromAgentMinutes 将 agent.tool_timeout_minutes 同步到经 HTTP POST /api/mcp 触发的 tools/call。 -// minutes<=0 表示不设置硬性截止时间(与配置「0 不限制」一致);minutes>0 为该次调用的最长等待时间。 -// 未调用前对 tools/call 使用默认 30 分钟(与历史硬编码一致)。 -func (s *Server) ConfigureHTTPToolCallTimeoutFromAgentMinutes(minutes int) { - if s == nil { - return - } - v := minutes - if v < 0 { - v = 0 - } - s.httpToolTimeoutMu.Lock() - defer s.httpToolTimeoutMu.Unlock() - s.httpToolTimeoutMinutes = &v -} - -func (s *Server) effectiveHTTPToolCallDeadline() (context.Context, context.CancelFunc) { - const defaultDur = 30 * time.Minute - if s == nil { - return context.WithTimeout(context.Background(), defaultDur) - } - s.httpToolTimeoutMu.RLock() - mPtr := s.httpToolTimeoutMinutes - s.httpToolTimeoutMu.RUnlock() - if mPtr == nil { - return context.WithTimeout(context.Background(), defaultDur) - } - if *mPtr <= 0 { - return context.WithCancel(context.Background()) - } - return context.WithTimeout(context.Background(), time.Duration(*mPtr)*time.Minute) -} - -// RegisterTool 注册工具 -func (s *Server) RegisterTool(tool Tool, handler ToolHandler) { - s.mu.Lock() - defer s.mu.Unlock() - s.tools[tool.Name] = handler - s.toolDefs[tool.Name] = tool - - // 自动为工具创建资源文档 - resourceURI := fmt.Sprintf("tool://%s", tool.Name) - s.resources[resourceURI] = &Resource{ - URI: resourceURI, - Name: fmt.Sprintf("%s工具文档", tool.Name), - Description: tool.Description, - MimeType: "text/plain", - } -} - -// ClearTools 清空所有工具(用于重新加载配置) -func (s *Server) ClearTools() { - s.mu.Lock() - defer s.mu.Unlock() - - // 清空工具和工具定义 - s.tools = make(map[string]ToolHandler) - s.toolDefs = make(map[string]Tool) - - // 清空工具相关的资源(保留其他资源) - newResources := make(map[string]*Resource) - for uri, resource := range s.resources { - // 保留非工具资源 - if !strings.HasPrefix(uri, "tool://") { - newResources[uri] = resource - } - } - s.resources = newResources -} - -// HandleHTTP 处理HTTP请求 -func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet && strings.Contains(r.Header.Get("Accept"), "text/event-stream") { - s.handleSSE(w, r) - return - } - - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回 - if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" { - s.serveSSESessionMessage(w, r, sessionID) - return - } - - // 简单 POST:请求体为 JSON-RPC,响应在 body 中返回 - body, err := io.ReadAll(r.Body) - if err != nil { - s.sendError(w, nil, -32700, "Parse error", err.Error()) - return - } - - var msg Message - if err := json.Unmarshal(body, &msg); err != nil { - s.sendError(w, nil, -32700, "Parse error", err.Error()) - return - } - - response := s.handleMessage(&msg) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送 -func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) { - s.mu.RLock() - client, exists := s.sseClients[sessionID] - s.mu.RUnlock() - if !exists || client == nil { - http.Error(w, "session not found", http.StatusNotFound) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) - return - } - - var msg Message - if err := json.Unmarshal(body, &msg); err != nil { - http.Error(w, "failed to parse body", http.StatusBadRequest) - return - } - - response := s.handleMessage(&msg) - if response == nil { - w.WriteHeader(http.StatusAccepted) - return - } - - respBytes, err := json.Marshal(response) - if err != nil { - http.Error(w, "failed to encode response", http.StatusInternalServerError) - return - } - - select { - case client.send <- respBytes: - w.WriteHeader(http.StatusAccepted) - default: - http.Error(w, "session send buffer full", http.StatusServiceUnavailable) - } -} - -// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范: -// 1. 首个事件必须为 event: endpoint,data 为客户端 POST 消息的 URL(含 sessionid) -// 2. 后续事件为 event: message,data 为 JSON-RPC 响应 -func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) { - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - - sessionID := uuid.New().String() - client := &sseClient{ - id: sessionID, - send: make(chan []byte, 32), - } - - s.addSSEClient(client) - defer s.removeSSEClient(client.id) - - // 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求) - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - if r.URL.Scheme != "" { - scheme = r.URL.Scheme - } - endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID) - fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL) - flusher.Flush() - - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - - for { - select { - case <-r.Context().Done(): - return - case msg, ok := <-client.send: - if !ok { - return - } - fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg) - flusher.Flush() - case <-ticker.C: - fmt.Fprintf(w, ": ping\n\n") - flusher.Flush() - } - } -} - -// addSSEClient 注册SSE客户端 -func (s *Server) addSSEClient(client *sseClient) { - s.mu.Lock() - defer s.mu.Unlock() - s.sseClients[client.id] = client -} - -// removeSSEClient 移除SSE客户端 -func (s *Server) removeSSEClient(id string) { - s.mu.Lock() - defer s.mu.Unlock() - if client, exists := s.sseClients[id]; exists { - close(client.send) - delete(s.sseClients, id) - } -} - -// handleMessage 处理MCP消息 -func (s *Server) handleMessage(msg *Message) *Message { - // 检查是否是通知(notification)- 通知没有id字段,不需要响应 - isNotification := msg.ID.Value() == nil || msg.ID.String() == "" - - // 如果不是通知且ID为空,生成新的UUID - if !isNotification && msg.ID.String() == "" { - msg.ID = MessageID{value: uuid.New().String()} - } - - switch msg.Method { - case "initialize": - return s.handleInitialize(msg) - case "tools/list": - return s.handleListTools(msg) - case "tools/call": - return s.handleCallTool(msg) - case "prompts/list": - return s.handleListPrompts(msg) - case "prompts/get": - return s.handleGetPrompt(msg) - case "resources/list": - return s.handleListResources(msg) - case "resources/read": - return s.handleReadResource(msg) - case "sampling/request": - return s.handleSamplingRequest(msg) - case "notifications/initialized": - // 通知类型,不需要响应 - s.logger.Debug("收到 initialized 通知") - return nil - case "": - // 空方法名,可能是通知,不返回错误 - if isNotification { - s.logger.Debug("收到无方法名的通知消息") - return nil - } - fallthrough - default: - // 如果是通知,不返回错误响应 - if isNotification { - s.logger.Debug("收到未知通知", zap.String("method", msg.Method)) - return nil - } - // 对于请求,返回方法未找到错误 - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32601, Message: "Method not found"}, - } - } -} - -// handleInitialize 处理初始化请求 -func (s *Server) handleInitialize(msg *Message) *Message { - var req InitializeRequest - if err := json.Unmarshal(msg.Params, &req); err != nil { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32602, Message: "Invalid params"}, - } - } - - response := InitializeResponse{ - ProtocolVersion: ProtocolVersion, - Capabilities: ServerCapabilities{ - Tools: map[string]interface{}{ - "listChanged": true, - }, - Prompts: map[string]interface{}{ - "listChanged": true, - }, - Resources: map[string]interface{}{ - "subscribe": true, - "listChanged": true, - }, - Sampling: map[string]interface{}{}, - }, - ServerInfo: ServerInfo{ - Name: "CyberStrikeAI", - Version: "1.0.0", - }, - } - - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// handleListTools 处理列出工具请求 -func (s *Server) handleListTools(msg *Message) *Message { - s.mu.RLock() - tools := make([]Tool, 0, len(s.toolDefs)) - for _, tool := range s.toolDefs { - tools = append(tools, tool) - } - s.mu.RUnlock() - s.logger.Debug("tools/list 请求", zap.Int("返回工具数", len(tools))) - - response := ListToolsResponse{Tools: tools} - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// handleCallTool 处理工具调用请求 -func (s *Server) handleCallTool(msg *Message) *Message { - var req CallToolRequest - if err := json.Unmarshal(msg.Params, &req); err != nil { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32602, Message: "Invalid params"}, - } - } - - executionID := uuid.New().String() - execution := &ToolExecution{ - ID: executionID, - ToolName: req.Name, - Arguments: req.Arguments, - Status: "running", - StartTime: time.Now(), - } - - s.mu.Lock() - s.executions[executionID] = execution - // 如果内存中的执行记录超过限制,清理最旧的记录 - s.cleanupOldExecutions() - s.mu.Unlock() - - if s.storage != nil { - if err := s.storage.SaveToolExecution(execution); err != nil { - s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - s.mu.RLock() - handler, exists := s.tools[req.Name] - s.mu.RUnlock() - - if !exists { - execution.Status = "failed" - execution.Error = "Tool not found" - now := time.Now() - execution.EndTime = &now - execution.Duration = now.Sub(execution.StartTime) - - if s.storage != nil { - if err := s.storage.SaveToolExecution(execution); err != nil { - s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - s.mu.Lock() - delete(s.executions, executionID) - s.mu.Unlock() - } - - s.updateStats(req.Name, true) - - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32601, Message: "Tool not found"}, - } - } - - baseCtx, timeoutCancel := s.effectiveHTTPToolCallDeadline() - defer timeoutCancel() - execCtx, runCancel := context.WithCancel(baseCtx) - s.registerRunningCancel(executionID, runCancel) - defer func() { - runCancel() - s.unregisterRunningCancel(executionID) - }() - - s.logger.Info("开始执行工具", - zap.String("toolName", req.Name), - zap.Any("arguments", req.Arguments), - ) - - result, err := handler(execCtx, req.Arguments) - cancelledWithUserNote := s.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) - now := time.Now() - var failed bool - var finalResult *ToolResult - - s.mu.Lock() - execution.EndTime = &now - execution.Duration = now.Sub(execution.StartTime) - - if err != nil { - st, msg := executionStatusAndMessage(err) - execution.Status = st - execution.Error = msg - failed = true - } else if result != nil && result.IsError { - if cancelledWithUserNote { - execution.Status = "cancelled" - execution.Error = "" - execution.Result = result - failed = true - } else { - execution.Status = "failed" - if len(result.Content) > 0 { - execution.Error = result.Content[0].Text - } else { - execution.Error = "工具执行返回错误结果" - } - execution.Result = result - failed = true - } - } else { - execution.Status = "completed" - if result == nil { - result = &ToolResult{ - Content: []Content{ - {Type: "text", Text: "工具执行完成,但未返回结果"}, - }, - } - } - execution.Result = result - failed = false - } - - finalResult = execution.Result - s.mu.Unlock() - - if s.storage != nil { - if err := s.storage.SaveToolExecution(execution); err != nil { - s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - s.updateStats(req.Name, failed) - - if s.storage != nil { - s.mu.Lock() - delete(s.executions, executionID) - s.mu.Unlock() - } - - if err != nil { - s.logger.Error("工具执行失败", - zap.String("toolName", req.Name), - zap.Error(err), - ) - - errText := fmt.Sprintf("工具执行失败: %v", err) - if errors.Is(err, context.Canceled) { - errText = "工具执行已手动终止(MCP 监控)。后续编排步骤可继续。" - } - errorResult, _ := json.Marshal(CallToolResponse{ - Content: []Content{ - {Type: "text", Text: errText}, - }, - IsError: true, - }) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: errorResult, - } - } - - if finalResult != nil && finalResult.IsError { - s.logger.Warn("工具执行返回错误结果", - zap.String("toolName", req.Name), - ) - - errorResult, _ := json.Marshal(CallToolResponse{ - Content: finalResult.Content, - IsError: true, - }) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: errorResult, - } - } - - if finalResult == nil { - finalResult = &ToolResult{ - Content: []Content{ - {Type: "text", Text: "工具执行完成,但未返回结果"}, - }, - } - } - - resultJSON, _ := json.Marshal(CallToolResponse{ - Content: finalResult.Content, - IsError: false, - }) - - s.logger.Info("工具执行完成", - zap.String("toolName", req.Name), - zap.Bool("isError", finalResult.IsError), - ) - - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: resultJSON, - } -} - -// updateStats 更新统计信息 -func (s *Server) updateStats(toolName string, failed bool) { - now := time.Now() - if s.storage != nil { - totalCalls := 1 - successCalls := 0 - failedCalls := 0 - if failed { - failedCalls = 1 - } else { - successCalls = 1 - } - if err := s.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil { - s.logger.Warn("保存统计信息到数据库失败", zap.Error(err)) - } - return - } - - s.mu.Lock() - defer s.mu.Unlock() - - if s.stats[toolName] == nil { - s.stats[toolName] = &ToolStats{ - ToolName: toolName, - } - } - - stats := s.stats[toolName] - stats.TotalCalls++ - stats.LastCallTime = &now - - if failed { - stats.FailedCalls++ - } else { - stats.SuccessCalls++ - } -} - -// GetExecution 获取执行记录(先从内存查找,再从数据库查找) -func (s *Server) GetExecution(id string) (*ToolExecution, bool) { - s.mu.RLock() - exec, exists := s.executions[id] - s.mu.RUnlock() - - if exists { - return exec, true - } - - if s.storage != nil { - exec, err := s.storage.GetToolExecution(id) - if err == nil { - return exec, true - } - } - - return nil, false -} - -// loadHistoricalData 从数据库加载历史数据 -func (s *Server) loadHistoricalData() { - if s.storage == nil { - return - } - - // 加载历史执行记录(最近1000条) - executions, err := s.storage.LoadToolExecutions() - if err != nil { - s.logger.Warn("加载历史执行记录失败", zap.Error(err)) - } else { - s.mu.Lock() - for _, exec := range executions { - // 只加载最近 maxExecutionsInMemory 条,避免内存占用过大 - if len(s.executions) < s.maxExecutionsInMemory { - s.executions[exec.ID] = exec - } else { - break - } - } - s.mu.Unlock() - s.logger.Info("加载历史执行记录", zap.Int("count", len(executions))) - } - - // 加载历史统计信息 - stats, err := s.storage.LoadToolStats() - if err != nil { - s.logger.Warn("加载历史统计信息失败", zap.Error(err)) - } else { - s.mu.Lock() - for k, v := range stats { - s.stats[k] = v - } - s.mu.Unlock() - s.logger.Info("加载历史统计信息", zap.Int("count", len(stats))) - } -} - -// GetAllExecutions 获取所有执行记录(合并内存和数据库) -func (s *Server) GetAllExecutions() []*ToolExecution { - if s.storage != nil { - dbExecutions, err := s.storage.LoadToolExecutions() - if err == nil { - execMap := make(map[string]*ToolExecution) - for _, exec := range dbExecutions { - if _, exists := execMap[exec.ID]; !exists { - execMap[exec.ID] = exec - } - } - - s.mu.RLock() - for id, exec := range s.executions { - if _, exists := execMap[id]; !exists { - execMap[id] = exec - } - } - s.mu.RUnlock() - - result := make([]*ToolExecution, 0, len(execMap)) - for _, exec := range execMap { - result = append(result, exec) - } - return result - } else { - s.logger.Warn("从数据库加载执行记录失败", zap.Error(err)) - } - } - - s.mu.RLock() - defer s.mu.RUnlock() - - memExecutions := make([]*ToolExecution, 0, len(s.executions)) - for _, exec := range s.executions { - memExecutions = append(memExecutions, exec) - } - return memExecutions -} - -// GetStats 获取统计信息(合并内存和数据库) -func (s *Server) GetStats() map[string]*ToolStats { - if s.storage != nil { - dbStats, err := s.storage.LoadToolStats() - if err == nil { - return dbStats - } - s.logger.Warn("从数据库加载统计信息失败", zap.Error(err)) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - memStats := make(map[string]*ToolStats) - for k, v := range s.stats { - statCopy := *v - memStats[k] = &statCopy - } - - return memStats -} - -// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表) -func (s *Server) GetAllTools() []Tool { - s.mu.RLock() - defer s.mu.RUnlock() - - tools := make([]Tool, 0, len(s.toolDefs)) - for _, tool := range s.toolDefs { - tools = append(tools, tool) - } - return tools -} - -// CallTool 直接调用工具(用于内部调用) -func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { - s.mu.RLock() - handler, exists := s.tools[toolName] - s.mu.RUnlock() - - if !exists { - return nil, "", fmt.Errorf("工具 %s 未找到", toolName) - } - - // 创建执行记录 - executionID := uuid.New().String() - execution := &ToolExecution{ - ID: executionID, - ToolName: toolName, - Arguments: args, - Status: "running", - StartTime: time.Now(), - } - - s.mu.Lock() - s.executions[executionID] = execution - // 如果内存中的执行记录超过限制,清理最旧的记录 - s.cleanupOldExecutions() - s.mu.Unlock() - - if s.storage != nil { - if err := s.storage.SaveToolExecution(execution); err != nil { - s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - execCtx, runCancel := context.WithCancel(ctx) - s.registerRunningCancel(executionID, runCancel) - notifyToolRunBegin(ctx, executionID) - defer func() { - notifyToolRunEnd(ctx, executionID) - runCancel() - s.unregisterRunningCancel(executionID) - }() - - result, err := handler(execCtx, args) - cancelledWithUserNote := s.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) - - s.mu.Lock() - now := time.Now() - execution.EndTime = &now - execution.Duration = now.Sub(execution.StartTime) - var failed bool - var finalResult *ToolResult - - if err != nil { - st, msg := executionStatusAndMessage(err) - execution.Status = st - execution.Error = msg - failed = true - } else if result != nil && result.IsError { - if cancelledWithUserNote { - execution.Status = "cancelled" - execution.Error = "" - execution.Result = result - failed = true - finalResult = result - } else { - execution.Status = "failed" - if len(result.Content) > 0 { - execution.Error = result.Content[0].Text - } else { - execution.Error = "工具执行返回错误结果" - } - execution.Result = result - failed = true - finalResult = result - } - } else { - execution.Status = "completed" - if result == nil { - result = &ToolResult{ - Content: []Content{ - {Type: "text", Text: "工具执行完成,但未返回结果"}, - }, - } - } - execution.Result = result - finalResult = result - failed = false - } - - if finalResult == nil { - finalResult = execution.Result - } - s.mu.Unlock() - - if s.storage != nil { - if err := s.storage.SaveToolExecution(execution); err != nil { - s.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) - } - } - - s.updateStats(toolName, failed) - - if s.storage != nil { - s.mu.Lock() - delete(s.executions, executionID) - s.mu.Unlock() - } - - if err != nil { - return nil, executionID, err - } - - return finalResult, executionID, nil -} - -// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致), -// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。 -func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string { - if s == nil { - return "" - } - if args == nil { - args = map[string]interface{}{} - } - executionID := uuid.New().String() - now := time.Now() - failed := invokeErr != nil - exec := &ToolExecution{ - ID: executionID, - ToolName: toolName, - Arguments: args, - StartTime: now, - EndTime: &now, - Duration: 0, - } - if failed { - exec.Status = "failed" - exec.Error = invokeErr.Error() - if strings.TrimSpace(resultText) != "" { - exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}} - } - } else { - exec.Status = "completed" - text := resultText - if strings.TrimSpace(text) == "" { - text = "(无输出)" - } - exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}} - } - if s.storage != nil { - if err := s.storage.SaveToolExecution(exec); err != nil { - s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err)) - } - } - s.updateStats(toolName, failed) - return executionID -} - -// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长 -func (s *Server) cleanupOldExecutions() { - if len(s.executions) <= s.maxExecutionsInMemory { - return - } - - // 按开始时间排序,找出最旧的记录 - type execWithTime struct { - id string - startTime time.Time - } - execs := make([]execWithTime, 0, len(s.executions)) - for id, exec := range s.executions { - execs = append(execs, execWithTime{ - id: id, - startTime: exec.StartTime, - }) - } - - // 使用 sort 包进行高效排序(最旧的在前) - sort.Slice(execs, func(i, j int) bool { - return execs[i].startTime.Before(execs[j].startTime) - }) - - // 删除最旧的记录,保留 maxExecutionsInMemory 条 - toDelete := len(s.executions) - s.maxExecutionsInMemory - for i := 0; i < toDelete; i++ { - delete(s.executions, execs[i].id) - } - - s.logger.Debug("清理旧的执行记录", - zap.Int("before", len(execs)), - zap.Int("after", len(s.executions)), - zap.Int("deleted", toDelete), - ) -} - -func (s *Server) registerRunningCancel(id string, cancel context.CancelFunc) { - s.runningCancelsMu.Lock() - s.runningCancels[id] = cancel - s.runningCancelsMu.Unlock() -} - -func (s *Server) unregisterRunningCancel(id string) { - s.runningCancelsMu.Lock() - delete(s.runningCancels, id) - s.runningCancelsMu.Unlock() -} - -func (s *Server) readAbortUserNote(id string) string { - s.runningCancelsMu.Lock() - defer s.runningCancelsMu.Unlock() - if s.abortUserNotes == nil { - return "" - } - return s.abortUserNotes[id] -} - -func (s *Server) takeAbortUserNote(id string) string { - s.runningCancelsMu.Lock() - defer s.runningCancelsMu.Unlock() - if s.abortUserNotes == nil { - return "" - } - n := s.abortUserNotes[id] - delete(s.abortUserNotes, id) - return n -} - -// applyAbortUserNoteToCancelledToolResult 监控页「终止并填写说明」时合并「工具已输出 + 用户说明」交给模型。 -// exec 等工具会把失败写在 *ToolResult 里并返回 err==nil,若仅在 err!=nil 时合并会漏掉说明,甚至误 clear 掉 note。 -func (s *Server) applyAbortUserNoteToCancelledToolResult(executionID string, result **ToolResult, err *error) (cancelledWithUserNote bool) { - note := strings.TrimSpace(s.readAbortUserNote(executionID)) - if note == "" { - return false - } - hasErr := err != nil && *err != nil - hasRes := result != nil && *result != nil - if !hasErr && !hasRes { - return false - } - _ = s.takeAbortUserNote(executionID) - partial := "" - if hasRes { - partial = ToolResultPlainText(*result) - } - if partial == "" && hasErr { - partial = (*err).Error() - } - merged := MergePartialToolOutputAndAbortNote(partial, note) - *err = nil - *result = &ToolResult{Content: []Content{{Type: "text", Text: merged}}, IsError: true} - return true -} - -// CancelToolExecutionWithNote 取消内部工具;note 非空时与工具已返回文本合并后交给上层模型。 -func (s *Server) CancelToolExecutionWithNote(id string, note string) bool { - s.runningCancelsMu.Lock() - cancel, ok := s.runningCancels[id] - if !ok || cancel == nil { - s.runningCancelsMu.Unlock() - return false - } - if strings.TrimSpace(note) != "" { - if s.abortUserNotes == nil { - s.abortUserNotes = make(map[string]string) - } - s.abortUserNotes[id] = strings.TrimSpace(note) - } - s.runningCancelsMu.Unlock() - cancel() - return true -} - -// CancelToolExecution 取消正在执行的内部工具调用(无用户说明)。 -func (s *Server) CancelToolExecution(id string) bool { - return s.CancelToolExecutionWithNote(id, "") -} - -// initDefaultPrompts 初始化默认提示词模板 -func (s *Server) initDefaultPrompts() { - s.mu.Lock() - defer s.mu.Unlock() - - // 网络安全测试提示词 - s.prompts["security_scan"] = &Prompt{ - Name: "security_scan", - Description: "生成网络安全扫描任务的提示词", - Arguments: []PromptArgument{ - {Name: "target", Description: "扫描目标(IP地址或域名)", Required: true}, - {Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false}, - }, - } - - // 渗透测试提示词 - s.prompts["penetration_test"] = &Prompt{ - Name: "penetration_test", - Description: "生成渗透测试任务的提示词", - Arguments: []PromptArgument{ - {Name: "target", Description: "测试目标", Required: true}, - {Name: "scope", Description: "测试范围", Required: false}, - }, - } -} - -// initDefaultResources 初始化默认资源 -// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源 -func (s *Server) initDefaultResources() { - // 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码 -} - -// handleListPrompts 处理列出提示词请求 -func (s *Server) handleListPrompts(msg *Message) *Message { - s.mu.RLock() - prompts := make([]Prompt, 0, len(s.prompts)) - for _, prompt := range s.prompts { - prompts = append(prompts, *prompt) - } - s.mu.RUnlock() - - response := ListPromptsResponse{ - Prompts: prompts, - } - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// handleGetPrompt 处理获取提示词请求 -func (s *Server) handleGetPrompt(msg *Message) *Message { - var req GetPromptRequest - if err := json.Unmarshal(msg.Params, &req); err != nil { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32602, Message: "Invalid params"}, - } - } - - s.mu.RLock() - prompt, exists := s.prompts[req.Name] - s.mu.RUnlock() - - if !exists { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32601, Message: "Prompt not found"}, - } - } - - // 根据提示词名称生成消息 - messages := s.generatePromptMessages(prompt, req.Arguments) - - response := GetPromptResponse{ - Messages: messages, - } - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// generatePromptMessages 生成提示词消息 -func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage { - messages := []PromptMessage{} - - switch prompt.Name { - case "security_scan": - target, _ := args["target"].(string) - scanType, _ := args["scan_type"].(string) - if scanType == "" { - scanType = "comprehensive" - } - - content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括: -1. 端口扫描和服务识别 -2. 漏洞检测 -3. Web应用安全测试 -4. 生成详细的安全报告`, target, scanType) - - messages = append(messages, PromptMessage{ - Role: "user", - Content: content, - }) - - case "penetration_test": - target, _ := args["target"].(string) - scope, _ := args["scope"].(string) - - content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target) - if scope != "" { - content += fmt.Sprintf("测试范围:%s", scope) - } - content += "\n请按照OWASP Top 10进行全面的安全测试。" - - messages = append(messages, PromptMessage{ - Role: "user", - Content: content, - }) - - default: - messages = append(messages, PromptMessage{ - Role: "user", - Content: "请执行安全测试任务", - }) - } - - return messages -} - -// handleListResources 处理列出资源请求 -func (s *Server) handleListResources(msg *Message) *Message { - s.mu.RLock() - resources := make([]Resource, 0, len(s.resources)) - for _, resource := range s.resources { - resources = append(resources, *resource) - } - s.mu.RUnlock() - - response := ListResourcesResponse{ - Resources: resources, - } - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// handleReadResource 处理读取资源请求 -func (s *Server) handleReadResource(msg *Message) *Message { - var req ReadResourceRequest - if err := json.Unmarshal(msg.Params, &req); err != nil { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32602, Message: "Invalid params"}, - } - } - - s.mu.RLock() - resource, exists := s.resources[req.URI] - s.mu.RUnlock() - - if !exists { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32601, Message: "Resource not found"}, - } - } - - // 生成资源内容 - content := s.generateResourceContent(resource) - - response := ReadResourceResponse{ - Contents: []ResourceContent{content}, - } - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// generateResourceContent 生成资源内容 -func (s *Server) generateResourceContent(resource *Resource) ResourceContent { - content := ResourceContent{ - URI: resource.URI, - MimeType: resource.MimeType, - } - - // 如果是工具资源,生成详细文档 - if strings.HasPrefix(resource.URI, "tool://") { - toolName := strings.TrimPrefix(resource.URI, "tool://") - content.Text = s.generateToolDocumentation(toolName, resource) - } else { - // 其他资源使用描述或默认内容 - content.Text = resource.Description - } - - return content -} - -// generateToolDocumentation 生成工具文档 -// 注意:硬编码的工具文档已移除,现在只使用工具定义中的信息 -func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string { - // 获取工具定义以获取更详细的信息 - s.mu.RLock() - tool, hasTool := s.toolDefs[toolName] - s.mu.RUnlock() - - // 使用工具定义中的描述信息 - if hasTool { - doc := fmt.Sprintf("%s\n\n", resource.Description) - if tool.InputSchema != nil { - if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok { - doc += "参数说明:\n" - for paramName, paramInfo := range props { - if paramMap, ok := paramInfo.(map[string]interface{}); ok { - if desc, ok := paramMap["description"].(string); ok { - doc += fmt.Sprintf("- %s: %s\n", paramName, desc) - } - } - } - } - } - return doc - } - return resource.Description -} - -// handleSamplingRequest 处理采样请求 -func (s *Server) handleSamplingRequest(msg *Message) *Message { - var req SamplingRequest - if err := json.Unmarshal(msg.Params, &req); err != nil { - return &Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32602, Message: "Invalid params"}, - } - } - - // 注意:采样功能通常需要连接到实际的LLM服务 - // 这里返回一个占位符响应,实际实现需要集成LLM API - s.logger.Warn("Sampling request received but not fully implemented", - zap.Any("request", req), - ) - - response := SamplingResponse{ - Content: []SamplingContent{ - { - Type: "text", - Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。", - }, - }, - StopReason: "length", - } - result, _ := json.Marshal(response) - return &Message{ - ID: msg.ID, - Type: MessageTypeResponse, - Version: "2.0", - Result: result, - } -} - -// RegisterPrompt 注册提示词模板 -func (s *Server) RegisterPrompt(prompt *Prompt) { - s.mu.Lock() - defer s.mu.Unlock() - s.prompts[prompt.Name] = prompt -} - -// RegisterResource 注册资源 -func (s *Server) RegisterResource(resource *Resource) { - s.mu.Lock() - defer s.mu.Unlock() - s.resources[resource.URI] = resource -} - -// HandleStdio 处理标准输入输出(用于 stdio 传输模式) -// MCP 协议使用换行分隔的 JSON-RPC 消息;管道下需每次写入后 Flush,否则客户端会读不到响应 -func (s *Server) HandleStdio() error { - decoder := json.NewDecoder(os.Stdin) - stdout := bufio.NewWriter(os.Stdout) - encoder := json.NewEncoder(stdout) - // 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式 - - for { - var msg Message - if err := decoder.Decode(&msg); err != nil { - if err == io.EOF { - break - } - // 日志输出到 stderr,避免干扰 stdout 的 JSON-RPC 通信 - s.logger.Error("读取消息失败", zap.Error(err)) - // 发送错误响应 - errorMsg := Message{ - ID: msg.ID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: -32700, Message: "Parse error", Data: err.Error()}, - } - if err := encoder.Encode(errorMsg); err != nil { - return fmt.Errorf("发送错误响应失败: %w", err) - } - if err := stdout.Flush(); err != nil { - return fmt.Errorf("刷新 stdout 失败: %w", err) - } - continue - } - - // 处理消息 - response := s.handleMessage(&msg) - - // 如果是通知(response 为 nil),不需要发送响应 - if response == nil { - continue - } - - // 发送响应 - if err := encoder.Encode(response); err != nil { - return fmt.Errorf("发送响应失败: %w", err) - } - if err := stdout.Flush(); err != nil { - return fmt.Errorf("刷新 stdout 失败: %w", err) - } - } - - return nil -} - -// sendError 发送错误响应 -func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) { - var msgID MessageID - if id != nil { - msgID = MessageID{value: id} - } - response := Message{ - ID: msgID, - Type: MessageTypeError, - Version: "2.0", - Error: &Error{Code: code, Message: message, Data: data}, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} diff --git a/mcp/types.go b/mcp/types.go deleted file mode 100644 index bc93bb72..00000000 --- a/mcp/types.go +++ /dev/null @@ -1,329 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "fmt" - "strings" - "time" -) - -// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现) -type ExternalMCPClient interface { - Initialize(ctx context.Context) error - ListTools(ctx context.Context) ([]Tool, error) - CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) - Close() error - IsConnected() bool - GetStatus() string -} - -// MCP消息类型 -const ( - MessageTypeRequest = "request" - MessageTypeResponse = "response" - MessageTypeError = "error" - MessageTypeNotify = "notify" -) - -// MCP协议版本 -const ProtocolVersion = "2024-11-05" - -// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null -type MessageID struct { - value interface{} -} - -// UnmarshalJSON 自定义反序列化,支持字符串、数字和null -func (m *MessageID) UnmarshalJSON(data []byte) error { - // 尝试解析为null - if string(data) == "null" { - m.value = nil - return nil - } - - // 尝试解析为字符串 - var str string - if err := json.Unmarshal(data, &str); err == nil { - m.value = str - return nil - } - - // 尝试解析为数字 - var num json.Number - if err := json.Unmarshal(data, &num); err == nil { - m.value = num - return nil - } - - return fmt.Errorf("invalid id type") -} - -// MarshalJSON 自定义序列化 -func (m MessageID) MarshalJSON() ([]byte, error) { - if m.value == nil { - return []byte("null"), nil - } - return json.Marshal(m.value) -} - -// String 返回字符串表示 -func (m MessageID) String() string { - if m.value == nil { - return "" - } - return fmt.Sprintf("%v", m.value) -} - -// Value 返回原始值 -func (m MessageID) Value() interface{} { - return m.value -} - -// Message 表示MCP消息(符合JSON-RPC 2.0规范) -type Message struct { - ID MessageID `json:"id,omitempty"` - Type string `json:"-"` // 内部使用,不序列化到JSON - Method string `json:"method,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *Error `json:"error,omitempty"` - Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识 -} - -// Error 表示MCP错误 -type Error struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// Tool 表示MCP工具定义 -type Tool struct { - Name string `json:"name"` - Description string `json:"description"` // 详细描述 - ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗) - InputSchema map[string]interface{} `json:"inputSchema"` -} - -// ToolCall 表示工具调用 -type ToolCall struct { - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments"` -} - -// ToolResult 表示工具执行结果 -type ToolResult struct { - Content []Content `json:"content"` - IsError bool `json:"isError,omitempty"` -} - -// Content 表示内容 -type Content struct { - Type string `json:"type"` - Text string `json:"text"` -} - -// InitializeRequest 初始化请求 -type InitializeRequest struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities map[string]interface{} `json:"capabilities"` - ClientInfo ClientInfo `json:"clientInfo"` -} - -// ClientInfo 客户端信息 -type ClientInfo struct { - Name string `json:"name"` - Version string `json:"version"` -} - -// InitializeResponse 初始化响应 -type InitializeResponse struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities ServerCapabilities `json:"capabilities"` - ServerInfo ServerInfo `json:"serverInfo"` -} - -// ServerCapabilities 服务器能力 -type ServerCapabilities struct { - Tools map[string]interface{} `json:"tools,omitempty"` - Prompts map[string]interface{} `json:"prompts,omitempty"` - Resources map[string]interface{} `json:"resources,omitempty"` - Sampling map[string]interface{} `json:"sampling,omitempty"` -} - -// ServerInfo 服务器信息 -type ServerInfo struct { - Name string `json:"name"` - Version string `json:"version"` -} - -// ListToolsRequest 列出工具请求 -type ListToolsRequest struct{} - -// ListToolsResponse 列出工具响应 -type ListToolsResponse struct { - Tools []Tool `json:"tools"` -} - -// ListPromptsResponse 列出提示词响应 -type ListPromptsResponse struct { - Prompts []Prompt `json:"prompts"` -} - -// ListResourcesResponse 列出资源响应 -type ListResourcesResponse struct { - Resources []Resource `json:"resources"` -} - -// CallToolRequest 调用工具请求 -type CallToolRequest struct { - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments"` -} - -// CallToolResponse 调用工具响应 -type CallToolResponse struct { - Content []Content `json:"content"` - IsError bool `json:"isError,omitempty"` -} - -// ToolExecution 工具执行记录 -type ToolExecution struct { - ID string `json:"id"` - ToolName string `json:"toolName"` - Arguments map[string]interface{} `json:"arguments"` - Status string `json:"status"` // pending, running, completed, failed, cancelled - Result *ToolResult `json:"result,omitempty"` - Error string `json:"error,omitempty"` - StartTime time.Time `json:"startTime"` - EndTime *time.Time `json:"endTime,omitempty"` - Duration time.Duration `json:"duration,omitempty"` -} - -// ToolStats 工具统计信息 -type ToolStats struct { - ToolName string `json:"toolName"` - TotalCalls int `json:"totalCalls"` - SuccessCalls int `json:"successCalls"` - FailedCalls int `json:"failedCalls"` - LastCallTime *time.Time `json:"lastCallTime,omitempty"` -} - -// Prompt 提示词模板 -type Prompt struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Arguments []PromptArgument `json:"arguments,omitempty"` -} - -// PromptArgument 提示词参数 -type PromptArgument struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Required bool `json:"required,omitempty"` -} - -// GetPromptRequest 获取提示词请求 -type GetPromptRequest struct { - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} - -// GetPromptResponse 获取提示词响应 -type GetPromptResponse struct { - Messages []PromptMessage `json:"messages"` -} - -// PromptMessage 提示词消息 -type PromptMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// Resource 资源 -type Resource struct { - URI string `json:"uri"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - MimeType string `json:"mimeType,omitempty"` -} - -// ReadResourceRequest 读取资源请求 -type ReadResourceRequest struct { - URI string `json:"uri"` -} - -// ReadResourceResponse 读取资源响应 -type ReadResourceResponse struct { - Contents []ResourceContent `json:"contents"` -} - -// ResourceContent 资源内容 -type ResourceContent struct { - URI string `json:"uri"` - MimeType string `json:"mimeType,omitempty"` - Text string `json:"text,omitempty"` - Blob string `json:"blob,omitempty"` -} - -// SamplingRequest 采样请求 -type SamplingRequest struct { - Messages []SamplingMessage `json:"messages"` - Model string `json:"model,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` -} - -// SamplingMessage 采样消息 -type SamplingMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// SamplingResponse 采样响应 -type SamplingResponse struct { - Content []SamplingContent `json:"content"` - Model string `json:"model,omitempty"` - StopReason string `json:"stopReason,omitempty"` -} - -// SamplingContent 采样内容 -type SamplingContent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` -} - -// ToolResultPlainText 拼接工具结果中的文本(手动终止时作为「工具原始输出」)。 -func ToolResultPlainText(r *ToolResult) string { - if r == nil || len(r.Content) == 0 { - return "" - } - var b strings.Builder - for _, c := range r.Content { - b.WriteString(c.Text) - } - return strings.TrimSpace(b.String()) -} - -// AbortNoteBannerForModel 标出后续文本来自「用户手动终止工具时在弹窗中填写」,避免与 stdout/stderr 混淆。 -const AbortNoteBannerForModel = "---\n" + - "【用户终止说明|USER INTERRUPT NOTE】\n" + - "(以下由操作者填写,用于指示模型如何继续;不是工具原始输出。)\n" + - "(Written by the operator when stopping this tool; not raw tool output.)\n" + - "---" - -// MergePartialToolOutputAndAbortNote 格式:工具原始输出 + 醒目标题 + 用户终止说明(无说明则原样返回 partial)。 -func MergePartialToolOutputAndAbortNote(partial, userNote string) string { - partial = strings.TrimSpace(partial) - userNote = strings.TrimSpace(userNote) - if userNote == "" { - return partial - } - section := AbortNoteBannerForModel + "\n" + userNote - if partial == "" { - return section - } - return partial + "\n\n" + section -}