Add files via upload

This commit is contained in:
公明
2026-04-21 19:17:46 +08:00
committed by GitHub
parent 26116b0822
commit 964c520215
11 changed files with 366 additions and 413 deletions
+46 -25
View File
@@ -257,28 +257,52 @@ type ExternalMCPConfig struct {
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
}
// ExternalMCPServerConfig 外部MCP服务器配置
// ExternalMCPServerConfig 外部MCP服务器配置(遵循官方 MCP 配置格式,兼容 Claude Desktop / Cursor / VS Code)。
// 所有字符串字段均支持 ${VAR} 和 ${VAR:-default} 环境变量展开语法。
type ExternalMCPServerConfig struct {
// stdio模式配置
// 传输类型: "stdio" | "sse" | "http"Streamable HTTP)。
// stdio 模式可省略,有 command 字段时自动推断。
Type string `yaml:"type,omitempty" json:"type,omitempty"`
// stdio 模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"`
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key
// HTTP/SSE 模式配置
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
// 官方标准字段
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 禁用服务器(官方字段)
AutoApprove []string `yaml:"autoApprove,omitempty" json:"autoApprove,omitempty"` // 自动批准的工具列表(官方字段)
// SDK 高级配置(对应 MCP Go SDK 传输层参数)
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // Streamable HTTP 断线重连次数(默认 5)
TerminateDuration int `yaml:"terminate_duration,omitempty" json:"terminate_duration,omitempty"` // stdio 进程优雅关闭等待秒数(默认 5)
KeepAlive int `yaml:"keep_alive,omitempty" json:"keep_alive,omitempty"` // 客户端心跳间隔秒数(0 = 禁用)
// 通用配置
Description string `yaml:"description,omitempty" json:"description,omitempty"`
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒)
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用)
// 向后兼容字段(已废弃,保留用于读取旧配置)
Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 连接超时(秒)
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态
}
// GetTransportType 返回实际传输类型。优先读 Type,否则根据 Command/URL 自动推断。
func (c ExternalMCPServerConfig) GetTransportType() string {
if c.Type != "" {
return c.Type
}
if c.Command != "" {
return "stdio"
}
if c.URL != "" {
return "http"
}
return ""
}
type ToolConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command"`
@@ -369,23 +393,20 @@ func Load(path string) (*Config, error) {
cfg.Security.Tools = tools
}
// 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable
// 外部 MCP:迁移 + 环境变量展开
if cfg.ExternalMCP.Servers != nil {
for name, serverCfg := range cfg.ExternalMCP.Servers {
// 如果已经设置了 external_mcp_enable,跳过迁移
// 否则从 enabled/disabled 字段迁移
// 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了
// 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移
// 官方 disabled 字段 → ExternalMCPEnable
if serverCfg.Disabled {
// 旧配置使用 disabled,迁移到 external_mcp_enable
serverCfg.ExternalMCPEnable = false
} else if serverCfg.Enabled {
// 旧配置使用 enabled,迁移到 external_mcp_enable
serverCfg.ExternalMCPEnable = true
} else {
// 都没有设置,默认为启用
} else if !serverCfg.ExternalMCPEnable {
// 默认启用
serverCfg.ExternalMCPEnable = true
}
// 展开所有 ${VAR} / ${VAR:-default} 环境变量引用
ExpandConfigEnv(&serverCfg)
cfg.ExternalMCP.Servers[name] = serverCfg
}
}
+66
View File
@@ -0,0 +1,66 @@
package config
import (
"os"
"strings"
)
// expandEnvVar 展开字符串中的 ${VAR} 和 ${VAR:-default} 环境变量引用。
// 与官方 MCP 配置格式一致(Claude Desktop / Cursor / VS Code 均支持此语法)。
func expandEnvVar(s string) string {
var b strings.Builder
i := 0
for i < len(s) {
// 查找 ${
idx := strings.Index(s[i:], "${")
if idx < 0 {
b.WriteString(s[i:])
break
}
b.WriteString(s[i : i+idx])
i += idx + 2 // skip ${
// 查找对应的 }
end := strings.IndexByte(s[i:], '}')
if end < 0 {
// 没有 },原样保留
b.WriteString("${")
continue
}
expr := s[i : i+end]
i += end + 1 // skip }
// 解析 VAR:-default
varName := expr
defaultVal := ""
hasDefault := false
if colonIdx := strings.Index(expr, ":-"); colonIdx >= 0 {
varName = expr[:colonIdx]
defaultVal = expr[colonIdx+2:]
hasDefault = true
}
val := os.Getenv(varName)
if val == "" && hasDefault {
val = defaultVal
}
b.WriteString(val)
}
return b.String()
}
// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。
// 展开范围:Command、Args、Env values、URL、Headers values。
func ExpandConfigEnv(cfg *ExternalMCPServerConfig) {
cfg.Command = expandEnvVar(cfg.Command)
for i, arg := range cfg.Args {
cfg.Args[i] = expandEnvVar(arg)
}
for k, v := range cfg.Env {
cfg.Env[k] = expandEnvVar(v)
}
cfg.URL = expandEnvVar(cfg.URL)
for k, v := range cfg.Headers {
cfg.Headers[k] = expandEnvVar(v)
}
}
+81
View File
@@ -0,0 +1,81 @@
package config
import (
"os"
"testing"
)
func TestExpandEnvVar(t *testing.T) {
os.Setenv("TEST_MCP_VAR", "hello")
os.Setenv("TEST_MCP_PATH", "/usr/local/bin")
defer os.Unsetenv("TEST_MCP_VAR")
defer os.Unsetenv("TEST_MCP_PATH")
tests := []struct {
name string
input string
expect string
}{
{"plain string", "no vars here", "no vars here"},
{"empty string", "", ""},
{"simple var", "${TEST_MCP_VAR}", "hello"},
{"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"},
{"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"},
{"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""},
{"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"},
{"default not used", "${TEST_MCP_VAR:-unused}", "hello"},
{"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"},
{"unclosed brace", "${UNCLOSED", "${UNCLOSED"},
{"dollar without brace", "$PLAIN", "$PLAIN"},
{"empty var name", "${}", ""},
{"default empty var", "${:-default}", "default"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := expandEnvVar(tt.input)
if got != tt.expect {
t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestExpandConfigEnv(t *testing.T) {
os.Setenv("TEST_MCP_CMD", "python3")
os.Setenv("TEST_MCP_TOKEN", "secret123")
defer os.Unsetenv("TEST_MCP_CMD")
defer os.Unsetenv("TEST_MCP_TOKEN")
cfg := &ExternalMCPServerConfig{
Command: "${TEST_MCP_CMD}",
Args: []string{"--token", "${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"},
Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"},
URL: "https://${MISSING:-example.com}/mcp",
Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"},
}
ExpandConfigEnv(cfg)
if cfg.Command != "python3" {
t.Errorf("Command = %q, want %q", cfg.Command, "python3")
}
if cfg.Args[1] != "secret123" {
t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123")
}
if cfg.Args[2] != "default_arg" {
t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg")
}
if cfg.Env["API_KEY"] != "secret123" {
t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123")
}
if cfg.Env["LEVEL"] != "INFO" {
t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO")
}
if cfg.URL != "https://example.com/mcp" {
t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp")
}
if cfg.Headers["Authorization"] != "Bearer secret123" {
t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123")
}
}
+15 -2
View File
@@ -4,11 +4,20 @@ import (
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性
func configureDBPool(db *sql.DB) {
// SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(30 * time.Minute)
}
// DB 数据库连接
type DB struct {
*sql.DB
@@ -17,11 +26,13 @@ type DB struct {
// NewDB 创建数据库连接
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
}
configureDBPool(db)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
@@ -674,11 +685,13 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
if err != nil {
return nil, fmt.Errorf("打开知识库数据库失败: %w", err)
}
configureDBPool(sqlDB)
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
}
+40 -186
View File
@@ -2,11 +2,9 @@
package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
@@ -16,7 +14,6 @@ import (
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/zap"
)
@@ -268,172 +265,6 @@ func mustJSON(v interface{}) []byte {
return b
}
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
type simpleHTTPClient struct {
url string
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string
}
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
c := &simpleHTTPClient{
url: url,
client: httpClientWithTimeoutAndHeaders(timeout, headers),
logger: logger,
status: "connecting",
}
if err := c.initialize(ctx); err != nil {
return nil, err
}
c.mu.Lock()
c.status = "connected"
c.mu.Unlock()
return c, nil
}
func (c *simpleHTTPClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *simpleHTTPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *simpleHTTPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *simpleHTTPClient) Initialize(context.Context) error {
return nil // 已在 newSimpleHTTPClient 中完成
}
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return fmt.Errorf("initialize: %w", err)
}
if resp.Error != nil {
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
// 发送 notifications/initialized(协议要求)
notify := &Message{
ID: MessageID{value: nil},
Method: "notifications/initialized",
Version: "2.0",
Params: json.RawMessage("{}"),
}
_ = c.sendNotification(notify)
return nil
}
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
body, err := json.Marshal(msg)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
}
var out Message
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, err
}
return &out, nil
}
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
body, _ := json.Marshal(msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
Params: json.RawMessage("{}"),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, err
}
return listResp.Tools, nil
}
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
params := CallToolRequest{Name: name, Arguments: args}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, err
}
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
}
func (c *simpleHTTPClient) Close() error {
c.setStatus("disconnected")
return nil
}
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
@@ -442,21 +273,23 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
timeout = 30 * time.Second
}
transport := serverCfg.Transport
transport := serverCfg.GetTransportType()
if transport == "" {
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
transport = "http"
} else {
return nil, fmt.Errorf("配置缺少 command 或 url")
return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport")
}
// 构造 ClientOptionsKeepAlive 心跳
var clientOpts *mcp.ClientOptions
if serverCfg.KeepAlive > 0 {
clientOpts = &mcp.ClientOptions{
KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second,
}
}
client := mcp.NewClient(&mcp.Implementation{
Name: clientName,
Version: clientVersion,
}, nil)
}, clientOpts)
var t mcp.Transport
switch transport {
@@ -470,12 +303,18 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
if len(serverCfg.Env) > 0 {
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
}
t = &mcp.CommandTransport{Command: cmd}
ct := &mcp.CommandTransport{Command: cmd}
if serverCfg.TerminateDuration > 0 {
ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second
}
t = ct
case "sse":
if serverCfg.URL == "" {
return nil, fmt.Errorf("sse 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
// SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。
// 超时由每次 ListTools/CallTool 的 context 单独控制。
httpClient := httpClientForLongLived(serverCfg.Headers)
t = &mcp.SSEClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
@@ -485,18 +324,16 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
return nil, fmt.Errorf("http 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.StreamableClientTransport{
st := &mcp.StreamableClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "simple_http":
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp
if serverCfg.URL == "" {
return nil, fmt.Errorf("simple_http 模式需要配置 url")
if serverCfg.MaxRetries > 0 {
st.MaxRetries = serverCfg.MaxRetries
}
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
t = st
default:
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http", transport)
}
session, err := client.Connect(ctx, t, nil)
@@ -538,6 +375,23 @@ func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]s
}
}
// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。
// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。
// 超时由调用方通过 context 控制。
func httpClientForLongLived(headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Transport: transport,
// 不设 TimeoutSSE 长连接的超时由 per-request context 控制
}
}
type headerRoundTripper struct {
headers map[string]string
base http.RoundTripper
+16 -40
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"cyberstrike-ai/internal/config"
@@ -29,6 +30,7 @@ type ExternalMCPManager struct {
toolCacheMu sync.RWMutex // 工具列表缓存的锁
stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
mu sync.RWMutex
}
@@ -721,7 +723,13 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int {
}
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
// 使用 atomic flag 防止并发堆积:如果上一次刷新尚未完成,本次触发直接跳过。
func (m *ExternalMCPManager) refreshToolCounts() {
if !m.refreshing.CompareAndSwap(false, true) {
return // 上一次刷新尚未完成,跳过
}
defer m.refreshing.Store(false)
m.mu.RLock()
clients := make(map[string]ExternalMCPClient)
for k, v := range m.clients {
@@ -874,16 +882,7 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() {
// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
transport := serverCfg.Transport
if transport == "" {
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
transport = "http"
} else {
return nil
}
}
transport := serverCfg.GetTransportType()
switch transport {
case "http":
@@ -891,12 +890,6 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
return nil
}
return newLazySDKClient(serverCfg, m.logger)
case "simple_http":
// 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等
if serverCfg.URL == "" {
return nil
}
return newLazySDKClient(serverCfg, m.logger)
case "stdio":
if serverCfg.Command == "" {
return nil
@@ -908,7 +901,11 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
}
return newLazySDKClient(serverCfg, m.logger)
default:
return nil
if transport == "" {
return nil
}
// 未知传输类型也尝试使用 lazy client
return newLazySDKClient(serverCfg, m.logger)
}
}
@@ -990,20 +987,7 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
// isEnabled 检查是否启用
func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool {
// 优先使用 ExternalMCPEnable 字段
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
if cfg.ExternalMCPEnable {
return true
}
// 向后兼容:检查旧字段
if cfg.Disabled {
return false
}
if cfg.Enabled {
return true
}
// 都没有设置,默认为启用
return true
return cfg.ExternalMCPEnable
}
// findSubstring 查找子字符串(简单实现)
@@ -1044,15 +1028,7 @@ func (m *ExternalMCPManager) StartAllEnabled() {
zap.Error(err),
}
// 根据传输模式添加相应的信息
transport := c.Transport
if transport == "" {
if c.Command != "" {
transport = "stdio"
} else if c.URL != "" {
transport = "http"
}
}
transport := c.GetTransportType()
if transport == "http" && c.URL != "" {
fields = append(fields, zap.String("url", c.URL))
+19 -23
View File
@@ -16,12 +16,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
// 测试添加stdio配置
stdioCfg := config.ExternalMCPServerConfig{
Command: "python3",
Args: []string{"/path/to/script.py"},
Transport: "stdio",
Description: "Test stdio MCP",
Timeout: 30,
Enabled: true,
Command: "python3",
Args: []string{"/path/to/script.py"},
Description: "Test stdio MCP",
Timeout: 30,
ExternalMCPEnable: true,
}
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
@@ -31,11 +30,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
// 测试添加HTTP配置
httpCfg := config.ExternalMCPServerConfig{
Transport: "http",
URL: "http://127.0.0.1:8081/mcp",
Description: "Test HTTP MCP",
Timeout: 30,
Enabled: false,
Type: "http",
URL: "http://127.0.0.1:8081/mcp",
Description: "Test HTTP MCP",
Timeout: 30,
ExternalMCPEnable: false,
}
err = manager.AddOrUpdateConfig("test-http", httpCfg)
@@ -64,8 +63,7 @@ func TestExternalMCPManager_RemoveConfig(t *testing.T) {
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Transport: "stdio",
Enabled: false,
ExternalMCPEnable: false,
}
manager.AddOrUpdateConfig("test-remove", cfg)
@@ -89,18 +87,17 @@ func TestExternalMCPManager_GetStats(t *testing.T) {
// 添加多个配置
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true, // 明确设置为禁用
ExternalMCPEnable: false,
})
stats := manager.GetStats()
@@ -126,11 +123,11 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
Servers: map[string]config.ExternalMCPServerConfig{
"loaded1": {
Command: "python3",
Enabled: true,
ExternalMCPEnable: true,
},
"loaded2": {
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
ExternalMCPEnable: false,
},
},
}
@@ -156,7 +153,7 @@ func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
// 使用不存在的 HTTP 地址,Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Transport: "http",
Type: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
@@ -180,8 +177,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
// 添加一个禁用的配置
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Transport: "stdio",
Enabled: false,
ExternalMCPEnable: false,
}
manager.AddOrUpdateConfig("test-start-stop", cfg)
@@ -200,7 +196,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
// 验证配置已更新为禁用
configs := manager.GetConfigs()
if configs["test-start-stop"].Enabled {
if configs["test-start-stop"].ExternalMCPEnable {
t.Error("配置应该已被禁用")
}
}
+41 -34
View File
@@ -230,54 +230,61 @@ attemptLoop:
continue
}
if ev.Err != nil {
canRetry := attempt+1 < maxToolCallRecoveryAttempts
if canRetry && isRecoverableToolCallArgumentsJSONError(ev.Err) {
if logger != nil {
logger.Warn("eino: recoverable tool-call JSON error from model/API", zap.Error(ev.Err), zap.Int("attempt", attempt))
}
retryHints = append(retryHints, toolCallArgumentsJSONRetryHint())
if progress != nil {
progress("eino_recovery", toolCallArgumentsJSONRecoveryTimelineMessage(attempt), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"einoRetry": attempt,
"runIndex": attempt + 1,
"maxRuns": maxToolCallRecoveryAttempts,
"reason": "invalid_tool_arguments_json",
})
}
continue attemptLoop
}
if canRetry && isRecoverableToolExecutionError(ev.Err) {
if logger != nil {
logger.Warn("eino: recoverable tool execution error, will retry with corrective hint",
zap.Error(ev.Err), zap.Int("attempt", attempt))
}
// context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。
if errors.Is(ev.Err, context.Canceled) {
flushAllPendingAsFailed(ev.Err)
retryHints = append(retryHints, toolExecutionRetryHint())
if progress != nil {
progress("eino_recovery", toolExecutionRecoveryTimelineMessage(attempt), map[string]interface{}{
progress("error", ev.Err.Error(), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"einoRetry": attempt,
"runIndex": attempt + 1,
"maxRuns": maxToolCallRecoveryAttempts,
"reason": "tool_execution_error",
})
}
continue attemptLoop
return nil, ev.Err
}
canRetry := attempt+1 < maxToolCallRecoveryAttempts
if !canRetry {
// 重试次数已耗尽,终止。
flushAllPendingAsFailed(ev.Err)
if progress != nil {
progress("error", ev.Err.Error(), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
})
}
return nil, ev.Err
}
// 区分错误类型以选择最合适的纠错提示,但无论哪种都执行重试(default-soft)。
var hint *schema.Message
var reason, timelineMsg string
if isRecoverableToolCallArgumentsJSONError(ev.Err) {
hint = toolCallArgumentsJSONRetryHint()
reason = "invalid_tool_arguments_json"
timelineMsg = toolCallArgumentsJSONRecoveryTimelineMessage(attempt)
} else {
hint = toolExecutionRetryHint()
reason = "tool_execution_error"
timelineMsg = toolExecutionRecoveryTimelineMessage(attempt)
}
if logger != nil {
logger.Warn("eino: recoverable error, will retry with corrective hint",
zap.Error(ev.Err), zap.Int("attempt", attempt), zap.String("reason", reason))
}
flushAllPendingAsFailed(ev.Err)
retryHints = append(retryHints, hint)
if progress != nil {
progress("error", ev.Err.Error(), map[string]interface{}{
progress("eino_recovery", timelineMsg, map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"einoRetry": attempt,
"runIndex": attempt + 1,
"maxRuns": maxToolCallRecoveryAttempts,
"reason": reason,
})
}
return nil, ev.Err
continue attemptLoop
}
if ev.AgentName != "" && progress != nil {
iterEinoAgent := orchestratorName
+12 -47
View File
@@ -41,62 +41,27 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
// isSoftRecoverableToolError determines whether a tool execution error should be
// silently converted to a tool-result message rather than crashing the graph.
//
// Design: default-soft (blacklist). Almost every tool execution error should be
// fed back to the LLM so it can self-correct or choose an alternative tool.
// Only a small set of "truly fatal" conditions (user cancellation) should
// propagate as hard errors that terminate the orchestration graph.
// This avoids the fragile whitelist approach where every new error pattern
// would need to be explicitly enumerated.
func isSoftRecoverableToolError(err error) bool {
if err == nil {
return false
}
// 用户取消 — 不应重试,让 hard error 传播以终止编排
// 用户主动取消 — 唯一应当终止编排的情况,不应重试。
if errors.Is(err, context.Canceled) {
return false
}
// 工具执行超时 — 转为 soft error 让 LLM 知晓并选择替代方案,而非全局重试。
if errors.Is(err, context.DeadlineExceeded) {
return true
}
s := strings.ToLower(err.Error())
// JSON unmarshal/parse failures — the model generated truncated or malformed arguments.
if isJSONRelatedError(s) {
return true
}
// Sub-agent type not found (from deep/task_tool.go)
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
return true
}
// Tool not found in ToolsNode indexes
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
return true
}
return false
}
// isJSONRelatedError checks whether an error string indicates a JSON parsing problem.
func isJSONRelatedError(lower string) bool {
if !strings.Contains(lower, "json") {
return false
}
jsonIndicators := []string{
"unexpected end of json",
"unmarshal",
"invalid character",
"cannot unmarshal",
"invalid tool arguments",
"failed to unmarshal",
"must be in json format",
"unexpected eof",
}
for _, ind := range jsonIndicators {
if strings.Contains(lower, ind) {
return true
}
}
return false
// 其他所有工具执行错误(超时、命令不存在、JSON 解析失败、工具未找到、
// 权限不足、网络不可达……)一律转为 soft error,让 LLM 看到错误信息
// 后自行决策:换工具、调整参数、或向用户说明。
return true
}
// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on.
@@ -53,7 +53,12 @@ func TestIsSoftRecoverableToolError(t *testing.T) {
{
name: "unrelated network error",
err: errors.New("connection refused"),
expected: false,
expected: true, // default-soft: non-cancel errors are recoverable
},
{
name: "tool binary not installed",
err: errors.New("[LocalFunc] failed to invoke tool, toolName=grep, err=ripgrep (rg) is not installed or not in PATH"),
expected: true,
},
{
name: "context cancelled",
@@ -131,15 +136,16 @@ func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) {
return nil, origErr
}
wrapped := mw(next)
_, err := wrapped(context.Background(), &compose.ToolInput{
out, err := wrapped(context.Background(), &compose.ToolInput{
Name: "test_tool",
Arguments: `{}`,
})
if err == nil {
t.Fatal("expected error to propagate for non-recoverable errors")
// Default-soft: non-cancel errors are converted to tool-result messages.
if err != nil {
t.Fatalf("expected nil error (soft recovery), got: %v", err)
}
if err != origErr {
t.Fatalf("expected original error, got: %v", err)
if out == nil || out.Result == "" {
t.Fatal("expected non-empty recovery message")
}
}
+18 -50
View File
@@ -2,74 +2,42 @@ package multiagent
import (
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// isRecoverableToolExecutionError detects tool-level execution errors that can be
// recovered by retrying with a corrective hint. These errors originate from eino
// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces
// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments,
// or unregistered tool names.
func isRecoverableToolExecutionError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
// Sub-agent type not found (from deep/task_tool.go)
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
return true
}
// Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil)
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
return true
}
// Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals)
if strings.Contains(s, "invalid tool arguments json") {
return true
}
// Failed to unmarshal task tool input json (from deep/task_tool.go)
if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") {
return true
}
// Generic tool call stream/invoke failure wrapping the above
if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) &&
(strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) {
return true
}
return false
}
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
// the LLM to correct its tool call after a tool execution error.
// the LLM to adjust after a tool execution error (tool not found, binary missing,
// runtime failure, network error, etc.).
func toolExecutionRetryHint() *schema.Message {
return schema.UserMessage(`[System] Your previous tool call failed because:
- The tool or sub-agent name you used does not exist, OR
return schema.UserMessage(`[System] Your previous tool call failed. Possible causes:
- The tool or sub-agent name does not exist (typo or unregistered name).
- The tool call arguments were not valid JSON.
- The tool's underlying binary is not installed or not in PATH.
- The tool encountered a runtime error (timeout, network failure, permission denied, etc.).
Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action.
Please review the error message above, check available tools, and either:
1. Retry with corrected arguments or a different tool, OR
2. Inform the user about the limitation and proceed with an alternative approach.
[系统提示] 上一次工具调用失败,可能原因:
- 你使用的工具名或子代理名称不存在;
- 工具调用参数不是合法 JSON
- 工具名或子代理名称不存在(拼写错误或未注册)
- 工具调用参数不是合法 JSON
- 工具依赖的底层二进制程序未安装或不在 PATH 中;
- 工具运行时遇到错误(超时、网络故障、权限不足等)。
仔细检查上下文中列出的可用工具和子代理名称(须完全匹配、区分大小写),确保所有参数均为合法的 JSON 对象,然后重新执行。`)
根据上述错误信息检查可用工具,然后:
1. 修正参数或改用其他工具重试,或者
2. 告知用户当前限制并采用替代方案继续。`)
}
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
// displayed in the UI timeline when a tool execution error triggers a retry.
func toolExecutionRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+
"工具调用执行失败。已向对话追加纠错提示并要求模型调整策略。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+
"Tool call execution failed. "+
"A corrective hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)