Delete mcp directory

This commit is contained in:
公明
2026-05-14 19:23:52 +08:00
committed by GitHub
parent da369c2edc
commit b6c864547e
7 changed files with 0 additions and 3811 deletions
-133
View File
@@ -1,133 +0,0 @@
package builtin
// 内置工具名称常量
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const (
// 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability"
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
ToolWebshellExec = "webshell_exec"
ToolWebshellFileList = "webshell_file_list"
ToolWebshellFileRead = "webshell_file_read"
ToolWebshellFileWrite = "webshell_file_write"
// WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接)
ToolManageWebshellList = "manage_webshell_list"
ToolManageWebshellAdd = "manage_webshell_add"
ToolManageWebshellUpdate = "manage_webshell_update"
ToolManageWebshellDelete = "manage_webshell_delete"
ToolManageWebshellTest = "manage_webshell_test"
// 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列)
ToolBatchTaskList = "batch_task_list"
ToolBatchTaskGet = "batch_task_get"
ToolBatchTaskCreate = "batch_task_create"
ToolBatchTaskStart = "batch_task_start"
ToolBatchTaskRerun = "batch_task_rerun"
ToolBatchTaskPause = "batch_task_pause"
ToolBatchTaskDelete = "batch_task_delete"
ToolBatchTaskUpdateMetadata = "batch_task_update_metadata"
ToolBatchTaskUpdateSchedule = "batch_task_update_schedule"
ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled"
ToolBatchTaskAdd = "batch_task_add_task"
ToolBatchTaskUpdate = "batch_task_update_task"
ToolBatchTaskRemove = "batch_task_remove_task"
// C2 工具集(合并同类项,8 个统一工具)
ToolC2Listener = "c2_listener" // 监听器管理(create/start/stop/list/get/update/delete
ToolC2Session = "c2_session" // 会话管理(list/get/set_sleep/kill/delete
ToolC2Task = "c2_task" // 任务下发(统一 task_type 参数)
ToolC2TaskManage = "c2_task_manage" // 任务管理(get_result/wait/list/cancel
ToolC2Payload = "c2_payload" // Payload 生成(oneliner/build
ToolC2Event = "c2_event" // 事件查询
ToolC2Profile = "c2_profile" // Malleable Profile 管理(list/get/create/update/delete
ToolC2File = "c2_file" // 文件管理(list/get_result
)
// IsBuiltinTool 检查工具名称是否是内置工具
func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
ToolWebshellExec,
ToolWebshellFileList,
ToolWebshellFileRead,
ToolWebshellFileWrite,
ToolManageWebshellList,
ToolManageWebshellAdd,
ToolManageWebshellUpdate,
ToolManageWebshellDelete,
ToolManageWebshellTest,
ToolBatchTaskList,
ToolBatchTaskGet,
ToolBatchTaskCreate,
ToolBatchTaskStart,
ToolBatchTaskRerun,
ToolBatchTaskPause,
ToolBatchTaskDelete,
ToolBatchTaskUpdateMetadata,
ToolBatchTaskUpdateSchedule,
ToolBatchTaskScheduleEnabled,
ToolBatchTaskAdd,
ToolBatchTaskUpdate,
ToolBatchTaskRemove,
// C2 工具
ToolC2Listener,
ToolC2Session,
ToolC2Task,
ToolC2TaskManage,
ToolC2Payload,
ToolC2Event,
ToolC2Profile,
ToolC2File:
return true
default:
return false
}
}
// GetAllBuiltinTools 返回所有内置工具名称列表
func GetAllBuiltinTools() []string {
return []string{
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
ToolWebshellExec,
ToolWebshellFileList,
ToolWebshellFileRead,
ToolWebshellFileWrite,
ToolManageWebshellList,
ToolManageWebshellAdd,
ToolManageWebshellUpdate,
ToolManageWebshellDelete,
ToolManageWebshellTest,
ToolBatchTaskList,
ToolBatchTaskGet,
ToolBatchTaskCreate,
ToolBatchTaskStart,
ToolBatchTaskRerun,
ToolBatchTaskPause,
ToolBatchTaskDelete,
ToolBatchTaskUpdateMetadata,
ToolBatchTaskUpdateSchedule,
ToolBatchTaskScheduleEnabled,
ToolBatchTaskAdd,
ToolBatchTaskUpdate,
ToolBatchTaskRemove,
// C2 工具
ToolC2Listener,
ToolC2Session,
ToolC2Task,
ToolC2TaskManage,
ToolC2Payload,
ToolC2Event,
ToolC2Profile,
ToolC2File,
}
}
-405
View File
@@ -1,405 +0,0 @@
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
package mcp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/zap"
)
const (
clientName = "CyberStrikeAI"
clientVersion = "1.0.0"
)
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
type sdkClient struct {
session *mcp.ClientSession
client *mcp.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
}
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
return &sdkClient{
session: session,
client: client,
logger: logger,
status: "connected",
}
}
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
type lazySDKClient struct {
serverCfg config.ExternalMCPServerConfig
logger *zap.Logger
inner ExternalMCPClient // 连接成功后为 *sdkClient
mu sync.RWMutex
status string
}
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
return &lazySDKClient{
serverCfg: serverCfg,
logger: logger,
status: "connecting",
}
}
func (c *lazySDKClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *lazySDKClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.inner != nil {
return c.inner.GetStatus()
}
return c.status
}
func (c *lazySDKClient) IsConnected() bool {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner != nil {
return inner.IsConnected()
}
return false
}
func (c *lazySDKClient) Initialize(ctx context.Context) error {
c.mu.Lock()
if c.inner != nil {
c.mu.Unlock()
return nil
}
c.mu.Unlock()
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
if err != nil {
c.setStatus("error")
return err
}
c.mu.Lock()
c.inner = inner
c.mu.Unlock()
c.setStatus("connected")
return nil
}
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.ListTools(ctx)
}
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.CallTool(ctx, name, args)
}
func (c *lazySDKClient) Close() error {
c.mu.Lock()
inner := c.inner
c.inner = nil
c.mu.Unlock()
c.setStatus("disconnected")
if inner != nil {
return inner.Close()
}
return nil
}
func (c *sdkClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *sdkClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *sdkClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *sdkClient) Initialize(ctx context.Context) error {
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
return nil
}
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
res, err := c.session.ListTools(ctx, nil)
if err != nil {
return nil, err
}
if res == nil {
return nil, nil
}
return sdkToolsToOur(res.Tools), nil
}
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
params := &mcp.CallToolParams{
Name: name,
Arguments: args,
}
res, err := c.session.CallTool(ctx, params)
if err != nil {
return nil, err
}
return sdkCallToolResultToOurs(res), nil
}
func (c *sdkClient) Close() error {
c.setStatus("disconnected")
if c.session != nil {
err := c.session.Close()
c.session = nil
return err
}
return nil
}
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
schema := make(map[string]interface{})
if t.InputSchema != nil {
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
if m, ok := t.InputSchema.(map[string]interface{}); ok {
schema = m
} else {
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
}
}
desc := t.Description
shortDesc := desc
if t.Annotations != nil && t.Annotations.Title != "" {
shortDesc = t.Annotations.Title
}
out = append(out, Tool{
Name: t.Name,
Description: desc,
ShortDescription: shortDesc,
InputSchema: schema,
})
}
return out
}
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
if res == nil {
return &ToolResult{Content: []Content{}}
}
content := sdkContentToOurs(res.Content)
return &ToolResult{
Content: content,
IsError: res.IsError,
}
}
func sdkContentToOurs(list []mcp.Content) []Content {
if len(list) == 0 {
return nil
}
out := make([]Content, 0, len(list))
for _, c := range list {
switch v := c.(type) {
case *mcp.TextContent:
out = append(out, Content{Type: "text", Text: v.Text})
default:
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
}
}
return out
}
func mustJSON(v interface{}) []byte {
b, _ := json.Marshal(v)
return b
}
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
transport := serverCfg.GetTransportType()
if transport == "" {
return nil, fmt.Errorf("配置缺少 command 或 url,且未指定 type/transport")
}
// 构造 ClientOptionsKeepAlive 心跳
var clientOpts *mcp.ClientOptions
if serverCfg.KeepAlive > 0 {
clientOpts = &mcp.ClientOptions{
KeepAlive: time.Duration(serverCfg.KeepAlive) * time.Second,
}
}
client := mcp.NewClient(&mcp.Implementation{
Name: clientName,
Version: clientVersion,
}, clientOpts)
var t mcp.Transport
switch transport {
case "stdio":
if serverCfg.Command == "" {
return nil, fmt.Errorf("stdio 模式需要配置 command")
}
// 必须用 exec.Command 而非 CommandContextdoConnect 返回后 ctx 会被 cancel
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
if len(serverCfg.Env) > 0 {
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
}
ct := &mcp.CommandTransport{Command: cmd}
if serverCfg.TerminateDuration > 0 {
ct.TerminateDuration = time.Duration(serverCfg.TerminateDuration) * time.Second
}
t = ct
case "sse":
if serverCfg.URL == "" {
return nil, fmt.Errorf("sse 模式需要配置 url")
}
// SSE 是长连接(GET 流持续打开),不能设置 http.Client.Timeout(会在超时后杀掉整个连接导致 EOF)。
// 超时由每次 ListTools/CallTool 的 context 单独控制。
httpClient := httpClientForLongLived(serverCfg.Headers)
t = &mcp.SSEClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "http":
if serverCfg.URL == "" {
return nil, fmt.Errorf("http 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
st := &mcp.StreamableClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
if serverCfg.MaxRetries > 0 {
st.MaxRetries = serverCfg.MaxRetries
}
t = st
default:
return nil, fmt.Errorf("不支持的传输模式: %s(支持: stdio, sse, http", transport)
}
session, err := client.Connect(ctx, t, nil)
if err != nil {
return nil, fmt.Errorf("连接失败: %w", err)
}
return newSDKClientFromSession(session, client, logger), nil
}
func envMapToSlice(env map[string]string) []string {
m := make(map[string]string)
for _, s := range os.Environ() {
if i := strings.IndexByte(s, '='); i > 0 {
m[s[:i]] = s[i+1:]
}
}
for k, v := range env {
m[k] = v
}
out := make([]string, 0, len(m))
for k, v := range m {
out = append(out, k+"="+v)
}
return out
}
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
// httpClientForLongLived 创建不设超时的 HTTP 客户端,用于 SSE 等长连接传输。
// SSE 的 GET 流会持续打开,http.Client.Timeout 会在超时后强制关闭连接导致 EOF。
// 超时由调用方通过 context 控制。
func httpClientForLongLived(headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Transport: transport,
// 不设 TimeoutSSE 长连接的超时由 per-request context 控制
}
}
type headerRoundTripper struct {
headers map[string]string
base http.RoundTripper
}
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Set(k, v)
}
return h.base.RoundTrip(req)
}
File diff suppressed because it is too large Load Diff
-235
View File
@@ -1,235 +0,0 @@
package mcp
import (
"context"
"testing"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 测试添加stdio配置
stdioCfg := config.ExternalMCPServerConfig{
Command: "python3",
Args: []string{"/path/to/script.py"},
Description: "Test stdio MCP",
Timeout: 30,
ExternalMCPEnable: true,
}
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
if err != nil {
t.Fatalf("添加stdio配置失败: %v", err)
}
// 测试添加HTTP配置
httpCfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://127.0.0.1:8081/mcp",
Description: "Test HTTP MCP",
Timeout: 30,
ExternalMCPEnable: false,
}
err = manager.AddOrUpdateConfig("test-http", httpCfg)
if err != nil {
t.Fatalf("添加HTTP配置失败: %v", err)
}
// 验证配置已保存
configs := manager.GetConfigs()
if len(configs) != 2 {
t.Fatalf("期望2个配置,实际%d个", len(configs))
}
if configs["test-stdio"].Command != stdioCfg.Command {
t.Errorf("stdio配置命令不匹配")
}
if configs["test-http"].URL != httpCfg.URL {
t.Errorf("HTTP配置URL不匹配")
}
}
func TestExternalMCPManager_RemoveConfig(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
cfg := config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: false,
}
manager.AddOrUpdateConfig("test-remove", cfg)
// 移除配置
err := manager.RemoveConfig("test-remove")
if err != nil {
t.Fatalf("移除配置失败: %v", err)
}
configs := manager.GetConfigs()
if _, exists := configs["test-remove"]; exists {
t.Error("配置应该已被移除")
}
}
func TestExternalMCPManager_GetStats(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 添加多个配置
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: false,
})
stats := manager.GetStats()
if stats["total"].(int) != 3 {
t.Errorf("期望总数3,实际%d", stats["total"])
}
if stats["enabled"].(int) != 2 {
t.Errorf("期望启用数2,实际%d", stats["enabled"])
}
if stats["disabled"].(int) != 1 {
t.Errorf("期望停用数1,实际%d", stats["disabled"])
}
}
func TestExternalMCPManager_LoadConfigs(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
externalMCPConfig := config.ExternalMCPConfig{
Servers: map[string]config.ExternalMCPServerConfig{
"loaded1": {
Command: "python3",
ExternalMCPEnable: true,
},
"loaded2": {
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false,
},
},
}
manager.LoadConfigs(&externalMCPConfig)
configs := manager.GetConfigs()
if len(configs) != 2 {
t.Fatalf("期望2个配置,实际%d个", len(configs))
}
if configs["loaded1"].Command != "python3" {
t.Error("配置1加载失败")
}
if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" {
t.Error("配置2加载失败")
}
}
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
// 使用不存在的 HTTP 地址,Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
c := newLazySDKClient(cfg, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := c.Initialize(ctx)
if err == nil {
t.Fatal("expected error when connecting to invalid server")
}
if c.GetStatus() != "error" {
t.Errorf("expected status error, got %s", c.GetStatus())
}
c.Close()
}
func TestExternalMCPManager_StartStopClient(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 添加一个禁用的配置
cfg := config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: false,
}
manager.AddOrUpdateConfig("test-start-stop", cfg)
// 尝试启动(可能会失败,因为没有真实的服务器)
err := manager.StartClient("test-start-stop")
if err != nil {
t.Logf("启动失败(可能是没有服务器): %v", err)
}
// 停止
err = manager.StopClient("test-start-stop")
if err != nil {
t.Fatalf("停止失败: %v", err)
}
// 验证配置已更新为禁用
configs := manager.GetConfigs()
if configs["test-start-stop"].ExternalMCPEnable {
t.Error("配置应该已被禁用")
}
}
func TestExternalMCPManager_CallTool(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 测试调用不存在的工具
_, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{})
if err == nil {
t.Error("应该返回错误")
}
// 测试无效的工具名称格式
_, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{})
if err == nil {
t.Error("应该返回错误(无效格式)")
}
}
func TestExternalMCPManager_GetAllTools(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
ctx := context.Background()
tools, err := manager.GetAllTools(ctx)
if err != nil {
t.Fatalf("获取工具列表失败: %v", err)
}
// 如果没有连接的客户端,应该返回空列表
if len(tools) != 0 {
t.Logf("获取到%d个工具", len(tools))
}
}
-77
View File
@@ -1,77 +0,0 @@
package mcp
import (
"context"
"strings"
)
// ToolRunRegistry 在工具开始/结束时登记当前 executionId,供对话页「仅终止当前工具」与监控页共用取消逻辑。
type ToolRunRegistry interface {
RegisterRunningTool(conversationID, executionID string)
UnregisterRunningTool(conversationID, executionID string)
}
type toolRunRegistryCtxKey struct{}
type mcpConversationIDCtxKey struct{}
// WithToolRunRegistry 将登记器注入 ctxEino / 原生 Agent 任务 ctx)。
func WithToolRunRegistry(ctx context.Context, reg ToolRunRegistry) context.Context {
if ctx == nil || reg == nil {
return ctx
}
return context.WithValue(ctx, toolRunRegistryCtxKey{}, reg)
}
// ToolRunRegistryFromContext 取出登记器(无则 nil)。
func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
if ctx == nil {
return nil
}
v, _ := ctx.Value(toolRunRegistryCtxKey{}).(ToolRunRegistry)
return v
}
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
if ctx == nil {
return nil
}
id := strings.TrimSpace(conversationID)
if id == "" {
return ctx
}
return context.WithValue(ctx, mcpConversationIDCtxKey{}, id)
}
// MCPConversationIDFromContext 读取对话 ID。
func MCPConversationIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
v, _ := ctx.Value(mcpConversationIDCtxKey{}).(string)
return v
}
func notifyToolRunBegin(ctx context.Context, executionID string) {
reg := ToolRunRegistryFromContext(ctx)
if reg == nil {
return
}
conv := MCPConversationIDFromContext(ctx)
if conv == "" || strings.TrimSpace(executionID) == "" {
return
}
reg.RegisterRunningTool(conv, executionID)
}
func notifyToolRunEnd(ctx context.Context, executionID string) {
reg := ToolRunRegistryFromContext(ctx)
if reg == nil {
return
}
conv := MCPConversationIDFromContext(ctx)
if conv == "" || strings.TrimSpace(executionID) == "" {
return
}
reg.UnregisterRunningTool(conv, executionID)
}
-1450
View File
File diff suppressed because it is too large Load Diff
-329
View File
@@ -1,329 +0,0 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
)
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
type ExternalMCPClient interface {
Initialize(ctx context.Context) error
ListTools(ctx context.Context) ([]Tool, error)
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
Close() error
IsConnected() bool
GetStatus() string
}
// MCP消息类型
const (
MessageTypeRequest = "request"
MessageTypeResponse = "response"
MessageTypeError = "error"
MessageTypeNotify = "notify"
)
// MCP协议版本
const ProtocolVersion = "2024-11-05"
// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null
type MessageID struct {
value interface{}
}
// UnmarshalJSON 自定义反序列化,支持字符串、数字和null
func (m *MessageID) UnmarshalJSON(data []byte) error {
// 尝试解析为null
if string(data) == "null" {
m.value = nil
return nil
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
m.value = str
return nil
}
// 尝试解析为数字
var num json.Number
if err := json.Unmarshal(data, &num); err == nil {
m.value = num
return nil
}
return fmt.Errorf("invalid id type")
}
// MarshalJSON 自定义序列化
func (m MessageID) MarshalJSON() ([]byte, error) {
if m.value == nil {
return []byte("null"), nil
}
return json.Marshal(m.value)
}
// String 返回字符串表示
func (m MessageID) String() string {
if m.value == nil {
return ""
}
return fmt.Sprintf("%v", m.value)
}
// Value 返回原始值
func (m MessageID) Value() interface{} {
return m.value
}
// Message 表示MCP消息(符合JSON-RPC 2.0规范)
type Message struct {
ID MessageID `json:"id,omitempty"`
Type string `json:"-"` // 内部使用,不序列化到JSON
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识
}
// Error 表示MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"` // 详细描述
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ToolCall 表示工具调用
type ToolCall struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// ToolResult 表示工具执行结果
type ToolResult struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 表示内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
Prompts map[string]interface{} `json:"prompts,omitempty"`
Resources map[string]interface{} `json:"resources,omitempty"`
Sampling map[string]interface{} `json:"sampling,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// ListToolsRequest 列出工具请求
type ListToolsRequest struct{}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// ListPromptsResponse 列出提示词响应
type ListPromptsResponse struct {
Prompts []Prompt `json:"prompts"`
}
// ListResourcesResponse 列出资源响应
type ListResourcesResponse struct {
Resources []Resource `json:"resources"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// ToolExecution 工具执行记录
type ToolExecution struct {
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed, cancelled
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
}
// ToolStats 工具统计信息
type ToolStats struct {
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
}
// Prompt 提示词模板
type Prompt struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
}
// PromptArgument 提示词参数
type PromptArgument struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
}
// GetPromptRequest 获取提示词请求
type GetPromptRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
// GetPromptResponse 获取提示词响应
type GetPromptResponse struct {
Messages []PromptMessage `json:"messages"`
}
// PromptMessage 提示词消息
type PromptMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Resource 资源
type Resource struct {
URI string `json:"uri"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
MimeType string `json:"mimeType,omitempty"`
}
// ReadResourceRequest 读取资源请求
type ReadResourceRequest struct {
URI string `json:"uri"`
}
// ReadResourceResponse 读取资源响应
type ReadResourceResponse struct {
Contents []ResourceContent `json:"contents"`
}
// ResourceContent 资源内容
type ResourceContent struct {
URI string `json:"uri"`
MimeType string `json:"mimeType,omitempty"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}
// SamplingRequest 采样请求
type SamplingRequest struct {
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// SamplingMessage 采样消息
type SamplingMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// SamplingResponse 采样响应
type SamplingResponse struct {
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
}
// SamplingContent 采样内容
type SamplingContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
// ToolResultPlainText 拼接工具结果中的文本(手动终止时作为「工具原始输出」)。
func ToolResultPlainText(r *ToolResult) string {
if r == nil || len(r.Content) == 0 {
return ""
}
var b strings.Builder
for _, c := range r.Content {
b.WriteString(c.Text)
}
return strings.TrimSpace(b.String())
}
// AbortNoteBannerForModel 标出后续文本来自「用户手动终止工具时在弹窗中填写」,避免与 stdout/stderr 混淆。
const AbortNoteBannerForModel = "---\n" +
"【用户终止说明|USER INTERRUPT NOTE】\n" +
"(以下由操作者填写,用于指示模型如何继续;不是工具原始输出。)\n" +
"Written by the operator when stopping this tool; not raw tool output.\n" +
"---"
// MergePartialToolOutputAndAbortNote 格式:工具原始输出 + 醒目标题 + 用户终止说明(无说明则原样返回 partial)。
func MergePartialToolOutputAndAbortNote(partial, userNote string) string {
partial = strings.TrimSpace(partial)
userNote = strings.TrimSpace(userNote)
if userNote == "" {
return partial
}
section := AbortNoteBannerForModel + "\n" + userNote
if partial == "" {
return section
}
return partial + "\n\n" + section
}