mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-15 04:51:01 +02:00
Delete mcp directory
This commit is contained in:
@@ -1,133 +0,0 @@
|
||||
package builtin
|
||||
|
||||
// 内置工具名称常量
|
||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||
const (
|
||||
// 漏洞管理工具
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
|
||||
// 知识库工具
|
||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||
ToolSearchKnowledgeBase = "search_knowledge_base"
|
||||
|
||||
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
|
||||
ToolWebshellExec = "webshell_exec"
|
||||
ToolWebshellFileList = "webshell_file_list"
|
||||
ToolWebshellFileRead = "webshell_file_read"
|
||||
ToolWebshellFileWrite = "webshell_file_write"
|
||||
|
||||
// WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接)
|
||||
ToolManageWebshellList = "manage_webshell_list"
|
||||
ToolManageWebshellAdd = "manage_webshell_add"
|
||||
ToolManageWebshellUpdate = "manage_webshell_update"
|
||||
ToolManageWebshellDelete = "manage_webshell_delete"
|
||||
ToolManageWebshellTest = "manage_webshell_test"
|
||||
|
||||
// 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列)
|
||||
ToolBatchTaskList = "batch_task_list"
|
||||
ToolBatchTaskGet = "batch_task_get"
|
||||
ToolBatchTaskCreate = "batch_task_create"
|
||||
ToolBatchTaskStart = "batch_task_start"
|
||||
ToolBatchTaskRerun = "batch_task_rerun"
|
||||
ToolBatchTaskPause = "batch_task_pause"
|
||||
ToolBatchTaskDelete = "batch_task_delete"
|
||||
ToolBatchTaskUpdateMetadata = "batch_task_update_metadata"
|
||||
ToolBatchTaskUpdateSchedule = "batch_task_update_schedule"
|
||||
ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled"
|
||||
ToolBatchTaskAdd = "batch_task_add_task"
|
||||
ToolBatchTaskUpdate = "batch_task_update_task"
|
||||
ToolBatchTaskRemove = "batch_task_remove_task"
|
||||
|
||||
// C2 工具集(合并同类项,8 个统一工具)
|
||||
ToolC2Listener = "c2_listener" // 监听器管理(create/start/stop/list/get/update/delete)
|
||||
ToolC2Session = "c2_session" // 会话管理(list/get/set_sleep/kill/delete)
|
||||
ToolC2Task = "c2_task" // 任务下发(统一 task_type 参数)
|
||||
ToolC2TaskManage = "c2_task_manage" // 任务管理(get_result/wait/list/cancel)
|
||||
ToolC2Payload = "c2_payload" // Payload 生成(oneliner/build)
|
||||
ToolC2Event = "c2_event" // 事件查询
|
||||
ToolC2Profile = "c2_profile" // Malleable Profile 管理(list/get/create/update/delete)
|
||||
ToolC2File = "c2_file" // 文件管理(list/get_result)
|
||||
)
|
||||
|
||||
// IsBuiltinTool 检查工具名称是否是内置工具
|
||||
func IsBuiltinTool(toolName string) bool {
|
||||
switch toolName {
|
||||
case ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
ToolWebshellFileList,
|
||||
ToolWebshellFileRead,
|
||||
ToolWebshellFileWrite,
|
||||
ToolManageWebshellList,
|
||||
ToolManageWebshellAdd,
|
||||
ToolManageWebshellUpdate,
|
||||
ToolManageWebshellDelete,
|
||||
ToolManageWebshellTest,
|
||||
ToolBatchTaskList,
|
||||
ToolBatchTaskGet,
|
||||
ToolBatchTaskCreate,
|
||||
ToolBatchTaskStart,
|
||||
ToolBatchTaskRerun,
|
||||
ToolBatchTaskPause,
|
||||
ToolBatchTaskDelete,
|
||||
ToolBatchTaskUpdateMetadata,
|
||||
ToolBatchTaskUpdateSchedule,
|
||||
ToolBatchTaskScheduleEnabled,
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove,
|
||||
// C2 工具
|
||||
ToolC2Listener,
|
||||
ToolC2Session,
|
||||
ToolC2Task,
|
||||
ToolC2TaskManage,
|
||||
ToolC2Payload,
|
||||
ToolC2Event,
|
||||
ToolC2Profile,
|
||||
ToolC2File:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllBuiltinTools 返回所有内置工具名称列表
|
||||
func GetAllBuiltinTools() []string {
|
||||
return []string{
|
||||
ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
ToolWebshellFileList,
|
||||
ToolWebshellFileRead,
|
||||
ToolWebshellFileWrite,
|
||||
ToolManageWebshellList,
|
||||
ToolManageWebshellAdd,
|
||||
ToolManageWebshellUpdate,
|
||||
ToolManageWebshellDelete,
|
||||
ToolManageWebshellTest,
|
||||
ToolBatchTaskList,
|
||||
ToolBatchTaskGet,
|
||||
ToolBatchTaskCreate,
|
||||
ToolBatchTaskStart,
|
||||
ToolBatchTaskRerun,
|
||||
ToolBatchTaskPause,
|
||||
ToolBatchTaskDelete,
|
||||
ToolBatchTaskUpdateMetadata,
|
||||
ToolBatchTaskUpdateSchedule,
|
||||
ToolBatchTaskScheduleEnabled,
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove,
|
||||
// C2 工具
|
||||
ToolC2Listener,
|
||||
ToolC2Session,
|
||||
ToolC2Task,
|
||||
ToolC2TaskManage,
|
||||
ToolC2Payload,
|
||||
ToolC2Event,
|
||||
ToolC2Profile,
|
||||
ToolC2File,
|
||||
}
|
||||
}
|
||||
@@ -1,405 +0,0 @@
|
||||
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
clientName = "CyberStrikeAI"
|
||||
clientVersion = "1.0.0"
|
||||
)
|
||||
|
||||
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
|
||||
type sdkClient struct {
|
||||
session *mcp.ClientSession
|
||||
client *mcp.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string // "disconnected", "connecting", "connected", "error"
|
||||
}
|
||||
|
||||
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
|
||||
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
|
||||
return &sdkClient{
|
||||
session: session,
|
||||
client: client,
|
||||
logger: logger,
|
||||
status: "connected",
|
||||
}
|
||||
}
|
||||
|
||||
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
|
||||
type lazySDKClient struct {
|
||||
serverCfg config.ExternalMCPServerConfig
|
||||
logger *zap.Logger
|
||||
inner ExternalMCPClient // 连接成功后为 *sdkClient
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
|
||||
return &lazySDKClient{
|
||||
serverCfg: serverCfg,
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.inner != nil {
|
||||
return c.inner.GetStatus()
|
||||
}
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) IsConnected() bool {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner != nil {
|
||||
return inner.IsConnected()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Initialize(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.inner != nil {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
|
||||
if err != nil {
|
||||
c.setStatus("error")
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.inner = inner
|
||||
c.mu.Unlock()
|
||||
c.setStatus("connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.ListTools(ctx)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.CallTool(ctx, name, args)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Close() error {
|
||||
c.mu.Lock()
|
||||
inner := c.inner
|
||||
c.inner = nil
|
||||
c.mu.Unlock()
|
||||
c.setStatus("disconnected")
|
||||
if inner != nil {
|
||||
return inner.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *sdkClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *sdkClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *sdkClient) Initialize(ctx context.Context) error {
|
||||
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
|
||||
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
res, err := c.session.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return sdkToolsToOur(res.Tools), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
params := &mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
res, err := c.session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sdkCallToolResultToOurs(res), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
if c.session != nil {
|
||||
err := c.session.Close()
|
||||
c.session = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
|
||||
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Tool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
schema := make(map[string]interface{})
|
||||
if t.InputSchema != nil {
|
||||
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
|
||||
if m, ok := t.InputSchema.(map[string]interface{}); ok {
|
||||
schema = m
|
||||
} else {
|
||||
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
|
||||
}
|
||||
}
|
||||
desc := t.Description
|
||||
shortDesc := desc
|
||||
if t.Annotations != nil && t.Annotations.Title != "" {
|
||||
shortDesc = t.Annotations.Title
|
||||
}
|
||||
out = append(out, Tool{
|
||||
Name: t.Name,
|
||||
Description: desc,
|
||||
ShortDescription: shortDesc,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
|
||||
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
|
||||
if res == nil {
|
||||
return &ToolResult{Content: []Content{}}
|
||||
}
|
||||
content := sdkContentToOurs(res.Content)
|
||||
return &ToolResult{
|
||||
Content: content,
|
||||
IsError: res.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
func sdkContentToOurs(list []mcp.Content) []Content {
|
||||
if len(list) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Content, 0, len(list))
|
||||
for _, c := range list {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
out = append(out, Content{Type: "text", Text: v.Text})
|
||||
default:
|
||||
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mustJSON(v interface{}) []byte {
|
||||
b, _ := json.Marshal(v)
|
||||
return b
|
||||
}
|
||||
|
||||
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
|
||||
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
|
||||
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
transport := serverCfg.GetTransportType()
|
||||
if transport == "" {
|
||||
return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport")
|
||||
}
|
||||
|
||||
// 构造 ClientOptions:KeepAlive 心跳
|
||||
var clientOpts *mcp.ClientOptions
|
||||
if serverCfg.KeepAlive > 0 {
|
||||
clientOpts = &mcp.ClientOptions{
|
||||
KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: clientName,
|
||||
Version: clientVersion,
|
||||
}, clientOpts)
|
||||
|
||||
var t mcp.Transport
|
||||
switch transport {
|
||||
case "stdio":
|
||||
if serverCfg.Command == "" {
|
||||
return nil, fmt.Errorf("stdio 模式需要配置 command")
|
||||
}
|
||||
// 必须用 exec.Command 而非 CommandContext:doConnect 返回后 ctx 会被 cancel,
|
||||
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
|
||||
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
|
||||
if len(serverCfg.Env) > 0 {
|
||||
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
|
||||
}
|
||||
ct := &mcp.CommandTransport{Command: cmd}
|
||||
if serverCfg.TerminateDuration > 0 {
|
||||
ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second
|
||||
}
|
||||
t = ct
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("sse 模式需要配置 url")
|
||||
}
|
||||
// SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。
|
||||
// 超时由每次 ListTools/CallTool 的 context 单独控制。
|
||||
httpClient := httpClientForLongLived(serverCfg.Headers)
|
||||
t = &mcp.SSEClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "http":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("http 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
st := &mcp.StreamableClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
if serverCfg.MaxRetries > 0 {
|
||||
st.MaxRetries = serverCfg.MaxRetries
|
||||
}
|
||||
t = st
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http)", transport)
|
||||
}
|
||||
|
||||
session, err := client.Connect(ctx, t, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接失败: %w", err)
|
||||
}
|
||||
|
||||
return newSDKClientFromSession(session, client, logger), nil
|
||||
}
|
||||
|
||||
func envMapToSlice(env map[string]string) []string {
|
||||
m := make(map[string]string)
|
||||
for _, s := range os.Environ() {
|
||||
if i := strings.IndexByte(s, '='); i > 0 {
|
||||
m[s[:i]] = s[i+1:]
|
||||
}
|
||||
}
|
||||
for k, v := range env {
|
||||
m[k] = v
|
||||
}
|
||||
out := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
out = append(out, k+"="+v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if len(headers) > 0 {
|
||||
transport = &headerRoundTripper{
|
||||
headers: headers,
|
||||
base: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。
|
||||
// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。
|
||||
// 超时由调用方通过 context 控制。
|
||||
func httpClientForLongLived(headers map[string]string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if len(headers) > 0 {
|
||||
transport = &headerRoundTripper{
|
||||
headers: headers,
|
||||
base: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
// 不设 Timeout,SSE 长连接的超时由 per-request context 控制
|
||||
}
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
headers map[string]string
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return h.base.RoundTrip(req)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,235 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 测试添加stdio配置
|
||||
stdioCfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Args: []string{"/path/to/script.py"},
|
||||
Description: "Test stdio MCP",
|
||||
Timeout: 30,
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("添加stdio配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试添加HTTP配置
|
||||
httpCfg := config.ExternalMCPServerConfig{
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Description: "Test HTTP MCP",
|
||||
Timeout: 30,
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
err = manager.AddOrUpdateConfig("test-http", httpCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("添加HTTP配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置已保存
|
||||
configs := manager.GetConfigs()
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("期望2个配置,实际%d个", len(configs))
|
||||
}
|
||||
|
||||
if configs["test-stdio"].Command != stdioCfg.Command {
|
||||
t.Errorf("stdio配置命令不匹配")
|
||||
}
|
||||
|
||||
if configs["test-http"].URL != httpCfg.URL {
|
||||
t.Errorf("HTTP配置URL不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_RemoveConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-remove", cfg)
|
||||
|
||||
// 移除配置
|
||||
err := manager.RemoveConfig("test-remove")
|
||||
if err != nil {
|
||||
t.Fatalf("移除配置失败: %v", err)
|
||||
}
|
||||
|
||||
configs := manager.GetConfigs()
|
||||
if _, exists := configs["test-remove"]; exists {
|
||||
t.Error("配置应该已被移除")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_GetStats(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 添加多个配置
|
||||
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
stats := manager.GetStats()
|
||||
|
||||
if stats["total"].(int) != 3 {
|
||||
t.Errorf("期望总数3,实际%d", stats["total"])
|
||||
}
|
||||
|
||||
if stats["enabled"].(int) != 2 {
|
||||
t.Errorf("期望启用数2,实际%d", stats["enabled"])
|
||||
}
|
||||
|
||||
if stats["disabled"].(int) != 1 {
|
||||
t.Errorf("期望停用数1,实际%d", stats["disabled"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_LoadConfigs(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
externalMCPConfig := config.ExternalMCPConfig{
|
||||
Servers: map[string]config.ExternalMCPServerConfig{
|
||||
"loaded1": {
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
},
|
||||
"loaded2": {
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager.LoadConfigs(&externalMCPConfig)
|
||||
|
||||
configs := manager.GetConfigs()
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("期望2个配置,实际%d个", len(configs))
|
||||
}
|
||||
|
||||
if configs["loaded1"].Command != "python3" {
|
||||
t.Error("配置1加载失败")
|
||||
}
|
||||
|
||||
if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Error("配置2加载失败")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
|
||||
func TestLazySDKClient_InitializeFails(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
// 使用不存在的 HTTP 地址,Initialize 应失败
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
}
|
||||
c := newLazySDKClient(cfg, logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := c.Initialize(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when connecting to invalid server")
|
||||
}
|
||||
if c.GetStatus() != "error" {
|
||||
t.Errorf("expected status error, got %s", c.GetStatus())
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 添加一个禁用的配置
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-start-stop", cfg)
|
||||
|
||||
// 尝试启动(可能会失败,因为没有真实的服务器)
|
||||
err := manager.StartClient("test-start-stop")
|
||||
if err != nil {
|
||||
t.Logf("启动失败(可能是没有服务器): %v", err)
|
||||
}
|
||||
|
||||
// 停止
|
||||
err = manager.StopClient("test-start-stop")
|
||||
if err != nil {
|
||||
t.Fatalf("停止失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置已更新为禁用
|
||||
configs := manager.GetConfigs()
|
||||
if configs["test-start-stop"].ExternalMCPEnable {
|
||||
t.Error("配置应该已被禁用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_CallTool(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 测试调用不存在的工具
|
||||
_, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("应该返回错误")
|
||||
}
|
||||
|
||||
// 测试无效的工具名称格式
|
||||
_, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("应该返回错误(无效格式)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_GetAllTools(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tools, err := manager.GetAllTools(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("获取工具列表失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有连接的客户端,应该返回空列表
|
||||
if len(tools) != 0 {
|
||||
t.Logf("获取到%d个工具", len(tools))
|
||||
}
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ToolRunRegistry 在工具开始/结束时登记当前 executionId,供对话页「仅终止当前工具」与监控页共用取消逻辑。
|
||||
type ToolRunRegistry interface {
|
||||
RegisterRunningTool(conversationID, executionID string)
|
||||
UnregisterRunningTool(conversationID, executionID string)
|
||||
}
|
||||
|
||||
type toolRunRegistryCtxKey struct{}
|
||||
type mcpConversationIDCtxKey struct{}
|
||||
|
||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||
func WithToolRunRegistry(ctx context.Context, reg ToolRunRegistry) context.Context {
|
||||
if ctx == nil || reg == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, toolRunRegistryCtxKey{}, reg)
|
||||
}
|
||||
|
||||
// ToolRunRegistryFromContext 取出登记器(无则 nil)。
|
||||
func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v, _ := ctx.Value(toolRunRegistryCtxKey{}).(ToolRunRegistry)
|
||||
return v
|
||||
}
|
||||
|
||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
id := strings.TrimSpace(conversationID)
|
||||
if id == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, mcpConversationIDCtxKey{}, id)
|
||||
}
|
||||
|
||||
// MCPConversationIDFromContext 读取对话 ID。
|
||||
func MCPConversationIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
v, _ := ctx.Value(mcpConversationIDCtxKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
func notifyToolRunBegin(ctx context.Context, executionID string) {
|
||||
reg := ToolRunRegistryFromContext(ctx)
|
||||
if reg == nil {
|
||||
return
|
||||
}
|
||||
conv := MCPConversationIDFromContext(ctx)
|
||||
if conv == "" || strings.TrimSpace(executionID) == "" {
|
||||
return
|
||||
}
|
||||
reg.RegisterRunningTool(conv, executionID)
|
||||
}
|
||||
|
||||
func notifyToolRunEnd(ctx context.Context, executionID string) {
|
||||
reg := ToolRunRegistryFromContext(ctx)
|
||||
if reg == nil {
|
||||
return
|
||||
}
|
||||
conv := MCPConversationIDFromContext(ctx)
|
||||
if conv == "" || strings.TrimSpace(executionID) == "" {
|
||||
return
|
||||
}
|
||||
reg.UnregisterRunningTool(conv, executionID)
|
||||
}
|
||||
-1450
File diff suppressed because it is too large
Load Diff
-329
@@ -1,329 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
|
||||
type ExternalMCPClient interface {
|
||||
Initialize(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]Tool, error)
|
||||
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
GetStatus() string
|
||||
}
|
||||
|
||||
// MCP消息类型
|
||||
const (
|
||||
MessageTypeRequest = "request"
|
||||
MessageTypeResponse = "response"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeNotify = "notify"
|
||||
)
|
||||
|
||||
// MCP协议版本
|
||||
const ProtocolVersion = "2024-11-05"
|
||||
|
||||
// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null
|
||||
type MessageID struct {
|
||||
value interface{}
|
||||
}
|
||||
|
||||
// UnmarshalJSON 自定义反序列化,支持字符串、数字和null
|
||||
func (m *MessageID) UnmarshalJSON(data []byte) error {
|
||||
// 尝试解析为null
|
||||
if string(data) == "null" {
|
||||
m.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err == nil {
|
||||
m.value = str
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试解析为数字
|
||||
var num json.Number
|
||||
if err := json.Unmarshal(data, &num); err == nil {
|
||||
m.value = num
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid id type")
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义序列化
|
||||
func (m MessageID) MarshalJSON() ([]byte, error) {
|
||||
if m.value == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.value)
|
||||
}
|
||||
|
||||
// String 返回字符串表示
|
||||
func (m MessageID) String() string {
|
||||
if m.value == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%v", m.value)
|
||||
}
|
||||
|
||||
// Value 返回原始值
|
||||
func (m MessageID) Value() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
// Message 表示MCP消息(符合JSON-RPC 2.0规范)
|
||||
type Message struct {
|
||||
ID MessageID `json:"id,omitempty"`
|
||||
Type string `json:"-"` // 内部使用,不序列化到JSON
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识
|
||||
}
|
||||
|
||||
// Error 表示MCP错误
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Tool 表示MCP工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"` // 详细描述
|
||||
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// ToolCall 表示工具调用
|
||||
type ToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult 表示工具执行结果
|
||||
type ToolResult struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// Content 表示内容
|
||||
type Content struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// InitializeRequest 初始化请求
|
||||
type InitializeRequest struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ClientInfo ClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// ClientInfo 客户端信息
|
||||
type ClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// InitializeResponse 初始化响应
|
||||
type InitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
// ServerCapabilities 服务器能力
|
||||
type ServerCapabilities struct {
|
||||
Tools map[string]interface{} `json:"tools,omitempty"`
|
||||
Prompts map[string]interface{} `json:"prompts,omitempty"`
|
||||
Resources map[string]interface{} `json:"resources,omitempty"`
|
||||
Sampling map[string]interface{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerInfo 服务器信息
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ListToolsRequest 列出工具请求
|
||||
type ListToolsRequest struct{}
|
||||
|
||||
// ListToolsResponse 列出工具响应
|
||||
type ListToolsResponse struct {
|
||||
Tools []Tool `json:"tools"`
|
||||
}
|
||||
|
||||
// ListPromptsResponse 列出提示词响应
|
||||
type ListPromptsResponse struct {
|
||||
Prompts []Prompt `json:"prompts"`
|
||||
}
|
||||
|
||||
// ListResourcesResponse 列出资源响应
|
||||
type ListResourcesResponse struct {
|
||||
Resources []Resource `json:"resources"`
|
||||
}
|
||||
|
||||
// CallToolRequest 调用工具请求
|
||||
type CallToolRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// CallToolResponse 调用工具响应
|
||||
type CallToolResponse struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// ToolExecution 工具执行记录
|
||||
type ToolExecution struct {
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed, cancelled
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
type ToolStats struct {
|
||||
ToolName string `json:"toolName"`
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
SuccessCalls int `json:"successCalls"`
|
||||
FailedCalls int `json:"failedCalls"`
|
||||
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
|
||||
}
|
||||
|
||||
// Prompt 提示词模板
|
||||
type Prompt struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// PromptArgument 提示词参数
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptRequest 获取提示词请求
|
||||
type GetPromptRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptResponse 获取提示词响应
|
||||
type GetPromptResponse struct {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// PromptMessage 提示词消息
|
||||
type PromptMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// Resource 资源
|
||||
type Resource struct {
|
||||
URI string `json:"uri"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
// ReadResourceRequest 读取资源请求
|
||||
type ReadResourceRequest struct {
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
// ReadResourceResponse 读取资源响应
|
||||
type ReadResourceResponse struct {
|
||||
Contents []ResourceContent `json:"contents"`
|
||||
}
|
||||
|
||||
// ResourceContent 资源内容
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingRequest 采样请求
|
||||
type SamplingRequest struct {
|
||||
Messages []SamplingMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingMessage 采样消息
|
||||
type SamplingMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SamplingResponse 采样响应
|
||||
type SamplingResponse struct {
|
||||
Content []SamplingContent `json:"content"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingContent 采样内容
|
||||
type SamplingContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// ToolResultPlainText 拼接工具结果中的文本(手动终止时作为「工具原始输出」)。
|
||||
func ToolResultPlainText(r *ToolResult) string {
|
||||
if r == nil || len(r.Content) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, c := range r.Content {
|
||||
b.WriteString(c.Text)
|
||||
}
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
// AbortNoteBannerForModel 标出后续文本来自「用户手动终止工具时在弹窗中填写」,避免与 stdout/stderr 混淆。
|
||||
const AbortNoteBannerForModel = "---\n" +
|
||||
"【用户终止说明|USER INTERRUPT NOTE】\n" +
|
||||
"(以下由操作者填写,用于指示模型如何继续;不是工具原始输出。)\n" +
|
||||
"(Written by the operator when stopping this tool; not raw tool output.)\n" +
|
||||
"---"
|
||||
|
||||
// MergePartialToolOutputAndAbortNote 格式:工具原始输出 + 醒目标题 + 用户终止说明(无说明则原样返回 partial)。
|
||||
func MergePartialToolOutputAndAbortNote(partial, userNote string) string {
|
||||
partial = strings.TrimSpace(partial)
|
||||
userNote = strings.TrimSpace(userNote)
|
||||
if userNote == "" {
|
||||
return partial
|
||||
}
|
||||
section := AbortNoteBannerForModel + "\n" + userNote
|
||||
if partial == "" {
|
||||
return section
|
||||
}
|
||||
return partial + "\n\n" + section
|
||||
}
|
||||
Reference in New Issue
Block a user