diff --git a/README.md b/README.md index b095dbd6..8a8ae5a4 100644 --- a/README.md +++ b/README.md @@ -845,6 +845,209 @@ curl -X POST http://localhost:8080/api/mcp \ }' ``` +## 🔗 External MCP Integration + +CyberStrikeAI supports integrating external MCP servers to extend tool capabilities. External MCP tools are automatically registered in the system, and AI can call them just like built-in tools. + +### Configuration Methods + +#### Method 1: Configure via Web Interface (Recommended) + +1. After starting the server, access the web interface +2. Click the "Settings" button in the top-right corner +3. Find the "External MCP" configuration section in the settings +4. Click the "Add External MCP" button +5. Fill in the configuration information: + - **Name**: Unique identifier for the external MCP server (e.g., `hexstrike-ai`) + - **Transport Mode**: Choose `stdio` or `http` + - **Description**: Optional, used to describe the MCP server's functionality + - **Timeout**: Tool execution timeout in seconds, default 300 seconds + - **Enable Status**: Whether to enable this external MCP immediately +6. Fill in the corresponding configuration based on transport mode: + - **stdio Mode**: + - **Command**: Startup command for the MCP server (e.g., `python3`) + - **Args**: Startup arguments array (e.g., `["/path/to/mcp_server.py", "--arg", "value"]`) + - **HTTP Mode**: + - **URL**: HTTP endpoint address of the MCP server (e.g., `http://127.0.0.1:8888`) +7. Click the "Save" button, configuration will be automatically saved to `config.yaml` +8. The system will automatically connect to the external MCP server and load its tools + +#### Method 2: Edit Configuration File Directly + +Add `external_mcp` configuration in `config.yaml`: + +```yaml +# External MCP Configuration +external_mcp: + servers: + # External MCP server name (unique identifier) + hexstrike-ai: + # stdio mode configuration + command: python3 + args: + - /path/to/hexstrike_mcp.py + - --server + - 'http://127.0.0.1:8888' + + # Or HTTP mode configuration (choose one) + # transport: http + # url: http://127.0.0.1:8888 + + # Common configuration + description: HexStrike AI v6.0 - Advanced Cybersecurity Automation Platform + timeout: 300 # Timeout in seconds + external_mcp_enable: true # Whether to enable + + # Tool-level control (optional) + tool_enabled: + nmap_scan: true + sqlmap_scan: true + # ... other tools +``` + +### Transport Mode Description + +#### stdio Mode + +Communicates with external MCP servers via standard input/output (stdio), suitable for locally running MCP servers. + +**Configuration Example:** +```yaml +external_mcp: + servers: + my-mcp-server: + command: python3 + args: + - /path/to/mcp_server.py + - --config + - /path/to/config.json + description: My Custom MCP Server + timeout: 300 + external_mcp_enable: true +``` + +**Features:** +- ✅ Local process communication, no network port required +- ✅ High security, data doesn't traverse network +- ✅ Suitable for local development and testing + +#### HTTP Mode + +Communicates with external MCP servers via HTTP requests, suitable for remote MCP servers or scenarios requiring cross-network access. + +**Configuration Example:** +```yaml +external_mcp: + servers: + remote-mcp-server: + transport: http + url: http://192.168.1.100:8888 + description: Remote MCP Server + timeout: 300 + external_mcp_enable: true +``` + +**Features:** +- ✅ Supports remote access +- ✅ Suitable for distributed deployment +- ✅ Easy to integrate with existing HTTP services + +### Tool Naming Convention + +External MCP tools in the system use the naming format: `{mcp-server-name}::{tool-name}` + +For example: +- External MCP server name: `hexstrike-ai` +- Tool name: `nmap_scan` +- Full tool name in system: `hexstrike-ai::nmap_scan` + +AI will automatically recognize and use the full tool name when calling. + +### Tool Enable/Disable Control + +#### Global Enable/Disable + +Control the enable status of the entire external MCP server via the `external_mcp_enable` field: +- `true`: Enabled, system will automatically connect and load tools +- `false`: Disabled, system will not connect to this server + +#### Tool-Level Control + +Precisely control the enable status of each tool via the `tool_enabled` field: + +```yaml +external_mcp: + servers: + hexstrike-ai: + # ... other configuration + tool_enabled: + nmap_scan: true # Enable this tool + sqlmap_scan: false # Disable this tool + nuclei_scan: true # Enable this tool +``` + +- If a tool is not listed in `tool_enabled`, it is enabled by default +- If a tool is set to `false`, it won't appear in the tool list and AI cannot call it + +### Managing External MCP + +#### Via Web Interface + +1. **View External MCP List**: View all configured external MCP servers in the settings interface +2. **Start/Stop**: Start or stop external MCP server connections at any time +3. **Edit Configuration**: Modify external MCP configuration information +4. **Delete Configuration**: Remove unnecessary external MCP servers + +#### Via API + +- **Get External MCP List**: `GET /api/external-mcp` +- **Add External MCP**: `POST /api/external-mcp` +- **Update External MCP**: `PUT /api/external-mcp/:name` +- **Delete External MCP**: `DELETE /api/external-mcp/:name` +- **Start External MCP**: `POST /api/external-mcp/:name/start` +- **Stop External MCP**: `POST /api/external-mcp/:name/stop` + +### Monitoring and Statistics + +External MCP tool execution records and statistics are automatically recorded in the system: + +- **Execution Records**: View execution history of all external MCP tools in the "Tool Monitoring" page +- **Execution Statistics**: The execution statistics panel displays call counts and success/failure statistics for external MCP tools +- **Real-time Monitoring**: Real-time viewing of external MCP tool execution status + +### Troubleshooting + +**Issue: External MCP Cannot Connect** + +- ✅ Check if `command` and `args` configuration are correct (stdio mode) +- ✅ Check if `url` configuration is correct and accessible (HTTP mode) +- ✅ Check if the external MCP server is running normally +- ✅ View server logs for detailed error information +- ✅ Check network connection (HTTP mode) +- ✅ Check firewall settings (HTTP mode) + +**Issue: External MCP Tools Not Displayed** + +- ✅ Confirm `external_mcp_enable: true` +- ✅ Check `tool_enabled` configuration to ensure tools are not disabled +- ✅ Confirm external MCP server has successfully connected +- ✅ View server logs to confirm if tools have been loaded + +**Issue: External MCP Tool Execution Failed** + +- ✅ Check if external MCP server is running normally +- ✅ View tool execution logs for detailed error information +- ✅ Check if timeout setting is reasonable +- ✅ Confirm external MCP server supports the tool call format + +### Best Practices + +1. **Naming Convention**: Use meaningful names to identify external MCP servers, avoid conflicts +2. **Timeout Settings**: Set timeout reasonably based on tool execution time +3. **Tool Control**: Use `tool_enabled` to precisely control needed tools, avoid loading too many unnecessary tools +4. **Security Considerations**: HTTP mode is recommended to use intranet addresses or configure appropriate access control +5. **Monitoring Management**: Regularly check external MCP connection status and execution statistics + ## 🛠️ Security Tool Support ### Tool Overview diff --git a/README_CN.md b/README_CN.md index 5bcb077f..3bac15fc 100644 --- a/README_CN.md +++ b/README_CN.md @@ -844,6 +844,209 @@ curl -X POST http://localhost:8080/api/mcp \ }' ``` +## 🔗 外部 MCP 接入 + +CyberStrikeAI 支持接入外部 MCP 服务器,扩展工具能力。外部 MCP 工具会自动注册到系统中,AI 可以像使用内置工具一样调用它们。 + +### 配置方式 + +#### 方式一:通过 Web 界面配置(推荐) + +1. 启动服务器后,访问 Web 界面 +2. 点击右上角的"设置"按钮 +3. 在设置界面中找到"外部 MCP"配置部分 +4. 点击"添加外部 MCP"按钮 +5. 填写配置信息: + - **名称**:外部 MCP 服务器的唯一标识(如:`hexstrike-ai`) + - **传输模式**:选择 `stdio` 或 `http` + - **描述**:可选,用于说明该 MCP 服务器的功能 + - **超时时间**:工具执行超时时间(秒),默认 300 秒 + - **启用状态**:是否立即启用该外部 MCP +6. 根据传输模式填写相应配置: + - **stdio 模式**: + - **命令**:MCP 服务器的启动命令(如:`python3`) + - **参数**:启动参数数组(如:`["/path/to/mcp_server.py", "--arg", "value"]`) + - **HTTP 模式**: + - **URL**:MCP 服务器的 HTTP 端点地址(如:`http://127.0.0.1:8888`) +7. 点击"保存"按钮,配置会自动保存到 `config.yaml` +8. 系统会自动连接外部 MCP 服务器并加载其工具 + +#### 方式二:直接编辑配置文件 + +在 `config.yaml` 中添加 `external_mcp` 配置: + +```yaml +# 外部MCP配置 +external_mcp: + servers: + # 外部MCP服务器名称(唯一标识) + hexstrike-ai: + # stdio 模式配置 + command: python3 + args: + - /path/to/hexstrike_mcp.py + - --server + - 'http://127.0.0.1:8888' + + # 或 HTTP 模式配置(二选一) + # transport: http + # url: http://127.0.0.1:8888 + + # 通用配置 + description: HexStrike AI v6.0 - Advanced Cybersecurity Automation Platform + timeout: 300 # 超时时间(秒) + external_mcp_enable: true # 是否启用 + + # 工具级别控制(可选) + tool_enabled: + nmap_scan: true + sqlmap_scan: true + # ... 其他工具 +``` + +### 传输模式说明 + +#### stdio 模式 + +通过标准输入输出(stdio)与外部 MCP 服务器通信,适合本地运行的 MCP 服务器。 + +**配置示例:** +```yaml +external_mcp: + servers: + my-mcp-server: + command: python3 + args: + - /path/to/mcp_server.py + - --config + - /path/to/config.json + description: My Custom MCP Server + timeout: 300 + external_mcp_enable: true +``` + +**特点:** +- ✅ 本地进程通信,无需网络端口 +- ✅ 安全性高,数据不经过网络 +- ✅ 适合本地开发和测试 + +#### HTTP 模式 + +通过 HTTP 请求与外部 MCP 服务器通信,适合远程 MCP 服务器或需要跨网络访问的场景。 + +**配置示例:** +```yaml +external_mcp: + servers: + remote-mcp-server: + transport: http + url: http://192.168.1.100:8888 + description: Remote MCP Server + timeout: 300 + external_mcp_enable: true +``` + +**特点:** +- ✅ 支持远程访问 +- ✅ 适合分布式部署 +- ✅ 易于集成现有 HTTP 服务 + +### 工具命名规则 + +外部 MCP 工具在系统中的名称格式为:`{mcp-server-name}::{tool-name}` + +例如: +- 外部 MCP 服务器名称:`hexstrike-ai` +- 工具名称:`nmap_scan` +- 系统内完整工具名:`hexstrike-ai::nmap_scan` + +AI 在调用时会自动识别并使用完整工具名。 + +### 工具启用控制 + +#### 全局启用/禁用 + +通过 `external_mcp_enable` 字段控制整个外部 MCP 服务器的启用状态: +- `true`:启用,系统会自动连接并加载工具 +- `false`:禁用,系统不会连接该服务器 + +#### 工具级别控制 + +通过 `tool_enabled` 字段精确控制每个工具的启用状态: + +```yaml +external_mcp: + servers: + hexstrike-ai: + # ... 其他配置 + tool_enabled: + nmap_scan: true # 启用此工具 + sqlmap_scan: false # 禁用此工具 + nuclei_scan: true # 启用此工具 +``` + +- 如果 `tool_enabled` 中未列出某个工具,默认启用 +- 如果某个工具设置为 `false`,该工具不会出现在工具列表中,AI 也无法调用 + +### 管理外部 MCP + +#### 通过 Web 界面管理 + +1. **查看外部 MCP 列表**:在设置界面中查看所有已配置的外部 MCP 服务器 +2. **启动/停止**:可以随时启动或停止外部 MCP 服务器连接 +3. **编辑配置**:修改外部 MCP 的配置信息 +4. **删除配置**:移除不需要的外部 MCP 服务器 + +#### 通过 API 管理 + +- **获取外部 MCP 列表**:`GET /api/external-mcp` +- **添加外部 MCP**:`POST /api/external-mcp` +- **更新外部 MCP**:`PUT /api/external-mcp/:name` +- **删除外部 MCP**:`DELETE /api/external-mcp/:name` +- **启动外部 MCP**:`POST /api/external-mcp/:name/start` +- **停止外部 MCP**:`POST /api/external-mcp/:name/stop` + +### 监控和统计 + +外部 MCP 工具的执行记录和统计信息会自动记录到系统中: + +- **执行记录**:在"工具监控"页面可以查看所有外部 MCP 工具的执行历史 +- **执行统计**:执行统计面板会显示外部 MCP 工具的调用次数、成功/失败统计 +- **实时监控**:可以实时查看外部 MCP 工具的执行状态 + +### 故障排除 + +**问题:外部 MCP 无法连接** + +- ✅ 检查 `command` 和 `args` 配置是否正确(stdio 模式) +- ✅ 检查 `url` 配置是否正确且可访问(HTTP 模式) +- ✅ 检查外部 MCP 服务器是否正常运行 +- ✅ 查看服务器日志获取详细错误信息 +- ✅ 检查网络连接(HTTP 模式) +- ✅ 检查防火墙设置(HTTP 模式) + +**问题:外部 MCP 工具未显示** + +- ✅ 确认 `external_mcp_enable: true` +- ✅ 检查 `tool_enabled` 配置,确保工具未被禁用 +- ✅ 确认外部 MCP 服务器已成功连接 +- ✅ 查看服务器日志确认工具是否已加载 + +**问题:外部 MCP 工具执行失败** + +- ✅ 检查外部 MCP 服务器是否正常运行 +- ✅ 查看工具执行日志获取详细错误信息 +- ✅ 检查超时时间设置是否合理 +- ✅ 确认外部 MCP 服务器支持的工具调用格式 + +### 最佳实践 + +1. **命名规范**:使用有意义的名称标识外部 MCP 服务器,避免冲突 +2. **超时设置**:根据工具执行时间合理设置超时时间 +3. **工具控制**:使用 `tool_enabled` 精确控制需要的工具,避免加载过多无用工具 +4. **安全考虑**:HTTP 模式建议使用内网地址或配置适当的访问控制 +5. **监控管理**:定期检查外部 MCP 的连接状态和执行统计 + ## 🛠️ 安全工具支持 ### 工具概览 diff --git a/cmd/test-config/main.go b/cmd/test-config/main.go index e8dfd150..ab69be78 100644 --- a/cmd/test-config/main.go +++ b/cmd/test-config/main.go @@ -1,45 +1,57 @@ package main import ( - "cyberstrike-ai/internal/config" "fmt" "os" + + "cyberstrike-ai/internal/config" ) func main() { - cfg, err := config.Load("config.yaml") - if err != nil { - fmt.Printf("❌ 加载配置失败: %v\n", err) + if len(os.Args) < 2 { + fmt.Println("Usage: go run cmd/test-config/main.go ") os.Exit(1) } - fmt.Printf("✅ 配置加载成功\n") - fmt.Printf(" 工具目录: %s\n", cfg.Security.ToolsDir) - fmt.Printf(" 工具数量: %d\n", len(cfg.Security.Tools)) + configPath := os.Args[1] + cfg, err := config.Load(configPath) + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } - if len(cfg.Security.Tools) > 0 { - fmt.Printf("\n 已加载的工具:\n") - for _, tool := range cfg.Security.Tools { - status := "❌ 禁用" - if tool.Enabled { - status = "✅ 启用" - } - shortDesc := tool.ShortDescription - if shortDesc == "" { - shortDesc = "(无简短描述,将自动提取)" - } - fmt.Printf(" %s %s\n", status, tool.Name) - fmt.Printf(" 简短描述: %s\n", shortDesc) - if len(tool.Description) > 100 { - fmt.Printf(" 详细描述: %s...\n", tool.Description[:100]) - } else { - fmt.Printf(" 详细描述: %s\n", tool.Description) - } - fmt.Printf(" 参数数量: %d\n", len(tool.Parameters)) - fmt.Println() + if cfg.ExternalMCP.Servers == nil { + fmt.Println("No external MCP servers configured") + os.Exit(0) + } + + fmt.Printf("Found %d external MCP server(s):\n\n", len(cfg.ExternalMCP.Servers)) + + for name, srv := range cfg.ExternalMCP.Servers { + fmt.Printf("Name: %s\n", name) + fmt.Printf(" Transport: %s\n", getTransport(srv)) + fmt.Printf(" Command: %s\n", srv.Command) + if len(srv.Args) > 0 { + fmt.Printf(" Args: %v\n", srv.Args) } - } else { - fmt.Printf(" ⚠️ 未加载任何工具\n") + fmt.Printf(" URL: %s\n", srv.URL) + fmt.Printf(" Description: %s\n", srv.Description) + fmt.Printf(" Timeout: %d seconds\n", srv.Timeout) + fmt.Printf(" Enabled: %v\n", srv.Enabled) + fmt.Printf(" Disabled: %v\n", srv.Disabled) + fmt.Println() } } +func getTransport(srv config.ExternalMCPServerConfig) string { + if srv.Transport != "" { + return srv.Transport + } + if srv.Command != "" { + return "stdio" + } + if srv.URL != "" { + return "http" + } + return "unknown" +} diff --git a/cmd/test-external-mcp/main.go b/cmd/test-external-mcp/main.go new file mode 100644 index 00000000..4b74bcbf --- /dev/null +++ b/cmd/test-external-mcp/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "fmt" + "os" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/logger" + "cyberstrike-ai/internal/mcp" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run cmd/test-external-mcp/main.go ") + os.Exit(1) + } + + configPath := os.Args[1] + cfg, err := config.Load(configPath) + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + if cfg.ExternalMCP.Servers == nil || len(cfg.ExternalMCP.Servers) == 0 { + fmt.Println("No external MCP servers configured") + os.Exit(0) + } + + fmt.Printf("Found %d external MCP server(s)\n\n", len(cfg.ExternalMCP.Servers)) + + // 创建日志 + log := logger.New("info", "stdout") + + // 创建外部MCP管理器 + manager := mcp.NewExternalMCPManager(log.Logger) + manager.LoadConfigs(&cfg.ExternalMCP) + + // 显示配置 + fmt.Println("=== 配置信息 ===") + for name, srv := range cfg.ExternalMCP.Servers { + fmt.Printf("\n%s:\n", name) + fmt.Printf(" Transport: %s\n", getTransport(srv)) + if srv.Command != "" { + fmt.Printf(" Command: %s\n", srv.Command) + fmt.Printf(" Args: %v\n", srv.Args) + } + if srv.URL != "" { + fmt.Printf(" URL: %s\n", srv.URL) + } + fmt.Printf(" Description: %s\n", srv.Description) + fmt.Printf(" Timeout: %d seconds\n", srv.Timeout) + fmt.Printf(" Enabled: %v\n", srv.Enabled) + fmt.Printf(" Disabled: %v\n", srv.Disabled) + } + + // 获取统计信息 + fmt.Println("\n=== 统计信息 ===") + stats := manager.GetStats() + fmt.Printf("总数: %d\n", stats["total"]) + fmt.Printf("已启用: %d\n", stats["enabled"]) + fmt.Printf("已停用: %d\n", stats["disabled"]) + fmt.Printf("已连接: %d\n", stats["connected"]) + + // 测试启动(仅测试启用的) + fmt.Println("\n=== 测试启动 ===") + for name, srv := range cfg.ExternalMCP.Servers { + if srv.Enabled && !srv.Disabled { + fmt.Printf("\n尝试启动 %s...\n", name) + // 注意:实际启动可能会失败,因为需要真实的MCP服务器 + err := manager.StartClient(name) + if err != nil { + fmt.Printf(" 启动失败(这是正常的,如果没有真实的MCP服务器): %v\n", err) + } else { + fmt.Printf(" 启动成功\n") + // 获取客户端状态 + if client, exists := manager.GetClient(name); exists { + fmt.Printf(" 状态: %s\n", client.GetStatus()) + fmt.Printf(" 已连接: %v\n", client.IsConnected()) + } + } + } + } + + // 等待一下 + time.Sleep(2 * time.Second) + + // 测试获取工具列表 + fmt.Println("\n=== 测试获取工具列表 ===") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tools, err := manager.GetAllTools(ctx) + if err != nil { + fmt.Printf("获取工具列表失败: %v\n", err) + } else { + fmt.Printf("获取到 %d 个工具\n", len(tools)) + for i, tool := range tools { + if i < 5 { // 只显示前5个 + fmt.Printf(" - %s: %s\n", tool.Name, tool.Description) + } + } + if len(tools) > 5 { + fmt.Printf(" ... 还有 %d 个工具\n", len(tools)-5) + } + } + + // 测试停止 + fmt.Println("\n=== 测试停止 ===") + for name := range cfg.ExternalMCP.Servers { + fmt.Printf("\n停止 %s...\n", name) + err := manager.StopClient(name) + if err != nil { + fmt.Printf(" 停止失败: %v\n", err) + } else { + fmt.Printf(" 停止成功\n") + } + } + + // 最终统计 + fmt.Println("\n=== 最终统计 ===") + stats = manager.GetStats() + fmt.Printf("总数: %d\n", stats["total"]) + fmt.Printf("已启用: %d\n", stats["enabled"]) + fmt.Printf("已停用: %d\n", stats["disabled"]) + fmt.Printf("已连接: %d\n", stats["connected"]) + + fmt.Println("\n=== 测试完成 ===") +} + +func getTransport(srv config.ExternalMCPServerConfig) string { + if srv.Transport != "" { + return srv.Transport + } + if srv.Command != "" { + return "stdio" + } + if srv.URL != "" { + return "http" + } + return "unknown" +} + diff --git a/config.yaml b/config.yaml index 1a003fd6..c9ec2226 100644 --- a/config.yaml +++ b/config.yaml @@ -20,7 +20,7 @@ log: # MCP 协议配置 # MCP (Model Context Protocol) 用于工具注册和调用 mcp: - enabled: false # 是否启用 MCP 服务器 + enabled: true # 是否启用 MCP 服务器 host: 0.0.0.0 # MCP 服务器监听地址 port: 8081 # MCP 服务器端口 # AI 模型配置(支持 OpenAI 兼容 API) @@ -46,3 +46,6 @@ security: tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录) # 系统会从该目录加载所有 .yaml 格式的工具配置文件 # 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件 +# 外部MCP配置 +external_mcp: + \ No newline at end of file diff --git a/internal/agent/agent.go b/internal/agent/agent.go index bb7a7820..99f6678a 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -19,16 +19,18 @@ import ( // Agent AI代理 type Agent struct { - openAIClient *http.Client - config *config.OpenAIConfig - mcpServer *mcp.Server - logger *zap.Logger - maxIterations int - mu sync.RWMutex // 添加互斥锁以支持并发更新 + openAIClient *http.Client + config *config.OpenAIConfig + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + logger *zap.Logger + maxIterations int + mu sync.RWMutex // 添加互斥锁以支持并发更新 + toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) } // NewAgent 创建新的Agent -func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger, maxIterations int) *Agent { +func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent { // 如果 maxIterations 为 0 或负数,使用默认值 30 if maxIterations <= 0 { maxIterations = 30 @@ -55,10 +57,12 @@ func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logge Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 Transport: transport, }, - config: cfg, - mcpServer: mcpServer, - logger: logger, - maxIterations: maxIterations, + config: cfg, + mcpServer: mcpServer, + externalMCPMgr: externalMCPMgr, + logger: logger, + maxIterations: maxIterations, + toolNameMapping: make(map[string]string), // 初始化工具名称映射 } } @@ -578,7 +582,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // getAvailableTools 获取可用工具 // 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 func (a *Agent) getAvailableTools() []Tool { - // 从MCP服务器获取所有已注册的工具 + // 从MCP服务器获取所有已注册的内部工具 mcpTools := a.mcpServer.GetAllTools() // 转换为OpenAI格式的工具定义 @@ -603,8 +607,91 @@ func (a *Agent) getAvailableTools() []Tool { }) } + // 获取外部MCP工具 + if a.externalMCPMgr != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + externalTools, err := a.externalMCPMgr.GetAllTools(ctx) + if err != nil { + a.logger.Warn("获取外部MCP工具失败", zap.Error(err)) + } else { + // 获取外部MCP配置,用于检查工具启用状态 + externalMCPConfigs := a.externalMCPMgr.GetConfigs() + + // 清空并重建工具名称映射 + a.mu.Lock() + a.toolNameMapping = make(map[string]string) + a.mu.Unlock() + + // 将外部MCP工具添加到工具列表(只添加启用的工具) + for _, externalTool := range externalTools { + // 解析工具名称:mcpName::toolName + var mcpName, actualToolName string + if idx := strings.Index(externalTool.Name, "::"); idx > 0 { + mcpName = externalTool.Name[:idx] + actualToolName = externalTool.Name[idx+2:] + } else { + continue // 跳过格式不正确的工具 + } + + // 检查工具是否启用 + enabled := false + if cfg, exists := externalMCPConfigs[mcpName]; exists { + // 首先检查外部MCP是否启用 + if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { + enabled = false // MCP未启用,所有工具都禁用 + } else { + // MCP已启用,检查单个工具的启用状态 + // 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) + if cfg.ToolEnabled == nil { + enabled = true // 未设置工具状态,默认为启用 + } else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists { + enabled = toolEnabled // 使用配置的工具状态 + } else { + enabled = true // 工具未在配置中,默认为启用 + } + } + } + + // 只添加启用的工具 + if !enabled { + continue + } + + // 使用简短描述(如果存在),否则使用详细描述 + description := externalTool.ShortDescription + if description == "" { + description = externalTool.Description + } + + // 转换schema中的类型为OpenAI标准类型 + convertedSchema := a.convertSchemaTypes(externalTool.InputSchema) + + // 将工具名称中的 "::" 替换为 "__" 以符合OpenAI命名规范 + // OpenAI要求工具名称只能包含 [a-zA-Z0-9_-] + openAIName := strings.ReplaceAll(externalTool.Name, "::", "__") + + // 保存名称映射关系(OpenAI格式 -> 原始格式) + a.mu.Lock() + a.toolNameMapping[openAIName] = externalTool.Name + a.mu.Unlock() + + tools = append(tools, Tool{ + Type: "function", + Function: FunctionDefinition{ + Name: openAIName, // 使用符合OpenAI规范的名称 + Description: description, + Parameters: convertedSchema, + }, + }) + } + } + } + a.logger.Debug("获取可用工具列表", - zap.Int("count", len(tools)), + zap.Int("internalTools", len(mcpTools)), + zap.Int("totalTools", len(tools)), ) return tools @@ -898,8 +985,26 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map zap.Any("args", args), ) - // 通过MCP服务器调用工具 - result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args) + var result *mcp.ToolResult + var executionID string + var err error + + // 检查是否是外部MCP工具(通过工具名称映射) + a.mu.RLock() + originalToolName, isExternalTool := a.toolNameMapping[toolName] + a.mu.RUnlock() + + if isExternalTool && a.externalMCPMgr != nil { + // 使用原始工具名称调用外部MCP工具 + a.logger.Debug("调用外部MCP工具", + zap.String("openAIName", toolName), + zap.String("originalName", originalToolName), + ) + result, executionID, err = a.externalMCPMgr.CallTool(ctx, originalToolName, args) + } else { + // 调用内部MCP工具 + result, executionID, err = a.mcpServer.CallTool(ctx, toolName, args) + } // 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常 if err != nil { diff --git a/internal/app/app.go b/internal/app/app.go index d1444859..2fecb409 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -20,14 +20,15 @@ import ( // App 应用 type App struct { - config *config.Config - logger *logger.Logger - router *gin.Engine - mcpServer *mcp.Server - agent *agent.Agent - executor *security.Executor - db *database.DB - auth *security.AuthManager + config *config.Config + logger *logger.Logger + router *gin.Engine + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager + agent *agent.Agent + executor *security.Executor + db *database.DB + auth *security.AuthManager } // New 创建新应用 @@ -76,12 +77,20 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { cfg.Auth.GeneratedPasswordPersistErr = "" } + // 创建外部MCP管理器(使用与内部MCP服务器相同的存储) + externalMCPMgr := mcp.NewExternalMCPManagerWithStorage(log.Logger, db) + if cfg.ExternalMCP.Servers != nil { + externalMCPMgr.LoadConfigs(&cfg.ExternalMCP) + // 启动所有启用的外部MCP客户端 + externalMCPMgr.StartAllEnabled() + } + // 创建Agent maxIterations := cfg.Agent.MaxIterations if maxIterations <= 0 { maxIterations = 30 // 默认值 } - agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger, maxIterations) + agent := agent.NewAgent(&cfg.OpenAI, mcpServer, externalMCPMgr, log.Logger, maxIterations) // 获取配置文件路径 configPath := "config.yaml" @@ -92,9 +101,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 创建处理器 agentHandler := handler.NewAgentHandler(agent, db, log.Logger) monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) + monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 conversationHandler := handler.NewConversationHandler(db, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) - configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, log.Logger) + configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, externalMCPMgr, log.Logger) + externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) // 设置路由 setupRoutes( @@ -104,19 +115,21 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { monitorHandler, conversationHandler, configHandler, + externalMCPHandler, mcpServer, authManager, ) return &App{ - config: cfg, - logger: log, - router: router, - mcpServer: mcpServer, - agent: agent, - executor: executor, - db: db, - auth: authManager, + config: cfg, + logger: log, + router: router, + mcpServer: mcpServer, + externalMCPMgr: externalMCPMgr, + agent: agent, + executor: executor, + db: db, + auth: authManager, }, nil } @@ -144,6 +157,14 @@ func (a *App) Run() error { return a.router.Run(addr) } +// Shutdown 关闭应用 +func (a *App) Shutdown() { + // 停止所有外部MCP客户端 + if a.externalMCPMgr != nil { + a.externalMCPMgr.StopAll() + } +} + // setupRoutes 设置路由 func setupRoutes( router *gin.Engine, @@ -152,6 +173,7 @@ func setupRoutes( monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, + externalMCPHandler *handler.ExternalMCPHandler, mcpServer *mcp.Server, authManager *security.AuthManager, ) { @@ -195,6 +217,15 @@ func setupRoutes( protected.PUT("/config", configHandler.UpdateConfig) protected.POST("/config/apply", configHandler.ApplyConfig) + // 外部MCP管理 + protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs) + protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats) + protected.GET("/external-mcp/:name", externalMCPHandler.GetExternalMCP) + protected.PUT("/external-mcp/:name", externalMCPHandler.AddOrUpdateExternalMCP) + protected.DELETE("/external-mcp/:name", externalMCPHandler.DeleteExternalMCP) + protected.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP) + protected.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP) + // MCP端点 protected.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) diff --git a/internal/config/config.go b/internal/config/config.go index d73b4c1b..e02cef2c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,14 +12,15 @@ import ( ) type Config struct { - Server ServerConfig `yaml:"server"` - Log LogConfig `yaml:"log"` - MCP MCPConfig `yaml:"mcp"` - OpenAI OpenAIConfig `yaml:"openai"` - Agent AgentConfig `yaml:"agent"` - Security SecurityConfig `yaml:"security"` - Database DatabaseConfig `yaml:"database"` - Auth AuthConfig `yaml:"auth"` + Server ServerConfig `yaml:"server"` + Log LogConfig `yaml:"log"` + MCP MCPConfig `yaml:"mcp"` + OpenAI OpenAIConfig `yaml:"openai"` + Agent AgentConfig `yaml:"agent"` + Security SecurityConfig `yaml:"security"` + Database DatabaseConfig `yaml:"database"` + Auth AuthConfig `yaml:"auth"` + ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` } type ServerConfig struct { @@ -64,6 +65,32 @@ type AuthConfig struct { GeneratedPasswordPersisted bool `yaml:"-" json:"-"` GeneratedPasswordPersistErr string `yaml:"-" json:"-"` } + +// ExternalMCPConfig 外部MCP配置 +type ExternalMCPConfig struct { + Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` +} + +// ExternalMCPServerConfig 外部MCP服务器配置 +type ExternalMCPServerConfig struct { + // stdio模式配置 + Command string `yaml:"command,omitempty" json:"command,omitempty"` + Args []string `yaml:"args,omitempty" json:"args,omitempty"` + + // HTTP模式配置 + Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio" + URL string `yaml:"url,omitempty" json:"url,omitempty"` + + // 通用配置 + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒) + ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP + ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用) + + // 向后兼容字段(已废弃,保留用于读取旧配置) + Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable + Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable +} type ToolConfig struct { Name string `yaml:"name"` Command string `yaml:"command"` @@ -152,6 +179,27 @@ func Load(path string) (*Config, error) { cfg.Security.Tools = tools } + // 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable + if cfg.ExternalMCP.Servers != nil { + for name, serverCfg := range cfg.ExternalMCP.Servers { + // 如果已经设置了 external_mcp_enable,跳过迁移 + // 否则从 enabled/disabled 字段迁移 + // 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了 + // 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移 + if serverCfg.Disabled { + // 旧配置使用 disabled,迁移到 external_mcp_enable + serverCfg.ExternalMCPEnable = false + } else if serverCfg.Enabled { + // 旧配置使用 enabled,迁移到 external_mcp_enable + serverCfg.ExternalMCPEnable = true + } else { + // 都没有设置,默认为启用 + serverCfg.ExternalMCPEnable = true + } + cfg.ExternalMCP.Servers[name] = serverCfg + } + } + return &cfg, nil } diff --git a/internal/handler/config.go b/internal/handler/config.go index 4abf8112..07f0a02a 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -2,6 +2,7 @@ package handler import ( "bytes" + "context" "fmt" "net/http" "os" @@ -9,6 +10,7 @@ import ( "strconv" "strings" "sync" + "time" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" @@ -20,13 +22,14 @@ import ( // ConfigHandler 配置处理器 type ConfigHandler struct { - configPath string - config *config.Config - mcpServer *mcp.Server - executor *security.Executor - agent AgentUpdater // Agent接口,用于更新Agent配置 - logger *zap.Logger - mu sync.RWMutex + configPath string + config *config.Config + mcpServer *mcp.Server + executor *security.Executor + agent AgentUpdater // Agent接口,用于更新Agent配置 + externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 + logger *zap.Logger + mu sync.RWMutex } // AgentUpdater Agent更新接口 @@ -36,14 +39,15 @@ type AgentUpdater interface { } // NewConfigHandler 创建新的配置处理器 -func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, logger *zap.Logger) *ConfigHandler { +func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { return &ConfigHandler{ - configPath: configPath, - config: cfg, - mcpServer: mcpServer, - executor: executor, - agent: agent, - logger: logger, + configPath: configPath, + config: cfg, + mcpServer: mcpServer, + executor: executor, + agent: agent, + externalMCPMgr: externalMCPMgr, + logger: logger, } } @@ -60,6 +64,8 @@ type ToolConfigInfo struct { Name string `json:"name"` Description string `json:"description"` Enabled bool `json:"enabled"` + IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 + ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) } // GetConfig 获取当前配置 @@ -67,13 +73,14 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { h.mu.RLock() defer h.mu.RUnlock() - // 获取工具列表 + // 获取工具列表(包含内部和外部工具) tools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) for _, tool := range h.config.Security.Tools { tools = append(tools, ToolConfigInfo{ Name: tool.Name, Description: tool.ShortDescription, Enabled: tool.Enabled, + IsExternal: false, }) // 如果没有简短描述,使用详细描述的前100个字符 if tools[len(tools)-1].Description == "" { @@ -85,6 +92,65 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { } } + // 获取外部MCP工具 + if h.externalMCPMgr != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + externalTools, err := h.externalMCPMgr.GetAllTools(ctx) + if err == nil { + externalMCPConfigs := h.externalMCPMgr.GetConfigs() + for _, externalTool := range externalTools { + var mcpName, actualToolName string + if idx := strings.Index(externalTool.Name, "::"); idx > 0 { + mcpName = externalTool.Name[:idx] + actualToolName = externalTool.Name[idx+2:] + } else { + continue + } + + enabled := false + if cfg, exists := externalMCPConfigs[mcpName]; exists { + // 首先检查外部MCP是否启用 + if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { + enabled = false // MCP未启用,所有工具都禁用 + } else { + // MCP已启用,检查单个工具的启用状态 + // 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) + if cfg.ToolEnabled == nil { + enabled = true // 未设置工具状态,默认为启用 + } else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists { + enabled = toolEnabled // 使用配置的工具状态 + } else { + enabled = true // 工具未在配置中,默认为启用 + } + } + } + + client, exists := h.externalMCPMgr.GetClient(mcpName) + if !exists || !client.IsConnected() { + enabled = false + } + + description := externalTool.ShortDescription + if description == "" { + description = externalTool.Description + } + if len(description) > 100 { + description = description[:100] + "..." + } + + tools = append(tools, ToolConfigInfo{ + Name: actualToolName, + Description: description, + Enabled: enabled, + IsExternal: true, + ExternalMCP: mcpName, + }) + } + } + } + c.JSON(http.StatusOK, GetConfigResponse{ OpenAI: h.config.OpenAI, MCP: h.config.MCP, @@ -128,13 +194,14 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { searchTermLower = strings.ToLower(searchTerm) } - // 获取所有工具并应用搜索过滤 + // 获取所有内部工具并应用搜索过滤 allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) for _, tool := range h.config.Security.Tools { toolInfo := ToolConfigInfo{ Name: tool.Name, Description: tool.ShortDescription, Enabled: tool.Enabled, + IsExternal: false, } // 如果没有简短描述,使用详细描述的前100个字符 if toolInfo.Description == "" { @@ -157,6 +224,81 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { allTools = append(allTools, toolInfo) } + // 获取外部MCP工具 + if h.externalMCPMgr != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + externalTools, err := h.externalMCPMgr.GetAllTools(ctx) + if err != nil { + h.logger.Warn("获取外部MCP工具失败", zap.Error(err)) + } else { + // 获取外部MCP配置,用于判断启用状态 + externalMCPConfigs := h.externalMCPMgr.GetConfigs() + + for _, externalTool := range externalTools { + // 解析工具名称:mcpName::toolName + var mcpName, actualToolName string + if idx := strings.Index(externalTool.Name, "::"); idx > 0 { + mcpName = externalTool.Name[:idx] + actualToolName = externalTool.Name[idx+2:] + } else { + continue // 跳过格式不正确的工具 + } + + // 获取外部工具的启用状态 + enabled := false + if cfg, exists := externalMCPConfigs[mcpName]; exists { + // 首先检查外部MCP是否启用 + if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) { + enabled = false // MCP未启用,所有工具都禁用 + } else { + // MCP已启用,检查单个工具的启用状态 + // 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) + if cfg.ToolEnabled == nil { + enabled = true // 未设置工具状态,默认为启用 + } else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists { + enabled = toolEnabled // 使用配置的工具状态 + } else { + enabled = true // 工具未在配置中,默认为启用 + } + } + } + + // 检查外部MCP是否已连接 + client, exists := h.externalMCPMgr.GetClient(mcpName) + if !exists || !client.IsConnected() { + enabled = false // 未连接时视为禁用 + } + + description := externalTool.ShortDescription + if description == "" { + description = externalTool.Description + } + if len(description) > 100 { + description = description[:100] + "..." + } + + // 如果有关键词,进行搜索过滤 + if searchTermLower != "" { + nameLower := strings.ToLower(actualToolName) + descLower := strings.ToLower(description) + if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) { + continue // 不匹配,跳过 + } + } + + allTools = append(allTools, ToolConfigInfo{ + Name: actualToolName, // 显示实际工具名称,不带前缀 + Description: description, + Enabled: enabled, + IsExternal: true, + ExternalMCP: mcpName, + }) + } + } + } + total := len(allTools) totalPages := (total + pageSize - 1) / pageSize if totalPages == 0 { @@ -196,8 +338,10 @@ type UpdateConfigRequest struct { // ToolEnableStatus 工具启用状态 type ToolEnableStatus struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 + ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) } // UpdateConfig 更新配置 @@ -240,14 +384,28 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { // 更新工具启用状态 if req.Tools != nil { - toolMap := make(map[string]bool) + // 分离内部工具和外部工具 + internalToolMap := make(map[string]bool) + // 外部工具状态:MCP名称 -> 工具名称 -> 启用状态 + externalMCPToolMap := make(map[string]map[string]bool) + for _, toolStatus := range req.Tools { - toolMap[toolStatus.Name] = toolStatus.Enabled + if toolStatus.IsExternal && toolStatus.ExternalMCP != "" { + // 外部工具:保存每个工具的独立状态 + mcpName := toolStatus.ExternalMCP + if externalMCPToolMap[mcpName] == nil { + externalMCPToolMap[mcpName] = make(map[string]bool) + } + externalMCPToolMap[mcpName][toolStatus.Name] = toolStatus.Enabled + } else { + // 内部工具 + internalToolMap[toolStatus.Name] = toolStatus.Enabled + } } - // 更新配置中的工具状态 + // 更新内部工具状态 for i := range h.config.Security.Tools { - if enabled, ok := toolMap[h.config.Security.Tools[i].Name]; ok { + if enabled, ok := internalToolMap[h.config.Security.Tools[i].Name]; ok { h.config.Security.Tools[i].Enabled = enabled h.logger.Info("更新工具启用状态", zap.String("tool", h.config.Security.Tools[i].Name), @@ -255,6 +413,80 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { ) } } + + // 更新外部MCP工具状态 + if h.externalMCPMgr != nil { + for mcpName, toolStates := range externalMCPToolMap { + // 更新配置中的工具启用状态 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg, exists := h.config.ExternalMCP.Servers[mcpName] + if !exists { + h.logger.Warn("外部MCP配置不存在", zap.String("mcp", mcpName)) + continue + } + + // 初始化ToolEnabled map + if cfg.ToolEnabled == nil { + cfg.ToolEnabled = make(map[string]bool) + } + + // 更新每个工具的启用状态 + for toolName, enabled := range toolStates { + cfg.ToolEnabled[toolName] = enabled + h.logger.Info("更新外部工具启用状态", + zap.String("mcp", mcpName), + zap.String("tool", toolName), + zap.Bool("enabled", enabled), + ) + } + + // 检查是否有任何工具启用,如果有则启用MCP + hasEnabledTool := false + for _, enabled := range cfg.ToolEnabled { + if enabled { + hasEnabledTool = true + break + } + } + + // 如果MCP之前未启用,但现在有工具启用,则启用MCP + // 如果MCP之前已启用,保持启用状态(允许部分工具禁用) + if !cfg.ExternalMCPEnable && hasEnabledTool { + cfg.ExternalMCPEnable = true + h.logger.Info("自动启用外部MCP(因为有工具启用)", zap.String("mcp", mcpName)) + } + + h.config.ExternalMCP.Servers[mcpName] = cfg + } + + // 同步更新 externalMCPMgr 中的配置,确保 GetConfigs() 返回最新配置 + // 在循环外部统一更新,避免重复调用 + h.externalMCPMgr.LoadConfigs(&h.config.ExternalMCP) + + // 处理MCP连接状态 + for mcpName := range externalMCPToolMap { + cfg := h.config.ExternalMCP.Servers[mcpName] + // 如果MCP需要启用,确保客户端已启动 + if cfg.ExternalMCPEnable { + // 启动外部MCP(如果未启动) + client, exists := h.externalMCPMgr.GetClient(mcpName) + if !exists || !client.IsConnected() { + if err := h.externalMCPMgr.StartClient(mcpName); err != nil { + h.logger.Warn("启动外部MCP失败", + zap.String("mcp", mcpName), + zap.Error(err), + ) + } else { + h.logger.Info("启动外部MCP", + zap.String("mcp", mcpName), + ) + } + } + } + } + } } // 保存配置到文件 @@ -318,6 +550,33 @@ func (h *ConfigHandler) saveConfig() error { updateAgentConfig(root, h.config.Agent.MaxIterations) updateMCPConfig(root, h.config.MCP) updateOpenAIConfig(root, h.config.OpenAI) + // 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用) + // 读取原始配置以保持向后兼容 + originalConfigs := make(map[string]map[string]bool) + externalMCPNode := findMapValue(root, "external_mcp") + if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { + serversNode := findMapValue(externalMCPNode, "servers") + if serversNode != nil && serversNode.Kind == yaml.MappingNode { + for i := 0; i < len(serversNode.Content); i += 2 { + if i+1 >= len(serversNode.Content) { + break + } + nameNode := serversNode.Content[i] + serverNode := serversNode.Content[i+1] + if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { + serverName := nameNode.Value + originalConfigs[serverName] = make(map[string]bool) + if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { + originalConfigs[serverName]["enabled"] = *enabledVal + } + if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { + originalConfigs[serverName]["disabled"] = *disabledVal + } + } + } + } + } + updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) if err := writeYAMLDocument(h.configPath, root); err != nil { return fmt.Errorf("保存配置文件失败: %w", err) @@ -504,6 +763,34 @@ func setIntInMap(mapNode *yaml.Node, key string, value int) { valueNode.Value = fmt.Sprintf("%d", value) } +func findBoolInMap(mapNode *yaml.Node, key string) *bool { + if mapNode == nil || mapNode.Kind != yaml.MappingNode { + return nil + } + + for i := 0; i < len(mapNode.Content); i += 2 { + if i+1 >= len(mapNode.Content) { + break + } + keyNode := mapNode.Content[i] + valueNode := mapNode.Content[i+1] + + if keyNode.Kind == yaml.ScalarNode && keyNode.Value == key { + if valueNode.Kind == yaml.ScalarNode { + if valueNode.Value == "true" { + result := true + return &result + } else if valueNode.Value == "false" { + result := false + return &result + } + } + return nil + } + } + return nil +} + func setBoolInMap(mapNode *yaml.Node, key string, value bool) { _, valueNode := ensureKeyValue(mapNode, key) valueNode.Kind = yaml.ScalarNode diff --git a/internal/handler/external_mcp.go b/internal/handler/external_mcp.go new file mode 100644 index 00000000..4e0add98 --- /dev/null +++ b/internal/handler/external_mcp.go @@ -0,0 +1,510 @@ +package handler + +import ( + "fmt" + "net/http" + "os" + "sync" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "github.com/gin-gonic/gin" + "go.uber.org/zap" + "gopkg.in/yaml.v3" +) + +// ExternalMCPHandler 外部MCP处理器 +type ExternalMCPHandler struct { + manager *mcp.ExternalMCPManager + config *config.Config + configPath string + logger *zap.Logger + mu sync.RWMutex +} + +// NewExternalMCPHandler 创建外部MCP处理器 +func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler { + return &ExternalMCPHandler{ + manager: manager, + config: cfg, + configPath: configPath, + logger: logger, + } +} + +// GetExternalMCPs 获取所有外部MCP配置 +func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) { + h.mu.RLock() + defer h.mu.RUnlock() + + configs := h.manager.GetConfigs() + + // 获取所有外部MCP的工具数量 + toolCounts := h.manager.GetToolCounts() + + // 转换为响应格式 + result := make(map[string]ExternalMCPResponse) + for name, cfg := range configs { + client, exists := h.manager.GetClient(name) + status := "disconnected" + if exists { + status = client.GetStatus() + } else if h.isEnabled(cfg) { + status = "disconnected" + } else { + status = "disabled" + } + + toolCount := toolCounts[name] + + result[name] = ExternalMCPResponse{ + Config: cfg, + Status: status, + ToolCount: toolCount, + } + } + + c.JSON(http.StatusOK, gin.H{ + "servers": result, + "stats": h.manager.GetStats(), + }) +} + +// GetExternalMCP 获取单个外部MCP配置 +func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.RLock() + defer h.mu.RUnlock() + + configs := h.manager.GetConfigs() + cfg, exists := configs[name] + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"}) + return + } + + client, clientExists := h.manager.GetClient(name) + status := "disconnected" + if clientExists { + status = client.GetStatus() + } else if h.isEnabled(cfg) { + status = "disconnected" + } else { + status = "disabled" + } + + // 获取工具数量 + toolCount := 0 + if clientExists && client.IsConnected() { + if count, err := h.manager.GetToolCount(name); err == nil { + toolCount = count + } + } + + c.JSON(http.StatusOK, ExternalMCPResponse{ + Config: cfg, + Status: status, + ToolCount: toolCount, + }) +} + +// AddOrUpdateExternalMCP 添加或更新外部MCP配置 +func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { + var req AddOrUpdateExternalMCPRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + + name := c.Param("name") + if name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"}) + return + } + + // 验证配置 + if err := h.validateConfig(req.Config); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + // 添加或更新配置 + if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil { + h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()}) + return + } + + // 更新内存中的配置 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + + // 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容 + // 同时将值迁移到 external_mcp_enable + cfg := req.Config + + if req.Config.Disabled { + // 用户设置了 disabled: true + cfg.ExternalMCPEnable = false + cfg.Disabled = true + cfg.Enabled = false + } else if req.Config.Enabled { + // 用户设置了 enabled: true + cfg.ExternalMCPEnable = true + cfg.Enabled = true + cfg.Disabled = false + } else if !req.Config.ExternalMCPEnable { + // 用户没有设置任何字段,且 external_mcp_enable 为 false + // 检查现有配置是否有旧字段 + if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists { + // 保留现有的旧字段 + cfg.Enabled = existingCfg.Enabled + cfg.Disabled = existingCfg.Disabled + } + } else { + // 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段 + // 为了向后兼容,我们设置 enabled: true + // 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true + cfg.Enabled = true + cfg.Disabled = false + } + + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP配置已更新", zap.String("name", name)) + c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) +} + +// DeleteExternalMCP 删除外部MCP配置 +func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 移除配置 + if err := h.manager.RemoveConfig(name); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"}) + return + } + + // 从内存配置中删除 + if h.config.ExternalMCP.Servers != nil { + delete(h.config.ExternalMCP.Servers, name) + } + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP配置已删除", zap.String("name", name)) + c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) +} + +// StartExternalMCP 启动外部MCP +func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 更新配置为启用 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg := h.config.ExternalMCP.Servers[name] + cfg.ExternalMCPEnable = true + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + // 启动客户端(这可能会花费一些时间) + h.logger.Info("开始启动外部MCP", zap.String("name", name)) + if err := h.manager.StartClient(name); err != nil { + h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + "status": "error", + }) + return + } + + // 获取连接状态 + client, exists := h.manager.GetClient(name) + status := "disconnected" + if exists { + status = client.GetStatus() + } + + h.logger.Info("外部MCP启动完成", zap.String("name", name), zap.String("status", status)) + c.JSON(http.StatusOK, gin.H{ + "message": "外部MCP启动完成", + "status": status, + }) +} + +// StopExternalMCP 停止外部MCP +func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) { + name := c.Param("name") + + h.mu.Lock() + defer h.mu.Unlock() + + // 停止客户端 + if err := h.manager.StopClient(name); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 更新配置 + if h.config.ExternalMCP.Servers == nil { + h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig) + } + cfg := h.config.ExternalMCP.Servers[name] + cfg.ExternalMCPEnable = false + h.config.ExternalMCP.Servers[name] = cfg + + // 保存到配置文件 + if err := h.saveConfig(); err != nil { + h.logger.Error("保存配置失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()}) + return + } + + h.logger.Info("外部MCP已停止", zap.String("name", name)) + c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"}) +} + +// GetExternalMCPStats 获取统计信息 +func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) { + stats := h.manager.GetStats() + c.JSON(http.StatusOK, stats) +} + +// validateConfig 验证配置 +func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error { + transport := cfg.Transport + if transport == "" { + // 如果没有指定transport,根据是否有command或url判断 + if cfg.Command != "" { + transport = "stdio" + } else if cfg.URL != "" { + transport = "http" + } else { + return fmt.Errorf("需要指定command(stdio模式)或url(http模式)") + } + } + + switch transport { + case "http": + if cfg.URL == "" { + return fmt.Errorf("HTTP模式需要URL") + } + case "stdio": + if cfg.Command == "" { + return fmt.Errorf("stdio模式需要command") + } + default: + return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport) + } + + return nil +} + +// isEnabled 检查是否启用 +func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool { + // 优先使用 ExternalMCPEnable 字段 + // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) + if cfg.ExternalMCPEnable { + return true + } + // 向后兼容:检查旧字段 + if cfg.Disabled { + return false + } + if cfg.Enabled { + return true + } + // 都没有设置,默认为启用 + return true +} + +// saveConfig 保存配置到文件 +func (h *ExternalMCPHandler) saveConfig() error { + // 读取现有配置文件并创建备份 + data, err := os.ReadFile(h.configPath) + if err != nil { + return fmt.Errorf("读取配置文件失败: %w", err) + } + + if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil { + h.logger.Warn("创建配置备份失败", zap.Error(err)) + } + + root, err := loadYAMLDocument(h.configPath) + if err != nil { + return fmt.Errorf("解析配置文件失败: %w", err) + } + + // 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容 + originalConfigs := make(map[string]map[string]bool) + externalMCPNode := findMapValue(root.Content[0], "external_mcp") + if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode { + serversNode := findMapValue(externalMCPNode, "servers") + if serversNode != nil && serversNode.Kind == yaml.MappingNode { + // 遍历现有的服务器配置,保存 enabled/disabled 字段 + for i := 0; i < len(serversNode.Content); i += 2 { + if i+1 >= len(serversNode.Content) { + break + } + nameNode := serversNode.Content[i] + serverNode := serversNode.Content[i+1] + if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode { + serverName := nameNode.Value + originalConfigs[serverName] = make(map[string]bool) + // 检查是否有 enabled 字段 + if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil { + originalConfigs[serverName]["enabled"] = *enabledVal + } + // 检查是否有 disabled 字段 + if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil { + originalConfigs[serverName]["disabled"] = *disabledVal + } + } + } + } + } + + // 更新外部MCP配置 + updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs) + + if err := writeYAMLDocument(h.configPath, root); err != nil { + return fmt.Errorf("保存配置文件失败: %w", err) + } + + h.logger.Info("配置已保存", zap.String("path", h.configPath)) + return nil +} + +// updateExternalMCPConfig 更新外部MCP配置 +func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) { + root := doc.Content[0] + externalMCPNode := ensureMap(root, "external_mcp") + serversNode := ensureMap(externalMCPNode, "servers") + + // 清空现有服务器配置 + serversNode.Content = nil + + // 添加新的服务器配置 + for name, serverCfg := range cfg.Servers { + // 添加服务器名称键 + nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name} + serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + serversNode.Content = append(serversNode.Content, nameNode, serverNode) + + // 设置服务器配置字段 + if serverCfg.Command != "" { + setStringInMap(serverNode, "command", serverCfg.Command) + } + if len(serverCfg.Args) > 0 { + setStringArrayInMap(serverNode, "args", serverCfg.Args) + } + if serverCfg.Transport != "" { + setStringInMap(serverNode, "transport", serverCfg.Transport) + } + if serverCfg.URL != "" { + setStringInMap(serverNode, "url", serverCfg.URL) + } + if serverCfg.Description != "" { + setStringInMap(serverNode, "description", serverCfg.Description) + } + if serverCfg.Timeout > 0 { + setIntInMap(serverNode, "timeout", serverCfg.Timeout) + } + // 保存 external_mcp_enable 字段(新字段) + setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable) + // 保存 tool_enabled 字段(每个工具的启用状态) + if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 { + toolEnabledNode := ensureMap(serverNode, "tool_enabled") + for toolName, enabled := range serverCfg.ToolEnabled { + setBoolInMap(toolEnabledNode, toolName, enabled) + } + } + // 保留旧的 enabled/disabled 字段以保持向后兼容 + originalFields, hasOriginal := originalConfigs[name] + + // 如果原始配置中有 enabled 字段,保留它 + if hasOriginal { + if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled { + setBoolInMap(serverNode, "enabled", enabledVal) + } + // 如果原始配置中有 disabled 字段,保留它 + // 注意:由于 omitempty,disabled: false 不会被保存,但 disabled: true 会被保存 + if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled { + if disabledVal { + setBoolInMap(serverNode, "disabled", disabledVal) + } else { + // 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示 + // 因为 disabled: false 等价于 enabled: true + setBoolInMap(serverNode, "enabled", true) + } + } + } + + // 如果用户在当前请求中明确设置了这些字段,也保存它们 + if serverCfg.Enabled { + setBoolInMap(serverNode, "enabled", serverCfg.Enabled) + } + if serverCfg.Disabled { + setBoolInMap(serverNode, "disabled", serverCfg.Disabled) + } else if !hasOriginal && serverCfg.ExternalMCPEnable { + // 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容 + setBoolInMap(serverNode, "enabled", true) + } + } +} + +// setStringArrayInMap 设置字符串数组 +func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Content = nil + for _, v := range values { + itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v} + valueNode.Content = append(valueNode.Content, itemNode) + } +} + +// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求 +type AddOrUpdateExternalMCPRequest struct { + Config config.ExternalMCPServerConfig `json:"config"` +} + +// ExternalMCPResponse 外部MCP响应 +type ExternalMCPResponse struct { + Config config.ExternalMCPServerConfig `json:"config"` + Status string `json:"status"` // "connected", "disconnected", "disabled", "error" + ToolCount int `json:"tool_count"` // 工具数量 +} + diff --git a/internal/handler/external_mcp_test.go b/internal/handler/external_mcp_test.go new file mode 100644 index 00000000..0ba0b1bb --- /dev/null +++ b/internal/handler/external_mcp_test.go @@ -0,0 +1,518 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) { + gin.SetMode(gin.TestMode) + router := gin.New() + + // 创建临时配置文件 + tmpFile, err := os.CreateTemp("", "test-config-*.yaml") + if err != nil { + panic(err) + } + tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n") + tmpFile.Close() + configPath := tmpFile.Name() + + logger := zap.NewNop() + manager := mcp.NewExternalMCPManager(logger) + cfg := &config.Config{ + ExternalMCP: config.ExternalMCPConfig{ + Servers: make(map[string]config.ExternalMCPServerConfig), + }, + } + + handler := NewExternalMCPHandler(manager, cfg, configPath, logger) + + api := router.Group("/api") + api.GET("/external-mcp", handler.GetExternalMCPs) + api.GET("/external-mcp/stats", handler.GetExternalMCPStats) + api.GET("/external-mcp/:name", handler.GetExternalMCP) + api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP) + api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP) + api.POST("/external-mcp/:name/start", handler.StartExternalMCP) + api.POST("/external-mcp/:name/stop", handler.StopExternalMCP) + + return router, handler, configPath +} + +func cleanupTestConfig(configPath string) { + os.Remove(configPath) + os.Remove(configPath + ".backup") +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 测试添加stdio模式的配置 + configJSON := `{ + "command": "python3", + "args": ["/path/to/script.py", "--server", "http://example.com"], + "description": "Test stdio MCP", + "timeout": 300, + "enabled": true + }` + + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已添加 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.Command != "python3" { + t.Errorf("期望command为python3,实际%s", response.Config.Command) + } + if len(response.Config.Args) != 3 { + t.Errorf("期望args长度为3,实际%d", len(response.Config.Args)) + } + if response.Config.Description != "Test stdio MCP" { + t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description) + } + if response.Config.Timeout != 300 { + t.Errorf("期望timeout为300,实际%d", response.Config.Timeout) + } + if !response.Config.Enabled { + t.Error("期望enabled为true") + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 测试添加HTTP模式的配置 + configJSON := `{ + "transport": "http", + "url": "http://127.0.0.1:8081/mcp", + "enabled": true + }` + + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已添加 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.Transport != "http" { + t.Errorf("期望transport为http,实际%s", response.Config.Transport) + } + if response.Config.URL != "http://127.0.0.1:8081/mcp" { + t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) + } + if !response.Config.Enabled { + t.Error("期望enabled为true") + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + testCases := []struct { + name string + configJSON string + expectedErr string + }{ + { + name: "缺少command和url", + configJSON: `{"enabled": true}`, + expectedErr: "需要指定command(stdio模式)或url(http模式)", + }, + { + name: "stdio模式缺少command", + configJSON: `{"args": ["test"], "enabled": true}`, + expectedErr: "stdio模式需要command", + }, + { + name: "http模式缺少url", + configJSON: `{"transport": "http", "enabled": true}`, + expectedErr: "HTTP模式需要URL", + }, + { + name: "无效的transport", + configJSON: `{"transport": "invalid", "enabled": true}`, + expectedErr: "不支持的传输模式", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var configObj config.ExternalMCPServerConfig + if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil { + t.Fatalf("解析配置JSON失败: %v", err) + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + errorMsg := response["error"].(string) + // 对于stdio模式缺少command的情况,错误信息可能略有不同 + if tc.name == "stdio模式缺少command" { + if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") { + t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg) + } + } else if !strings.Contains(errorMsg, tc.expectedErr) { + t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg) + } + }) + } +} + +func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 先添加一个配置 + configObj := config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + } + handler.manager.AddOrUpdateConfig("test-delete", configObj) + + // 删除配置 + req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已删除 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusNotFound { + t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String()) + } +} + +func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { + router, handler, _ := setupTestRouter() + + // 添加多个配置 + handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + }) + handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + Enabled: false, + }) + + req := httptest.NewRequest("GET", "/api/external-mcp", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + servers := response["servers"].(map[string]interface{}) + if len(servers) != 2 { + t.Errorf("期望2个服务器,实际%d", len(servers)) + } + if _, ok := servers["test1"]; !ok { + t.Error("期望包含test1") + } + if _, ok := servers["test2"]; !ok { + t.Error("期望包含test2") + } + + stats := response["stats"].(map[string]interface{}) + if int(stats["total"].(float64)) != 2 { + t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64))) + } +} + +func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) { + router, handler, _ := setupTestRouter() + + // 添加配置 + handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + }) + handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + Enabled: true, + }) + handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: false, + Disabled: true, + }) + + req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + var stats map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if int(stats["total"].(float64)) != 3 { + t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64))) + } + if int(stats["enabled"].(float64)) != 2 { + t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64))) + } + if int(stats["disabled"].(float64)) != 1 { + t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64))) + } +} + +func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 添加一个禁用的配置 + handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: false, + Disabled: true, + }) + + // 测试启动(可能会失败,因为没有真实的服务器) + req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // 启动可能会失败,但应该返回合理的状态码 + if w.Code != http.StatusOK { + // 如果启动失败,应该是400或500 + if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError { + t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String()) + } + } + + // 测试停止 + req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } +} + +func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) { + router, _, _ := setupTestRouter() + + req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) { + router, _, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // 删除不存在的配置可能返回200(幂等操作)或404,都是合理的 + if w.Code != http.StatusNotFound && w.Code != http.StatusOK { + t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) { + router, _, _ := setupTestRouter() + + configObj := config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: configObj, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + // 空名称应该返回404或400 + if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest { + t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) { + router, _, _ := setupTestRouter() + + // 发送无效的JSON + body := []byte(`{"config": invalid json}`) + req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String()) + } +} + +func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) { + router, handler, configPath := setupTestRouter() + defer cleanupTestConfig(configPath) + + // 先添加配置 + config1 := config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + } + handler.manager.AddOrUpdateConfig("test-update", config1) + + // 更新配置 + config2 := config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + Enabled: true, + } + + reqBody := AddOrUpdateExternalMCPRequest{ + Config: config2, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String()) + } + + // 验证配置已更新 + req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String()) + } + + var response ExternalMCPResponse + if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + if response.Config.URL != "http://127.0.0.1:8081/mcp" { + t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL) + } + if response.Config.Command != "" { + t.Errorf("期望command为空,实际%s", response.Config.Command) + } +} + diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go index 8a47268f..a7155c83 100644 --- a/internal/handler/monitor.go +++ b/internal/handler/monitor.go @@ -14,22 +14,29 @@ import ( // MonitorHandler 监控处理器 type MonitorHandler struct { - mcpServer *mcp.Server - executor *security.Executor - db *database.DB - logger *zap.Logger + mcpServer *mcp.Server + externalMCPMgr *mcp.ExternalMCPManager + executor *security.Executor + db *database.DB + logger *zap.Logger } // NewMonitorHandler 创建新的监控处理器 func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler { return &MonitorHandler{ - mcpServer: mcpServer, - executor: executor, - db: db, - logger: logger, + mcpServer: mcpServer, + externalMCPMgr: nil, // 将在创建后设置 + executor: executor, + db: db, + logger: logger, } } +// SetExternalMCPManager 设置外部MCP管理器 +func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) { + h.externalMCPMgr = mgr +} + // MonitorResponse 监控响应 type MonitorResponse struct { Executions []*mcp.ToolExecution `json:"executions"` @@ -128,15 +135,49 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int) ([]*mc } func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { + // 合并内部MCP服务器和外部MCP管理器的统计信息 + stats := make(map[string]*mcp.ToolStats) + + // 加载内部MCP服务器的统计信息 if h.db == nil { - return h.mcpServer.GetStats() + internalStats := h.mcpServer.GetStats() + for k, v := range internalStats { + stats[k] = v + } + } else { + dbStats, err := h.db.LoadToolStats() + if err != nil { + h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err)) + internalStats := h.mcpServer.GetStats() + for k, v := range internalStats { + stats[k] = v + } + } else { + for k, v := range dbStats { + stats[k] = v + } + } } - stats, err := h.db.LoadToolStats() - if err != nil { - h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err)) - return h.mcpServer.GetStats() + // 合并外部MCP管理器的统计信息 + if h.externalMCPMgr != nil { + externalStats := h.externalMCPMgr.GetToolStats() + for k, v := range externalStats { + // 如果已存在,合并统计信息 + if existing, exists := stats[k]; exists { + existing.TotalCalls += v.TotalCalls + existing.SuccessCalls += v.SuccessCalls + existing.FailedCalls += v.FailedCalls + // 使用最新的调用时间 + if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { + existing.LastCallTime = v.LastCallTime + } + } else { + stats[k] = v + } + } } + return stats } @@ -145,13 +186,32 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { func (h *MonitorHandler) GetExecution(c *gin.Context) { id := c.Param("id") + // 先从内部MCP服务器查找 exec, exists := h.mcpServer.GetExecution(id) - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) + if exists { + c.JSON(http.StatusOK, exec) return } - c.JSON(http.StatusOK, exec) + // 如果找不到,尝试从外部MCP管理器查找 + if h.externalMCPMgr != nil { + exec, exists = h.externalMCPMgr.GetExecution(id) + if exists { + c.JSON(http.StatusOK, exec) + return + } + } + + // 如果都找不到,尝试从数据库查找(如果使用数据库存储) + if h.db != nil { + exec, err := h.db.GetToolExecution(id) + if err == nil && exec != nil { + c.JSON(http.StatusOK, exec) + return + } + } + + c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) } // GetStats 获取统计信息 @@ -160,3 +220,4 @@ func (h *MonitorHandler) GetStats(c *gin.Context) { c.JSON(http.StatusOK, stats) } + diff --git a/internal/mcp/client.go b/internal/mcp/client.go new file mode 100644 index 00000000..752d1588 --- /dev/null +++ b/internal/mcp/client.go @@ -0,0 +1,474 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os/exec" + "sync" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// ExternalMCPClient 外部MCP客户端接口 +type ExternalMCPClient interface { + // Initialize 初始化连接 + Initialize(ctx context.Context) error + // ListTools 列出工具 + ListTools(ctx context.Context) ([]Tool, error) + // CallTool 调用工具 + CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) + // Close 关闭连接 + Close() error + // IsConnected 检查是否已连接 + IsConnected() bool + // GetStatus 获取状态 + GetStatus() string +} + +// HTTPMCPClient HTTP模式的MCP客户端 +type HTTPMCPClient struct { + url string + timeout time.Duration + client *http.Client + logger *zap.Logger + mu sync.RWMutex + status string // "disconnected", "connecting", "connected", "error" +} + +// NewHTTPMCPClient 创建HTTP模式的MCP客户端 +func NewHTTPMCPClient(url string, timeout time.Duration, logger *zap.Logger) *HTTPMCPClient { + if timeout <= 0 { + timeout = 30 * time.Second + } + return &HTTPMCPClient{ + url: url, + timeout: timeout, + client: &http.Client{ + Timeout: timeout, + }, + logger: logger, + status: "disconnected", + } +} + +func (c *HTTPMCPClient) setStatus(status string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = status +} + +func (c *HTTPMCPClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *HTTPMCPClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *HTTPMCPClient) Initialize(ctx context.Context) error { + c.setStatus("connecting") + + req := Message{ + ID: MessageID{value: "1"}, + Method: "initialize", + Version: "2.0", + } + + params := InitializeRequest{ + ProtocolVersion: ProtocolVersion, + Capabilities: make(map[string]interface{}), + ClientInfo: ClientInfo{ + Name: "CyberStrikeAI", + Version: "1.0.0", + }, + } + + paramsJSON, _ := json.Marshal(params) + req.Params = paramsJSON + + _, err := c.sendRequest(ctx, &req) + if err != nil { + c.setStatus("error") + return fmt.Errorf("初始化失败: %w", err) + } + + c.setStatus("connected") + return nil +} + +func (c *HTTPMCPClient) ListTools(ctx context.Context) ([]Tool, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/list", + Version: "2.0", + } + + req.Params = json.RawMessage("{}") + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("获取工具列表失败: %w", err) + } + + var listResp ListToolsResponse + if err := json.Unmarshal(resp.Result, &listResp); err != nil { + return nil, fmt.Errorf("解析工具列表失败: %w", err) + } + + return listResp.Tools, nil +} + +func (c *HTTPMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/call", + Version: "2.0", + } + + callReq := CallToolRequest{ + Name: name, + Arguments: args, + } + + paramsJSON, _ := json.Marshal(callReq) + req.Params = paramsJSON + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("调用工具失败: %w", err) + } + + var callResp CallToolResponse + if err := json.Unmarshal(resp.Result, &callResp); err != nil { + return nil, fmt.Errorf("解析工具调用结果失败: %w", err) + } + + return &ToolResult{ + Content: callResp.Content, + IsError: callResp.IsError, + }, nil +} + +func (c *HTTPMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("创建HTTP请求失败: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("HTTP请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP错误 %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var mcpResp Message + if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + if mcpResp.Error != nil { + return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code) + } + + return &mcpResp, nil +} + +func (c *HTTPMCPClient) Close() error { + c.setStatus("disconnected") + return nil +} + +// StdioMCPClient stdio模式的MCP客户端 +type StdioMCPClient struct { + command string + args []string + timeout time.Duration + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + decoder *json.Decoder + encoder *json.Encoder + logger *zap.Logger + mu sync.RWMutex + status string + requestID int64 + responses map[string]chan *Message + responsesMu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +// NewStdioMCPClient 创建stdio模式的MCP客户端 +func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient { + if timeout <= 0 { + timeout = 30 * time.Second + } + ctx, cancel := context.WithCancel(context.Background()) + return &StdioMCPClient{ + command: command, + args: args, + timeout: timeout, + logger: logger, + status: "disconnected", + responses: make(map[string]chan *Message), + ctx: ctx, + cancel: cancel, + } +} + +func (c *StdioMCPClient) setStatus(status string) { + c.mu.Lock() + defer c.mu.Unlock() + c.status = status +} + +func (c *StdioMCPClient) GetStatus() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.status +} + +func (c *StdioMCPClient) IsConnected() bool { + return c.GetStatus() == "connected" +} + +func (c *StdioMCPClient) Initialize(ctx context.Context) error { + c.setStatus("connecting") + + if err := c.startProcess(); err != nil { + c.setStatus("error") + return fmt.Errorf("启动进程失败: %w", err) + } + + // 启动响应读取goroutine + go c.readResponses() + + // 发送初始化请求 + req := Message{ + ID: MessageID{value: "1"}, + Method: "initialize", + Version: "2.0", + } + + params := InitializeRequest{ + ProtocolVersion: ProtocolVersion, + Capabilities: make(map[string]interface{}), + ClientInfo: ClientInfo{ + Name: "CyberStrikeAI", + Version: "1.0.0", + }, + } + + paramsJSON, _ := json.Marshal(params) + req.Params = paramsJSON + + _, err := c.sendRequest(ctx, &req) + if err != nil { + c.setStatus("error") + c.Close() + return fmt.Errorf("初始化失败: %w", err) + } + + c.setStatus("connected") + return nil +} + +func (c *StdioMCPClient) startProcess() error { + cmd := exec.CommandContext(c.ctx, c.command, c.args...) + + stdin, err := cmd.StdinPipe() + if err != nil { + return err + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + stdin.Close() + return err + } + + if err := cmd.Start(); err != nil { + stdin.Close() + stdout.Close() + return err + } + + c.cmd = cmd + c.stdin = stdin + c.stdout = stdout + c.decoder = json.NewDecoder(stdout) + c.encoder = json.NewEncoder(stdin) + + return nil +} + +func (c *StdioMCPClient) readResponses() { + defer func() { + if r := recover(); r != nil { + c.logger.Error("读取响应时发生panic", zap.Any("error", r)) + } + }() + + for { + var msg Message + if err := c.decoder.Decode(&msg); err != nil { + if err == io.EOF { + c.setStatus("disconnected") + break + } + c.logger.Error("读取响应失败", zap.Error(err)) + break + } + + // 处理响应 + id := msg.ID.String() + c.responsesMu.Lock() + if ch, ok := c.responses[id]; ok { + select { + case ch <- &msg: + default: + } + delete(c.responses, id) + } + c.responsesMu.Unlock() + } +} + +func (c *StdioMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) { + if c.encoder == nil { + return nil, fmt.Errorf("进程未启动") + } + + id := msg.ID.String() + if id == "" { + c.mu.Lock() + c.requestID++ + id = fmt.Sprintf("%d", c.requestID) + msg.ID = MessageID{value: id} + c.mu.Unlock() + } + + // 创建响应通道 + responseCh := make(chan *Message, 1) + c.responsesMu.Lock() + c.responses[id] = responseCh + c.responsesMu.Unlock() + + // 发送请求 + if err := c.encoder.Encode(msg); err != nil { + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("发送请求失败: %w", err) + } + + // 等待响应 + select { + case resp := <-responseCh: + if resp.Error != nil { + return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code) + } + return resp, nil + case <-ctx.Done(): + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, ctx.Err() + case <-time.After(c.timeout): + c.responsesMu.Lock() + delete(c.responses, id) + c.responsesMu.Unlock() + return nil, fmt.Errorf("请求超时") + } +} + +func (c *StdioMCPClient) ListTools(ctx context.Context) ([]Tool, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/list", + Version: "2.0", + } + + req.Params = json.RawMessage("{}") + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("获取工具列表失败: %w", err) + } + + var listResp ListToolsResponse + if err := json.Unmarshal(resp.Result, &listResp); err != nil { + return nil, fmt.Errorf("解析工具列表失败: %w", err) + } + + return listResp.Tools, nil +} + +func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { + req := Message{ + ID: MessageID{value: uuid.New().String()}, + Method: "tools/call", + Version: "2.0", + } + + callReq := CallToolRequest{ + Name: name, + Arguments: args, + } + + paramsJSON, _ := json.Marshal(callReq) + req.Params = paramsJSON + + resp, err := c.sendRequest(ctx, &req) + if err != nil { + return nil, fmt.Errorf("调用工具失败: %w", err) + } + + var callResp CallToolResponse + if err := json.Unmarshal(resp.Result, &callResp); err != nil { + return nil, fmt.Errorf("解析工具调用结果失败: %w", err) + } + + return &ToolResult{ + Content: callResp.Content, + IsError: callResp.IsError, + }, nil +} + +func (c *StdioMCPClient) Close() error { + c.cancel() + + if c.stdin != nil { + c.stdin.Close() + } + if c.stdout != nil { + c.stdout.Close() + } + if c.cmd != nil { + c.cmd.Process.Kill() + c.cmd.Wait() + } + + c.setStatus("disconnected") + return nil +} diff --git a/internal/mcp/external_manager.go b/internal/mcp/external_manager.go new file mode 100644 index 00000000..1dac3dc2 --- /dev/null +++ b/internal/mcp/external_manager.go @@ -0,0 +1,660 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "cyberstrike-ai/internal/config" + "github.com/google/uuid" + + "go.uber.org/zap" +) + +// ExternalMCPManager 外部MCP管理器 +type ExternalMCPManager struct { + clients map[string]ExternalMCPClient + configs map[string]config.ExternalMCPServerConfig + logger *zap.Logger + storage MonitorStorage // 可选的持久化存储 + executions map[string]*ToolExecution // 执行记录 + stats map[string]*ToolStats // 工具统计信息 + mu sync.RWMutex +} + +// NewExternalMCPManager 创建外部MCP管理器 +func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager { + return NewExternalMCPManagerWithStorage(logger, nil) +} + +// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储) +func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager { + return &ExternalMCPManager{ + clients: make(map[string]ExternalMCPClient), + configs: make(map[string]config.ExternalMCPServerConfig), + logger: logger, + storage: storage, + executions: make(map[string]*ToolExecution), + stats: make(map[string]*ToolStats), + } +} + +// LoadConfigs 加载配置 +func (m *ExternalMCPManager) LoadConfigs(cfg *config.ExternalMCPConfig) { + m.mu.Lock() + defer m.mu.Unlock() + + if cfg == nil || cfg.Servers == nil { + return + } + + m.configs = make(map[string]config.ExternalMCPServerConfig) + for name, serverCfg := range cfg.Servers { + m.configs[name] = serverCfg + } +} + +// GetConfigs 获取所有配置 +func (m *ExternalMCPManager) GetConfigs() map[string]config.ExternalMCPServerConfig { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]config.ExternalMCPServerConfig) + for k, v := range m.configs { + result[k] = v + } + return result +} + +// AddOrUpdateConfig 添加或更新配置 +func (m *ExternalMCPManager) AddOrUpdateConfig(name string, serverCfg config.ExternalMCPServerConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + // 如果已存在客户端,先关闭 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + m.configs[name] = serverCfg + + // 如果启用,自动连接 + if m.isEnabled(serverCfg) { + go m.connectClient(name, serverCfg) + } + + return nil +} + +// RemoveConfig 移除配置 +func (m *ExternalMCPManager) RemoveConfig(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + // 关闭客户端 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + delete(m.configs, name) + return nil +} + +// StartClient 启动客户端 +func (m *ExternalMCPManager) StartClient(name string) error { + m.mu.Lock() + serverCfg, exists := m.configs[name] + m.mu.Unlock() + + if !exists { + return fmt.Errorf("配置不存在: %s", name) + } + + // 检查是否已经有连接的客户端 + m.mu.RLock() + _, hasClient := m.clients[name] + m.mu.RUnlock() + + if hasClient { + // 检查客户端是否已连接 + if client, ok := m.GetClient(name); ok && client.IsConnected() { + return fmt.Errorf("客户端已连接") + } + // 如果有客户端但未连接,先关闭 + if client, ok := m.GetClient(name); ok { + client.Close() + m.mu.Lock() + delete(m.clients, name) + m.mu.Unlock() + } + } + + // 更新配置为启用 + m.mu.Lock() + serverCfg.ExternalMCPEnable = true + m.configs[name] = serverCfg + m.mu.Unlock() + + // 连接客户端 + return m.connectClient(name, serverCfg) +} + +// StopClient 停止客户端 +func (m *ExternalMCPManager) StopClient(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + serverCfg, exists := m.configs[name] + if !exists { + return fmt.Errorf("配置不存在: %s", name) + } + + // 关闭客户端 + if client, exists := m.clients[name]; exists { + client.Close() + delete(m.clients, name) + } + + // 更新配置为禁用 + serverCfg.ExternalMCPEnable = false + m.configs[name] = serverCfg + + return nil +} + +// GetClient 获取客户端 +func (m *ExternalMCPManager) GetClient(name string) (ExternalMCPClient, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + client, exists := m.clients[name] + return client, exists +} + +// GetAllTools 获取所有外部MCP的工具 +func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) { + m.mu.RLock() + clients := make(map[string]ExternalMCPClient) + for k, v := range m.clients { + clients[k] = v + } + m.mu.RUnlock() + + var allTools []Tool + for name, client := range clients { + if !client.IsConnected() { + continue + } + + tools, err := client.ListTools(ctx) + if err != nil { + m.logger.Warn("获取外部MCP工具列表失败", + zap.String("name", name), + zap.Error(err), + ) + continue + } + + // 为工具添加前缀,避免冲突 + for _, tool := range tools { + tool.Name = fmt.Sprintf("%s::%s", name, tool.Name) + allTools = append(allTools, tool) + } + } + + return allTools, nil +} + +// CallTool 调用外部MCP工具(返回执行ID) +func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { + // 解析工具名称:name::toolName + var mcpName, actualToolName string + if idx := findSubstring(toolName, "::"); idx > 0 { + mcpName = toolName[:idx] + actualToolName = toolName[idx+2:] + } else { + return nil, "", fmt.Errorf("无效的工具名称格式: %s", toolName) + } + + client, exists := m.GetClient(mcpName) + if !exists { + return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName) + } + + if !client.IsConnected() { + return nil, "", fmt.Errorf("外部MCP客户端未连接: %s", mcpName) + } + + // 创建执行记录 + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: toolName, // 使用完整工具名称(包含MCP名称) + Arguments: args, + Status: "running", + StartTime: time.Now(), + } + + m.mu.Lock() + m.executions[executionID] = execution + // 如果内存中的执行记录超过限制,清理最旧的记录 + m.cleanupOldExecutions() + m.mu.Unlock() + + if m.storage != nil { + if err := m.storage.SaveToolExecution(execution); err != nil { + m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + // 调用工具 + result, err := client.CallTool(ctx, actualToolName, args) + + // 更新执行记录 + m.mu.Lock() + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + } else if result != nil && result.IsError { + execution.Status = "failed" + if len(result.Content) > 0 { + execution.Error = result.Content[0].Text + } else { + execution.Error = "工具执行返回错误结果" + } + execution.Result = result + } else { + execution.Status = "completed" + if result == nil { + result = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + execution.Result = result + } + m.mu.Unlock() + + if m.storage != nil { + if err := m.storage.SaveToolExecution(execution); err != nil { + m.logger.Warn("保存执行记录到数据库失败", zap.Error(err)) + } + } + + // 更新统计信息 + failed := err != nil || (result != nil && result.IsError) + m.updateStats(toolName, failed) + + // 如果使用存储,从内存中删除(已持久化) + if m.storage != nil { + m.mu.Lock() + delete(m.executions, executionID) + m.mu.Unlock() + } + + if err != nil { + return nil, executionID, err + } + + return result, executionID, nil +} + +// cleanupOldExecutions 清理旧的执行记录(保持内存中的记录数量在限制内) +func (m *ExternalMCPManager) cleanupOldExecutions() { + const maxExecutionsInMemory = 1000 + if len(m.executions) <= maxExecutionsInMemory { + return + } + + // 按开始时间排序,删除最旧的记录 + type execTime struct { + id string + startTime time.Time + } + var execs []execTime + for id, exec := range m.executions { + execs = append(execs, execTime{id: id, startTime: exec.StartTime}) + } + + // 按时间排序 + for i := 0; i < len(execs)-1; i++ { + for j := i + 1; j < len(execs); j++ { + if execs[i].startTime.After(execs[j].startTime) { + execs[i], execs[j] = execs[j], execs[i] + } + } + } + + // 删除最旧的记录 + toDelete := len(m.executions) - maxExecutionsInMemory + for i := 0; i < toDelete && i < len(execs); i++ { + delete(m.executions, execs[i].id) + } +} + +// GetExecution 获取执行记录(先从内存查找,再从数据库查找) +func (m *ExternalMCPManager) GetExecution(id string) (*ToolExecution, bool) { + m.mu.RLock() + exec, exists := m.executions[id] + m.mu.RUnlock() + + if exists { + return exec, true + } + + if m.storage != nil { + exec, err := m.storage.GetToolExecution(id) + if err == nil { + return exec, true + } + } + + return nil, false +} + +// updateStats 更新统计信息 +func (m *ExternalMCPManager) updateStats(toolName string, failed bool) { + now := time.Now() + if m.storage != nil { + totalCalls := 1 + successCalls := 0 + failedCalls := 0 + if failed { + failedCalls = 1 + } else { + successCalls = 1 + } + if err := m.storage.UpdateToolStats(toolName, totalCalls, successCalls, failedCalls, &now); err != nil { + m.logger.Warn("保存统计信息到数据库失败", zap.Error(err)) + } + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.stats[toolName] == nil { + m.stats[toolName] = &ToolStats{ + ToolName: toolName, + } + } + + stats := m.stats[toolName] + stats.TotalCalls++ + stats.LastCallTime = &now + + if failed { + stats.FailedCalls++ + } else { + stats.SuccessCalls++ + } +} + +// GetStats 获取MCP服务器统计信息 +func (m *ExternalMCPManager) GetStats() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + + total := len(m.configs) + enabled := 0 + disabled := 0 + connected := 0 + + for name, cfg := range m.configs { + if m.isEnabled(cfg) { + enabled++ + if client, exists := m.clients[name]; exists && client.IsConnected() { + connected++ + } + } else { + disabled++ + } + } + + return map[string]interface{}{ + "total": total, + "enabled": enabled, + "disabled": disabled, + "connected": connected, + } +} + +// GetToolStats 获取工具统计信息(合并内存和数据库) +// 只返回外部MCP工具的统计信息(工具名称包含 "::") +func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats { + result := make(map[string]*ToolStats) + + // 从数据库加载统计信息(如果使用数据库存储) + if m.storage != nil { + dbStats, err := m.storage.LoadToolStats() + if err == nil { + // 只保留外部MCP工具的统计信息(工具名称包含 "::") + for k, v := range dbStats { + if findSubstring(k, "::") > 0 { + result[k] = v + } + } + } else { + m.logger.Warn("从数据库加载统计信息失败", zap.Error(err)) + } + } + + // 合并内存中的统计信息 + m.mu.RLock() + for k, v := range m.stats { + // 如果数据库中已有该工具的统计信息,合并它们 + if existing, exists := result[k]; exists { + // 创建新的统计信息对象,避免修改共享对象 + merged := &ToolStats{ + ToolName: k, + TotalCalls: existing.TotalCalls + v.TotalCalls, + SuccessCalls: existing.SuccessCalls + v.SuccessCalls, + FailedCalls: existing.FailedCalls + v.FailedCalls, + } + // 使用最新的调用时间 + if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) { + merged.LastCallTime = v.LastCallTime + } else if existing.LastCallTime != nil { + timeCopy := *existing.LastCallTime + merged.LastCallTime = &timeCopy + } + result[k] = merged + } else { + // 如果数据库中没有,直接使用内存中的统计信息 + statCopy := *v + result[k] = &statCopy + } + } + m.mu.RUnlock() + + return result +} + +// GetToolCount 获取指定外部MCP的工具数量 +func (m *ExternalMCPManager) GetToolCount(name string) (int, error) { + client, exists := m.GetClient(name) + if !exists { + return 0, fmt.Errorf("客户端不存在: %s", name) + } + + if !client.IsConnected() { + return 0, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tools, err := client.ListTools(ctx) + if err != nil { + return 0, fmt.Errorf("获取工具列表失败: %w", err) + } + + return len(tools), nil +} + +// GetToolCounts 获取所有外部MCP的工具数量 +func (m *ExternalMCPManager) GetToolCounts() map[string]int { + m.mu.RLock() + clients := make(map[string]ExternalMCPClient) + for k, v := range m.clients { + clients[k] = v + } + m.mu.RUnlock() + + result := make(map[string]int) + for name, client := range clients { + if !client.IsConnected() { + result[name] = 0 + continue + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + tools, err := client.ListTools(ctx) + cancel() + + if err != nil { + m.logger.Warn("获取外部MCP工具数量失败", + zap.String("name", name), + zap.Error(err), + ) + result[name] = 0 + continue + } + + result[name] = len(tools) + } + + return result +} + +// connectClient 连接客户端(异步) +func (m *ExternalMCPManager) connectClient(name string, serverCfg config.ExternalMCPServerConfig) error { + var client ExternalMCPClient + + timeout := time.Duration(serverCfg.Timeout) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + + // 根据传输模式创建客户端 + transport := serverCfg.Transport + if transport == "" { + // 如果没有指定transport,根据是否有command或url判断 + if serverCfg.Command != "" { + transport = "stdio" + } else if serverCfg.URL != "" { + transport = "http" + } else { + return fmt.Errorf("无法确定传输模式: 需要指定command或url") + } + } + + switch transport { + case "http": + if serverCfg.URL == "" { + return fmt.Errorf("HTTP模式需要URL") + } + client = NewHTTPMCPClient(serverCfg.URL, timeout, m.logger) + case "stdio": + if serverCfg.Command == "" { + return fmt.Errorf("stdio模式需要command") + } + client = NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger) + default: + return fmt.Errorf("不支持的传输模式: %s", transport) + } + + // 初始化连接 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := client.Initialize(ctx); err != nil { + m.logger.Error("初始化外部MCP客户端失败", + zap.String("name", name), + zap.Error(err), + ) + return err + } + + // 保存客户端 + m.mu.Lock() + m.clients[name] = client + m.mu.Unlock() + + m.logger.Info("外部MCP客户端已连接", + zap.String("name", name), + zap.String("transport", transport), + ) + + return nil +} + +// isEnabled 检查是否启用 +func (m *ExternalMCPManager) isEnabled(cfg config.ExternalMCPServerConfig) bool { + // 优先使用 ExternalMCPEnable 字段 + // 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容) + if cfg.ExternalMCPEnable { + return true + } + // 向后兼容:检查旧字段 + if cfg.Disabled { + return false + } + if cfg.Enabled { + return true + } + // 都没有设置,默认为启用 + return true +} + +// findSubstring 查找子字符串(简单实现) +func findSubstring(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} + +// StartAllEnabled 启动所有启用的客户端 +func (m *ExternalMCPManager) StartAllEnabled() { + m.mu.RLock() + configs := make(map[string]config.ExternalMCPServerConfig) + for k, v := range m.configs { + configs[k] = v + } + m.mu.RUnlock() + + for name, cfg := range configs { + if m.isEnabled(cfg) { + go func(n string, c config.ExternalMCPServerConfig) { + if err := m.connectClient(n, c); err != nil { + m.logger.Error("启动外部MCP客户端失败", + zap.String("name", n), + zap.Error(err), + ) + } + }(name, cfg) + } + } +} + +// StopAll 停止所有客户端 +func (m *ExternalMCPManager) StopAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for name, client := range m.clients { + client.Close() + delete(m.clients, name) + } +} diff --git a/internal/mcp/external_manager_test.go b/internal/mcp/external_manager_test.go new file mode 100644 index 00000000..90542c1c --- /dev/null +++ b/internal/mcp/external_manager_test.go @@ -0,0 +1,261 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 测试添加stdio配置 + stdioCfg := config.ExternalMCPServerConfig{ + Command: "python3", + Args: []string{"/path/to/script.py"}, + Transport: "stdio", + Description: "Test stdio MCP", + Timeout: 30, + Enabled: true, + } + + err := manager.AddOrUpdateConfig("test-stdio", stdioCfg) + if err != nil { + t.Fatalf("添加stdio配置失败: %v", err) + } + + // 测试添加HTTP配置 + httpCfg := config.ExternalMCPServerConfig{ + Transport: "http", + URL: "http://127.0.0.1:8081/mcp", + Description: "Test HTTP MCP", + Timeout: 30, + Enabled: false, + } + + err = manager.AddOrUpdateConfig("test-http", httpCfg) + if err != nil { + t.Fatalf("添加HTTP配置失败: %v", err) + } + + // 验证配置已保存 + configs := manager.GetConfigs() + if len(configs) != 2 { + t.Fatalf("期望2个配置,实际%d个", len(configs)) + } + + if configs["test-stdio"].Command != stdioCfg.Command { + t.Errorf("stdio配置命令不匹配") + } + + if configs["test-http"].URL != httpCfg.URL { + t.Errorf("HTTP配置URL不匹配") + } +} + +func TestExternalMCPManager_RemoveConfig(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + cfg := config.ExternalMCPServerConfig{ + Command: "python3", + Transport: "stdio", + Enabled: false, + } + + manager.AddOrUpdateConfig("test-remove", cfg) + + // 移除配置 + err := manager.RemoveConfig("test-remove") + if err != nil { + t.Fatalf("移除配置失败: %v", err) + } + + configs := manager.GetConfigs() + if _, exists := configs["test-remove"]; exists { + t.Error("配置应该已被移除") + } +} + +func TestExternalMCPManager_GetStats(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 添加多个配置 + manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: true, + }) + + manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ + URL: "http://127.0.0.1:8081/mcp", + Enabled: true, + }) + + manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ + Command: "python3", + Enabled: false, + Disabled: true, // 明确设置为禁用 + }) + + stats := manager.GetStats() + + if stats["total"].(int) != 3 { + t.Errorf("期望总数3,实际%d", stats["total"]) + } + + if stats["enabled"].(int) != 2 { + t.Errorf("期望启用数2,实际%d", stats["enabled"]) + } + + if stats["disabled"].(int) != 1 { + t.Errorf("期望停用数1,实际%d", stats["disabled"]) + } +} + +func TestExternalMCPManager_LoadConfigs(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + externalMCPConfig := config.ExternalMCPConfig{ + Servers: map[string]config.ExternalMCPServerConfig{ + "loaded1": { + Command: "python3", + Enabled: true, + }, + "loaded2": { + URL: "http://127.0.0.1:8081/mcp", + Enabled: false, + }, + }, + } + + manager.LoadConfigs(&externalMCPConfig) + + configs := manager.GetConfigs() + if len(configs) != 2 { + t.Fatalf("期望2个配置,实际%d个", len(configs)) + } + + if configs["loaded1"].Command != "python3" { + t.Error("配置1加载失败") + } + + if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" { + t.Error("配置2加载失败") + } +} + +func TestHTTPMCPClient_Initialize(t *testing.T) { + // 注意:这个测试需要一个真实的HTTP MCP服务器 + // 如果没有服务器,这个测试会失败 + // 在实际测试中,可以使用mock服务器 + logger := zap.NewNop() + client := NewHTTPMCPClient("http://127.0.0.1:8081/mcp", 5*time.Second, logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 这个测试可能会失败,如果没有真实的服务器 + // 在实际环境中,应该使用mock服务器 + err := client.Initialize(ctx) + if err != nil { + t.Logf("初始化失败(可能是没有服务器): %v", err) + } + + status := client.GetStatus() + if status == "" { + t.Error("状态不应该为空") + } + + client.Close() +} + +func TestStdioMCPClient_Initialize(t *testing.T) { + // 注意:这个测试需要一个真实的stdio MCP服务器 + // 如果没有服务器,这个测试会失败 + logger := zap.NewNop() + client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 这个测试可能会失败,因为echo不是MCP服务器 + // 在实际环境中,应该使用真实的MCP服务器或mock + err := client.Initialize(ctx) + if err != nil { + t.Logf("初始化失败(echo不是MCP服务器): %v", err) + } + + client.Close() +} + +func TestExternalMCPManager_StartStopClient(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 添加一个禁用的配置 + cfg := config.ExternalMCPServerConfig{ + Command: "python3", + Transport: "stdio", + Enabled: false, + } + + manager.AddOrUpdateConfig("test-start-stop", cfg) + + // 尝试启动(可能会失败,因为没有真实的服务器) + err := manager.StartClient("test-start-stop") + if err != nil { + t.Logf("启动失败(可能是没有服务器): %v", err) + } + + // 停止 + err = manager.StopClient("test-start-stop") + if err != nil { + t.Fatalf("停止失败: %v", err) + } + + // 验证配置已更新为禁用 + configs := manager.GetConfigs() + if configs["test-start-stop"].Enabled { + t.Error("配置应该已被禁用") + } +} + +func TestExternalMCPManager_CallTool(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + // 测试调用不存在的工具 + _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{}) + if err == nil { + t.Error("应该返回错误") + } + + // 测试无效的工具名称格式 + _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{}) + if err == nil { + t.Error("应该返回错误(无效格式)") + } +} + +func TestExternalMCPManager_GetAllTools(t *testing.T) { + logger := zap.NewNop() + manager := NewExternalMCPManager(logger) + + ctx := context.Background() + tools, err := manager.GetAllTools(ctx) + if err != nil { + t.Fatalf("获取工具列表失败: %v", err) + } + + // 如果没有连接的客户端,应该返回空列表 + if len(tools) != 0 { + t.Logf("获取到%d个工具", len(tools)) + } +} diff --git a/web/static/css/style.css b/web/static/css/style.css index a845744b..ed5601a7 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -1525,6 +1525,8 @@ header { display: flex; gap: 8px; align-items: center; + flex-wrap: nowrap; + width: 100%; } .tools-actions button { @@ -1557,6 +1559,8 @@ header { display: flex; gap: 4px; flex: 1; + min-width: 0; + max-width: 300px; align-items: center; } @@ -1639,6 +1643,23 @@ header { color: var(--text-primary); font-size: 0.9375rem; margin-bottom: 4px; + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +.external-tool-badge { + display: inline-flex; + align-items: center; + padding: 2px 8px; + background: rgba(255, 152, 0, 0.12); + border: 1px solid rgba(255, 152, 0, 0.3); + border-radius: 12px; + font-size: 0.75rem; + font-weight: 600; + color: #ff9800; + white-space: nowrap; } .tool-item-desc { @@ -2007,3 +2028,392 @@ header { background: rgba(220, 53, 69, 0.08); border-radius: 12px; } + +/* 外部MCP配置样式 */ +.external-mcp-controls { + display: flex; + flex-direction: column; + gap: 16px; +} + +.external-mcp-actions { + display: flex; + align-items: center; + gap: 12px; + flex-wrap: wrap; +} + +.external-mcp-stats { + display: flex; + align-items: center; + gap: 20px; + margin-left: auto; + padding: 10px 20px; + background: linear-gradient(135deg, var(--bg-secondary) 0%, var(--bg-tertiary) 100%); + border: 1px solid var(--border-color); + border-radius: 10px; + font-size: 0.875rem; + color: var(--text-secondary); + box-shadow: var(--shadow-sm); +} + +.external-mcp-stats span { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 4px 0; + white-space: nowrap; +} + +.external-mcp-stats span strong { + color: var(--text-primary); + font-weight: 600; + margin-left: 4px; +} + +.external-mcp-stats span:not(:last-child)::after { + content: ''; +} + +.tools-stats { + display: flex; + align-items: center; + gap: 15px; + margin-left: auto; + padding: 8px 16px; + background: linear-gradient(135deg, var(--bg-secondary) 0%, var(--bg-tertiary) 100%); + border: 1px solid var(--border-color); + border-radius: 10px; + font-size: 0.8125rem; + color: var(--text-secondary); + box-shadow: var(--shadow-sm); + flex-wrap: nowrap; + flex-shrink: 0; + white-space: nowrap; +} + +.tools-stats span { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 4px 0; + white-space: nowrap; + flex-shrink: 0; +} + +.tools-stats span strong { + color: var(--text-primary); + font-weight: 600; + margin-left: 4px; +} + +.tools-stats span:not(:last-child)::after { + content: ''; + width: 1px; + height: 20px; + background: var(--border-color); + margin-left: 10px; + display: inline-block; +} + +.external-mcp-list { + display: flex; + flex-direction: column; + gap: 12px; +} + +.external-mcp-items { + display: flex; + flex-direction: column; + gap: 12px; +} + +.external-mcp-item { + background: var(--bg-primary); + border: 1px solid var(--border-color); + border-radius: 12px; + padding: 20px; + box-shadow: var(--shadow-sm); + transition: all 0.2s ease; + display: flex; + flex-direction: column; + gap: 16px; +} + +.external-mcp-item:hover { + box-shadow: var(--shadow-md); + border-color: var(--accent-color); + transform: translateY(-2px); +} + +.external-mcp-item-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 16px; + flex-wrap: wrap; +} + +.external-mcp-item-info { + display: flex; + align-items: center; + gap: 12px; + flex: 1; + min-width: 0; +} + +.external-mcp-item-info h4 { + margin: 0; + font-size: 1.125rem; + font-weight: 600; + color: var(--text-primary); + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +.tool-count-badge { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + background: rgba(0, 102, 255, 0.1); + border: 1px solid rgba(0, 102, 255, 0.3); + border-radius: 12px; + font-size: 0.75rem; + font-weight: 600; + color: var(--accent-color); + margin-left: 8px; + white-space: nowrap; +} + +.external-mcp-status { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 4px 12px; + border-radius: 999px; + font-size: 0.8125rem; + font-weight: 600; + white-space: nowrap; +} + +.external-mcp-status.status-connected { + background: rgba(40, 167, 69, 0.12); + color: var(--success-color); + border: 1px solid rgba(40, 167, 69, 0.3); +} + +.external-mcp-status.status-connected::before { + content: ''; + width: 8px; + height: 8px; + border-radius: 50%; + background: var(--success-color); + display: inline-block; + animation: pulse 2s infinite; +} + +@keyframes pulse { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.5; + } +} + +.external-mcp-status.status-disconnected { + background: rgba(108, 117, 125, 0.12); + color: var(--text-secondary); + border: 1px solid rgba(108, 117, 125, 0.3); +} + +.external-mcp-status.status-connecting { + background: rgba(0, 123, 255, 0.12); + color: var(--accent-color); + border: 1px solid rgba(0, 123, 255, 0.3); +} + +.external-mcp-status.status-connecting::before { + content: ''; + width: 8px; + height: 8px; + border-radius: 50%; + background: var(--accent-color); + display: inline-block; + animation: pulse 1.5s infinite; +} + +.external-mcp-status.status-disabled { + background: rgba(255, 193, 7, 0.12); + color: #b8860b; + border: 1px solid rgba(255, 193, 7, 0.3); +} + +.external-mcp-item-actions { + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +.btn-small { + padding: 6px 14px; + font-size: 0.8125rem; + border-radius: 6px; + border: 1px solid var(--border-color); + background: var(--bg-secondary); + color: var(--text-primary); + cursor: pointer; + transition: all 0.2s ease; + font-weight: 500; + white-space: nowrap; +} + +.btn-small:hover { + background: var(--bg-tertiary); + border-color: var(--accent-color); + color: var(--accent-color); + transform: translateY(-1px); + box-shadow: var(--shadow-sm); +} + +.btn-small.btn-danger { + background: rgba(220, 53, 69, 0.08); + border-color: rgba(220, 53, 69, 0.3); + color: var(--error-color); +} + +.btn-small.btn-danger:hover { + background: rgba(220, 53, 69, 0.15); + border-color: var(--error-color); + color: #c82333; +} + +.external-mcp-item-details { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 16px; + padding-top: 16px; + border-top: 1px solid var(--border-color); +} + +.external-mcp-item-details > div { + display: flex; + flex-direction: column; + gap: 6px; + padding: 12px; + background: var(--bg-secondary); + border-radius: 8px; + border: 1px solid var(--border-color); + transition: all 0.2s ease; +} + +.external-mcp-item-details > div:hover { + background: var(--bg-tertiary); + border-color: var(--accent-color); + transform: translateY(-1px); + box-shadow: var(--shadow-sm); +} + +.external-mcp-item-details strong { + font-size: 0.75rem; + font-weight: 600; + color: var(--text-secondary); + text-transform: uppercase; + letter-spacing: 0.5px; + margin-bottom: 2px; +} + +.external-mcp-item-details span { + font-size: 0.875rem; + color: var(--text-primary); + word-break: break-word; + line-height: 1.5; +} + +.external-mcp-list .empty { + text-align: center; + padding: 48px 24px; + color: var(--text-muted); + font-size: 0.9375rem; + background: var(--bg-secondary); + border: 2px dashed var(--border-color); + border-radius: 12px; +} + +.external-mcp-list .error { + text-align: center; + padding: 24px; + color: var(--error-color); + background: rgba(220, 53, 69, 0.08); + border: 1px solid rgba(220, 53, 69, 0.3); + border-radius: 12px; + font-size: 0.875rem; +} + +.error-message { + display: block; + padding: 8px 12px; + background: rgba(220, 53, 69, 0.1); + border: 1px solid rgba(220, 53, 69, 0.3); + border-radius: 6px; + color: var(--error-color); + font-size: 0.875rem; + line-height: 1.5; +} + +.form-group textarea { + padding: 12px; + border: 1px solid var(--border-color); + border-radius: 6px; + font-size: 0.9375rem; + background: var(--bg-primary); + color: var(--text-primary); + transition: border-color 0.2s; + width: 100%; + font-family: inherit; + resize: vertical; + min-height: 200px; +} + +.form-group textarea:focus { + outline: none; + border-color: var(--accent-color); + box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1); +} + +.form-group textarea.error { + border-color: var(--error-color); + box-shadow: 0 0 0 3px rgba(220, 53, 69, 0.1); +} + +/* 响应式优化 */ +@media (max-width: 768px) { + .external-mcp-actions { + flex-direction: column; + align-items: stretch; + } + + .external-mcp-stats { + margin-left: 0; + width: 100%; + justify-content: space-around; + } + + .external-mcp-item-header { + flex-direction: column; + align-items: flex-start; + } + + .external-mcp-item-actions { + width: 100%; + justify-content: flex-end; + } + + .external-mcp-item-details { + grid-template-columns: 1fr; + } +} diff --git a/web/static/js/app.js b/web/static/js/app.js index 3eba9f9a..4b50b2f6 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -1837,6 +1837,8 @@ function renderToolsList() { if (!toolsList.contains(listContainer)) { toolsList.appendChild(listContainer); } + // 更新统计 + updateToolsStats(); return; } @@ -1844,10 +1846,19 @@ function renderToolsList() { const toolItem = document.createElement('div'); toolItem.className = 'tool-item'; toolItem.dataset.toolName = tool.name; // 保存原始工具名称 + toolItem.dataset.isExternal = tool.is_external ? 'true' : 'false'; + toolItem.dataset.externalMcp = tool.external_mcp || ''; + + // 外部工具标签 + const externalBadge = tool.is_external ? '外部' : ''; + toolItem.innerHTML = ` - +
-
${escapeHtml(tool.name)}
+
+ ${escapeHtml(tool.name)} + ${externalBadge} +
${escapeHtml(tool.description || '无描述')}
`; @@ -1857,6 +1868,9 @@ function renderToolsList() { if (!toolsList.contains(listContainer)) { toolsList.appendChild(listContainer); } + + // 更新统计 + updateToolsStats(); } // 渲染工具列表分页控件 @@ -1903,6 +1917,7 @@ function selectAllTools() { document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => { checkbox.checked = true; }); + updateToolsStats(); } // 全不选工具 @@ -1910,6 +1925,88 @@ function deselectAllTools() { document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => { checkbox.checked = false; }); + updateToolsStats(); +} + +// 更新工具统计信息 +async function updateToolsStats() { + const statsEl = document.getElementById('tools-stats'); + if (!statsEl) return; + + // 计算当前页的启用工具数 + const currentPageEnabled = Array.from(document.querySelectorAll('#tools-list input[type="checkbox"]:checked')).length; + const currentPageTotal = document.querySelectorAll('#tools-list input[type="checkbox"]').length; + + // 计算所有工具的启用数 + let totalEnabled = 0; + let totalTools = toolsPagination.total || 0; + + try { + // 如果有搜索关键词,只统计搜索结果 + if (toolsSearchKeyword) { + totalTools = allTools.length; + totalEnabled = allTools.filter(tool => { + const checkbox = document.getElementById(`tool-${tool.name}`); + return checkbox ? checkbox.checked : tool.enabled; + }).length; + } else { + // 没有搜索时,需要获取所有工具的状态 + // 先使用当前已知的工具状态 + const toolStateMap = new Map(); + + // 从当前页的checkbox获取状态 + allTools.forEach(tool => { + const checkbox = document.getElementById(`tool-${tool.name}`); + if (checkbox) { + toolStateMap.set(tool.name, checkbox.checked); + } else { + // 如果checkbox不存在(不在当前页),使用工具原始状态 + toolStateMap.set(tool.name, tool.enabled); + } + }); + + // 如果总工具数大于当前页,需要获取所有工具的状态 + if (totalTools > allTools.length) { + // 遍历所有页面获取完整状态 + let page = 1; + let hasMore = true; + const pageSize = 100; // 使用较大的页面大小以减少请求次数 + + while (hasMore && page <= 10) { // 限制最多10页,避免无限循环 + const url = `/api/config/tools?page=${page}&page_size=${pageSize}`; + const pageResponse = await apiFetch(url); + if (!pageResponse.ok) break; + + const pageResult = await pageResponse.json(); + pageResult.tools.forEach(tool => { + // 如果工具不在当前页,使用服务器返回的状态 + if (!toolStateMap.has(tool.name)) { + toolStateMap.set(tool.name, tool.enabled); + } + }); + + if (page >= pageResult.total_pages) { + hasMore = false; + } else { + page++; + } + } + } + + // 计算启用的工具数 + totalEnabled = Array.from(toolStateMap.values()).filter(enabled => enabled).length; + } + } catch (error) { + console.warn('获取工具统计失败,使用当前页数据', error); + // 如果获取失败,使用当前页的数据 + totalTools = totalTools || currentPageTotal; + totalEnabled = currentPageEnabled; + } + + statsEl.innerHTML = ` + ✅ 当前页已启用: ${currentPageEnabled} / ${currentPageTotal} + 📊 总计已启用: ${totalEnabled} / ${totalTools} + `; } // 过滤工具(已废弃,现在使用服务端搜索) @@ -1974,39 +2071,75 @@ async function applySettings() { document.querySelectorAll('#tools-list .tool-item').forEach(item => { const checkbox = item.querySelector('input[type="checkbox"]'); const toolName = item.dataset.toolName; + const isExternal = item.dataset.isExternal === 'true'; + const externalMcp = item.dataset.externalMcp || ''; if (toolName) { - currentPageTools.set(toolName, checkbox.checked); + currentPageTools.set(toolName, { + enabled: checkbox.checked, + is_external: isExternal, + external_mcp: externalMcp + }); } }); - // 获取所有工具列表以获取完整状态 + // 获取所有工具列表以获取完整状态(遍历所有页面) + // 注意:无论是否在搜索状态下,都要获取所有工具的状态,以确保完整保存 try { - const allToolsResponse = await apiFetch(`/api/config/tools?page=1&page_size=1000`); - if (allToolsResponse.ok) { - const allToolsResult = await allToolsResponse.json(); - // 使用所有工具,但用当前页的修改覆盖 - allToolsResult.tools.forEach(tool => { - config.tools.push({ + const allToolsMap = new Map(); + let page = 1; + let hasMore = true; + const pageSize = 100; // 使用合理的页面大小 + + // 遍历所有页面获取所有工具(不使用搜索关键词,获取全部工具) + while (hasMore) { + const url = `/api/config/tools?page=${page}&page_size=${pageSize}`; + + const pageResponse = await apiFetch(url); + if (!pageResponse.ok) { + throw new Error('获取工具列表失败'); + } + + const pageResult = await pageResponse.json(); + + // 将当前页的工具添加到映射中 + // 如果工具在当前显示的页面中(匹配搜索且在当前页),使用当前页的修改 + // 否则使用服务器返回的状态 + pageResult.tools.forEach(tool => { + const currentPageTool = currentPageTools.get(tool.name); + allToolsMap.set(tool.name, { name: tool.name, - enabled: currentPageTools.has(tool.name) ? currentPageTools.get(tool.name) : tool.enabled - }); - }); - } else { - // 如果获取失败,只使用当前页的工具 - currentPageTools.forEach((enabled, toolName) => { - config.tools.push({ - name: toolName, - enabled: enabled + enabled: currentPageTool ? currentPageTool.enabled : tool.enabled, + is_external: currentPageTool ? currentPageTool.is_external : (tool.is_external || false), + external_mcp: currentPageTool ? currentPageTool.external_mcp : (tool.external_mcp || '') }); }); + + // 检查是否还有更多页面 + if (page >= pageResult.total_pages) { + hasMore = false; + } else { + page++; + } } + + // 将所有工具添加到配置中 + allToolsMap.forEach(tool => { + config.tools.push({ + name: tool.name, + enabled: tool.enabled, + is_external: tool.is_external, + external_mcp: tool.external_mcp + }); + }); } catch (error) { console.warn('获取所有工具列表失败,仅使用当前页工具状态', error); // 如果获取失败,只使用当前页的工具 - currentPageTools.forEach((enabled, toolName) => { + currentPageTools.forEach((toolData, toolName) => { config.tools.push({ name: toolName, - enabled: enabled + enabled: toolData.enabled, + is_external: toolData.is_external, + external_mcp: toolData.external_mcp }); }); } @@ -2432,3 +2565,499 @@ function formatExecutionDuration(start, end) { const remainMinutes = minutes % 60; return remainMinutes > 0 ? `${hours} 小时 ${remainMinutes} 分` : `${hours} 小时`; } + +// ==================== 外部MCP管理 ==================== + +let currentEditingMCPName = null; + +// 加载外部MCP列表 +async function loadExternalMCPs() { + try { + const response = await apiFetch('/api/external-mcp'); + if (!response.ok) { + throw new Error('获取外部MCP列表失败'); + } + + const data = await response.json(); + renderExternalMCPList(data.servers || {}); + renderExternalMCPStats(data.stats || {}); + } catch (error) { + console.error('加载外部MCP列表失败:', error); + const list = document.getElementById('external-mcp-list'); + if (list) { + list.innerHTML = `
加载失败: ${escapeHtml(error.message)}
`; + } + } +} + +// 渲染外部MCP列表 +function renderExternalMCPList(servers) { + const list = document.getElementById('external-mcp-list'); + if (!list) return; + + if (Object.keys(servers).length === 0) { + list.innerHTML = '
📋 暂无外部MCP配置
点击"添加外部MCP"按钮开始配置
'; + return; + } + + let html = '
'; + for (const [name, server] of Object.entries(servers)) { + const status = server.status || 'disconnected'; + const statusClass = status === 'connected' ? 'status-connected' : + status === 'connecting' ? 'status-connecting' : + status === 'disabled' ? 'status-disabled' : 'status-disconnected'; + const statusText = status === 'connected' ? '已连接' : + status === 'connecting' ? '连接中...' : + status === 'disabled' ? '已禁用' : '未连接'; + const transport = server.config.transport || (server.config.command ? 'stdio' : 'http'); + const transportIcon = transport === 'stdio' ? '⚙️' : '🌐'; + + html += ` +
+
+
+

