mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-05 22:06:41 +02:00
Add files via upload
This commit is contained in:
+46
-25
@@ -257,28 +257,52 @@ type ExternalMCPConfig struct {
|
||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
// ExternalMCPServerConfig 外部MCP服务器配置
|
||||
// ExternalMCPServerConfig 外部MCP服务器配置(遵循官方 MCP 配置格式,兼容 Claude Desktop / Cursor / VS Code)。
|
||||
// 所有字符串字段均支持 ${VAR} 和 ${VAR:-default} 环境变量展开语法。
|
||||
type ExternalMCPServerConfig struct {
|
||||
// stdio模式配置
|
||||
// 传输类型: "stdio" | "sse" | "http"(Streamable HTTP)。
|
||||
// stdio 模式可省略,有 command 字段时自动推断。
|
||||
Type string `yaml:"type,omitempty" json:"type,omitempty"`
|
||||
|
||||
// stdio 模式配置
|
||||
Command string `yaml:"command,omitempty" json:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
|
||||
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
|
||||
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"`
|
||||
|
||||
// HTTP模式配置
|
||||
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
|
||||
URL string `yaml:"url,omitempty" json:"url,omitempty"`
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key)
|
||||
// HTTP/SSE 模式配置
|
||||
URL string `yaml:"url,omitempty" json:"url,omitempty"`
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
// 官方标准字段
|
||||
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 禁用服务器(官方字段)
|
||||
AutoApprove []string `yaml:"autoApprove,omitempty" json:"autoApprove,omitempty"` // 自动批准的工具列表(官方字段)
|
||||
|
||||
// SDK 高级配置(对应 MCP Go SDK 传输层参数)
|
||||
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // Streamable HTTP 断线重连次数(默认 5)
|
||||
TerminateDuration int `yaml:"terminate_duration,omitempty" json:"terminate_duration,omitempty"` // stdio 进程优雅关闭等待秒数(默认 5)
|
||||
KeepAlive int `yaml:"keep_alive,omitempty" json:"keep_alive,omitempty"` // 客户端心跳间隔秒数(0 = 禁用)
|
||||
|
||||
// 通用配置
|
||||
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
||||
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒)
|
||||
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP
|
||||
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用)
|
||||
|
||||
// 向后兼容字段(已废弃,保留用于读取旧配置)
|
||||
Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable
|
||||
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable
|
||||
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 连接超时(秒)
|
||||
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用
|
||||
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态
|
||||
}
|
||||
|
||||
// GetTransportType 返回实际传输类型。优先读 Type,否则根据 Command/URL 自动推断。
|
||||
func (c ExternalMCPServerConfig) GetTransportType() string {
|
||||
if c.Type != "" {
|
||||
return c.Type
|
||||
}
|
||||
if c.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
if c.URL != "" {
|
||||
return "http"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Command string `yaml:"command"`
|
||||
@@ -369,23 +393,20 @@ func Load(path string) (*Config, error) {
|
||||
cfg.Security.Tools = tools
|
||||
}
|
||||
|
||||
// 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable
|
||||
// 外部 MCP:迁移 + 环境变量展开
|
||||
if cfg.ExternalMCP.Servers != nil {
|
||||
for name, serverCfg := range cfg.ExternalMCP.Servers {
|
||||
// 如果已经设置了 external_mcp_enable,跳过迁移
|
||||
// 否则从 enabled/disabled 字段迁移
|
||||
// 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了
|
||||
// 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移
|
||||
// 官方 disabled 字段 → ExternalMCPEnable
|
||||
if serverCfg.Disabled {
|
||||
// 旧配置使用 disabled,迁移到 external_mcp_enable
|
||||
serverCfg.ExternalMCPEnable = false
|
||||
} else if serverCfg.Enabled {
|
||||
// 旧配置使用 enabled,迁移到 external_mcp_enable
|
||||
serverCfg.ExternalMCPEnable = true
|
||||
} else {
|
||||
// 都没有设置,默认为启用
|
||||
} else if !serverCfg.ExternalMCPEnable {
|
||||
// 默认启用
|
||||
serverCfg.ExternalMCPEnable = true
|
||||
}
|
||||
|
||||
// 展开所有 ${VAR} / ${VAR:-default} 环境变量引用
|
||||
ExpandConfigEnv(&serverCfg)
|
||||
|
||||
cfg.ExternalMCP.Servers[name] = serverCfg
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// expandEnvVar expands ${VAR} and ${VAR:-default} references in s using the
// process environment and returns the resulting string.
//
// Rules:
//   - ${VAR} is replaced by the value of VAR, or by "" when VAR is unset.
//   - ${VAR:-default} is replaced by default when VAR is unset or empty.
//   - A "${" that has no closing "}" is copied through unchanged.
//   - A "$" that is not followed by "{" is not special.
//
// References do not nest: a reference ends at the first "}" after its "${",
// and only the first ":-" inside it separates the name from the default.
func expandEnvVar(s string) string {
	// Fast path: most configuration values contain no reference at all, so
	// return the input as is instead of copying it through a builder.
	if !strings.Contains(s, "${") {
		return s
	}

	var b strings.Builder
	b.Grow(len(s))

	rest := s
	for {
		literal, after, found := strings.Cut(rest, "${")
		b.WriteString(literal)
		if !found {
			break
		}

		expr, tail, closed := strings.Cut(after, "}")
		if !closed {
			// There is no "}" anywhere in the remainder, so nothing further
			// can expand: emit the rest verbatim and stop scanning.
			b.WriteString("${")
			b.WriteString(after)
			break
		}

		name, fallback, hasFallback := strings.Cut(expr, ":-")
		val := os.Getenv(name)
		if val == "" && hasFallback {
			val = fallback
		}
		b.WriteString(val)

		rest = tail
	}
	return b.String()
}
|
||||
|
||||
// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。
|
||||
// 展开范围:Command、Args、Env values、URL、Headers values。
|
||||
func ExpandConfigEnv(cfg *ExternalMCPServerConfig) {
|
||||
cfg.Command = expandEnvVar(cfg.Command)
|
||||
for i, arg := range cfg.Args {
|
||||
cfg.Args[i] = expandEnvVar(arg)
|
||||
}
|
||||
for k, v := range cfg.Env {
|
||||
cfg.Env[k] = expandEnvVar(v)
|
||||
}
|
||||
cfg.URL = expandEnvVar(cfg.URL)
|
||||
for k, v := range cfg.Headers {
|
||||
cfg.Headers[k] = expandEnvVar(v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandEnvVar(t *testing.T) {
|
||||
os.Setenv("TEST_MCP_VAR", "hello")
|
||||
os.Setenv("TEST_MCP_PATH", "/usr/local/bin")
|
||||
defer os.Unsetenv("TEST_MCP_VAR")
|
||||
defer os.Unsetenv("TEST_MCP_PATH")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{"plain string", "no vars here", "no vars here"},
|
||||
{"empty string", "", ""},
|
||||
{"simple var", "${TEST_MCP_VAR}", "hello"},
|
||||
{"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"},
|
||||
{"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"},
|
||||
{"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""},
|
||||
{"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"},
|
||||
{"default not used", "${TEST_MCP_VAR:-unused}", "hello"},
|
||||
{"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"},
|
||||
{"unclosed brace", "${UNCLOSED", "${UNCLOSED"},
|
||||
{"dollar without brace", "$PLAIN", "$PLAIN"},
|
||||
{"empty var name", "${}", ""},
|
||||
{"default empty var", "${:-default}", "default"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := expandEnvVar(tt.input)
|
||||
if got != tt.expect {
|
||||
t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandConfigEnv(t *testing.T) {
|
||||
os.Setenv("TEST_MCP_CMD", "python3")
|
||||
os.Setenv("TEST_MCP_TOKEN", "secret123")
|
||||
defer os.Unsetenv("TEST_MCP_CMD")
|
||||
defer os.Unsetenv("TEST_MCP_TOKEN")
|
||||
|
||||
cfg := &ExternalMCPServerConfig{
|
||||
Command: "${TEST_MCP_CMD}",
|
||||
Args: []string{"--token", "${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"},
|
||||
Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"},
|
||||
URL: "https://${MISSING:-example.com}/mcp",
|
||||
Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"},
|
||||
}
|
||||
|
||||
ExpandConfigEnv(cfg)
|
||||
|
||||
if cfg.Command != "python3" {
|
||||
t.Errorf("Command = %q, want %q", cfg.Command, "python3")
|
||||
}
|
||||
if cfg.Args[1] != "secret123" {
|
||||
t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123")
|
||||
}
|
||||
if cfg.Args[2] != "default_arg" {
|
||||
t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg")
|
||||
}
|
||||
if cfg.Env["API_KEY"] != "secret123" {
|
||||
t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123")
|
||||
}
|
||||
if cfg.Env["LEVEL"] != "INFO" {
|
||||
t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO")
|
||||
}
|
||||
if cfg.URL != "https://example.com/mcp" {
|
||||
t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp")
|
||||
}
|
||||
if cfg.Headers["Authorization"] != "Bearer secret123" {
|
||||
t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123")
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,20 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// configureDBPool applies the connection-pool settings used for the SQLite
// handles: at most 25 open connections, at most 5 idle connections, and a
// maximum lifetime of 30 minutes per connection.
//
// The stated intent is to cap the number of connections because SQLite allows
// only one writer at a time, reducing "database is locked" errors.
// NOTE(review): a cap of 25 still permits many concurrent writers; confirm
// that this value is intentional.
func configureDBPool(db *sql.DB) {
	const (
		maxOpenConns    = 25
		maxIdleConns    = 5
		connMaxLifetime = 30 * time.Minute
	)

	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	db.SetConnMaxLifetime(connMaxLifetime)
}
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
@@ -17,11 +26,13 @@ type DB struct {
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
configureDBPool(db)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
@@ -674,11 +685,13 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
|
||||
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
|
||||
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开知识库数据库失败: %w", err)
|
||||
}
|
||||
|
||||
configureDBPool(sqlDB)
|
||||
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
||||
}
|
||||
|
||||
+40
-186
@@ -2,11 +2,9 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -16,7 +14,6 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -268,172 +265,6 @@ func mustJSON(v interface{}) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
|
||||
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
|
||||
type simpleHTTPClient struct {
|
||||
url string
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
c := &simpleHTTPClient{
|
||||
url: url,
|
||||
client: httpClientWithTimeoutAndHeaders(timeout, headers),
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
if err := c.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.status = "connected"
|
||||
c.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Initialize(context.Context) error {
|
||||
return nil // 已在 newSimpleHTTPClient 中完成
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
|
||||
params := InitializeRequest{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: make(map[string]interface{}),
|
||||
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
|
||||
}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: "1"},
|
||||
Method: "initialize",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
// 发送 notifications/initialized(协议要求)
|
||||
notify := &Message{
|
||||
ID: MessageID{value: nil},
|
||||
Method: "notifications/initialized",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
_ = c.sendNotification(notify)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
var out Message
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
|
||||
body, _ := json.Marshal(msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/list",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var listResp ListToolsResponse
|
||||
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listResp.Tools, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
params := CallToolRequest{Name: name, Arguments: args}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/call",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var callResp CallToolResponse
|
||||
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
|
||||
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
|
||||
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
@@ -442,21 +273,23 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
transport := serverCfg.Transport
|
||||
transport := serverCfg.GetTransportType()
|
||||
if transport == "" {
|
||||
if serverCfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if serverCfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return nil, fmt.Errorf("配置缺少 command 或 url")
|
||||
return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport")
|
||||
}
|
||||
|
||||
// 构造 ClientOptions:KeepAlive 心跳
|
||||
var clientOpts *mcp.ClientOptions
|
||||
if serverCfg.KeepAlive > 0 {
|
||||
clientOpts = &mcp.ClientOptions{
|
||||
KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: clientName,
|
||||
Version: clientVersion,
|
||||
}, nil)
|
||||
}, clientOpts)
|
||||
|
||||
var t mcp.Transport
|
||||
switch transport {
|
||||
@@ -470,12 +303,18 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
|
||||
if len(serverCfg.Env) > 0 {
|
||||
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
|
||||
}
|
||||
t = &mcp.CommandTransport{Command: cmd}
|
||||
ct := &mcp.CommandTransport{Command: cmd}
|
||||
if serverCfg.TerminateDuration > 0 {
|
||||
ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second
|
||||
}
|
||||
t = ct
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("sse 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
// SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。
|
||||
// 超时由每次 ListTools/CallTool 的 context 单独控制。
|
||||
httpClient := httpClientForLongLived(serverCfg.Headers)
|
||||
t = &mcp.SSEClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
@@ -485,18 +324,16 @@ func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConf
|
||||
return nil, fmt.Errorf("http 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
t = &mcp.StreamableClientTransport{
|
||||
st := &mcp.StreamableClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "simple_http":
|
||||
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp)
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("simple_http 模式需要配置 url")
|
||||
if serverCfg.MaxRetries > 0 {
|
||||
st.MaxRetries = serverCfg.MaxRetries
|
||||
}
|
||||
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
|
||||
t = st
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
|
||||
return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http)", transport)
|
||||
}
|
||||
|
||||
session, err := client.Connect(ctx, t, nil)
|
||||
@@ -538,6 +375,23 @@ func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]s
|
||||
}
|
||||
}
|
||||
|
||||
// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。
|
||||
// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。
|
||||
// 超时由调用方通过 context 控制。
|
||||
func httpClientForLongLived(headers map[string]string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if len(headers) > 0 {
|
||||
transport = &headerRoundTripper{
|
||||
headers: headers,
|
||||
base: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
// 不设 Timeout,SSE 长连接的超时由 per-request context 控制
|
||||
}
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
headers map[string]string
|
||||
base http.RoundTripper
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
@@ -29,6 +30,7 @@ type ExternalMCPManager struct {
|
||||
toolCacheMu sync.RWMutex // 工具列表缓存的锁
|
||||
stopRefresh chan struct{} // 停止后台刷新的信号
|
||||
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
||||
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -721,7 +723,13 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int {
|
||||
}
|
||||
|
||||
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
|
||||
// 使用 atomic flag 防止并发堆积:如果上一次刷新尚未完成,本次触发直接跳过。
|
||||
func (m *ExternalMCPManager) refreshToolCounts() {
|
||||
if !m.refreshing.CompareAndSwap(false, true) {
|
||||
return // 上一次刷新尚未完成,跳过
|
||||
}
|
||||
defer m.refreshing.Store(false)
|
||||
|
||||
m.mu.RLock()
|
||||
clients := make(map[string]ExternalMCPClient)
|
||||
for k, v := range m.clients {
|
||||
@@ -874,16 +882,7 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() {
|
||||
|
||||
// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。
|
||||
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
|
||||
transport := serverCfg.Transport
|
||||
if transport == "" {
|
||||
if serverCfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if serverCfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
transport := serverCfg.GetTransportType()
|
||||
|
||||
switch transport {
|
||||
case "http":
|
||||
@@ -891,12 +890,6 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
|
||||
return nil
|
||||
}
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
case "simple_http":
|
||||
// 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等
|
||||
if serverCfg.URL == "" {
|
||||
return nil
|
||||
}
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
case "stdio":
|
||||
if serverCfg.Command == "" {
|
||||
return nil
|
||||
@@ -908,7 +901,11 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
|
||||
}
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
default:
|
||||
return nil
|
||||
if transport == "" {
|
||||
return nil
|
||||
}
|
||||
// 未知传输类型也尝试使用 lazy client
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -990,20 +987,7 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
|
||||
|
||||
// isEnabled 检查是否启用
|
||||
func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||||
// 优先使用 ExternalMCPEnable 字段
|
||||
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
|
||||
if cfg.ExternalMCPEnable {
|
||||
return true
|
||||
}
|
||||
// 向后兼容:检查旧字段
|
||||
if cfg.Disabled {
|
||||
return false
|
||||
}
|
||||
if cfg.Enabled {
|
||||
return true
|
||||
}
|
||||
// 都没有设置,默认为启用
|
||||
return true
|
||||
return cfg.ExternalMCPEnable
|
||||
}
|
||||
|
||||
// findSubstring 查找子字符串(简单实现)
|
||||
@@ -1044,15 +1028,7 @@ func (m *ExternalMCPManager) StartAllEnabled() {
|
||||
zap.Error(err),
|
||||
}
|
||||
|
||||
// 根据传输模式添加相应的信息
|
||||
transport := c.Transport
|
||||
if transport == "" {
|
||||
if c.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if c.URL != "" {
|
||||
transport = "http"
|
||||
}
|
||||
}
|
||||
transport := c.GetTransportType()
|
||||
|
||||
if transport == "http" && c.URL != "" {
|
||||
fields = append(fields, zap.String("url", c.URL))
|
||||
|
||||
@@ -16,12 +16,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
|
||||
|
||||
// 测试添加stdio配置
|
||||
stdioCfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Args: []string{"/path/to/script.py"},
|
||||
Transport: "stdio",
|
||||
Description: "Test stdio MCP",
|
||||
Timeout: 30,
|
||||
Enabled: true,
|
||||
Command: "python3",
|
||||
Args: []string{"/path/to/script.py"},
|
||||
Description: "Test stdio MCP",
|
||||
Timeout: 30,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
|
||||
@@ -31,11 +30,11 @@ func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
|
||||
|
||||
// 测试添加HTTP配置
|
||||
httpCfg := config.ExternalMCPServerConfig{
|
||||
Transport: "http",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Description: "Test HTTP MCP",
|
||||
Timeout: 30,
|
||||
Enabled: false,
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Description: "Test HTTP MCP",
|
||||
Timeout: 30,
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
err = manager.AddOrUpdateConfig("test-http", httpCfg)
|
||||
@@ -64,8 +63,7 @@ func TestExternalMCPManager_RemoveConfig(t *testing.T) {
|
||||
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Transport: "stdio",
|
||||
Enabled: false,
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-remove", cfg)
|
||||
@@ -89,18 +87,17 @@ func TestExternalMCPManager_GetStats(t *testing.T) {
|
||||
// 添加多个配置
|
||||
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true, // 明确设置为禁用
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
stats := manager.GetStats()
|
||||
@@ -126,11 +123,11 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
|
||||
Servers: map[string]config.ExternalMCPServerConfig{
|
||||
"loaded1": {
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
ExternalMCPEnable: true,
|
||||
},
|
||||
"loaded2": {
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: false,
|
||||
ExternalMCPEnable: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -156,7 +153,7 @@ func TestLazySDKClient_InitializeFails(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
// 使用不存在的 HTTP 地址,Initialize 应失败
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Transport: "http",
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
}
|
||||
@@ -180,8 +177,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
// 添加一个禁用的配置
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Transport: "stdio",
|
||||
Enabled: false,
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-start-stop", cfg)
|
||||
@@ -200,7 +196,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
|
||||
// 验证配置已更新为禁用
|
||||
configs := manager.GetConfigs()
|
||||
if configs["test-start-stop"].Enabled {
|
||||
if configs["test-start-stop"].ExternalMCPEnable {
|
||||
t.Error("配置应该已被禁用")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,54 +230,61 @@ attemptLoop:
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
canRetry := attempt+1 < maxToolCallRecoveryAttempts
|
||||
|
||||
if canRetry && isRecoverableToolCallArgumentsJSONError(ev.Err) {
|
||||
if logger != nil {
|
||||
logger.Warn("eino: recoverable tool-call JSON error from model/API", zap.Error(ev.Err), zap.Int("attempt", attempt))
|
||||
}
|
||||
retryHints = append(retryHints, toolCallArgumentsJSONRetryHint())
|
||||
if progress != nil {
|
||||
progress("eino_recovery", toolCallArgumentsJSONRecoveryTimelineMessage(attempt), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoRetry": attempt,
|
||||
"runIndex": attempt + 1,
|
||||
"maxRuns": maxToolCallRecoveryAttempts,
|
||||
"reason": "invalid_tool_arguments_json",
|
||||
})
|
||||
}
|
||||
continue attemptLoop
|
||||
}
|
||||
|
||||
if canRetry && isRecoverableToolExecutionError(ev.Err) {
|
||||
if logger != nil {
|
||||
logger.Warn("eino: recoverable tool execution error, will retry with corrective hint",
|
||||
zap.Error(ev.Err), zap.Int("attempt", attempt))
|
||||
}
|
||||
// context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。
|
||||
if errors.Is(ev.Err, context.Canceled) {
|
||||
flushAllPendingAsFailed(ev.Err)
|
||||
retryHints = append(retryHints, toolExecutionRetryHint())
|
||||
if progress != nil {
|
||||
progress("eino_recovery", toolExecutionRecoveryTimelineMessage(attempt), map[string]interface{}{
|
||||
progress("error", ev.Err.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoRetry": attempt,
|
||||
"runIndex": attempt + 1,
|
||||
"maxRuns": maxToolCallRecoveryAttempts,
|
||||
"reason": "tool_execution_error",
|
||||
})
|
||||
}
|
||||
continue attemptLoop
|
||||
return nil, ev.Err
|
||||
}
|
||||
|
||||
canRetry := attempt+1 < maxToolCallRecoveryAttempts
|
||||
if !canRetry {
|
||||
// 重试次数已耗尽,终止。
|
||||
flushAllPendingAsFailed(ev.Err)
|
||||
if progress != nil {
|
||||
progress("error", ev.Err.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
return nil, ev.Err
|
||||
}
|
||||
|
||||
// 区分错误类型以选择最合适的纠错提示,但无论哪种都执行重试(default-soft)。
|
||||
var hint *schema.Message
|
||||
var reason, timelineMsg string
|
||||
if isRecoverableToolCallArgumentsJSONError(ev.Err) {
|
||||
hint = toolCallArgumentsJSONRetryHint()
|
||||
reason = "invalid_tool_arguments_json"
|
||||
timelineMsg = toolCallArgumentsJSONRecoveryTimelineMessage(attempt)
|
||||
} else {
|
||||
hint = toolExecutionRetryHint()
|
||||
reason = "tool_execution_error"
|
||||
timelineMsg = toolExecutionRecoveryTimelineMessage(attempt)
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Warn("eino: recoverable error, will retry with corrective hint",
|
||||
zap.Error(ev.Err), zap.Int("attempt", attempt), zap.String("reason", reason))
|
||||
}
|
||||
flushAllPendingAsFailed(ev.Err)
|
||||
retryHints = append(retryHints, hint)
|
||||
if progress != nil {
|
||||
progress("error", ev.Err.Error(), map[string]interface{}{
|
||||
progress("eino_recovery", timelineMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoRetry": attempt,
|
||||
"runIndex": attempt + 1,
|
||||
"maxRuns": maxToolCallRecoveryAttempts,
|
||||
"reason": reason,
|
||||
})
|
||||
}
|
||||
return nil, ev.Err
|
||||
continue attemptLoop
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
|
||||
@@ -41,62 +41,27 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
|
||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||
// silently converted to a tool-result message rather than crashing the graph.
|
||||
//
|
||||
// Design: default-soft (blacklist). Almost every tool execution error should be
|
||||
// fed back to the LLM so it can self-correct or choose an alternative tool.
|
||||
// Only a small set of "truly fatal" conditions (user cancellation) should
|
||||
// propagate as hard errors that terminate the orchestration graph.
|
||||
// This avoids the fragile whitelist approach where every new error pattern
|
||||
// would need to be explicitly enumerated.
|
||||
func isSoftRecoverableToolError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 用户取消 — 不应重试,让 hard error 传播以终止编排。
|
||||
// 用户主动取消 — 唯一应当终止编排的情况,不应重试。
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 工具执行超时 — 转为 soft error 让 LLM 知晓并选择替代方案,而非全局重试。
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
s := strings.ToLower(err.Error())
|
||||
|
||||
// JSON unmarshal/parse failures — the model generated truncated or malformed arguments.
|
||||
if isJSONRelatedError(s) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Sub-agent type not found (from deep/task_tool.go)
|
||||
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool not found in ToolsNode indexes
|
||||
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isJSONRelatedError checks whether an error string indicates a JSON parsing problem.
|
||||
func isJSONRelatedError(lower string) bool {
|
||||
if !strings.Contains(lower, "json") {
|
||||
return false
|
||||
}
|
||||
jsonIndicators := []string{
|
||||
"unexpected end of json",
|
||||
"unmarshal",
|
||||
"invalid character",
|
||||
"cannot unmarshal",
|
||||
"invalid tool arguments",
|
||||
"failed to unmarshal",
|
||||
"must be in json format",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, ind := range jsonIndicators {
|
||||
if strings.Contains(lower, ind) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
// 其他所有工具执行错误(超时、命令不存在、JSON 解析失败、工具未找到、
|
||||
// 权限不足、网络不可达……)一律转为 soft error,让 LLM 看到错误信息
|
||||
// 后自行决策:换工具、调整参数、或向用户说明。
|
||||
return true
|
||||
}
|
||||
|
||||
// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on.
|
||||
|
||||
@@ -53,7 +53,12 @@ func TestIsSoftRecoverableToolError(t *testing.T) {
|
||||
{
|
||||
name: "unrelated network error",
|
||||
err: errors.New("connection refused"),
|
||||
expected: false,
|
||||
expected: true, // default-soft: non-cancel errors are recoverable
|
||||
},
|
||||
{
|
||||
name: "tool binary not installed",
|
||||
err: errors.New("[LocalFunc] failed to invoke tool, toolName=grep, err=ripgrep (rg) is not installed or not in PATH"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "context cancelled",
|
||||
@@ -131,15 +136,16 @@ func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) {
|
||||
return nil, origErr
|
||||
}
|
||||
wrapped := mw(next)
|
||||
_, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: `{}`,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for non-recoverable errors")
|
||||
// Default-soft: non-cancel errors are converted to tool-result messages.
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||
}
|
||||
if err != origErr {
|
||||
t.Fatalf("expected original error, got: %v", err)
|
||||
if out == nil || out.Result == "" {
|
||||
t.Fatal("expected non-empty recovery message")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,74 +2,42 @@ package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// isRecoverableToolExecutionError detects tool-level execution errors that can be
|
||||
// recovered by retrying with a corrective hint. These errors originate from eino
|
||||
// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces
|
||||
// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments,
|
||||
// or unregistered tool names.
|
||||
func isRecoverableToolExecutionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
|
||||
// Sub-agent type not found (from deep/task_tool.go)
|
||||
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil)
|
||||
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals)
|
||||
if strings.Contains(s, "invalid tool arguments json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Failed to unmarshal task tool input json (from deep/task_tool.go)
|
||||
if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Generic tool call stream/invoke failure wrapping the above
|
||||
if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) &&
|
||||
(strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
|
||||
// the LLM to correct its tool call after a tool execution error.
|
||||
// the LLM to adjust after a tool execution error (tool not found, binary missing,
|
||||
// runtime failure, network error, etc.).
|
||||
func toolExecutionRetryHint() *schema.Message {
|
||||
return schema.UserMessage(`[System] Your previous tool call failed because:
|
||||
- The tool or sub-agent name you used does not exist, OR
|
||||
return schema.UserMessage(`[System] Your previous tool call failed. Possible causes:
|
||||
- The tool or sub-agent name does not exist (typo or unregistered name).
|
||||
- The tool call arguments were not valid JSON.
|
||||
- The tool's underlying binary is not installed or not in PATH.
|
||||
- The tool encountered a runtime error (timeout, network failure, permission denied, etc.).
|
||||
|
||||
Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action.
|
||||
Please review the error message above, check available tools, and either:
|
||||
1. Retry with corrected arguments or a different tool, OR
|
||||
2. Inform the user about the limitation and proceed with an alternative approach.
|
||||
|
||||
[系统提示] 上一次工具调用失败,可能原因:
|
||||
- 你使用的工具名或子代理名称不存在;
|
||||
- 工具调用参数不是合法 JSON。
|
||||
- 工具名或子代理名称不存在(拼写错误或未注册);
|
||||
- 工具调用参数不是合法 JSON;
|
||||
- 工具依赖的底层二进制程序未安装或不在 PATH 中;
|
||||
- 工具运行时遇到错误(超时、网络故障、权限不足等)。
|
||||
|
||||
请仔细检查上下文中列出的可用工具和子代理名称(须完全匹配、区分大小写),确保所有参数均为合法的 JSON 对象,然后重新执行。`)
|
||||
请根据上述错误信息检查可用工具,然后:
|
||||
1. 修正参数或改用其他工具重试,或者
|
||||
2. 告知用户当前限制并采用替代方案继续。`)
|
||||
}
|
||||
|
||||
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
|
||||
// displayed in the UI timeline when a tool execution error triggers a retry.
|
||||
func toolExecutionRecoveryTimelineMessage(attempt int) string {
|
||||
return fmt.Sprintf(
|
||||
"工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+
|
||||
"工具调用执行失败。已向对话追加纠错提示并要求模型调整策略。"+
|
||||
"当前为第 %d/%d 轮完整运行。\n\n"+
|
||||
"Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+
|
||||
"Tool call execution failed. "+
|
||||
"A corrective hint was appended. This is full run %d of %d.",
|
||||
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user