From 964c52021516087b55f609fda6d7fd4a958c8f42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 21 Apr 2026 19:17:46 +0800 Subject: [PATCH] Add files via upload --- internal/config/config.go | 71 ++++-- internal/config/envexpand.go | 66 +++++ internal/config/envexpand_test.go | 81 +++++++ internal/database/database.go | 17 +- internal/mcp/client_sdk.go | 226 ++++-------------- internal/mcp/external_manager.go | 56 ++--- internal/mcp/external_manager_test.go | 42 ++-- internal/multiagent/eino_adk_run_loop.go | 75 +++--- internal/multiagent/tool_error_middleware.go | 59 +---- .../multiagent/tool_error_middleware_test.go | 18 +- internal/multiagent/tool_execution_retry.go | 68 ++---- 11 files changed, 366 insertions(+), 413 deletions(-) create mode 100644 internal/config/envexpand.go create mode 100644 internal/config/envexpand_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 056c9920..c647026b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -257,28 +257,52 @@ type ExternalMCPConfig struct { Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` } -// ExternalMCPServerConfig 外部MCP服务器配置 +// ExternalMCPServerConfig 外部MCP服务器配置(遵循官方 MCP 配置格式,兼容 Claude Desktop / Cursor / VS Code)。 +// 所有字符串字段均支持 ${VAR} 和 ${VAR:-default} 环境变量展开语法。 type ExternalMCPServerConfig struct { - // stdio模式配置 + // 传输类型: "stdio" | "sse" | "http"(Streamable HTTP)。 + // stdio 模式可省略,有 command 字段时自动推断。 + Type string `yaml:"type,omitempty" json:"type,omitempty"` + + // stdio 模式配置 Command string `yaml:"command,omitempty" json:"command,omitempty"` Args []string `yaml:"args,omitempty" json:"args,omitempty"` - Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式) + Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` - // HTTP模式配置 - Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp) - URL string `yaml:"url,omitempty" json:"url,omitempty"` - Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key) + // HTTP/SSE 模式配置 + URL string `yaml:"url,omitempty" json:"url,omitempty"` + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // 官方标准字段 + Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 禁用服务器(官方字段) + AutoApprove []string `yaml:"autoApprove,omitempty" json:"autoApprove,omitempty"` // 自动批准的工具列表(官方字段) + + // SDK 高级配置(对应 MCP Go SDK 传输层参数) + MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // Streamable HTTP 断线重连次数(默认 5) + TerminateDuration int `yaml:"terminate_duration,omitempty" json:"terminate_duration,omitempty"` // stdio 进程优雅关闭等待秒数(默认 5) + KeepAlive int `yaml:"keep_alive,omitempty" json:"keep_alive,omitempty"` // 客户端心跳间隔秒数(0 = 禁用) // 通用配置 Description string `yaml:"description,omitempty" json:"description,omitempty"` - Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒) - ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP - ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用) - - // 向后兼容字段(已废弃,保留用于读取旧配置) - Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable - Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 连接超时(秒) + ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用 + ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态 } + +// GetTransportType 返回实际传输类型。优先读 Type,否则根据 Command/URL 自动推断。 +func (c ExternalMCPServerConfig) GetTransportType() string { + if c.Type != "" { + return c.Type + } + if c.Command != "" { + return "stdio" + } + if c.URL != "" { + return "http" + } + return "" +} + type ToolConfig struct { Name string `yaml:"name"` Command string `yaml:"command"` @@ -369,23 +393,20 @@ func Load(path string) (*Config, error) { cfg.Security.Tools = tools } - // 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable + // 外部 MCP:迁移 + 环境变量展开 if cfg.ExternalMCP.Servers != nil { for name, serverCfg := range cfg.ExternalMCP.Servers { - // 如果已经设置了 external_mcp_enable,跳过迁移 - // 否则从 enabled/disabled 字段迁移 - // 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了 - // 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移 + // 官方 disabled 字段 → ExternalMCPEnable if serverCfg.Disabled { - // 旧配置使用 disabled,迁移到 external_mcp_enable serverCfg.ExternalMCPEnable = false - } else if serverCfg.Enabled { - // 旧配置使用 enabled,迁移到 external_mcp_enable - serverCfg.ExternalMCPEnable = true - } else { - // 都没有设置,默认为启用 + } else if !serverCfg.ExternalMCPEnable { + // 默认启用 serverCfg.ExternalMCPEnable = true } + + // 展开所有 ${VAR} / ${VAR:-default} 环境变量引用 + ExpandConfigEnv(&serverCfg) + cfg.ExternalMCP.Servers[name] = serverCfg } } diff --git a/internal/config/envexpand.go b/internal/config/envexpand.go new file mode 100644 index 00000000..0ffc1784 --- /dev/null +++ b/internal/config/envexpand.go @@ -0,0 +1,66 @@ +package config + +import ( + "os" + "strings" +) + +// expandEnvVar 展开字符串中的 ${VAR} 和 ${VAR:-default} 环境变量引用。 +// 与官方 MCP 配置格式一致(Claude Desktop / Cursor / VS Code 均支持此语法)。 +func expandEnvVar(s string) string { + var b strings.Builder + i := 0 + for i < len(s) { + // 查找 ${ + idx := strings.Index(s[i:], "${") + if idx < 0 { + b.WriteString(s[i:]) + break + } + b.WriteString(s[i : i+idx]) + i += idx + 2 // skip ${ + + // 查找对应的 } + end := strings.IndexByte(s[i:], '}') + if end < 0 { + // 没有 },原样保留 + b.WriteString("${") + continue + } + expr := s[i : i+end] + i += end + 1 // skip } + + // 解析 VAR:-default + varName := expr + defaultVal := "" + hasDefault := false + if colonIdx := strings.Index(expr, ":-"); colonIdx >= 0 { + varName = expr[:colonIdx] + defaultVal = expr[colonIdx+2:] + hasDefault = true + } + + val := os.Getenv(varName) + if val == "" && hasDefault { + val = defaultVal + } + b.WriteString(val) + } + return b.String() +} + +// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。 +// 展开范围:Command、Args、Env values、URL、Headers values。 +func ExpandConfigEnv(cfg *ExternalMCPServerConfig) { + cfg.Command = expandEnvVar(cfg.Command) + for i, arg := range cfg.Args { + cfg.Args[i] = expandEnvVar(arg) + } + for k, v := range cfg.Env { + cfg.Env[k] = expandEnvVar(v) + } + cfg.URL = expandEnvVar(cfg.URL) + for k, v := range cfg.Headers { + cfg.Headers[k] = expandEnvVar(v) + } +} diff --git a/internal/config/envexpand_test.go b/internal/config/envexpand_test.go new file mode 100644 index 00000000..a17c4514 --- /dev/null +++ b/internal/config/envexpand_test.go @@ -0,0 +1,81 @@ +package config + +import ( + "os" + "testing" +) + +func TestExpandEnvVar(t *testing.T) { + os.Setenv("TEST_MCP_VAR", "hello") + os.Setenv("TEST_MCP_PATH", "/usr/local/bin") + defer os.Unsetenv("TEST_MCP_VAR") + defer os.Unsetenv("TEST_MCP_PATH") + + tests := []struct { + name string + input string + expect string + }{ + {"plain string", "no vars here", "no vars here"}, + {"empty string", "", ""}, + {"simple var", "${TEST_MCP_VAR}", "hello"}, + {"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"}, + {"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"}, + {"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""}, + {"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"}, + {"default not used", "${TEST_MCP_VAR:-unused}", "hello"}, + {"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"}, + {"unclosed brace", "${UNCLOSED", "${UNCLOSED"}, + {"dollar without brace", "$PLAIN", "$PLAIN"}, + {"empty var name", "${}", ""}, + {"default empty var", "${:-default}", "default"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := expandEnvVar(tt.input) + if got != tt.expect { + t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestExpandConfigEnv(t *testing.T) { + os.Setenv("TEST_MCP_CMD", "python3") + os.Setenv("TEST_MCP_TOKEN", "secret123") + defer os.Unsetenv("TEST_MCP_CMD") + defer os.Unsetenv("TEST_MCP_TOKEN") + + cfg := &ExternalMCPServerConfig{ + Command: "${TEST_MCP_CMD}", + Args: []string{"--token", "${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"}, + Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"}, + URL: "https://${MISSING:-example.com}/mcp", + Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"}, + } + + ExpandConfigEnv(cfg) + + if cfg.Command != "python3" { + t.Errorf("Command = %q, want %q", cfg.Command, "python3") + } + if cfg.Args[1] != "secret123" { + t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123") + } + if cfg.Args[2] != "default_arg" { + t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg") + } + if cfg.Env["API_KEY"] != "secret123" { + t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123") + } + if cfg.Env["LEVEL"] != "INFO" { + t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO") + } + if cfg.URL != "https://example.com/mcp" { + t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp") + } + if cfg.Headers["Authorization"] != "Bearer secret123" { + t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123") + } +} diff --git a/internal/database/database.go b/internal/database/database.go index 0e0ec524..61fc053f 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -4,11 +4,20 @@ import ( "database/sql" "fmt" "strings" + "time" _ "github.com/mattn/go-sqlite3" "go.uber.org/zap" ) +// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性 +func configureDBPool(db *sql.DB) { + // SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误 + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) +} + // DB 数据库连接 type DB struct { *sql.DB @@ -17,11 +26,13 @@ type DB struct { // NewDB 创建数据库连接 func NewDB(dbPath string, logger *zap.Logger) (*DB, error) { - db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") if err != nil { return nil, fmt.Errorf("打开数据库失败: %w", err) } + configureDBPool(db) + if err := db.Ping(); err != nil { return nil, fmt.Errorf("连接数据库失败: %w", err) } @@ -674,11 +685,13 @@ func (db *DB) migrateBatchTaskQueuesTable() error { // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { - sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") if err != nil { return nil, fmt.Errorf("打开知识库数据库失败: %w", err) } + configureDBPool(sqlDB) + if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("连接知识库数据库失败: %w", err) } diff --git a/internal/mcp/client_sdk.go b/internal/mcp/client_sdk.go index 59b513b2..bfbbcb15 100644 --- a/internal/mcp/client_sdk.go +++ b/internal/mcp/client_sdk.go @@ -2,11 +2,9 @@ package mcp import ( - "bytes" "context" "encoding/json" "fmt" - "io" "net/http" "os" "os/exec" @@ -16,7 +14,6 @@ import ( "cyberstrike-ai/internal/config" - "github.com/google/uuid" "github.com/modelcontextprotocol/go-sdk/mcp" "go.uber.org/zap" ) @@ -268,172 +265,6 @@ func mustJSON(v interface{}) []byte { return b } -// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。 -// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。 -type simpleHTTPClient struct { - url string - client *http.Client - logger *zap.Logger - mu sync.RWMutex - status string -} - -func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) { - c := &simpleHTTPClient{ - url: url, - client: httpClientWithTimeoutAndHeaders(timeout, headers), - logger: logger, - status: "connecting", - } - if err := c.initialize(ctx); err != nil { - return nil, err - } - c.mu.Lock() - c.status = "connected" - c.mu.Unlock() - return c, nil -} - -func (c *simpleHTTPClient) setStatus(s string) { - c.mu.Lock() - defer c.mu.Unlock() - c.status = s -} - -func (c *simpleHTTPClient) GetStatus() string { - c.mu.RLock() - defer c.mu.RUnlock() - return c.status -} - -func (c *simpleHTTPClient) IsConnected() bool { - return c.GetStatus() == "connected" -} - -func (c *simpleHTTPClient) Initialize(context.Context) error { - return nil // 已在 newSimpleHTTPClient 中完成 -} - -func (c *simpleHTTPClient) initialize(ctx context.Context) error { - params := InitializeRequest{ - ProtocolVersion: ProtocolVersion, - Capabilities: make(map[string]interface{}), - ClientInfo: ClientInfo{Name: clientName, Version: clientVersion}, - } - paramsJSON, _ := json.Marshal(params) - req := &Message{ - ID: MessageID{value: "1"}, - Method: "initialize", - Version: "2.0", - Params: paramsJSON, - } - resp, err := c.sendRequest(ctx, req) - if err != nil { - return fmt.Errorf("initialize: %w", err) - } - if resp.Error != nil { - return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code) - } - // 发送 notifications/initialized(协议要求) - notify := &Message{ - ID: MessageID{value: nil}, - Method: "notifications/initialized", - Version: "2.0", - Params: json.RawMessage("{}"), - } - _ = c.sendNotification(notify) - return nil -} - -func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { - body, err := json.Marshal(msg) - if err != nil { - return nil, err - } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - httpReq.Header.Set("Content-Type", "application/json") - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b)) - } - var out Message - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { - return nil, err - } - return &out, nil -} - -func (c *simpleHTTPClient) sendNotification(msg *Message) error { - body, _ := json.Marshal(msg) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/json") - resp, err := c.client.Do(httpReq) - if err != nil { - return err - } - resp.Body.Close() - return nil -} - -func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) { - req := &Message{ - ID: MessageID{value: uuid.New().String()}, - Method: "tools/list", - Version: "2.0", - Params: json.RawMessage("{}"), - } - resp, err := c.sendRequest(ctx, req) - if err != nil { - return nil, err - } - if resp.Error != nil { - return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code) - } - var listResp ListToolsResponse - if err := json.Unmarshal(resp.Result, &listResp); err != nil { - return nil, err - } - return listResp.Tools, nil -} - -func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { - params := CallToolRequest{Name: name, Arguments: args} - paramsJSON, _ := json.Marshal(params) - req := &Message{ - ID: MessageID{value: uuid.New().String()}, - Method: "tools/call", - Version: "2.0", - Params: paramsJSON, - } - resp, err := c.sendRequest(ctx, req) - if err != nil { - return nil, err - } - if resp.Error != nil { - return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code) - } - var callResp CallToolResponse - if err := json.Unmarshal(resp.Result, &callResp); err != nil { - return nil, err - } - return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil -} - -func (c *simpleHTTPClient) Close() error { - c.setStatus("disconnected") - return nil -} - // createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient // 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。 func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) { @@ -442,21 +273,23 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf timeout = 30 * time.Second } - transport := serverCfg.Transport + transport := serverCfg.GetTransportType() if transport == "" { - if serverCfg.Command != "" { - transport = "stdio" - } else if serverCfg.URL != "" { - transport = "http" - } else { - return nil, fmt.Errorf("配置缺少 command 或 url") + return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport") + } + + // 构造 ClientOptions:KeepAlive 心跳 + var clientOpts *mcp.ClientOptions + if serverCfg.KeepAlive > 0 { + clientOpts = &mcp.ClientOptions{ + KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second, } } client := mcp.NewClient(&mcp.Implementation{ Name: clientName, Version: clientVersion, - }, nil) + }, clientOpts) var t mcp.Transport switch transport { @@ -470,12 +303,18 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf if len(serverCfg.Env) > 0 { cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...) } - t = &mcp.CommandTransport{Command: cmd} + ct := &mcp.CommandTransport{Command: cmd} + if serverCfg.TerminateDuration > 0 { + ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second + } + t = ct case "sse": if serverCfg.URL == "" { return nil, fmt.Errorf("sse 模式需要配置 url") } - httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) + // SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。 + // 超时由每次 ListTools/CallTool 的 context 单独控制。 + httpClient := httpClientForLongLived(serverCfg.Headers) t = &mcp.SSEClientTransport{ Endpoint: serverCfg.URL, HTTPClient: httpClient, @@ -485,18 +324,16 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf return nil, fmt.Errorf("http 模式需要配置 url") } httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers) - t = &mcp.StreamableClientTransport{ + st := &mcp.StreamableClientTransport{ Endpoint: serverCfg.URL, HTTPClient: httpClient, } - case "simple_http": - // 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp) - if serverCfg.URL == "" { - return nil, fmt.Errorf("simple_http 模式需要配置 url") + if serverCfg.MaxRetries > 0 { + st.MaxRetries = serverCfg.MaxRetries } - return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger) + t = st default: - return nil, fmt.Errorf("不支持的传输模式: %s", transport) + return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http)", transport) } session, err := client.Connect(ctx, t, nil) @@ -538,6 +375,23 @@ func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]s } } +// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。 +// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。 +// 超时由调用方通过 context 控制。 +func httpClientForLongLived(headers map[string]string) *http.Client { + transport := http.DefaultTransport + if len(headers) > 0 { + transport = &headerRoundTripper{ + headers: headers, + base: http.DefaultTransport, + } + } + return &http.Client{ + Transport: transport, + // 不设 Timeout,SSE 长连接的超时由 per-request context 控制 + } +} + type headerRoundTripper struct { headers map[string]string base http.RoundTripper diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go index 1d9c3164..3d3346d6 100644 --- a/internal/mcp/external_manager.go +++ b/internal/mcp/external_manager.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "cyberstrike-ai/internal/config" @@ -29,6 +30,7 @@ type ExternalMCPManager struct { toolCacheMu sync.RWMutex // 工具列表缓存的锁 stopRefresh chan struct{} // 停止后台刷新的信号 refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 + refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积 mu sync.RWMutex } @@ -721,7 +723,13 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int { } // refreshToolCounts 刷新工具数量缓存(后台异步执行) +// 使用 atomic flag 防止并发堆积:如果上一次刷新尚未完成,本次触发直接跳过。 func (m *ExternalMCPManager) refreshToolCounts() { + if !m.refreshing.CompareAndSwap(false, true) { + return // 上一次刷新尚未完成,跳过 + } + defer m.refreshing.Store(false) + m.mu.RLock() clients := make(map[string]ExternalMCPClient) for k, v := range m.clients { @@ -874,16 +882,7 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() { // createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。 func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient { - transport := serverCfg.Transport - if transport == "" { - if serverCfg.Command != "" { - transport = "stdio" - } else if serverCfg.URL != "" { - transport = "http" - } else { - return nil - } - } + transport := serverCfg.GetTransportType() switch transport { case "http": @@ -891,12 +890,6 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf return nil } return newLazySDKClient(serverCfg, m.logger) - case "simple_http": - // 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等 - if serverCfg.URL == "" { - return nil - } - return newLazySDKClient(serverCfg, m.logger) case "stdio": if serverCfg.Command == "" { return nil @@ -908,7 +901,11 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf } return newLazySDKClient(serverCfg, m.logger) default: - return nil + if transport == "" { + return nil + } + // 未知传输类型也尝试使用 lazy client + return newLazySDKClient(serverCfg, m.logger) } } @@ -990,20 +987,7 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa // isEnabled 检查是否启用 func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool { - // 优先使用 ExternalMCPEnable 字段 - // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) - if cfg.ExternalMCPEnable { - return true - } - // 向后兼容:检查旧字段 - if cfg.Disabled { - return false - } - if cfg.Enabled { - return true - } - // 都没有设置,默认为启用 - return true + return cfg.ExternalMCPEnable } // findSubstring 查找子字符串(简单实现) @@ -1044,15 +1028,7 @@ func (m *ExternalMCPManager) StartAllEnabled() { zap.Error(err), } - // 根据传输模式添加相应的信息 - transport := c.Transport - if transport == "" { - if c.Command != "" { - transport = "stdio" - } else if c.URL != "" { - transport = "http" - } - } + transport := c.GetTransportType() if transport == "http" && c.URL != "" { fields = append(fields, zap.String("url", c.URL)) diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go index d4c49851..b7692c33 100644 --- a/internal/mcp/external_manager_test.go +++ b/internal/mcp/external_manager_test.go @@ -16,12 +16,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) { // 测试添加stdio配置 stdioCfg := config.ExternalMCPServerConfig{ - Command: "python3", - Args: []string{"/path/to/script.py"}, - Transport: "stdio", - Description: "Test stdio MCP", - Timeout: 30, - Enabled: true, + Command: "python3", + Args: []string{"/path/to/script.py"}, + Description: "Test stdio MCP", + Timeout: 30, + ExternalMCPEnable: true, } err := manager.AddOrUpdateConfig("test-stdio", stdioCfg) @@ -31,11 +30,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) { // 测试添加HTTP配置 httpCfg := config.ExternalMCPServerConfig{ - Transport: "http", - URL: "http://127.0.0.1:8081/mcp", - Description: "Test HTTP MCP", - Timeout: 30, - Enabled: false, + Type: "http", + URL: "http://127.0.0.1:8081/mcp", + Description: "Test HTTP MCP", + Timeout: 30, + ExternalMCPEnable: false, } err = manager.AddOrUpdateConfig("test-http", httpCfg) @@ -64,8 +63,7 @@ func TestExternalMCPManager_RemoveConfig(t *testing.T) { cfg := config.ExternalMCPServerConfig{ Command: "python3", - Transport: "stdio", - Enabled: false, + ExternalMCPEnable: false, } manager.AddOrUpdateConfig("test-remove", cfg) @@ -89,18 +87,17 @@ func TestExternalMCPManager_GetStats(t *testing.T) { // 添加多个配置 manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: true, + ExternalMCPEnable: true, }) manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ URL: "http://127.0.0.1:8081/mcp", - Enabled: true, + ExternalMCPEnable: true, }) manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ Command: "python3", - Enabled: false, - Disabled: true, // 明确设置为禁用 + ExternalMCPEnable: false, }) stats := manager.GetStats() @@ -126,11 +123,11 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) { Servers: map[string]config.ExternalMCPServerConfig{ "loaded1": { Command: "python3", - Enabled: true, + ExternalMCPEnable: true, }, "loaded2": { URL: "http://127.0.0.1:8081/mcp", - Enabled: false, + ExternalMCPEnable: false, }, }, } @@ -156,7 +153,7 @@ func TestLazySDKClient_InitializeFails(t *testing.T) { logger := zap.NewNop() // 使用不存在的 HTTP 地址,Initialize 应失败 cfg := config.ExternalMCPServerConfig{ - Transport: "http", + Type: "http", URL: "http://127.0.0.1:19999/nonexistent", Timeout: 2, } @@ -180,8 +177,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) { // 添加一个禁用的配置 cfg := config.ExternalMCPServerConfig{ Command: "python3", - Transport: "stdio", - Enabled: false, + ExternalMCPEnable: false, } manager.AddOrUpdateConfig("test-start-stop", cfg) @@ -200,7 +196,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) { // 验证配置已更新为禁用 configs := manager.GetConfigs() - if configs["test-start-stop"].Enabled { + if configs["test-start-stop"].ExternalMCPEnable { t.Error("配置应该已被禁用") } } diff --git a/internal/multiagent/eino_adk_run_loop.go b/internal/multiagent/eino_adk_run_loop.go index 9aa497bd..9541a33b 100644 --- a/internal/multiagent/eino_adk_run_loop.go +++ b/internal/multiagent/eino_adk_run_loop.go @@ -230,54 +230,61 @@ attemptLoop: continue } if ev.Err != nil { - canRetry := attempt+1 < maxToolCallRecoveryAttempts - - if canRetry && isRecoverableToolCallArgumentsJSONError(ev.Err) { - if logger != nil { - logger.Warn("eino: recoverable tool-call JSON error from model/API", zap.Error(ev.Err), zap.Int("attempt", attempt)) - } - retryHints = append(retryHints, toolCallArgumentsJSONRetryHint()) - if progress != nil { - progress("eino_recovery", toolCallArgumentsJSONRecoveryTimelineMessage(attempt), map[string]interface{}{ - "conversationId": conversationID, - "source": "eino", - "einoRetry": attempt, - "runIndex": attempt + 1, - "maxRuns": maxToolCallRecoveryAttempts, - "reason": "invalid_tool_arguments_json", - }) - } - continue attemptLoop - } - - if canRetry && isRecoverableToolExecutionError(ev.Err) { - if logger != nil { - logger.Warn("eino: recoverable tool execution error, will retry with corrective hint", - zap.Error(ev.Err), zap.Int("attempt", attempt)) - } + // context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。 + if errors.Is(ev.Err, context.Canceled) { flushAllPendingAsFailed(ev.Err) - retryHints = append(retryHints, toolExecutionRetryHint()) if progress != nil { - progress("eino_recovery", toolExecutionRecoveryTimelineMessage(attempt), map[string]interface{}{ + progress("error", ev.Err.Error(), map[string]interface{}{ "conversationId": conversationID, "source": "eino", - "einoRetry": attempt, - "runIndex": attempt + 1, - "maxRuns": maxToolCallRecoveryAttempts, - "reason": "tool_execution_error", }) } - continue attemptLoop + return nil, ev.Err } + canRetry := attempt+1 < maxToolCallRecoveryAttempts + if !canRetry { + // 重试次数已耗尽,终止。 + flushAllPendingAsFailed(ev.Err) + if progress != nil { + progress("error", ev.Err.Error(), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + }) + } + return nil, ev.Err + } + + // 区分错误类型以选择最合适的纠错提示,但无论哪种都执行重试(default-soft)。 + var hint *schema.Message + var reason, timelineMsg string + if isRecoverableToolCallArgumentsJSONError(ev.Err) { + hint = toolCallArgumentsJSONRetryHint() + reason = "invalid_tool_arguments_json" + timelineMsg = toolCallArgumentsJSONRecoveryTimelineMessage(attempt) + } else { + hint = toolExecutionRetryHint() + reason = "tool_execution_error" + timelineMsg = toolExecutionRecoveryTimelineMessage(attempt) + } + + if logger != nil { + logger.Warn("eino: recoverable error, will retry with corrective hint", + zap.Error(ev.Err), zap.Int("attempt", attempt), zap.String("reason", reason)) + } flushAllPendingAsFailed(ev.Err) + retryHints = append(retryHints, hint) if progress != nil { - progress("error", ev.Err.Error(), map[string]interface{}{ + progress("eino_recovery", timelineMsg, map[string]interface{}{ "conversationId": conversationID, "source": "eino", + "einoRetry": attempt, + "runIndex": attempt + 1, + "maxRuns": maxToolCallRecoveryAttempts, + "reason": reason, }) } - return nil, ev.Err + continue attemptLoop } if ev.AgentName != "" && progress != nil { iterEinoAgent := orchestratorName diff --git a/internal/multiagent/tool_error_middleware.go b/internal/multiagent/tool_error_middleware.go index 147aaa29..15e523a9 100644 --- a/internal/multiagent/tool_error_middleware.go +++ b/internal/multiagent/tool_error_middleware.go @@ -41,62 +41,27 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware { // isSoftRecoverableToolError determines whether a tool execution error should be // silently converted to a tool-result message rather than crashing the graph. +// +// Design: default-soft (blacklist). Almost every tool execution error should be +// fed back to the LLM so it can self-correct or choose an alternative tool. +// Only a small set of "truly fatal" conditions (user cancellation) should +// propagate as hard errors that terminate the orchestration graph. +// This avoids the fragile whitelist approach where every new error pattern +// would need to be explicitly enumerated. func isSoftRecoverableToolError(err error) bool { if err == nil { return false } - // 用户取消 — 不应重试,让 hard error 传播以终止编排。 + // 用户主动取消 — 唯一应当终止编排的情况,不应重试。 if errors.Is(err, context.Canceled) { return false } - // 工具执行超时 — 转为 soft error 让 LLM 知晓并选择替代方案,而非全局重试。 - if errors.Is(err, context.DeadlineExceeded) { - return true - } - - s := strings.ToLower(err.Error()) - - // JSON unmarshal/parse failures — the model generated truncated or malformed arguments. - if isJSONRelatedError(s) { - return true - } - - // Sub-agent type not found (from deep/task_tool.go) - if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") { - return true - } - - // Tool not found in ToolsNode indexes - if strings.Contains(s, "tool") && strings.Contains(s, "not found") { - return true - } - - return false -} - -// isJSONRelatedError checks whether an error string indicates a JSON parsing problem. -func isJSONRelatedError(lower string) bool { - if !strings.Contains(lower, "json") { - return false - } - jsonIndicators := []string{ - "unexpected end of json", - "unmarshal", - "invalid character", - "cannot unmarshal", - "invalid tool arguments", - "failed to unmarshal", - "must be in json format", - "unexpected eof", - } - for _, ind := range jsonIndicators { - if strings.Contains(lower, ind) { - return true - } - } - return false + // 其他所有工具执行错误(超时、命令不存在、JSON 解析失败、工具未找到、 + // 权限不足、网络不可达……)一律转为 soft error,让 LLM 看到错误信息 + // 后自行决策:换工具、调整参数、或向用户说明。 + return true } // buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on. diff --git a/internal/multiagent/tool_error_middleware_test.go b/internal/multiagent/tool_error_middleware_test.go index d87e417b..bf2e622e 100644 --- a/internal/multiagent/tool_error_middleware_test.go +++ b/internal/multiagent/tool_error_middleware_test.go @@ -53,7 +53,12 @@ func TestIsSoftRecoverableToolError(t *testing.T) { { name: "unrelated network error", err: errors.New("connection refused"), - expected: false, + expected: true, // default-soft: non-cancel errors are recoverable + }, + { + name: "tool binary not installed", + err: errors.New("[LocalFunc] failed to invoke tool, toolName=grep, err=ripgrep (rg) is not installed or not in PATH"), + expected: true, }, { name: "context cancelled", @@ -131,15 +136,16 @@ func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) { return nil, origErr } wrapped := mw(next) - _, err := wrapped(context.Background(), &compose.ToolInput{ + out, err := wrapped(context.Background(), &compose.ToolInput{ Name: "test_tool", Arguments: `{}`, }) - if err == nil { - t.Fatal("expected error to propagate for non-recoverable errors") + // Default-soft: non-cancel errors are converted to tool-result messages. + if err != nil { + t.Fatalf("expected nil error (soft recovery), got: %v", err) } - if err != origErr { - t.Fatalf("expected original error, got: %v", err) + if out == nil || out.Result == "" { + t.Fatal("expected non-empty recovery message") } } diff --git a/internal/multiagent/tool_execution_retry.go b/internal/multiagent/tool_execution_retry.go index c79f8a66..6c5dad37 100644 --- a/internal/multiagent/tool_execution_retry.go +++ b/internal/multiagent/tool_execution_retry.go @@ -2,74 +2,42 @@ package multiagent import ( "fmt" - "strings" "github.com/cloudwego/eino/schema" ) -// isRecoverableToolExecutionError detects tool-level execution errors that can be -// recovered by retrying with a corrective hint. These errors originate from eino -// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces -// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments, -// or unregistered tool names. -func isRecoverableToolExecutionError(err error) bool { - if err == nil { - return false - } - s := strings.ToLower(err.Error()) - - // Sub-agent type not found (from deep/task_tool.go) - if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") { - return true - } - - // Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil) - if strings.Contains(s, "tool") && strings.Contains(s, "not found") { - return true - } - - // Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals) - if strings.Contains(s, "invalid tool arguments json") { - return true - } - - // Failed to unmarshal task tool input json (from deep/task_tool.go) - if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") { - return true - } - - // Generic tool call stream/invoke failure wrapping the above - if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) && - (strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) { - return true - } - - return false -} - // toolExecutionRetryHint returns a user message appended to the conversation to prompt -// the LLM to correct its tool call after a tool execution error. +// the LLM to adjust after a tool execution error (tool not found, binary missing, +// runtime failure, network error, etc.). func toolExecutionRetryHint() *schema.Message { - return schema.UserMessage(`[System] Your previous tool call failed because: -- The tool or sub-agent name you used does not exist, OR + return schema.UserMessage(`[System] Your previous tool call failed. Possible causes: +- The tool or sub-agent name does not exist (typo or unregistered name). - The tool call arguments were not valid JSON. +- The tool's underlying binary is not installed or not in PATH. +- The tool encountered a runtime error (timeout, network failure, permission denied, etc.). -Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action. +Please review the error message above, check available tools, and either: +1. Retry with corrected arguments or a different tool, OR +2. Inform the user about the limitation and proceed with an alternative approach. [系统提示] 上一次工具调用失败,可能原因: -- 你使用的工具名或子代理名称不存在; -- 工具调用参数不是合法 JSON。 +- 工具名或子代理名称不存在(拼写错误或未注册); +- 工具调用参数不是合法 JSON; +- 工具依赖的底层二进制程序未安装或不在 PATH 中; +- 工具运行时遇到错误(超时、网络故障、权限不足等)。 -请仔细检查上下文中列出的可用工具和子代理名称(须完全匹配、区分大小写),确保所有参数均为合法的 JSON 对象,然后重新执行。`) +请根据上述错误信息检查可用工具,然后: +1. 修正参数或改用其他工具重试,或者 +2. 告知用户当前限制并采用替代方案继续。`) } // toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event // displayed in the UI timeline when a tool execution error triggers a retry. func toolExecutionRecoveryTimelineMessage(attempt int) string { return fmt.Sprintf( - "工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+ + "工具调用执行失败。已向对话追加纠错提示并要求模型调整策略。"+ "当前为第 %d/%d 轮完整运行。\n\n"+ - "Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+ + "Tool call execution failed. "+ "A corrective hint was appended. This is full run %d of %d.", attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts, )