${transportIcon} ${escapeHtml(name)}${server.tool_count !== undefined && server.tool_count > 0 ? `🔧 ${server.tool_count}` : ''}

+ ${statusText} +
+
+ ${status === 'connected' || status === 'disconnected' ? + `` : + status === 'connecting' ? + `` : ''} + + +
+
+
+
+ 传输模式 + ${transportIcon} ${escapeHtml(transport.toUpperCase())} +
+ ${server.tool_count !== undefined && server.tool_count > 0 ? ` +
+ 工具数量 + 🔧 ${server.tool_count} 个工具 +
` : server.tool_count === 0 && status === 'connected' ? ` +
+ 工具数量 + 暂无工具 +
` : ''} + ${server.config.description ? ` +
+ 描述 + ${escapeHtml(server.config.description)} +
` : ''} + ${server.config.timeout ? ` +
+ 超时时间 + ${server.config.timeout} 秒 +
` : ''} + ${transport === 'stdio' && server.config.command ? ` +
+ 命令 + ${escapeHtml(server.config.command)} +
` : ''} + ${transport === 'http' && server.config.url ? ` +
+ URL + ${escapeHtml(server.config.url)} +
` : ''} +
+
+ `; + } + html += '
'; + list.innerHTML = html; +} + +// 渲染外部MCP统计信息 +function renderExternalMCPStats(stats) { + const statsEl = document.getElementById('external-mcp-stats'); + if (!statsEl) return; + + const total = stats.total || 0; + const enabled = stats.enabled || 0; + const disabled = stats.disabled || 0; + const connected = stats.connected || 0; + + statsEl.innerHTML = ` + 📊 总数: ${total} + ✅ 已启用: ${enabled} + ⏸ 已停用: ${disabled} + 🔗 已连接: ${connected} + `; +} + +// 显示添加外部MCP模态框 +function showAddExternalMCPModal() { + currentEditingMCPName = null; + document.getElementById('external-mcp-modal-title').textContent = '添加外部MCP'; + document.getElementById('external-mcp-json').value = ''; + document.getElementById('external-mcp-json-error').style.display = 'none'; + document.getElementById('external-mcp-json-error').textContent = ''; + document.getElementById('external-mcp-json').classList.remove('error'); + document.getElementById('external-mcp-modal').style.display = 'block'; +} + +// 关闭外部MCP模态框 +function closeExternalMCPModal() { + document.getElementById('external-mcp-modal').style.display = 'none'; + currentEditingMCPName = null; +} + +// 编辑外部MCP +async function editExternalMCP(name) { + try { + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(name)}`); + if (!response.ok) { + throw new Error('获取外部MCP配置失败'); + } + + const server = await response.json(); + currentEditingMCPName = name; + + document.getElementById('external-mcp-modal-title').textContent = '编辑外部MCP'; + + // 将配置转换为对象格式(key为名称) + const config = { ...server.config }; + // 移除tool_count、external_mcp_enable等前端字段,但保留enabled/disabled用于向后兼容 + delete config.tool_count; + delete config.external_mcp_enable; + + // 包装成对象格式:{ "name": { config } } + const configObj = {}; + configObj[name] = config; + + // 格式化JSON + const jsonStr = JSON.stringify(configObj, null, 2); + document.getElementById('external-mcp-json').value = jsonStr; + document.getElementById('external-mcp-json-error').style.display = 'none'; + document.getElementById('external-mcp-json-error').textContent = ''; + document.getElementById('external-mcp-json').classList.remove('error'); + + document.getElementById('external-mcp-modal').style.display = 'block'; + } catch (error) { + console.error('编辑外部MCP失败:', error); + alert('编辑失败: ' + error.message); + } +} + +// 格式化JSON +function formatExternalMCPJSON() { + const jsonTextarea = document.getElementById('external-mcp-json'); + const errorDiv = document.getElementById('external-mcp-json-error'); + + try { + const jsonStr = jsonTextarea.value.trim(); + if (!jsonStr) { + errorDiv.textContent = 'JSON不能为空'; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + const parsed = JSON.parse(jsonStr); + const formatted = JSON.stringify(parsed, null, 2); + jsonTextarea.value = formatted; + errorDiv.style.display = 'none'; + jsonTextarea.classList.remove('error'); + } catch (error) { + errorDiv.textContent = 'JSON格式错误: ' + error.message; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + } +} + +// 加载示例 +function loadExternalMCPExample() { + const example = { + "hexstrike-ai": { + command: "python3", + args: [ + "/path/to/script.py", + "--server", + "http://example.com" + ], + description: "示例描述", + timeout: 300 + } + }; + + document.getElementById('external-mcp-json').value = JSON.stringify(example, null, 2); + document.getElementById('external-mcp-json-error').style.display = 'none'; + document.getElementById('external-mcp-json').classList.remove('error'); +} + +// 保存外部MCP +async function saveExternalMCP() { + const jsonTextarea = document.getElementById('external-mcp-json'); + const jsonStr = jsonTextarea.value.trim(); + const errorDiv = document.getElementById('external-mcp-json-error'); + + if (!jsonStr) { + errorDiv.textContent = 'JSON配置不能为空'; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + jsonTextarea.focus(); + return; + } + + let configObj; + try { + configObj = JSON.parse(jsonStr); + } catch (error) { + errorDiv.textContent = 'JSON格式错误: ' + error.message; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + jsonTextarea.focus(); + return; + } + + // 验证必须是对象格式 + if (typeof configObj !== 'object' || Array.isArray(configObj) || configObj === null) { + errorDiv.textContent = '配置错误: 必须是JSON对象格式,key为配置名称,value为配置内容'; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + // 获取所有配置名称 + const names = Object.keys(configObj); + if (names.length === 0) { + errorDiv.textContent = '配置错误: 至少需要一个配置项'; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + // 验证每个配置 + for (const name of names) { + if (!name || name.trim() === '') { + errorDiv.textContent = '配置错误: 配置名称不能为空'; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + const config = configObj[name]; + if (typeof config !== 'object' || Array.isArray(config) || config === null) { + errorDiv.textContent = `配置错误: "${name}" 的配置必须是对象`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + // 移除 external_mcp_enable 字段(由按钮控制,但保留 enabled/disabled 用于向后兼容) + delete config.external_mcp_enable; + + // 验证配置内容 + const transport = config.transport || (config.command ? 'stdio' : config.url ? 'http' : ''); + if (!transport) { + errorDiv.textContent = `配置错误: "${name}" 需要指定command(stdio模式)或url(http模式)`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + if (transport === 'stdio' && !config.command) { + errorDiv.textContent = `配置错误: "${name}" stdio模式需要command字段`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + if (transport === 'http' && !config.url) { + errorDiv.textContent = `配置错误: "${name}" http模式需要url字段`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + } + + // 清除错误提示 + errorDiv.style.display = 'none'; + jsonTextarea.classList.remove('error'); + + try { + // 如果是编辑模式,只更新当前编辑的配置 + if (currentEditingMCPName) { + if (!configObj[currentEditingMCPName]) { + errorDiv.textContent = `配置错误: 编辑模式下,JSON必须包含配置名称 "${currentEditingMCPName}"`; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + return; + } + + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(currentEditingMCPName)}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ config: configObj[currentEditingMCPName] }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || '保存失败'); + } + } else { + // 添加模式:保存所有配置 + for (const name of names) { + const config = configObj[name]; + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(name)}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ config }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(`保存 "${name}" 失败: ${error.error || '未知错误'}`); + } + } + } + + closeExternalMCPModal(); + await loadExternalMCPs(); + alert('保存成功'); + } catch (error) { + console.error('保存外部MCP失败:', error); + errorDiv.textContent = '保存失败: ' + error.message; + errorDiv.style.display = 'block'; + jsonTextarea.classList.add('error'); + } +} + +// 删除外部MCP +async function deleteExternalMCP(name) { + if (!confirm(`确定要删除外部MCP "${name}" 吗?`)) { + return; + } + + try { + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(name)}`, { + method: 'DELETE', + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || '删除失败'); + } + + await loadExternalMCPs(); + alert('删除成功'); + } catch (error) { + console.error('删除外部MCP失败:', error); + alert('删除失败: ' + error.message); + } +} + +// 切换外部MCP启停 +async function toggleExternalMCP(name, currentStatus) { + const action = currentStatus === 'connected' ? 'stop' : 'start'; + const buttonId = `btn-toggle-${name}`; + const button = document.getElementById(buttonId); + + // 如果是启动操作,显示加载状态 + if (action === 'start' && button) { + button.disabled = true; + button.style.opacity = '0.6'; + button.style.cursor = 'not-allowed'; + button.innerHTML = '⏳ 连接中...'; + } + + try { + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(name)}/${action}`, { + method: 'POST', + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || '操作失败'); + } + + const result = await response.json(); + + // 如果是启动操作,轮询状态直到连接成功或失败 + if (action === 'start') { + await pollExternalMCPStatus(name, 30); // 最多轮询30次(约30秒) + } else { + // 停止操作,直接刷新 + await loadExternalMCPs(); + } + } catch (error) { + console.error('切换外部MCP状态失败:', error); + alert('操作失败: ' + error.message); + + // 恢复按钮状态 + if (button) { + button.disabled = false; + button.style.opacity = '1'; + button.style.cursor = 'pointer'; + button.innerHTML = '▶ 启动'; + } + + // 刷新状态 + await loadExternalMCPs(); + } +} + +// 轮询外部MCP状态 +async function pollExternalMCPStatus(name, maxAttempts = 30) { + let attempts = 0; + const pollInterval = 1000; // 1秒轮询一次 + + while (attempts < maxAttempts) { + await new Promise(resolve => setTimeout(resolve, pollInterval)); + + try { + const response = await apiFetch(`/api/external-mcp/${encodeURIComponent(name)}`); + if (response.ok) { + const data = await response.json(); + const status = data.status || 'disconnected'; + + // 更新按钮状态 + const buttonId = `btn-toggle-${name}`; + const button = document.getElementById(buttonId); + + if (status === 'connected') { + // 连接成功,刷新列表 + await loadExternalMCPs(); + return; + } else if (status === 'error' || status === 'disconnected') { + // 连接失败,刷新列表并显示错误 + await loadExternalMCPs(); + if (status === 'error') { + alert('连接失败,请检查配置和网络连接'); + } + return; + } else if (status === 'connecting') { + // 仍在连接中,继续轮询 + attempts++; + continue; + } + } + } catch (error) { + console.error('轮询状态失败:', error); + } + + attempts++; + } + + // 超时,刷新列表 + await loadExternalMCPs(); + alert('连接超时,请检查配置和网络连接'); +} + +// 在打开设置时加载外部MCP列表 +const originalOpenSettings = openSettings; +openSettings = async function() { + await originalOpenSettings(); + await loadExternalMCPs(); +}; diff --git a/web/templates/index.html b/web/templates/index.html index b1f1cb16..dbf50f8a 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -119,11 +119,25 @@ +
+ +
+

外部 MCP 配置

+
+
+ + +
+
+
+
+
+

Agent 配置

@@ -247,6 +261,51 @@
+ + +