Compare commits

...

25 Commits

Author SHA1 Message Date
公明 1336d6f9a6 Add files via upload 2026-01-11 02:17:07 +08:00
公明 5ce1fb7501 Add files via upload 2026-01-11 02:05:11 +08:00
公明 aa9819a2c8 Update config.yaml 2026-01-11 02:04:21 +08:00
公明 3aee7022c4 Add files via upload 2026-01-11 02:03:33 +08:00
公明 4ca1aa9aa8 Add files via upload 2026-01-09 23:00:32 +08:00
公明 3448c661b8 Add files via upload 2026-01-09 22:16:09 +08:00
公明 b524ce68ea Add files via upload 2026-01-09 21:16:38 +08:00
公明 2c973f8c3b Add files via upload 2026-01-09 19:44:59 +08:00
公明 c3a1d95a92 Add files via upload 2026-01-09 19:32:14 +08:00
公明 60e3795322 Add files via upload 2026-01-09 19:02:16 +08:00
公明 28ca7f1851 Add files via upload 2026-01-09 18:52:38 +08:00
公明 14e9b986b0 Add files via upload 2026-01-08 23:43:09 +08:00
公明 dccbb80fa4 Add files via upload 2026-01-08 22:54:36 +08:00
公明 3043232937 Add files via upload 2026-01-08 22:43:41 +08:00
公明 2aeb2705e9 Add files via upload 2026-01-07 19:41:35 +08:00
公明 6bd558cbd4 Add files via upload 2026-01-07 19:38:36 +08:00
公明 71abfb2384 Update README.md 2026-01-07 14:10:29 +08:00
公明 d3f6a87448 Update README.md 2026-01-07 14:07:23 +08:00
公明 2076266844 Update README.md 2026-01-07 14:05:25 +08:00
公明 42293a9f49 Update README.md 2026-01-06 00:58:55 +08:00
公明 92580bebd5 Update README_CN.md 2026-01-06 00:58:23 +08:00
公明 23fd79d50d Update README_CN.md 2026-01-06 00:57:58 +08:00
公明 5216cebb2f Update README.md 2026-01-06 00:52:43 +08:00
公明 e55dd0265e Update README.md 2026-01-06 00:52:07 +08:00
公明 d550853b56 Add files via upload 2026-01-02 04:11:46 +08:00
49 changed files with 6822 additions and 431 deletions
+169
View File
@@ -0,0 +1,169 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [1.2.0] - 2026-01-11
### Added
- Role-based testing feature: predefined security testing roles with custom system prompts and tool restrictions. Users can select roles (Penetration Testing, CTF, Web App Scanning, etc.) from the chat interface to customize AI behavior and available tools. Roles are defined as YAML files in the `roles/` directory with support for hot-reload.
## [1.1.0] - 2026-01-08
### Added
- SSE (Server-Sent Events) transport mode support for external MCP servers. External MCP federation now supports HTTP, stdio, and SSE modes. SSE mode enables real-time streaming communication for push-based scenarios.
## [1.0.0] - 2026-01-01
### Added
- Batch task management feature: create task queues with multiple tasks, add/edit/delete tasks before execution, and execute them sequentially. Each task runs as a separate conversation with status tracking (pending/running/completed/failed/cancelled). All queues and tasks are persisted in the database.
## [0.7.0] - 2025-12-25
### Added
- Vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
- Conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
## [0.6.1] - 2025-12-24
### Changed
- Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
## [0.6.0] - 2025-12-20
### Added
- Knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
## [0.5.1] - 2025-12-19
### Added
- ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
## [0.5.0] - 2025-12-18
### Changed
- Optimized web frontend with enhanced sidebar navigation and improved user experience.
## [0.4.1] - 2025-12-07
### Added
- FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
### Fixed
- Positional parameter handling bug: ensure correct parameter position when using default values.
## [0.4.0] - 2025-11-20
### Added
- Automatic compression/summarization for oversized tool logs and MCP transcripts.
## [0.3.0] - 2025-11-17
### Added
- AI-built attack-chain visualization with interactive graph and risk scoring.
## [0.2.0] - 2025-11-15
### Added
- Large-result pagination, advanced filtering, and external MCP federation.
## [0.1.1] - 2025-11-14
### Changed
- Optimized tool lookups to O(1) time complexity.
- Execution record cleanup and DB pagination improvements.
## [0.1.0] - 2025-11-13
### Added
- Web authentication, settings UI, and MCP stdio mode integration.
---
# 更新日志
本项目的重要变更将记录在此文件中。
格式基于 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
并遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
## [未发布]
## [1.2.0] - 2026-01-11
### 新增
- 角色化测试功能:预设安全测试角色,支持自定义系统提示词和工具限制。用户可在聊天界面选择角色(渗透测试、CTF、Web 应用扫描等),以自定义 AI 行为和可用工具。角色以 YAML 文件形式定义在 `roles/` 目录,支持热加载。
## [1.1.0] - 2026-01-08
### 新增
- SSEServer-Sent Events)传输模式支持,外部 MCP 联邦现支持 HTTP、stdio 和 SSE 三种模式。SSE 模式支持实时流式通信,适用于基于推送的场景。
## [1.0.0] - 2026-01-01
### 新增
- 批量任务管理功能:支持创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务作为独立对话运行,支持状态跟踪(待执行/执行中/已完成/失败/已取消),所有队列和任务数据持久化存储到数据库。
## [0.7.0] - 2025-12-25
### 新增
- 漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。
- 对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。
## [0.6.1] - 2025-12-24
### 变更
- 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
## [0.6.0] - 2025-12-20
### 新增
- 知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。
## [0.5.1] - 2025-12-19
### 新增
- 钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
## [0.5.0] - 2025-12-18
### 变更
- 优化 Web 前端界面,增加侧边栏导航,提升用户体验。
## [0.4.1] - 2025-12-07
### 新增
- FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。
### 修复
- 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。
## [0.4.0] - 2025-11-20
### 新增
- 支持超大日志/MCP 记录的自动压缩与摘要回写。
## [0.3.0] - 2025-11-17
### 新增
- 上线 AI 驱动的攻击链图谱与风险评分。
## [0.2.0] - 2025-11-15
### 新增
- 提供大结果分页检索与外部 MCP 挂载能力。
## [0.1.1] - 2025-11-14
### 变更
- 工具检索优化至 O(1) 时间复杂度。
- 执行记录清理、数据库分页优化。
## [0.1.0] - 2025-11-13
### 新增
- Web 鉴权、Settings 面板与 MCP stdio 模式发布。
+170 -56
View File
@@ -33,7 +33,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
## Highlights
- 🤖 AI decision engine with OpenAI-compatible models (GPT, Claude, DeepSeek, etc.)
- 🔌 Native MCP implementation with HTTP/stdio transports and external MCP federation
- 🔌 Native MCP implementation with HTTP/stdio/SSE transports and external MCP federation
- 🧰 100+ prebuilt tool recipes + YAML-based extension system
- 📄 Large-result pagination, compression, and searchable archives
- 🔗 Attack-chain graph, risk scoring, and step-by-step replay
@@ -42,6 +42,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 📁 Conversation grouping with pinning, rename, and batch management
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
## Tool Overview
@@ -65,35 +66,40 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
## Basic Usage
### Quick Start
1. **Clone & install**
```bash
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
cd CyberStrikeAI-main
go mod download
```
2. **Set up the Python tooling stack (required for the YAML tools directory)**
A large portion of `tools/*.yaml` recipes wrap Python utilities (`api-fuzzer`, `http-framework-test`, `install-python-package`, etc.). Create the project-local virtual environment once and install the shared dependencies:
```bash
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
The helper tools automatically detect this `venv` (or any already active `$VIRTUAL_ENV`), so the default `env_name` works out of the box unless you intentionally supply another target.
3. **Configure OpenAI-compatible access**
Either open the in-app `Settings` panel after launch or edit `config.yaml`:
```yaml
openai:
api_key: "sk-your-key"
base_url: "https://api.openai.com/v1"
model: "gpt-4o"
auth:
password: "" # empty = auto-generate & log once
session_duration_hours: 12
security:
tools_dir: "tools"
```
4. **Install the tooling you need (optional)**
### Quick Start (One-Command Deployment)
**Prerequisites:**
- Go 1.21+ ([Install](https://go.dev/dl/))
- Python 3.10+ ([Install](https://www.python.org/downloads/))
**One-Command Deployment:**
```bash
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
cd CyberStrikeAI-main
chmod +x run.sh && ./run.sh
```
The `run.sh` script will automatically:
- ✅ Check and validate Go & Python environments
- ✅ Create Python virtual environment
- ✅ Install Python dependencies
- ✅ Download Go dependencies
- ✅ Build the project
- ✅ Start the server
**First-Time Configuration:**
1. **Configure OpenAI-compatible API** (required before first use)
- Open http://localhost:8080 after launch
- Go to `Settings` → Fill in your API credentials:
```yaml
openai:
api_key: "sk-your-key"
base_url: "https://api.openai.com/v1" # or https://api.deepseek.com/v1
model: "gpt-4o" # or deepseek-chat, claude-3-opus, etc.
```
- Or edit `config.yaml` directly before launching
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
3. **Install security tools (optional)** - Install tools as needed:
```bash
# macOS
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
@@ -101,18 +107,22 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
```
AI automatically falls back to alternatives when a tool is missing.
5. **Launch**
```bash
chmod +x run.sh && ./run.sh
# or
go run cmd/server/main.go
# or
go build -o cyberstrike-ai cmd/server/main.go
```
6. **Open the console** at http://localhost:8080, log in with the generated password, and start chatting.
**Alternative Launch Methods:**
```bash
# Direct Go run (requires manual setup)
go run cmd/server/main.go
# Manual build
go build -o cyberstrike-ai cmd/server/main.go
./cyberstrike-ai
```
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
### Core Workflows
- **Conversation testing** Natural-language prompts trigger toolchains with streaming SSE output.
- **Role-based testing** Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
- **Tool monitor** Inspect running jobs, execution logs, and large-result attachments.
- **History & audit** Every conversation and tool invocation is stored in SQLite with replay.
- **Conversation groups** Organize conversations into groups, pin important groups, rename or delete groups via context menu.
@@ -128,6 +138,28 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
## Advanced Usage
### Role-Based Testing
- **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
- **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
- **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
- **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
**Creating a custom role (example):**
1. Create a YAML file in `roles/` (e.g., `roles/custom-role.yaml`):
```yaml
name: Custom Role
description: Specialized testing scenario
user_prompt: You are a specialized security tester focusing on API security...
icon: "\U0001F4E1"
tools:
- api-fuzzer
- arjun
- graphql-scanner
enabled: true
```
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
### Tool Orchestration & Extensions
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
- **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
@@ -149,7 +181,7 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
### MCP Everywhere
- **Web mode** ships with HTTP MCP server automatically consumed by the UI.
- **MCP stdio mode** `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI.
- **External MCP federation** register third-party MCP servers (HTTP or stdio) from the UI, toggle them per engagement, and monitor their health and call volume in real time.
- **External MCP federation** register third-party MCP servers (HTTP, stdio, or SSE) from the UI, toggle them per engagement, and monitor their health and call volume in real time.
#### MCP stdio quick start
1. **Build the binary** (run from the project root):
@@ -189,6 +221,62 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
}
```
#### External MCP federation (HTTP/stdio/SSE)
CyberStrikeAI supports connecting to external MCP servers via three transport modes:
- **HTTP mode** traditional request/response over HTTP POST
- **stdio mode** process-based communication via standard input/output
- **SSE mode** Server-Sent Events for real-time streaming communication
To add an external MCP server:
1. Open the Web UI and navigate to **Settings → External MCP**.
2. Click **Add External MCP** and provide the configuration in JSON format:
**HTTP mode example:**
```json
{
"my-http-mcp": {
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"description": "HTTP MCP server",
"timeout": 30
}
}
```
**stdio mode example:**
```json
{
"my-stdio-mcp": {
"command": "python3",
"args": ["/path/to/mcp-server.py"],
"description": "stdio MCP server",
"timeout": 30
}
}
```
**SSE mode example:**
```json
{
"my-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP server",
"timeout": 30
}
}
```
3. Click **Save** and then **Start** to connect to the server.
4. Monitor the connection status, tool count, and health in real time.
**SSE mode benefits:**
- Real-time bidirectional communication via Server-Sent Events
- Suitable for scenarios requiring continuous data streaming
- Lower latency for push-based notifications
A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation purposes.
### Knowledge Base
- **Vector search** AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool.
- **Hybrid retrieval** combines vector similarity search with keyword matching for better accuracy.
@@ -228,7 +316,8 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
### Automation Hooks
- **REST APIs** everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities) is available over JSON.
- **REST APIs** everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities, roles) is available over JSON.
- **Role APIs** manage security testing roles via `/api/roles` endpoints: `GET /api/roles` (list all roles), `GET /api/roles/:name` (get role), `POST /api/roles` (create role), `PUT /api/roles/:name` (update role), `DELETE /api/roles/:name` (delete role). Roles are stored as YAML files in the `roles/` directory and support hot-reload.
- **Vulnerability APIs** manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics).
- **Batch Task APIs** manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking.
- **Task control** pause/resume/stop long scans, re-run steps with new params, or stream transcripts.
@@ -271,6 +360,7 @@ knowledge:
top_k: 5 # Number of top results to return
similarity_threshold: 0.7 # Minimum similarity score (0-1)
hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword)
roles_dir: "roles" # Role configuration directory (relative to config file)
```
### Tool Definition Example (`tools/nmap.yaml`)
@@ -293,6 +383,26 @@ parameters:
description: "Range, e.g. 1-1000"
```
### Role Definition Example (`roles/penetration-testing.yaml`)
```yaml
name: Penetration Testing
description: Professional penetration testing expert for comprehensive security testing
user_prompt: You are a professional cybersecurity penetration testing expert. Please use professional penetration testing methods and tools to conduct comprehensive security testing on targets, including but not limited to SQL injection, XSS, CSRF, file inclusion, command execution and other common vulnerabilities.
icon: "\U0001F3AF"
tools:
- nmap
- sqlmap
- nuclei
- burpsuite
- metasploit
- httpx
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
```
## Project Layout
```
@@ -301,6 +411,7 @@ CyberStrikeAI/
├── internal/ # Agent, MCP core, handlers, security executor
├── web/ # Static SPA + templates
├── tools/ # YAML tool recipes (100+ examples provided)
├── roles/ # Role configurations (12+ predefined security testing roles)
├── img/ # Docs screenshots & diagrams
├── config.yaml # Runtime configuration
├── run.sh # Convenience launcher
@@ -326,22 +437,22 @@ Compress the 5 MB nuclei report, summarize critical CVEs, and attach the artifac
Build an attack chain for the latest engagement and export the node list with severity >= high.
```
## Changelog (Recent)
## Changelog
See [CHANGELOG.md](CHANGELOG.md) for detailed version history and all changes.
### Recent Highlights
- **2026-01-11** Role-based testing with predefined security testing roles
- **2026-01-08** SSE transport mode support for external MCP servers
- **2026-01-01** Batch task management with queue-based execution
- **2025-12-25** Vulnerability management and conversation grouping features
- **2025-12-20** Knowledge base with vector search and hybrid retrieval
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
- 2026-01-01 Added batch task management feature: create task queues with multiple tasks, add/edit/delete tasks before execution, and execute them sequentially. Each task runs as a separate conversation with status tracking (pending/running/completed/failed/cancelled). All queues and tasks are persisted in the database.
- 2025-12-25 Added vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
- 2025-12-25 Added conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
- 2025-12-24 Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
- 2025-12-20 Added knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
- 2025-12-19 Added ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
- 2025-12-18 Optimized web frontend with enhanced sidebar navigation and improved user experience.
- 2025-12-07 Added FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
- 2025-12-07 Fixed positional parameter handling bug: ensure correct parameter position when using default values.
- 2025-11-20 Added automatic compression/summarization for oversized tool logs and MCP transcripts.
- 2025-11-17 Introduced AI-built attack-chain visualization with interactive graph and risk scoring.
- 2025-11-15 Delivered large-result pagination, advanced filtering, and external MCP federation.
- 2025-11-14 Optimized tool lookups (O(1)), execution record cleanup, and DB pagination.
- 2025-11-13 Added web authentication, settings UI, and MCP stdio mode integration.
## 404Starlink
@@ -357,6 +468,9 @@ CyberStrikeAI has joined [404Starlink](https://github.com/knownsec/404StarLink)
</div>
---
Need help or want to contribute? Open an issue or PR—community tooling additions are welcome!
+167 -56
View File
@@ -32,7 +32,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
## 特性速览
- 🤖 兼容 OpenAI/DeepSeek/Claude 等模型的智能决策引擎
- 🔌 原生 MCP 协议,支持 HTTP / stdio 以及外部 MCP 接入
- 🔌 原生 MCP 协议,支持 HTTP / stdio / SSE 传输模式以及外部 MCP 接入
- 🧰 100+ 现成工具模版 + YAML 扩展能力
- 📄 大结果分页、压缩与全文检索
- 🔗 攻击链可视化、风险打分与步骤回放
@@ -41,6 +41,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
## 工具概览
@@ -64,35 +65,40 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
## 基础使用
### 快速上手
1. **获取代码并安装依赖**
```bash
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
cd CyberStrikeAI-main
go mod download
```
2. **初始化 Python 虚拟环境(tools 目录所需)**
`tools/*.yaml` 中大量工具(如 `api-fuzzer`、`http-framework-test`、`install-python-package` 等)依赖 Python 生态。首次进入项目根目录时请创建本地虚拟环境并安装依赖:
```bash
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
两个 Python 专用工具(`install-python-package` 与 `execute-python-script`)会自动检测该 `venv`(或已经激活的 `$VIRTUAL_ENV`),因此默认 `env_name` 即可满足大多数场景。
3. **配置模型与鉴权**
启动后在 Web 端 `Settings` 填写,或直接编辑 `config.yaml`
```yaml
openai:
api_key: "sk-your-key"
base_url: "https://api.openai.com/v1"
model: "gpt-4o"
auth:
password: "" # 为空则首次启动自动生成强口令
session_duration_hours: 12
security:
tools_dir: "tools"
```
4. **按需安装安全工具(可选)**
### 快速上手(一条命令部署)
**环境要求:**
- Go 1.21+ ([下载安装](https://go.dev/dl/))
- Python 3.10+ ([下载安装](https://www.python.org/downloads/))
**一条命令部署:**
```bash
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
cd CyberStrikeAI-main
chmod +x run.sh && ./run.sh
```
`run.sh` 脚本会自动完成:
- ✅ 检查并验证 Go 和 Python 环境
- ✅ 创建 Python 虚拟环境
- ✅ 安装 Python 依赖包
- ✅ 下载 Go 依赖模块
- ✅ 编译构建项目
- ✅ 启动服务器
**首次配置:**
1. **配置 AI 模型 API**(首次使用前必填)
- 启动后访问 http://localhost:8080
- 进入 `设置` → 填写 API 配置信息:
```yaml
openai:
api_key: "sk-your-key"
base_url: "https://api.openai.com/v1" # 或 https://api.deepseek.com/v1
model: "gpt-4o" # 或 deepseek-chat, claude-3-opus 等
```
- 或启动前直接编辑 `config.yaml` 文件
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`
3. **安装安全工具(可选)** - 按需安装所需工具:
```bash
# macOS
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
@@ -100,18 +106,22 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
```
未安装的工具会自动跳过或改用替代方案。
5. **启动服务**
```bash
chmod +x run.sh && ./run.sh
# 或
go run cmd/server/main.go
# 或
go build -o cyberstrike-ai cmd/server/main.go
```
6. **浏览器访问** http://localhost:8080 ,使用日志中提示的密码登录并开始对话。
**其他启动方式:**
```bash
# 直接运行(需手动配置环境)
go run cmd/server/main.go
# 手动编译
go build -o cyberstrike-ai cmd/server/main.go
./cyberstrike-ai
```
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
### 常用流程
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
- **工具监控**:查看任务队列、执行日志、大文件附件。
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
- **对话分组**:将对话按项目或主题组织到不同分组,支持置顶、重命名、删除等操作,所有数据持久化存储。
@@ -127,6 +137,28 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
## 进阶使用
### 角色化测试
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
**创建自定义角色示例:**
1. 在 `roles/` 目录创建 YAML 文件(如 `roles/custom-role.yaml`):
```yaml
name: 自定义角色
description: 专用测试场景
user_prompt: 你是一个专注于 API 安全的专业安全测试人员...
icon: "\U0001F4E1"
tools:
- api-fuzzer
- arjun
- graphql-scanner
enabled: true
```
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
### 工具编排与扩展
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
@@ -147,7 +179,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
### MCP 全场景
- **Web 模式**:自带 HTTP MCP 服务供前端调用。
- **MCP stdio 模式**`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。
- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio),按需启停并实时查看调用统计与健康度。
- **外部 MCP 联邦**:在设置中注册第三方 MCPHTTP/stdio/SSE),按需启停并实时查看调用统计与健康度。
#### MCP stdio 快速集成
1. **编译可执行文件**(在项目根目录执行):
@@ -187,6 +219,62 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
}
```
#### 外部 MCP 联邦(HTTP/stdio/SSE
CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
- **HTTP 模式** 通过 HTTP POST 进行传统的请求/响应通信
- **stdio 模式** – 通过标准输入/输出进行进程间通信
- **SSE 模式** 通过 Server-Sent Events 实现实时流式通信
添加外部 MCP 服务器:
1. 打开 Web 界面,进入 **设置 → 外部MCP**。
2. 点击 **添加外部MCP**,以 JSON 格式提供配置:
**HTTP 模式示例:**
```json
{
"my-http-mcp": {
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"description": "HTTP MCP 服务器",
"timeout": 30
}
}
```
**stdio 模式示例:**
```json
{
"my-stdio-mcp": {
"command": "python3",
"args": ["/path/to/mcp-server.py"],
"description": "stdio MCP 服务器",
"timeout": 30
}
}
```
**SSE 模式示例:**
```json
{
"my-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP 服务器",
"timeout": 30
}
}
```
3. 点击 **保存**,然后点击 **启动** 连接服务器。
4. 实时监控连接状态、工具数量和健康度。
**SSE 模式优势:**
- 通过 Server-Sent Events 实现实时双向通信
- 适用于需要持续数据流的场景
- 对于基于推送的通知,延迟更低
可在 `cmd/test-sse-mcp-server/` 目录找到用于验证的测试 SSE MCP 服务器。
### 知识库功能
- **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。
@@ -227,7 +315,8 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
### 自动化与安全
- **REST API**:认证、会话、任务、监控、漏洞管理等接口全部开放,可与 CI/CD 集成。
- **REST API**:认证、会话、任务、监控、漏洞管理、角色管理等接口全部开放,可与 CI/CD 集成。
- **角色管理 API**:通过 `/api/roles` 端点管理安全测试角色:`GET /api/roles`(列表)、`GET /api/roles/:name`(获取角色)、`POST /api/roles`(创建角色)、`PUT /api/roles/:name`(更新角色)、`DELETE /api/roles/:name`(删除角色)。角色以 YAML 文件形式存储在 `roles/` 目录,支持热加载。
- **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。
- **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。
- **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。
@@ -270,6 +359,7 @@ knowledge:
top_k: 5 # 检索返回的 Top-K 结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索
roles_dir: "roles" # 角色配置文件目录(相对于配置文件所在目录)
```
### 工具模版示例(`tools/nmap.yaml`
@@ -292,6 +382,26 @@ parameters:
description: "端口范围,如 1-1000"
```
### 角色配置示例(`roles/渗透测试.yaml`
```yaml
name: 渗透测试
description: 专业渗透测试专家,全面深入的漏洞检测
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
icon: "\U0001F3AF"
tools:
- nmap
- sqlmap
- nuclei
- burpsuite
- metasploit
- httpx
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
```
## 项目结构
```
@@ -300,6 +410,7 @@ CyberStrikeAI/
├── internal/ # Agent、MCP 核心、路由与执行器
├── web/ # 前端静态资源与模板
├── tools/ # YAML 工具目录(含 100+ 示例)
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
├── img/ # 文档配图
├── config.yaml # 运行配置
├── run.sh # 启动脚本
@@ -325,21 +436,21 @@ CyberStrikeAI/
构建最新一次测试的攻击链,只导出风险 >= 高的节点列表。
```
## Changelog(近期)
- 2026-01-01 —— 新增批量任务管理功能:支持创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务作为独立对话运行,支持状态跟踪(待执行/执行中/已完成/失败/已取消),所有队列和任务数据持久化存储到数据库。
- 2025-12-25 —— 新增漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板
- 2025-12-25 —— 新增对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。
- 2025-12-24 —— 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
- 2025-12-20 —— 新增知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。
- 2025-12-19 —— 新增钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
- 2025-12-18 —— 优化 Web 前端界面,增加侧边栏导航,提升用户体验。
- 2025-12-07 —— 新增 FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。
- 2025-12-07 —— 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。
- 2025-11-20 —— 支持超大日志/MCP 记录的自动压缩与摘要回写。
- 2025-11-17 —— 上线 AI 驱动的攻击链图谱与风险评分。
- 2025-11-15 —— 提供大结果分页检索与外部 MCP 挂载能力。
- 2025-11-14 —— 工具检索 O(1)、执行记录清理、数据库分页优化。
- 2025-11-13 —— Web 鉴权、Settings 面板与 MCP stdio 模式发布。
## 更新日志
详细版本历史和所有变更请查看 [CHANGELOG.md](CHANGELOG.md)
### 近期亮点
- **2026-01-11** – 新增角色化测试功能,支持预设安全测试角色
- **2026-01-08** 新增 SSE 传输模式支持,外部 MCP 联邦支持三种模式
- **2026-01-01** – 新增批量任务管理功能,支持队列式任务执行
- **2025-12-25** – 新增漏洞管理和对话分组功能
- **2025-12-20** – 新增知识库功能,支持向量检索和混合搜索
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
## 404星链计划
<img src="./img/404StarLinkLogo.png" width="30%">
+56
View File
@@ -0,0 +1,56 @@
# SSE MCP 测试服务器
这是一个用于验证SSE模式外部MCP功能的测试服务器。
## 使用方法
### 1. 启动测试服务器
```bash
cd cmd/test-sse-mcp-server
go run main.go
```
服务器将在 `http://127.0.0.1:8082` 启动,提供以下端点:
- `GET /sse` - SSE事件流端点
- `POST /message` - 消息接收端点
### 2. 在CyberStrikeAI中添加配置
在Web界面中添加外部MCP配置,使用以下JSON:
```json
{
"test-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP测试服务器",
"timeout": 30
}
}
```
### 3. 测试功能
测试服务器提供两个测试工具:
1. **test_echo** - 回显输入的文本
- 参数:`text` (string) - 要回显的文本
2. **test_add** - 计算两个数字的和
- 参数:`a` (number) - 第一个数字
- 参数:`b` (number) - 第二个数字
## 工作原理
1. 客户端通过 `GET /sse` 建立SSE连接,接收服务器推送的事件
2. 客户端通过 `POST /message` 发送MCP协议消息
3. 服务器处理消息后,通过SSE连接推送响应
## 日志
服务器会输出以下日志:
- SSE客户端连接/断开
- 收到的请求(方法名和ID
- 工具调用详情
+395
View File
@@ -0,0 +1,395 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/google/uuid"
)
const ProtocolVersion = "2024-11-05"
// Message MCP消息
type Message struct {
ID interface{} `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"`
}
// Error MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// Tool 工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// SSEServer SSE MCP服务器
type SSEServer struct {
sseClients map[string]chan []byte
mu sync.RWMutex
}
func NewSSEServer() *SSEServer {
return &SSEServer{
sseClients: make(map[string]chan []byte),
}
}
// handleSSE 处理SSE连接
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
clientID := uuid.New().String()
clientChan := make(chan []byte, 10)
s.mu.Lock()
s.sseClients[clientID] = clientChan
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.sseClients, clientID)
close(clientChan)
s.mu.Unlock()
}()
// 发送初始ready事件
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
flusher.Flush()
log.Printf("SSE客户端连接: %s", clientID)
// 心跳
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
log.Printf("SSE客户端断开: %s", clientID)
return
case msg, ok := <-clientChan:
if !ok {
return
}
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
flusher.Flush()
case <-ticker.C:
// 心跳
fmt.Fprintf(w, ": ping\n\n")
flusher.Flush()
}
}
}
// handleMessage 处理POST消息
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var msg Message
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
// 处理消息
response := s.processMessage(&msg)
// 如果有SSE客户端,通过SSE推送响应
if response != nil {
responseJSON, _ := json.Marshal(response)
s.mu.RLock()
// 发送给所有SSE客户端
for _, ch := range s.sseClients {
select {
case ch <- responseJSON:
default:
}
}
s.mu.RUnlock()
}
// 也直接返回响应(兼容非SSE模式)
if response != nil {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
} else {
w.WriteHeader(http.StatusOK)
}
}
// processMessage 处理MCP消息
func (s *SSEServer) processMessage(msg *Message) *Message {
switch msg.Method {
case "initialize":
return s.handleInitialize(msg)
case "tools/list":
return s.handleListTools(msg)
case "tools/call":
return s.handleCallTool(msg)
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Method not found",
},
}
}
}
// handleInitialize 处理初始化
func (s *SSEServer) handleInitialize(msg *Message) *Message {
var req InitializeRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
response := InitializeResponse{
ProtocolVersion: ProtocolVersion,
Capabilities: ServerCapabilities{
Tools: map[string]interface{}{
"listChanged": true,
},
},
ServerInfo: ServerInfo{
Name: "Test SSE MCP Server",
Version: "1.0.0",
},
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleListTools 处理列出工具
func (s *SSEServer) handleListTools(msg *Message) *Message {
tools := []Tool{
{
Name: "test_echo",
Description: "回显输入的文本,用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"text": map[string]interface{}{
"type": "string",
"description": "要回显的文本",
},
},
"required": []string{"text"},
},
},
{
Name: "test_add",
Description: "计算两个数字的和,用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"a": map[string]interface{}{
"type": "number",
"description": "第一个数字",
},
"b": map[string]interface{}{
"type": "number",
"description": "第二个数字",
},
},
"required": []string{"a", "b"},
},
},
}
response := ListToolsResponse{Tools: tools}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleCallTool 处理工具调用
func (s *SSEServer) handleCallTool(msg *Message) *Message {
var req CallToolRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
var content []Content
switch req.Name {
case "test_echo":
text, _ := req.Arguments["text"].(string)
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("回显: %s", text),
},
}
case "test_add":
var a, b float64
if val, ok := req.Arguments["a"].(float64); ok {
a = val
}
if val, ok := req.Arguments["b"].(float64); ok {
b = val
}
sum := a + b
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
},
}
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Tool not found",
},
}
}
response := CallToolResponse{
Content: content,
IsError: false,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
func main() {
server := NewSSEServer()
http.HandleFunc("/sse", server.handleSSE)
http.HandleFunc("/message", server.handleMessage)
port := ":8082"
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
log.Printf("SSE端点: http://localhost%s/sse", port)
log.Printf("消息端点: http://localhost%s/message", port)
log.Printf("配置示例:")
log.Printf(`{
"test-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse"
}
}`)
if err := http.ListenAndServe(port, nil); err != nil {
log.Fatal(err)
}
}
+4
View File
@@ -66,3 +66,7 @@ knowledge:
top_k: 5 # 检索返回的Top-K结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
# 角色配置
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
+37 -9
View File
@@ -12,6 +12,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage"
@@ -302,16 +303,16 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
}
// AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) {
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
// 设置当前对话ID
a.mu.Lock()
a.currentConversationID = conversationID
@@ -401,8 +402,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
漏洞记录要求:
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 当你发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 工具记录漏洞详情
` + `- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 严重程度评估标准:
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
@@ -512,7 +513,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
// 获取可用工具
tools := a.getAvailableTools()
tools := a.getAvailableTools(roleTools)
// 记录当前上下文的Token用量,展示压缩器运行状态
if a.memoryCompressor != nil {
@@ -837,13 +838,29 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
func (a *Agent) getAvailableTools() []Tool {
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找)
roleToolSet := make(map[string]bool)
if len(roleTools) > 0 {
for _, toolKey := range roleTools {
roleToolSet[toolKey] = true
}
}
// 从MCP服务器获取所有已注册的内部工具
mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools {
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
toolKey := mcpTool.Name // 内置工具使用工具名称作为key
if !roleToolSet[toolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription
if description == "" {
@@ -865,7 +882,8 @@ func (a *Agent) getAvailableTools() []Tool {
// 获取外部MCP工具
if a.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
externalTools, err := a.externalMCPMgr.GetAllTools(ctx)
@@ -882,6 +900,16 @@ func (a *Agent) getAvailableTools() []Tool {
// 将外部MCP工具添加到工具列表(只添加启用的工具)
for _, externalTool := range externalTools {
// 外部工具使用 "mcpName::toolName" 作为toolKey
externalToolKey := externalTool.Name
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
if !roleToolSet[externalToolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 解析工具名称:mcpName::toolName
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
@@ -1135,7 +1163,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
)
// 如果是record_vulnerability工具,自动添加conversation_id
if toolName == "record_vulnerability" {
if toolName == builtin.ToolRecordVulnerability {
a.mu.RLock()
conversationID := a.currentConversationID
a.mu.RUnlock()
+89 -18
View File
@@ -16,6 +16,7 @@ import (
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
@@ -214,23 +215,53 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
continue
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
return
}
return
}
// 只有在没有索引时才自动重建
log.Logger.Info("未检测到知识库索引,开始自动构建索引")
@@ -248,7 +279,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
@@ -262,6 +293,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
// 创建 App 实例(部分字段稍后填充)
app := &App{
@@ -338,6 +370,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
roleHandler,
mcpServer,
authManager,
)
@@ -398,6 +431,7 @@ func setupRoutes(
attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
) {
@@ -623,6 +657,13 @@ func setupRoutes(
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
@@ -642,7 +683,7 @@ func setupRoutes(
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: "record_vulnerability",
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
@@ -934,13 +975,43 @@ func initializeKnowledge(
if len(itemsToIndex) > 0 {
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
+136
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"gopkg.in/yaml.v3"
@@ -22,6 +23,8 @@ type Config struct {
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
}
type ServerConfig struct {
@@ -81,6 +84,7 @@ type ExternalMCPServerConfig struct {
// stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
@@ -206,6 +210,29 @@ func Load(path string) (*Config, error) {
}
}
// 从角色目录加载角色配置
if cfg.RolesDir != "" {
configDir := filepath.Dir(path)
rolesDir := cfg.RolesDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
roles, err := LoadRolesFromDir(rolesDir)
if err != nil {
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
}
cfg.Roles = roles
} else {
// 如果未配置 roles_dir,初始化为空 map
if cfg.Roles == nil {
cfg.Roles = make(map[string]RoleConfig)
}
}
return &cfg, nil
}
@@ -374,6 +401,98 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
return &tool, nil
}
// LoadRolesFromDir 从目录加载所有角色配置文件
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
roles := make(map[string]RoleConfig)
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return roles, nil // 目录不存在时返回空map,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取角色目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
role, err := LoadRoleFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
continue
}
// 使用角色名称作为key
roleName := role.Name
if roleName == "" {
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
role.Name = roleName
}
roles[roleName] = *role
}
return roles, nil
}
// LoadRoleFromFile 从单个文件加载角色配置
func LoadRoleFromFile(path string) (*RoleConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var role RoleConfig
if err := yaml.Unmarshal(data, &role); err != nil {
return nil, fmt.Errorf("解析角色配置失败: %w", err)
}
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
if role.Icon != "" {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制)或 \uXXXX(4位十六进制)
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
// \U0001F3C6 格式(8位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
} else if icon[1] == 'u' && len(icon) >= 6 {
// \uXXXX 格式(4位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
}
}
}
// 验证必需字段
if role.Name == "" {
// 如果名称为空,尝试从文件名获取
baseName := filepath.Base(path)
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
}
return &role, nil
}
func Default() *Config {
return &Config{
Server: ServerConfig{
@@ -447,3 +566,20 @@ type RetrievalConfig struct {
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值
HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1
}
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
type RolesConfig struct {
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
}
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}
+16 -15
View File
@@ -11,6 +11,7 @@ import (
// BatchTaskQueueRow 批量任务队列数据库行
type BatchTaskQueueRow struct {
ID string
Title sql.NullString
Status string
CreatedAt time.Time
StartedAt sql.NullTime
@@ -32,7 +33,7 @@ type BatchTaskRow struct {
}
// CreateBatchQueue 创建批量任务队列
func (db *DB) CreateBatchQueue(queueID string, tasks []map[string]interface{}) error {
func (db *DB) CreateBatchQueue(queueID string, title string, tasks []map[string]interface{}) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
@@ -41,8 +42,8 @@ func (db *DB) CreateBatchQueue(queueID string, tasks []map[string]interface{}) e
now := time.Now()
_, err = tx.Exec(
"INSERT INTO batch_task_queues (id, status, created_at, current_index) VALUES (?, ?, ?, ?)",
queueID, "pending", now, 0,
"INSERT INTO batch_task_queues (id, title, status, created_at, current_index) VALUES (?, ?, ?, ?, ?)",
queueID, title, "pending", now, 0,
)
if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err)
@@ -76,9 +77,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow
var createdAt string
err := db.QueryRow(
"SELECT id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
"SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
queueID,
).Scan(&row.ID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
).Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows {
return nil, nil
}
@@ -102,7 +103,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
// GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query(
"SELECT id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
"SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
@@ -113,7 +114,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
if err := rows.Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -133,7 +134,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
query := "SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
@@ -142,10 +143,10 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
args = append(args, status)
}
// 关键字搜索(搜索队列ID
// 关键字搜索(搜索队列ID和标题
if keyword != "" {
query += " AND id LIKE ?"
args = append(args, "%"+keyword+"%")
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
@@ -161,7 +162,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
if err := rows.Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -190,10 +191,10 @@ func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
args = append(args, status)
}
// 关键字搜索
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND id LIKE ?"
args = append(args, "%"+keyword+"%")
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
var count int
+31
View File
@@ -193,6 +193,7 @@ func (db *DB) initTables() error {
createBatchTaskQueuesTable := `
CREATE TABLE IF NOT EXISTS batch_task_queues (
id TEXT PRIMARY KEY,
title TEXT,
status TEXT NOT NULL,
created_at DATETIME NOT NULL,
started_at DATETIME,
@@ -240,6 +241,7 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
`
if _, err := db.Exec(createConversationsTable); err != nil {
@@ -310,6 +312,11 @@ func (db *DB) initTables() error {
// 不返回错误,允许继续运行
}
if err := db.migrateBatchTaskQueuesTable(); err != nil {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
@@ -426,6 +433,30 @@ func (db *DB) migrateConversationGroupMappingsTable() error {
return nil
}
// migrateBatchTaskQueuesTable 迁移batch_task_queues表,添加title字段
func (db *DB) migrateBatchTaskQueuesTable() error {
// 检查title字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加title字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil {
db.logger.Warn("添加title字段失败", zap.Error(err))
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
+73 -1
View File
@@ -205,7 +205,7 @@ func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
if err != nil {
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
}
// 然后插入新的分组关联
id := uuid.New().String()
_, err = db.Exec(
@@ -282,6 +282,78 @@ func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
return conversations, nil
}
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
// 构建SQL查询,支持按标题和消息内容搜索
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?`
args := []interface{}{groupID}
// 如果有搜索关键词,添加标题和消息内容搜索条件
if searchQuery != "" {
searchPattern := "%" + searchQuery + "%"
// 搜索标题或消息内容
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
query += ` AND (
LOWER(c.title) LIKE LOWER(?)
OR EXISTS (
SELECT 1 FROM messages m
WHERE m.conversation_id = c.id
AND LOWER(m.content) LIKE LOWER(?)
)
)`
args = append(args, searchPattern, searchPattern)
}
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// GetGroupByConversation 获取对话所属的分组
func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
var groupID string
+64 -12
View File
@@ -12,7 +12,9 @@ import (
"unicode/utf8"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -66,13 +68,14 @@ type AgentHandler struct {
logger *zap.Logger
tasks *AgentTaskManager
batchTaskManager *BatchTaskManager
knowledgeManager interface { // 知识库管理器接口
config *config.Config // 配置引用,用于获取角色信息
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
}
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
batchTaskManager := NewBatchTaskManager()
batchTaskManager.SetDB(db)
@@ -87,6 +90,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
logger: logger,
tasks: NewAgentTaskManager(),
batchTaskManager: batchTaskManager,
config: cfg,
}
}
@@ -101,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
}
// ChatResponse 聊天响应
@@ -161,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop,传入历史消息和对话ID
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -231,7 +256,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_call" {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
toolCallCache[toolCallId] = argumentsObj
@@ -245,7 +270,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
// 提取检索信息
query := ""
riskType := ""
@@ -470,7 +495,32 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
} else if len(role.MCPs) > 0 {
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
// 因为mcps是MCP服务器名称,不是工具列表
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
}
}
}
}
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
@@ -547,9 +597,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus)
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
@@ -759,6 +809,7 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
// BatchTaskRequest 批量任务请求
type BatchTaskRequest struct {
Title string `json:"title"` // 任务标题(可选)
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
}
@@ -788,7 +839,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
return
}
queue := h.batchTaskManager.CreateBatchQueue(validTasks)
queue := h.batchTaskManager.CreateBatchQueue(req.Title, validTasks)
c.JSON(http.StatusOK, gin.H{
"queueId": queue.ID,
"queue": queue,
@@ -1071,7 +1122,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
// 任务执行完成,清理取消函数
h.batchTaskManager.SetTaskCancel(queueID, nil)
cancel()
+13 -4
View File
@@ -28,6 +28,7 @@ type BatchTask struct {
// BatchTaskQueue 批量任务队列
type BatchTaskQueue struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"`
@@ -61,13 +62,14 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
}
// CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue(tasks []string) *BatchTaskQueue {
func (m *BatchTaskManager) CreateBatchQueue(title string, tasks []string) *BatchTaskQueue {
m.mu.Lock()
defer m.mu.Unlock()
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
queue := &BatchTaskQueue{
ID: queueID,
Title: title,
Tasks: make([]*BatchTask, 0, len(tasks)),
Status: "pending",
CreatedAt: time.Now(),
@@ -96,7 +98,7 @@ func (m *BatchTaskManager) CreateBatchQueue(tasks []string) *BatchTaskQueue {
// 保存到数据库
if m.db != nil {
if err := m.db.CreateBatchQueue(queueID, dbTasks); err != nil {
if err := m.db.CreateBatchQueue(queueID, title, dbTasks); err != nil {
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
// 这里可以添加日志记录
}
@@ -153,6 +155,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.Title.Valid {
queue.Title = queueRow.Title.String
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
@@ -271,11 +276,12 @@ func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string)
if status != "" && status != "all" && queue.Status != status {
continue
}
// 关键字搜索
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
keywordLower := strings.ToLower(keyword)
queueIDLower := strings.ToLower(queue.ID)
if !strings.Contains(queueIDLower, keywordLower) {
queueTitleLower := strings.ToLower(queue.Title)
if !strings.Contains(queueIDLower, keywordLower) && !strings.Contains(queueTitleLower, keywordLower) {
// 也可以搜索创建时间
createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05")
if !strings.Contains(createdAtStr, keyword) {
@@ -342,6 +348,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.Title.Valid {
queue.Title = queueRow.Title.String
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
+208 -28
View File
@@ -47,16 +47,17 @@ type ConfigHandler struct {
config *config.Config
mcpServer *mcp.Server
executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
agent AgentUpdater // Agent接口,用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
logger *zap.Logger
mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
}
// AttackChainUpdater 攻击链处理器更新接口
@@ -72,15 +73,26 @@ type AgentUpdater interface {
// NewConfigHandler 创建新的配置处理器
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
// 保存初始的嵌入模型配置(如果知识库已启用)
var lastEmbeddingConfig *config.EmbeddingConfig
if cfg.Knowledge.Enabled {
lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: cfg.Knowledge.Embedding.Provider,
Model: cfg.Knowledge.Embedding.Model,
BaseURL: cfg.Knowledge.Embedding.BaseURL,
APIKey: cfg.Knowledge.Embedding.APIKey,
}
}
return &ConfigHandler{
configPath: configPath,
config: cfg,
mcpServer: mcpServer,
executor: executor,
agent: agent,
attackChainHandler: attackChainHandler,
externalMCPMgr: externalMCPMgr,
logger: logger,
configPath: configPath,
config: cfg,
mcpServer: mcpServer,
executor: executor,
agent: agent,
attackChainHandler: attackChainHandler,
externalMCPMgr: externalMCPMgr,
logger: logger,
lastEmbeddingConfig: lastEmbeddingConfig,
}
}
@@ -135,6 +147,7 @@ type ToolConfigInfo struct {
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
}
// GetConfig 获取当前配置
@@ -191,7 +204,8 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// 获取外部MCP工具
if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
@@ -259,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// GetToolsResponse 获取工具列表响应(分页)
type GetToolsResponse struct {
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// GetTools 获取工具列表(支持分页和搜索)
@@ -292,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
searchTermLower = strings.ToLower(searchTerm)
}
// 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool)
for _, toolKey := range role.Tools {
roleToolsSet[toolKey] = true
}
roleUsesAllTools = false
}
}
}
// 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
@@ -312,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
toolInfo.Description = desc
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
if tool.Enabled {
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[tool.Name] {
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -348,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
IsExternal: false,
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,直接注册的工具默认启用
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[mcpTool.Name] {
roleEnabled := true // 在角色列表中且工具本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -363,7 +440,8 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 获取外部MCP工具
if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
@@ -425,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
}
allTools = append(allTools, ToolConfigInfo{
toolInfo := ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
toolInfo.RoleEnabled = &enabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
if roleToolsSet[externalToolKey] {
roleEnabled := enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
}
}
}
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
total := len(allTools)
// 统计已启用的工具数(在角色中的启用工具数)
totalEnabled := 0
for _, tool := range allTools {
if tool.RoleEnabled != nil && *tool.RoleEnabled {
totalEnabled++
} else if tool.RoleEnabled == nil && tool.Enabled {
// 如果未指定角色,统计所有启用的工具
totalEnabled++
}
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
@@ -457,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
c.JSON(http.StatusOK, GetToolsResponse{
Tools: tools,
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
Tools: tools,
Total: total,
TotalEnabled: totalEnabled,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}
@@ -522,6 +638,15 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 更新Knowledge配置
if req.Knowledge != nil {
// 保存旧的嵌入模型配置(用于检测变更)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
}
h.config.Knowledge = *req.Knowledge
h.logger.Info("更新Knowledge配置",
zap.Bool("enabled", h.config.Knowledge.Enabled),
@@ -676,10 +801,55 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("知识库动态初始化完成,工具已注册")
}
// 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞)
var needReinitKnowledge bool
var reinitKnowledgeInitializer KnowledgeInitializer
h.mu.RLock()
if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil {
// 检查嵌入模型配置是否变更
currentEmbedding := h.config.Knowledge.Embedding
if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider ||
currentEmbedding.Model != h.lastEmbeddingConfig.Model ||
currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL ||
currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey {
needReinitKnowledge = true
reinitKnowledgeInitializer = h.knowledgeInitializer
h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件",
zap.String("old_model", h.lastEmbeddingConfig.Model),
zap.String("new_model", currentEmbedding.Model),
zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL),
zap.String("new_base_url", currentEmbedding.BaseURL),
)
}
}
h.mu.RUnlock()
// 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行
if needReinitKnowledge {
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
if _, err := reinitKnowledgeInitializer(); err != nil {
h.logger.Error("重新初始化知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
return
}
h.logger.Info("知识库组件重新初始化完成")
}
// 现在获取写锁,执行快速的操作
h.mu.Lock()
defer h.mu.Unlock()
// 如果重新初始化了知识库,更新嵌入模型配置记录
if needReinitKnowledge && h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
h.logger.Info("已更新嵌入模型配置记录")
}
// 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具")
@@ -737,6 +907,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
)
}
// 更新嵌入模型配置记录(如果知识库启用)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),
)
+13 -2
View File
@@ -324,7 +324,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
} else if cfg.URL != "" {
transport = "http"
} else {
return fmt.Errorf("需要指定commandstdio模式)或urlhttp模式)")
return fmt.Errorf("需要指定commandstdio模式)或urlhttp/sse模式)")
}
}
@@ -337,8 +337,12 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
if cfg.Command == "" {
return fmt.Errorf("stdio模式需要command")
}
case "sse":
if cfg.URL == "" {
return fmt.Errorf("SSE模式需要URL")
}
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport)
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
}
return nil
@@ -442,6 +446,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
// 保存 env 字段(环境变量)
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport)
}
+11 -1
View File
@@ -189,8 +189,18 @@ type GroupConversation struct {
// GetGroupConversations 获取分组中的所有对话
func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
groupID := c.Param("id")
searchQuery := c.Query("search") // 获取搜索参数
var conversations []*database.Conversation
var err error
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
if searchQuery != "" {
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
} else {
conversations, err = h.db.GetConversationsByGroup(groupID)
}
conversations, err := h.db.GetConversationsByGroup(groupID)
if err != nil {
h.logger.Error("获取分组对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+56 -3
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge"
@@ -336,14 +337,54 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
go func() {
ctx := context.Background()
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
failedCount := 0
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
h.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemsToIndex)),
zap.Error(err),
)
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
}
}
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
}()
c.JSON(http.StatusOK, gin.H{
@@ -396,6 +437,18 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
return
}
// 获取索引器的错误信息
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
}
c.JSON(http.StatusOK, status)
}
+453
View File
@@ -0,0 +1,453 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}
+130 -15
View File
@@ -7,6 +7,8 @@ import (
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
@@ -19,6 +21,12 @@ type Indexer struct {
logger *zap.Logger
chunkSize int // 每个块的最大token数(估算)
overlap int // 块之间的重叠token数
// 错误跟踪
mu sync.RWMutex
lastError string // 最近一次错误信息
lastErrorTime time.Time // 最近一次错误时间
errorCount int // 连续错误计数
}
// NewIndexer 创建新的索引器
@@ -267,13 +275,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
chunks := idx.ChunkText(content)
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 跟踪该知识项的错误
itemErrorCount := 0
var firstError error
firstErrorChunkIndex := -1
// 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型)
for i, chunk := range chunks {
chunkPreview := chunk
if len(chunkPreview) > 200 {
chunkPreview = chunkPreview[:200] + "..."
}
// 将category和title信息包含到向量化的文本中
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
@@ -281,13 +289,43 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
if err != nil {
idx.logger.Warn("向量化失败",
zap.String("itemId", itemID),
zap.Int("chunkIndex", i),
zap.Int("chunkLength", len(chunk)),
zap.String("chunkPreview", chunkPreview),
zap.Error(err),
)
itemErrorCount++
if firstError == nil {
firstError = err
firstErrorChunkIndex = i
// 只在第一个块失败时记录详细日志
chunkPreview := chunk
if len(chunkPreview) > 200 {
chunkPreview = chunkPreview[:200] + "..."
}
idx.logger.Warn("向量化失败",
zap.String("itemId", itemID),
zap.Int("chunkIndex", i),
zap.Int("totalChunks", len(chunks)),
zap.String("chunkPreview", chunkPreview),
zap.Error(err),
)
// 更新全局错误跟踪
errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
}
// 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止)
// 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题
if itemErrorCount >= 2 {
idx.logger.Error("知识项连续向量化失败,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
zap.Int("failedChunks", itemErrorCount),
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
zap.Error(firstError),
)
return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError)
}
continue
}
@@ -321,6 +359,13 @@ func (idx *Indexer) HasIndex() (bool, error) {
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
// 重置错误跟踪
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
@@ -348,14 +393,84 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.logger.Info("已清空旧索引,开始重建")
}
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止)
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
idx.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemIDs)),
zap.Error(err),
)
}
// 如果连续失败过多,可能是配置问题,立即停止索引
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError)
}
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%)
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
}
continue
}
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
// 成功时重置连续失败计数和第一个失败信息
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率(每10个或每10%记录一次)
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)))
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
// GetLastError 获取最近一次错误信息
func (idx *Indexer) GetLastError() (string, time.Time) {
idx.mu.RLock()
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
+8 -8
View File
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
// 查询所有向量(或按风险类型过滤)
// 使用精确匹配(=)以提高性能和准确性
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用相应的内置工具获取准确的category名称
rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
+12 -11
View File
@@ -8,6 +8,7 @@ import (
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: "list_knowledge_risk_types",
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: "search_knowledge_base",
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
InputSchema: map[string]interface{}{
"type": "object",
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
// 按文档分组结果,以便更好地展示上下文
// 使用有序的slice来保持文档顺序(按最高混合分数)
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
})
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
+33
View File
@@ -0,0 +1,33 @@
package builtin
// 内置工具名称常量
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const (
// 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability"
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
)
// IsBuiltinTool 检查工具名称是否是内置工具
func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase:
return true
default:
return false
}
}
// GetAllBuiltinTools 返回所有内置工具名称列表
func GetAllBuiltinTools() []string {
return []string{
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
}
}
+559 -1
View File
@@ -1,13 +1,16 @@
package mcp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
@@ -100,6 +103,20 @@ func (c *HTTPMCPClient) Initialize(ctx context.Context) error {
return fmt.Errorf("初始化失败: %w", err)
}
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected")
return nil
}
@@ -193,6 +210,34 @@ func (c *HTTPMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message
return &mcpResp, nil
}
func (c *HTTPMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("序列化通知失败: %w", err)
}
// 使用较短的超时发送通知
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("创建HTTP请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 发送通知,不等待响应(通知不需要响应)
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
resp.Body.Close()
return nil
}
func (c *HTTPMCPClient) Close() error {
c.setStatus("disconnected")
return nil
@@ -202,6 +247,7 @@ func (c *HTTPMCPClient) Close() error {
type StdioMCPClient struct {
command string
args []string
env map[string]string
timeout time.Duration
cmd *exec.Cmd
stdin io.WriteCloser
@@ -219,7 +265,7 @@ type StdioMCPClient struct {
}
// NewStdioMCPClient 创建stdio模式的MCP客户端
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
func NewStdioMCPClient(command string, args []string, env map[string]string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
if timeout <= 0 {
timeout = 30 * time.Second
}
@@ -227,6 +273,7 @@ func NewStdioMCPClient(command string, args []string, timeout time.Duration, log
return &StdioMCPClient{
command: command,
args: args,
env: env,
timeout: timeout,
logger: logger,
status: "disconnected",
@@ -289,6 +336,20 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
return fmt.Errorf("初始化失败: %w", err)
}
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected")
return nil
}
@@ -296,6 +357,27 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
func (c *StdioMCPClient) startProcess() error {
cmd := exec.CommandContext(c.ctx, c.command, c.args...)
// 设置环境变量
if c.env != nil && len(c.env) > 0 {
// 获取当前环境变量
cmd.Env = os.Environ()
// 添加或覆盖配置的环境变量
for key, value := range c.env {
// 检查是否已存在该环境变量
found := false
for i, envVar := range cmd.Env {
if strings.HasPrefix(envVar, key+"=") {
cmd.Env[i] = key + "=" + value
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, key+"="+value)
}
}
}
stdin, err := cmd.StdinPipe()
if err != nil {
return err
@@ -424,6 +506,20 @@ func (c *StdioMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
return listResp.Tools, nil
}
func (c *StdioMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
if c.encoder == nil {
return fmt.Errorf("进程未启动")
}
// 直接发送通知,不等待响应
if err := c.encoder.Encode(msg); err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
return nil
}
func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
req := Message{
ID: MessageID{value: uuid.New().String()},
@@ -472,3 +568,465 @@ func (c *StdioMCPClient) Close() error {
c.setStatus("disconnected")
return nil
}
// SSEMCPClient SSE模式的MCP客户端
type SSEMCPClient struct {
url string
timeout time.Duration
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
sseConn io.ReadCloser
sseCancel context.CancelFunc
requestID int64
responses map[string]chan *Message
responsesMu sync.Mutex
ctx context.Context
}
// NewSSEMCPClient 创建SSE模式的MCP客户端
func NewSSEMCPClient(url string, timeout time.Duration, logger *zap.Logger) *SSEMCPClient {
if timeout <= 0 {
timeout = 30 * time.Second
}
ctx, cancel := context.WithCancel(context.Background())
return &SSEMCPClient{
url: url,
timeout: timeout,
client: &http.Client{Timeout: timeout},
logger: logger,
status: "disconnected",
responses: make(map[string]chan *Message),
ctx: ctx,
sseCancel: cancel,
}
}
func (c *SSEMCPClient) setStatus(status string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = status
}
func (c *SSEMCPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *SSEMCPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *SSEMCPClient) Initialize(ctx context.Context) error {
c.setStatus("connecting")
// 建立SSE连接
if err := c.connectSSE(); err != nil {
c.setStatus("error")
return fmt.Errorf("建立SSE连接失败: %w", err)
}
// 启动响应读取goroutine
go c.readSSEResponses()
// 发送初始化请求
req := Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
}
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{
Name: "CyberStrikeAI",
Version: "1.0.0",
},
}
paramsJSON, _ := json.Marshal(params)
req.Params = paramsJSON
_, err := c.sendRequest(ctx, &req)
if err != nil {
c.setStatus("error")
c.Close()
return fmt.Errorf("初始化失败: %w", err)
}
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected")
return nil
}
func (c *SSEMCPClient) connectSSE() error {
// 建立SSE连接(GET请求,Accept: text/event-stream
// SSE连接需要长连接,使用无超时的客户端
sseClient := &http.Client{
Timeout: 0, // 无超时,用于长连接
}
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
if err != nil {
return fmt.Errorf("创建SSE请求失败: %w", err)
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
resp, err := sseClient.Do(req)
if err != nil {
return fmt.Errorf("SSE连接失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return fmt.Errorf("SSE连接失败,状态码: %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
resp.Body.Close()
return fmt.Errorf("服务器不支持SSEContent-Type: %s", contentType)
}
c.sseConn = resp.Body
return nil
}
func (c *SSEMCPClient) readSSEResponses() {
defer func() {
if r := recover(); r != nil {
c.logger.Error("读取SSE响应时发生panic", zap.Any("error", r))
}
if c.sseConn != nil {
c.sseConn.Close()
}
c.setStatus("disconnected")
}()
if c.sseConn == nil {
return
}
scanner := &sseScanner{reader: bufio.NewReader(c.sseConn)}
for {
select {
case <-c.ctx.Done():
return
default:
}
// 读取SSE事件
event, err := scanner.readEvent()
if err != nil {
if err == io.EOF {
c.setStatus("disconnected")
return
}
c.logger.Error("读取SSE数据失败", zap.Error(err))
return
}
if event == nil || len(event.Data) == 0 {
continue
}
// 解析JSON消息
var msg Message
if err := json.Unmarshal(event.Data, &msg); err != nil {
c.logger.Warn("解析SSE消息失败", zap.Error(err), zap.String("data", string(event.Data)))
continue
}
// 处理响应
id := msg.ID.String()
c.responsesMu.Lock()
if ch, ok := c.responses[id]; ok {
select {
case ch <- &msg:
default:
}
delete(c.responses, id)
}
c.responsesMu.Unlock()
}
}
// sseEvent SSE事件
type sseEvent struct {
Event string
Data []byte
ID string
Retry int
}
// sseScanner SSE扫描器
type sseScanner struct {
reader *bufio.Reader
}
func (s *sseScanner) readEvent() (*sseEvent, error) {
event := &sseEvent{}
for {
line, err := s.reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimRight(line, "\r\n")
// 空行表示事件结束
if len(line) == 0 {
if len(event.Data) > 0 {
return event, nil
}
continue
}
// 解析SSE行
if strings.HasPrefix(line, "event: ") {
event.Event = strings.TrimSpace(line[7:])
} else if strings.HasPrefix(line, "data: ") {
data := []byte(strings.TrimSpace(line[6:]))
if len(event.Data) > 0 {
event.Data = append(event.Data, '\n')
}
event.Data = append(event.Data, data...)
} else if strings.HasPrefix(line, "id: ") {
event.ID = strings.TrimSpace(line[4:])
} else if strings.HasPrefix(line, "retry: ") {
fmt.Sscanf(line[7:], "%d", &event.Retry)
}
}
}
func (c *SSEMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
if c.sseConn == nil {
return nil, fmt.Errorf("SSE连接未建立")
}
id := msg.ID.String()
if id == "" {
c.mu.Lock()
c.requestID++
id = fmt.Sprintf("%d", c.requestID)
msg.ID = MessageID{value: id}
c.mu.Unlock()
}
// 创建响应通道
responseCh := make(chan *Message, 1)
c.responsesMu.Lock()
c.responses[id] = responseCh
c.responsesMu.Unlock()
// 通过HTTP POST发送请求(SSE用于接收响应,请求通过POST发送)
body, err := json.Marshal(msg)
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 使用POST请求发送消息(通常SSE服务器会提供两个端点:一个用于SSE,一个用于POST)
// 如果URL是SSE端点,尝试使用相同的URL但改为POST,或者使用URL + "/message"
postURL := c.url
if strings.HasSuffix(postURL, "/sse") {
postURL = strings.TrimSuffix(postURL, "/sse")
postURL += "/message"
} else if strings.HasSuffix(postURL, "/events") {
postURL = strings.TrimSuffix(postURL, "/events")
postURL += "/message"
} else if !strings.Contains(postURL, "/message") {
// 如果URL不包含/message,尝试添加
postURL = strings.TrimSuffix(postURL, "/")
postURL += "/message"
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("创建POST请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("发送POST请求失败: %w", err)
}
defer resp.Body.Close()
// 如果POST请求直接返回响应(非SSE模式),直接解析
if resp.StatusCode == http.StatusOK && resp.Header.Get("Content-Type") == "application/json" {
var mcpResp Message
if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if mcpResp.Error != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code)
}
return &mcpResp, nil
}
// 否则等待SSE响应
select {
case resp := <-responseCh:
if resp.Error != nil {
return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code)
}
return resp, nil
case <-ctx.Done():
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, ctx.Err()
case <-time.After(c.timeout):
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("请求超时")
}
}
func (c *SSEMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
}
req.Params = json.RawMessage("{}")
resp, err := c.sendRequest(ctx, &req)
if err != nil {
return nil, fmt.Errorf("获取工具列表失败: %w", err)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, fmt.Errorf("解析工具列表失败: %w", err)
}
return listResp.Tools, nil
}
func (c *SSEMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
if c.sseConn == nil {
return fmt.Errorf("SSE连接未建立")
}
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("序列化通知失败: %w", err)
}
// 使用 POST 发送通知(与 sendRequest 类似的逻辑)
postURL := c.url
if strings.HasSuffix(postURL, "/sse") {
postURL = strings.TrimSuffix(postURL, "/sse")
postURL += "/message"
} else if strings.HasSuffix(postURL, "/events") {
postURL = strings.TrimSuffix(postURL, "/events")
postURL += "/message"
} else if !strings.Contains(postURL, "/message") {
postURL = strings.TrimSuffix(postURL, "/")
postURL += "/message"
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("创建POST请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 发送通知,不等待响应(通知不需要响应)
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
resp.Body.Close()
return nil
}
func (c *SSEMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
req := Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
}
callReq := CallToolRequest{
Name: name,
Arguments: args,
}
paramsJSON, _ := json.Marshal(callReq)
req.Params = paramsJSON
resp, err := c.sendRequest(ctx, &req)
if err != nil {
return nil, fmt.Errorf("调用工具失败: %w", err)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
}
return &ToolResult{
Content: callResp.Content,
IsError: callResp.IsError,
}, nil
}
func (c *SSEMCPClient) Close() error {
c.sseCancel()
if c.sseConn != nil {
c.sseConn.Close()
c.sseConn = nil
}
c.setStatus("disconnected")
return nil
}
+189 -46
View File
@@ -16,14 +16,18 @@ import (
// ExternalMCPManager 外部MCP管理器
type ExternalMCPManager struct {
clients map[string]ExternalMCPClient
configs map[string]config.ExternalMCPServerConfig
logger *zap.Logger
storage MonitorStorage // 可选的持久化存储
executions map[string]*ToolExecution // 执行记录
stats map[string]*ToolStats // 工具统计信息
errors map[string]string // 错误信息
mu sync.RWMutex
clients map[string]ExternalMCPClient
configs map[string]config.ExternalMCPServerConfig
logger *zap.Logger
storage MonitorStorage // 可选的持久化存储
executions map[string]*ToolExecution // 执行记录
stats map[string]*ToolStats // 工具统计信息
errors map[string]string // 错误信息
toolCounts map[string]int // 工具数量缓存
toolCountsMu sync.RWMutex // 工具数量缓存的锁
stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
mu sync.RWMutex
}
// NewExternalMCPManager 创建外部MCP管理器
@@ -33,15 +37,20 @@ func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager {
// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储)
func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager {
return &ExternalMCPManager{
clients: make(map[string]ExternalMCPClient),
configs: make(map[string]config.ExternalMCPServerConfig),
logger: logger,
storage: storage,
executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats),
errors: make(map[string]string),
manager := &ExternalMCPManager{
clients: make(map[string]ExternalMCPClient),
configs: make(map[string]config.ExternalMCPServerConfig),
logger: logger,
storage: storage,
executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats),
errors: make(map[string]string),
toolCounts: make(map[string]int),
stopRefresh: make(chan struct{}),
}
// 启动后台刷新工具数量的goroutine
manager.startToolCountRefresh()
return manager
}
// LoadConfigs 加载配置
@@ -104,6 +113,12 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
}
delete(m.configs, name)
// 清理工具数量缓存
m.toolCountsMu.Lock()
delete(m.toolCounts, name)
m.toolCountsMu.Unlock()
return nil
}
@@ -174,11 +189,15 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Lock()
m.errors[name] = err.Error()
m.mu.Unlock()
// 触发工具数量刷新(连接失败,工具数量应为0)
m.triggerToolCountRefresh()
} else {
// 连接成功,清除错误信息
m.mu.Lock()
delete(m.errors, name)
m.mu.Unlock()
// 连接成功,立即刷新工具数量
m.triggerToolCountRefresh()
}
}()
@@ -204,6 +223,11 @@ func (m *ExternalMCPManager) StopClient(name string) error {
// 清除错误信息
delete(m.errors, name)
// 更新工具数量缓存(停止后工具数量为0)
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
// 更新配置为禁用
serverCfg.ExternalMCPEnable = false
m.configs[name] = serverCfg
@@ -532,30 +556,50 @@ func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats {
return result
}
// GetToolCount 获取指定外部MCP的工具数量
// GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞)
func (m *ExternalMCPManager) GetToolCount(name string) (int, error) {
// 先从缓存读取
m.toolCountsMu.RLock()
if count, exists := m.toolCounts[name]; exists {
m.toolCountsMu.RUnlock()
return count, nil
}
m.toolCountsMu.RUnlock()
// 如果缓存中没有,检查客户端状态
client, exists := m.GetClient(name)
if !exists {
return 0, fmt.Errorf("客户端不存在: %s", name)
}
if !client.IsConnected() {
// 未连接,缓存为0
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
return 0, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
tools, err := client.ListTools(ctx)
if err != nil {
return 0, fmt.Errorf("获取工具列表失败: %w", err)
}
return len(tools), nil
// 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞)
m.triggerToolCountRefresh()
return 0, nil
}
// GetToolCounts 获取所有外部MCP的工具数量
// GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞)
func (m *ExternalMCPManager) GetToolCounts() map[string]int {
m.toolCountsMu.RLock()
defer m.toolCountsMu.RUnlock()
// 返回缓存的副本,避免外部修改
result := make(map[string]int)
for k, v := range m.toolCounts {
result[k] = v
}
return result
}
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
func (m *ExternalMCPManager) refreshToolCounts() {
m.mu.RLock()
clients := make(map[string]ExternalMCPClient)
for k, v := range m.clients {
@@ -563,30 +607,104 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int {
}
m.mu.RUnlock()
result := make(map[string]int)
newCounts := make(map[string]int)
// 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞
type countResult struct {
name string
count int
}
resultChan := make(chan countResult, len(clients))
for name, client := range clients {
if !client.IsConnected() {
result[name] = 0
continue
}
go func(n string, c ExternalMCPClient) {
if !c.IsConnected() {
resultChan <- countResult{name: n, count: 0}
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
tools, err := client.ListTools(ctx)
cancel()
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞
// 由于这是后台异步刷新,超时不会影响前端响应
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
tools, err := c.ListTools(ctx)
cancel()
if err != nil {
m.logger.Warn("获取外部MCP工具数量失败",
zap.String("name", name),
zap.Error(err),
)
result[name] = 0
continue
}
if err != nil {
m.logger.Debug("获取外部MCP工具数量失败",
zap.String("name", n),
zap.Error(err),
)
// 如果获取失败,保留旧值(在更新时处理)
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
return
}
result[name] = len(tools)
resultChan <- countResult{name: n, count: len(tools)}
}(name, client)
}
return result
// 收集结果
m.toolCountsMu.RLock()
oldCounts := make(map[string]int)
for k, v := range m.toolCounts {
oldCounts[k] = v
}
m.toolCountsMu.RUnlock()
for i := 0; i < len(clients); i++ {
result := <-resultChan
if result.count >= 0 {
newCounts[result.name] = result.count
} else {
// 获取失败,保留旧值
if oldCount, exists := oldCounts[result.name]; exists {
newCounts[result.name] = oldCount
} else {
newCounts[result.name] = 0
}
}
}
// 更新缓存
m.toolCountsMu.Lock()
// 更新所有获取到的值
for name, count := range newCounts {
m.toolCounts[name] = count
}
// 对于未连接的客户端,设置为0
for name, client := range clients {
if !client.IsConnected() {
m.toolCounts[name] = 0
}
}
m.toolCountsMu.Unlock()
}
// startToolCountRefresh 启动后台刷新工具数量的goroutine
func (m *ExternalMCPManager) startToolCountRefresh() {
m.refreshWg.Add(1)
go func() {
defer m.refreshWg.Done()
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次
defer ticker.Stop()
// 立即执行一次刷新
m.refreshToolCounts()
for {
select {
case <-ticker.C:
m.refreshToolCounts()
case <-m.stopRefresh:
return
}
}
}()
}
// triggerToolCountRefresh 触发立即刷新工具数量(异步)
func (m *ExternalMCPManager) triggerToolCountRefresh() {
go m.refreshToolCounts()
}
// createClient 创建客户端(不连接)
@@ -603,6 +721,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
// 默认使用http,但可以通过transport字段指定sse
transport = "http"
} else {
return nil
@@ -619,7 +738,12 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.Command == "" {
return nil
}
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger)
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
case "sse":
if serverCfg.URL == "" {
return nil
}
return NewSSEMCPClient(serverCfg.URL, timeout, m.logger)
default:
return nil
}
@@ -654,6 +778,8 @@ func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status st
c.setStatus(status)
case *StdioMCPClient:
c.setStatus(status)
case *SSEMCPClient:
c.setStatus(status)
}
}
@@ -693,6 +819,9 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
zap.String("name", name),
)
// 连接成功,触发工具数量刷新
m.triggerToolCountRefresh()
return nil
}
@@ -791,4 +920,18 @@ func (m *ExternalMCPManager) StopAll() {
client.Close()
delete(m.clients, name)
}
// 清理所有工具数量缓存
m.toolCountsMu.Lock()
m.toolCounts = make(map[string]int)
m.toolCountsMu.Unlock()
// 停止后台刷新(使用 select 避免重复关闭 channel
select {
case <-m.stopRefresh:
// 已经关闭,不需要再次关闭
default:
close(m.stopRefresh)
m.refreshWg.Wait()
}
}
+1 -1
View File
@@ -180,7 +180,7 @@ func TestStdioMCPClient_Initialize(t *testing.T) {
// 注意:这个测试需要一个真实的stdio MCP服务器
// 如果没有服务器,这个测试会失败
logger := zap.NewNop()
client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger)
client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
+20
View File
@@ -0,0 +1,20 @@
name: API安全测试
description: API安全测试专家,专注于API接口安全检测
user_prompt: 你是一个专业的API安全测试专家。请使用专业的API测试工具对目标API接口进行全面的安全检测,包括GraphQL安全、API参数fuzzing、JWT分析、API架构分析等工作。
icon: "\U0001F4E1"
tools:
- api-fuzzer
- api-schema-analyzer
- graphql-scanner
- arjun
- jwt-analyzer
- http-intruder
- http-framework-test
- burpsuite
- httpx
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+33
View File
@@ -0,0 +1,33 @@
name: CTF
description: CTF竞赛专家,擅长解题和漏洞利用
user_prompt: 你是一个CTF竞赛专家。请使用CTF解题思维和方法,快速定位和利用漏洞,解决各类CTF题目。
icon: "\U0001F3C6"
tools:
- amass
- anew
- angr
- api-fuzzer
- api-schema-analyzer
- arjun
- arp-scan
- autorecon
- binwalk
- bloodhound
- burpsuite
- cat
- checkov
- checksec
- cloudmapper
- create-file
- cyberchef
- dalfox
- delete-file
- httpx
- http-framework-test
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+25
View File
@@ -0,0 +1,25 @@
name: Web应用扫描
description: Web应用漏洞扫描专家,全面的Web安全检测
user_prompt: 你是一个专业的Web应用漏洞扫描专家。请使用各种Web扫描工具对目标Web应用进行全面的安全检测,包括目录枚举、文件扫描、漏洞识别等工作。
icon: "\U0001F310"
tools:
- dirsearch
- dirb
- gobuster
- feroxbuster
- ffuf
- wfuzz
- sqlmap
- dalfox
- xsser
- nikto
- nuclei
- wpscan
- httpx
- http-framework-test
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+19
View File
@@ -0,0 +1,19 @@
name: Web框架测试
description: Web框架安全测试专家,专注于Web应用框架漏洞检测
user_prompt: 你是一个专业的Web框架安全测试专家。请使用专业的工具对Web应用框架进行安全测试,识别框架相关的安全漏洞和配置问题。
icon: "\U0001F310"
tools:
- http-framework-test
- nikto
- nuclei
- wafw00f
- wpscan
- httpx
- burpsuite
- zap
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+31
View File
@@ -0,0 +1,31 @@
name: 二进制分析
description: 二进制分析与利用专家,擅长逆向工程和密码破解
user_prompt: 你是一个专业的二进制分析与利用专家。请使用逆向工程工具分析二进制文件,识别漏洞,进行利用开发。同时擅长密码破解、哈希分析等技术。
icon: "\U0001F52C"
tools:
- dirsearch
- docker-bench-security
- exec
- execute-python-script
- install-python-package
- ghidra
- graphql-scanner
- hakrawler
- hash-identifier
- hashcat
- hashpump
- http-framework-test
- httpx
- gdb
- radare2
- objdump
- strings
- binwalk
- ropper
- ropgadget
- john
- cyberchef
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+17
View File
@@ -0,0 +1,17 @@
name: 云安全审计
description: 云安全审计专家,多云环境安全检测
user_prompt: 你是一个专业的云安全审计专家。请使用专业的云安全工具对AWS、Azure、GCP等云环境进行全面的安全审计,包括配置检查、合规性评估、权限审计、安全最佳实践验证等工作。
icon: "\U00002601"
tools:
- prowler
- scout-suite
- cloudmapper
- pacu
- terrascan
- checkov
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+31
View File
@@ -0,0 +1,31 @@
name: 信息收集
description: 资产发现与信息搜集专家
user_prompt: 你是一个专业的信息收集专家。请使用各种信息收集技术和工具,对目标进行全面的资产发现、子域名枚举、端口扫描、服务识别等信息收集工作。
icon: "\U0001F50D"
tools:
- amass
- subfinder
- dnsenum
- fierce
- fofa_search
- zoomeye_search
- nmap
- masscan
- rustscan
- arp-scan
- nbtscan
- httpx
- http-framework-test
- katana
- hakrawler
- waybackurls
- paramspider
- gau
- uro
- qsreplace
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+23
View File
@@ -0,0 +1,23 @@
name: 后渗透测试
description: 后渗透测试专家,权限维持与横向移动
user_prompt: 你是一个专业的后渗透测试专家。请使用专业的后渗透工具在获得初始访问权限后进行权限提升、横向移动、权限维持、数据收集等后渗透测试工作。
icon: "\U0001F575"
tools:
- linpeas
- winpeas
- mimikatz
- bloodhound
- impacket
- responder
- netexec
- rpcclient
- smbmap
- enum4linux
- enum4linux-ng
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+18
View File
@@ -0,0 +1,18 @@
name: 容器安全
description: 容器与Kubernetes安全专家,容器环境安全检测
user_prompt: 你是一个专业的容器与Kubernetes安全专家。请使用专业的容器安全工具对Docker容器和Kubernetes集群进行全面的安全检测,包括镜像漏洞扫描、配置检查、运行时安全等工作。
icon: "\U0001F6E1"
tools:
- trivy
- clair
- docker-bench-security
- kube-bench
- kube-hunter
- falco
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+24
View File
@@ -0,0 +1,24 @@
name: 数字取证
description: 数字取证与隐写分析专家,文件与内存取证
user_prompt: 你是一个专业的数字取证与隐写分析专家。请使用专业的取证工具对文件、磁盘镜像、内存转储进行分析,提取证据信息。同时擅长隐写分析、数据恢复、元数据提取等技术。
icon: "\U0001F50E"
tools:
- volatility
- volatility3
- foremost
- steghide
- stegsolve
- zsteg
- exiftool
- binwalk
- strings
- xxd
- fcrackzip
- pdfcrack
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+33
View File
@@ -0,0 +1,33 @@
name: 渗透测试
description: 专业渗透测试专家,全面深入的漏洞检测
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
icon: "\U0001F3AF"
tools:
- http-framework-test
- httpx
- amass
- anew
- angr
- api-fuzzer
- api-schema-analyzer
- arjun
- arp-scan
- autorecon
- binwalk
- bloodhound
- burpsuite
- cat
- checkov
- checksec
- cloudmapper
- create-file
- cyberchef
- dalfox
- delete-file
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+23
View File
@@ -0,0 +1,23 @@
name: 综合漏洞扫描
description: 综合漏洞扫描专家,多类型漏洞检测
user_prompt: 你是一个专业的综合漏洞扫描专家。请使用各种漏洞扫描工具对目标进行全面的安全检测,包括Web漏洞、网络服务漏洞、配置缺陷等多种类型的漏洞识别和分析。
icon: "\U000026A0"
tools:
- nuclei
- nikto
- sqlmap
- nmap
- masscan
- rustscan
- wafw00f
- dalfox
- xsser
- jaeles
- httpx
- http-framework-test
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+5
View File
@@ -0,0 +1,5 @@
name: 默认
description: 默认角色,不额外携带用户提示词,使用默认MCP
user_prompt: ""
icon: "\U0001F535"
enabled: true
+183 -37
View File
@@ -2,59 +2,205 @@
set -euo pipefail
# CyberStrikeAI 启动脚本
# CyberStrikeAI 一键部署启动脚本
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$ROOT_DIR"
echo "🚀 启动 CyberStrikeAI..."
# 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# 打印带颜色的消息
info() { echo -e "${BLUE}$1${NC}"; }
success() { echo -e "${GREEN}$1${NC}"; }
warning() { echo -e "${YELLOW}⚠️ $1${NC}"; }
error() { echo -e "${RED}$1${NC}"; }
echo ""
echo "=========================================="
echo " CyberStrikeAI 一键部署启动脚本"
echo "=========================================="
echo ""
CONFIG_FILE="$ROOT_DIR/config.yaml"
VENV_DIR="$ROOT_DIR/venv"
REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt"
BINARY_NAME="cyberstrike-ai"
# 检查配置文件
if [ ! -f "$CONFIG_FILE" ]; then
echo "配置文件 config.yaml 不存在"
error "配置文件 config.yaml 不存在"
info "请确保在项目根目录运行此脚本"
exit 1
fi
# 检查 Python 环境
if ! command -v python3 >/dev/null 2>&1; then
echo "❌ 未找到 python3,请先安装 Python 3.10+"
exit 1
fi
# 检查并安装 Python 环境
check_python() {
if ! command -v python3 >/dev/null 2>&1; then
error "未找到 python3"
echo ""
info "请先安装 Python 3.10 或更高版本:"
echo " macOS: brew install python3"
echo " Ubuntu: sudo apt-get install python3 python3-venv"
echo " CentOS: sudo yum install python3 python3-pip"
exit 1
fi
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d. -f1)
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d. -f2)
if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then
error "Python 版本过低: $PYTHON_VERSION (需要 3.10+)"
exit 1
fi
success "Python 环境检查通过: $PYTHON_VERSION"
}
# 创建并激活虚拟环境
if [ ! -d "$VENV_DIR" ]; then
echo "🐍 创建 Python 虚拟环境..."
python3 -m venv "$VENV_DIR"
fi
# 检查并安装 Go 环境
check_go() {
if ! command -v go >/dev/null 2>&1; then
error "未找到 Go"
echo ""
info "请先安装 Go 1.21 或更高版本:"
echo " macOS: brew install go"
echo " Ubuntu: sudo apt-get install golang-go"
echo " CentOS: sudo yum install golang"
echo " 或访问: https://go.dev/dl/"
exit 1
fi
GO_VERSION=$(go version | awk '{print $3}' | sed 's/go//')
GO_MAJOR=$(echo "$GO_VERSION" | cut -d. -f1)
GO_MINOR=$(echo "$GO_VERSION" | cut -d. -f2)
if [ "$GO_MAJOR" -lt 1 ] || ([ "$GO_MAJOR" -eq 1 ] && [ "$GO_MINOR" -lt 21 ]); then
error "Go 版本过低: $GO_VERSION (需要 1.21+)"
exit 1
fi
success "Go 环境检查通过: $(go version)"
}
echo "🐍 激活虚拟环境..."
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"
# 设置 Python 虚拟环境
setup_python_env() {
if [ ! -d "$VENV_DIR" ]; then
info "创建 Python 虚拟环境..."
python3 -m venv "$VENV_DIR"
success "虚拟环境创建完成"
else
info "Python 虚拟环境已存在"
fi
info "激活虚拟环境..."
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"
if [ -f "$REQUIREMENTS_FILE" ]; then
info "安装/更新 Python 依赖..."
pip install --quiet --upgrade pip >/dev/null 2>&1 || true
# 尝试安装依赖,捕获错误输出
PIP_LOG=$(mktemp)
if pip install -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1; then
success "Python 依赖安装完成"
else
# 检查是否是 angr 安装失败(需要 Rust)
if grep -q "angr" "$PIP_LOG" && grep -q "Rust compiler\|can't find Rust" "$PIP_LOG"; then
warning "angr 安装失败(需要 Rust 编译器)"
echo ""
info "angr 是可选依赖,主要用于二进制分析工具"
info "如果需要使用 angr,请先安装 Rust"
echo " macOS: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " Ubuntu: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " 或访问: https://rustup.rs/"
echo ""
info "其他依赖已安装,可以继续使用(部分工具可能不可用)"
else
warning "部分 Python 依赖安装失败,但可以继续尝试运行"
warning "如果遇到问题,请检查错误信息并手动安装缺失的依赖"
# 显示最后几行错误信息
echo ""
info "错误详情(最后 10 行):"
tail -n 10 "$PIP_LOG" | sed 's/^/ /'
echo ""
fi
fi
rm -f "$PIP_LOG"
else
warning "未找到 requirements.txt,跳过 Python 依赖安装"
fi
}
if [ -f "$REQUIREMENTS_FILE" ]; then
echo "📦 安装/更新 Python 依赖..."
pip install -r "$REQUIREMENTS_FILE"
else
echo "⚠️ 未找到 requirements.txt,跳过 Python 依赖安装"
fi
# 构建 Go 项目
build_go_project() {
info "下载 Go 依赖..."
go mod download >/dev/null 2>&1 || {
error "Go 依赖下载失败"
exit 1
}
info "构建项目..."
if go build -o "$BINARY_NAME" cmd/server/main.go 2>&1; then
success "项目构建完成: $BINARY_NAME"
else
error "项目构建失败"
exit 1
fi
}
# 检查 Go 环境
if ! command -v go >/dev/null 2>&1; then
echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本"
exit 1
fi
# 检查是否需要重新构建
need_rebuild() {
if [ ! -f "$BINARY_NAME" ]; then
return 0 # 需要构建
fi
# 检查源代码是否有更新
if [ "$BINARY_NAME" -ot cmd/server/main.go ] || \
[ "$BINARY_NAME" -ot go.mod ] || \
find internal cmd -name "*.go" -newer "$BINARY_NAME" 2>/dev/null | grep -q .; then
return 0 # 需要重新构建
fi
return 1 # 不需要构建
}
# 下载依赖
echo "📦 下载 Go 依赖..."
go mod download
# 主流程
main() {
# 环境检查
info "检查运行环境..."
check_python
check_go
echo ""
# 设置 Python 环境
info "设置 Python 环境..."
setup_python_env
echo ""
# 构建 Go 项目
if need_rebuild; then
info "准备构建项目..."
build_go_project
else
success "可执行文件已是最新,跳过构建"
fi
echo ""
# 启动服务器
success "所有准备工作完成!"
echo ""
info "启动 CyberStrikeAI 服务器..."
echo "=========================================="
echo ""
# 运行服务器
exec "./$BINARY_NAME"
}
# 构建项目
echo "🔨 构建项目..."
go build -o cyberstrike-ai cmd/server/main.go
# 运行服务器
echo "✅ 启动服务器..."
./cyberstrike-ai
# 执行主流程
main
+1243 -22
View File
File diff suppressed because it is too large Load Diff
+27
View File
@@ -0,0 +1,27 @@
/**
* 内置工具名称常量
* 所有前端代码中使用内置工具名称的地方都应该使用这些常量而不是硬编码字符串
*
* 注意这些常量必须与后端的 internal/mcp/builtin/constants.go 中的常量保持一致
*/
// 内置工具名称常量
const BuiltinTools = {
// 漏洞管理工具
RECORD_VULNERABILITY: 'record_vulnerability',
// 知识库工具
LIST_KNOWLEDGE_RISK_TYPES: 'list_knowledge_risk_types',
SEARCH_KNOWLEDGE_BASE: 'search_knowledge_base'
};
// 检查是否是内置工具
function isBuiltinTool(toolName) {
return Object.values(BuiltinTools).includes(toolName);
}
// 获取所有内置工具名称列表
function getAllBuiltinTools() {
return Object.values(BuiltinTools);
}
+299 -23
View File
@@ -14,6 +14,9 @@ const mentionState = {
selectedIndex: 0,
};
// IME输入法状态跟踪
let isComposing = false;
// 输入框草稿保存相关
const DRAFT_STORAGE_KEY = 'cyberstrike-chat-draft';
let draftSaveTimer = null;
@@ -85,13 +88,20 @@ function clearChatDraft() {
function adjustTextareaHeight(textarea) {
if (!textarea) return;
// 重置高度以获取准确的scrollHeight
textarea.style.height = '44px';
// 重置高度为auto,然后立即设置为固定值,确保能准确获取scrollHeight
textarea.style.height = 'auto';
// 强制浏览器重新计算布局
void textarea.offsetHeight;
// 计算新高度(最小44px,最大不超过300px)
const scrollHeight = textarea.scrollHeight;
const newHeight = Math.min(Math.max(scrollHeight, 44), 300);
textarea.style.height = newHeight + 'px';
// 如果内容为空或只有很少内容,立即重置到最小高度
if (!textarea.value || textarea.value.trim().length === 0) {
textarea.style.height = '44px';
}
}
// 发送消息
@@ -135,6 +145,9 @@ async function sendMessage() {
let mcpExecutionIds = [];
try {
// 获取当前选中的角色(从 roles.js 的函数获取)
const roleName = typeof getCurrentRole === 'function' ? getCurrentRole() : '';
const response = await apiFetch('/api/agent-loop/stream', {
method: 'POST',
headers: {
@@ -142,7 +155,8 @@ async function sendMessage() {
},
body: JSON.stringify({
message: message,
conversationId: currentConversationId
conversationId: currentConversationId,
role: roleName || undefined
}),
});
@@ -242,6 +256,13 @@ if (typeof window !== 'undefined') {
}
function ensureMentionToolsLoaded() {
// 检查角色是否改变,如果改变则强制重新加载
if (typeof window !== 'undefined' && window._mentionToolsRoleChanged) {
mentionToolsLoaded = false;
mentionTools = [];
delete window._mentionToolsRoleChanged;
}
if (mentionToolsLoaded) {
return Promise.resolve(mentionTools);
}
@@ -254,6 +275,16 @@ function ensureMentionToolsLoaded() {
return mentionToolsLoadingPromise;
}
// 生成工具的唯一标识符,用于区分同名但来源不同的工具
function getToolKeyForMention(tool) {
// 如果是外部工具,使用 external_mcp::tool.name 作为唯一标识
// 如果是内部工具,使用 tool.name 作为标识
if (tool.is_external && tool.external_mcp) {
return `${tool.external_mcp}::${tool.name}`;
}
return tool.name;
}
async function fetchMentionTools() {
const pageSize = 100;
let page = 1;
@@ -262,6 +293,9 @@ async function fetchMentionTools() {
const collected = [];
try {
// 获取当前选中的角色(从 roles.js 的函数获取)
const roleName = typeof getCurrentRole === 'function' ? getCurrentRole() : '';
// 同时获取外部MCP列表
try {
const mcpResponse = await apiFetch('/api/external-mcp');
@@ -280,23 +314,45 @@ async function fetchMentionTools() {
}
while (page <= totalPages && page <= 20) {
const response = await apiFetch(`/api/config/tools?page=${page}&page_size=${pageSize}`);
// 构建API URL,如果指定了角色,添加role查询参数
let url = `/api/config/tools?page=${page}&page_size=${pageSize}`;
if (roleName && roleName !== '默认') {
url += `&role=${encodeURIComponent(roleName)}`;
}
const response = await apiFetch(url);
if (!response.ok) {
break;
}
const result = await response.json();
const tools = Array.isArray(result.tools) ? result.tools : [];
tools.forEach(tool => {
if (!tool || !tool.name || seen.has(tool.name)) {
if (!tool || !tool.name) {
return;
}
seen.add(tool.name);
// 使用唯一标识符来去重,而不是只使用工具名称
const toolKey = getToolKeyForMention(tool);
if (seen.has(toolKey)) {
return;
}
seen.add(toolKey);
// 确定工具在当前角色中的启用状态
// 如果有 role_enabled 字段,使用它(表示指定了角色)
// 否则使用 enabled 字段(表示未指定角色或使用所有工具)
let roleEnabled = tool.enabled !== false;
if (tool.role_enabled !== undefined && tool.role_enabled !== null) {
roleEnabled = tool.role_enabled;
}
collected.push({
name: tool.name,
description: tool.description || '',
enabled: tool.enabled !== false,
enabled: tool.enabled !== false, // 工具本身的启用状态
roleEnabled: roleEnabled, // 在当前角色中的启用状态
isExternal: !!tool.is_external,
externalMcp: tool.external_mcp || '',
toolKey: toolKey, // 保存唯一标识符
});
});
totalPages = result.total_pages || 1;
@@ -317,7 +373,10 @@ function handleChatInputInput(event) {
const textarea = event.target;
updateMentionStateFromInput(textarea);
// 自动调整输入框高度
adjustTextareaHeight(textarea);
// 使用requestAnimationFrame确保在DOM更新后立即调整,特别是在删除内容时
requestAnimationFrame(() => {
adjustTextareaHeight(textarea);
});
// 保存输入内容到localStorage(防抖)
saveChatDraftDebounced(textarea.value);
}
@@ -327,6 +386,12 @@ function handleChatInputClick(event) {
}
function handleChatInputKeydown(event) {
// 如果正在使用输入法输入(IME),回车键应该用于确认候选词,而不是发送消息
// 使用 event.isComposing 或 isComposing 标志来判断
if (event.isComposing || isComposing) {
return;
}
if (mentionState.active && mentionSuggestionsEl && mentionSuggestionsEl.style.display !== 'none') {
if (event.key === 'ArrowDown') {
event.preventDefault();
@@ -453,6 +518,15 @@ function updateMentionCandidates() {
}
filtered = filtered.slice().sort((a, b) => {
// 如果指定了角色,优先显示在当前角色中启用的工具
if (a.roleEnabled !== undefined || b.roleEnabled !== undefined) {
const aRoleEnabled = a.roleEnabled !== undefined ? a.roleEnabled : a.enabled;
const bRoleEnabled = b.roleEnabled !== undefined ? b.roleEnabled : b.enabled;
if (aRoleEnabled !== bRoleEnabled) {
return aRoleEnabled ? -1 : 1; // 启用的工具排在前面
}
}
if (normalizedQuery) {
// 精确匹配MCP名称的工具优先显示
const aMcpExact = a.externalMcp && a.externalMcp.toLowerCase() === normalizedQuery;
@@ -467,8 +541,11 @@ function updateMentionCandidates() {
return aStarts ? -1 : 1;
}
}
if (a.enabled !== b.enabled) {
return a.enabled ? -1 : 1;
// 如果指定了角色,使用 roleEnabled;否则使用 enabled
const aEnabled = a.roleEnabled !== undefined ? a.roleEnabled : a.enabled;
const bEnabled = b.roleEnabled !== undefined ? b.roleEnabled : b.enabled;
if (aEnabled !== bEnabled) {
return aEnabled ? -1 : 1;
}
return a.name.localeCompare(b.name, 'zh-CN');
});
@@ -510,13 +587,16 @@ function renderMentionSuggestions({ showLoading = false } = {}) {
const itemsHtml = mentionFilteredTools.map((tool, index) => {
const activeClass = index === mentionState.selectedIndex ? 'active' : '';
const disabledClass = tool.enabled ? '' : 'disabled';
// 如果工具有 roleEnabled 字段(指定了角色),使用它;否则使用 enabled
const toolEnabled = tool.roleEnabled !== undefined ? tool.roleEnabled : tool.enabled;
const disabledClass = toolEnabled ? '' : 'disabled';
const badge = tool.isExternal ? '<span class="mention-item-badge">外部</span>' : '<span class="mention-item-badge internal">内置</span>';
const nameHtml = escapeHtml(tool.name);
const description = tool.description && tool.description.length > 0 ? escapeHtml(tool.description) : '暂无描述';
const descHtml = `<div class="mention-item-desc">${description}</div>`;
const statusLabel = tool.enabled ? '可用' : '已禁用';
const statusClass = tool.enabled ? 'enabled' : 'disabled';
// 根据工具在当前角色中的启用状态显示状态标签
const statusLabel = toolEnabled ? '可用' : (tool.roleEnabled !== undefined ? '已禁用(当前角色)' : '已禁用');
const statusClass = toolEnabled ? 'enabled' : 'disabled';
const originLabel = tool.isExternal
? (tool.externalMcp ? `来源:${escapeHtml(tool.externalMcp)}` : '来源:外部MCP')
: '来源:内置工具';
@@ -800,6 +880,24 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
bubble.innerHTML = formattedContent;
contentWrapper.appendChild(bubble);
// 保存原始内容到消息元素,用于复制功能
if (role === 'assistant') {
messageDiv.dataset.originalContent = content;
}
// 为助手消息添加复制按钮(复制整个回复内容)- 放在消息气泡右下角
if (role === 'assistant') {
const copyBtn = document.createElement('button');
copyBtn.className = 'message-copy-btn';
copyBtn.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><rect x="9" y="9" width="13" height="13" rx="2" ry="2" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>复制</span>';
copyBtn.title = '复制消息内容';
copyBtn.onclick = function(e) {
e.stopPropagation();
copyMessageToClipboard(messageDiv, this);
};
bubble.appendChild(copyBtn);
}
// 添加时间戳
const timeDiv = document.createElement('div');
timeDiv.className = 'message-time';
@@ -869,6 +967,69 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
return id;
}
// 复制消息内容到剪贴板(使用原始Markdown格式)
function copyMessageToClipboard(messageDiv, button) {
try {
// 获取保存的原始Markdown内容
const originalContent = messageDiv.dataset.originalContent;
if (!originalContent) {
// 如果没有保存原始内容,尝试从渲染后的HTML提取(降级方案)
const bubble = messageDiv.querySelector('.message-bubble');
if (bubble) {
const tempDiv = document.createElement('div');
tempDiv.innerHTML = bubble.innerHTML;
// 移除复制按钮本身(避免复制按钮文本)
const copyBtnInTemp = tempDiv.querySelector('.message-copy-btn');
if (copyBtnInTemp) {
copyBtnInTemp.remove();
}
// 提取纯文本内容
let textContent = tempDiv.textContent || tempDiv.innerText || '';
textContent = textContent.replace(/\n{3,}/g, '\n\n').trim();
navigator.clipboard.writeText(textContent).then(() => {
showCopySuccess(button);
}).catch(err => {
console.error('复制失败:', err);
alert('复制失败,请手动选择内容复制');
});
}
return;
}
// 使用原始Markdown内容
navigator.clipboard.writeText(originalContent).then(() => {
showCopySuccess(button);
}).catch(err => {
console.error('复制失败:', err);
alert('复制失败,请手动选择内容复制');
});
} catch (error) {
console.error('复制消息时出错:', error);
alert('复制失败,请手动选择内容复制');
}
}
// 显示复制成功提示
function showCopySuccess(button) {
if (button) {
const originalText = button.innerHTML;
button.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M20 6L9 17l-5-5" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>已复制</span>';
button.style.color = '#10b981';
button.style.background = 'rgba(16, 185, 129, 0.1)';
button.style.borderColor = 'rgba(16, 185, 129, 0.3)';
setTimeout(() => {
button.innerHTML = originalText;
button.style.color = '';
button.style.background = '';
button.style.borderColor = '';
}, 2000);
}
}
// 渲染过程详情
function renderProcessDetails(messageId, processDetails) {
const messageElement = document.getElementById(messageId);
@@ -993,7 +1154,7 @@ function renderProcessDetails(messageId, processDetails) {
itemTitle = `${statusIcon} 工具 ${escapeHtml(toolName)} 执行${success ? '完成' : '失败'}`;
// 如果是知识检索工具,添加特殊标记
if (toolName === 'search_knowledge_base' && success) {
if (toolName === BuiltinTools.SEARCH_KNOWLEDGE_BASE && success) {
itemTitle = `📚 ${itemTitle} - 知识检索`;
}
} else if (eventType === 'knowledge_retrieval') {
@@ -1041,6 +1202,13 @@ if (chatInput) {
chatInput.addEventListener('input', handleChatInputInput);
chatInput.addEventListener('click', handleChatInputClick);
chatInput.addEventListener('focus', handleChatInputClick);
// IME输入法事件监听,用于跟踪输入法状态
chatInput.addEventListener('compositionstart', () => {
isComposing = true;
});
chatInput.addEventListener('compositionend', () => {
isComposing = false;
});
chatInput.addEventListener('blur', () => {
setTimeout(() => {
if (!chatInput.matches(':focus')) {
@@ -5148,7 +5316,8 @@ async function enterGroupDetail(groupId) {
// 刷新分组列表,确保当前分组高亮显示
await loadGroups();
loadGroupConversations(groupId);
// 加载分组对话(如果有搜索查询则使用搜索查询)
loadGroupConversations(groupId, currentGroupSearchQuery);
} catch (error) {
console.error('加载分组失败:', error);
currentGroupId = null;
@@ -5158,6 +5327,14 @@ async function enterGroupDetail(groupId) {
// 退出分组详情
function exitGroupDetail() {
currentGroupId = null;
currentGroupSearchQuery = ''; // 清除搜索状态
// 隐藏搜索框并清除搜索内容
const searchContainer = document.getElementById('group-search-container');
const searchInput = document.getElementById('group-search-input');
if (searchContainer) searchContainer.style.display = 'none';
if (searchInput) searchInput.value = '';
const sidebar = document.querySelector('.conversation-sidebar');
const groupDetailPage = document.getElementById('group-detail-page');
const chatContainer = document.querySelector('.chat-container');
@@ -5172,7 +5349,7 @@ function exitGroupDetail() {
}
// 加载分组中的对话
async function loadGroupConversations(groupId) {
async function loadGroupConversations(groupId, searchQuery = '') {
try {
if (!groupId) {
console.error('loadGroupConversations: groupId is null or undefined');
@@ -5190,10 +5367,20 @@ async function loadGroupConversations(groupId) {
console.error('group-conversations-list element not found');
return;
}
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">加载中...</div>';
// 显示加载状态
if (searchQuery) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">搜索中...</div>';
} else {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">加载中...</div>';
}
// 确保使用正确的 groupId
const url = `/api/groups/${groupId}/conversations`;
// 构建URL,如果有搜索关键词则添加search参数
let url = `/api/groups/${groupId}/conversations`;
if (searchQuery && searchQuery.trim()) {
url += '?search=' + encodeURIComponent(searchQuery.trim());
}
const response = await apiFetch(url);
if (!response.ok) {
console.error(`Failed to load conversations for group ${groupId}:`, response.statusText);
@@ -5235,7 +5422,11 @@ async function loadGroupConversations(groupId) {
list.innerHTML = '';
if (groupConvs.length === 0) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">该分组暂无对话</div>';
if (searchQuery && searchQuery.trim()) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">未找到匹配的对话</div>';
} else {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">该分组暂无对话</div>';
}
return;
}
@@ -5631,9 +5822,94 @@ function closeGroupContextMenu() {
}
// 分组搜索(占位函数)
function searchInGroup() {
alert('搜索功能待实现');
// 分组搜索相关变量
let groupSearchTimer = null;
let currentGroupSearchQuery = '';
// 切换分组搜索框显示/隐藏
function toggleGroupSearch() {
const searchContainer = document.getElementById('group-search-container');
const searchInput = document.getElementById('group-search-input');
if (!searchContainer || !searchInput) return;
if (searchContainer.style.display === 'none') {
searchContainer.style.display = 'block';
searchInput.focus();
} else {
searchContainer.style.display = 'none';
clearGroupSearch();
}
}
// 处理分组搜索输入
function handleGroupSearchInput(event) {
// 支持回车键搜索
if (event.key === 'Enter') {
event.preventDefault();
performGroupSearch();
return;
}
// 支持ESC键关闭搜索
if (event.key === 'Escape') {
clearGroupSearch();
toggleGroupSearch();
return;
}
const searchInput = document.getElementById('group-search-input');
const clearBtn = document.getElementById('group-search-clear-btn');
if (!searchInput) return;
const query = searchInput.value.trim();
// 显示/隐藏清除按钮
if (clearBtn) {
clearBtn.style.display = query ? 'block' : 'none';
}
// 防抖搜索
if (groupSearchTimer) {
clearTimeout(groupSearchTimer);
}
groupSearchTimer = setTimeout(() => {
performGroupSearch();
}, 300); // 300ms 防抖
}
// 执行分组搜索
async function performGroupSearch() {
const searchInput = document.getElementById('group-search-input');
if (!searchInput || !currentGroupId) return;
const query = searchInput.value.trim();
currentGroupSearchQuery = query;
// 加载搜索结果
await loadGroupConversations(currentGroupId, query);
}
// 清除分组搜索
function clearGroupSearch() {
const searchInput = document.getElementById('group-search-input');
const clearBtn = document.getElementById('group-search-clear-btn');
if (searchInput) {
searchInput.value = '';
}
if (clearBtn) {
clearBtn.style.display = 'none';
}
currentGroupSearchQuery = '';
// 重新加载分组对话(不搜索)
if (currentGroupId) {
loadGroupConversations(currentGroupId, '');
}
}
// 初始化时加载分组
+93 -2
View File
@@ -457,6 +457,7 @@ async function updateIndexProgress() {
const indexedItems = status.indexed_items || 0;
const progressPercent = status.progress_percent || 0;
const isComplete = status.is_complete || false;
const lastError = status.last_error || '';
if (totalItems === 0) {
// 没有知识项,隐藏进度条
@@ -471,6 +472,58 @@ async function updateIndexProgress() {
// 显示进度条
progressContainer.style.display = 'block';
// 如果有错误信息,显示错误
if (lastError) {
progressContainer.innerHTML = `
<div class="knowledge-index-progress-error" style="
background: #fee;
border: 1px solid #fcc;
border-radius: 8px;
padding: 16px;
margin-bottom: 16px;
">
<div style="display: flex; align-items: center; margin-bottom: 8px;">
<span style="font-size: 20px; margin-right: 8px;"></span>
<span style="font-weight: bold; color: #c00;">索引构建失败</span>
</div>
<div style="color: #666; font-size: 14px; margin-bottom: 12px; line-height: 1.5;">
${escapeHtml(lastError)}
</div>
<div style="color: #999; font-size: 12px; margin-bottom: 12px;">
可能的原因嵌入模型配置错误API密钥无效余额不足等请检查配置后重试
</div>
<div style="display: flex; gap: 8px;">
<button onclick="rebuildKnowledgeIndex()" style="
background: #007bff;
color: white;
border: none;
padding: 6px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
">重试</button>
<button onclick="stopIndexProgressPolling()" style="
background: #6c757d;
color: white;
border: none;
padding: 6px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
">关闭</button>
</div>
</div>
`;
// 停止轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
// 显示错误通知
showNotification('索引构建失败: ' + lastError.substring(0, 100), 'error');
return;
}
if (isComplete) {
progressContainer.innerHTML = `
<div class="knowledge-index-progress-complete">
@@ -503,8 +556,46 @@ async function updateIndexProgress() {
}
}
} catch (error) {
// 静默失败
console.debug('获取索引状态失败:', error);
// 显示错误信息
console.error('获取索引状态失败:', error);
const progressContainer = document.getElementById('knowledge-index-progress');
if (progressContainer) {
progressContainer.style.display = 'block';
progressContainer.innerHTML = `
<div class="knowledge-index-progress-error" style="
background: #fee;
border: 1px solid #fcc;
border-radius: 8px;
padding: 16px;
margin-bottom: 16px;
">
<div style="display: flex; align-items: center; margin-bottom: 8px;">
<span style="font-size: 20px; margin-right: 8px;"></span>
<span style="font-weight: bold; color: #c00;">无法获取索引状态</span>
</div>
<div style="color: #666; font-size: 14px;">
无法连接到服务器获取索引状态请检查网络连接或刷新页面
</div>
</div>
`;
}
// 停止轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
}
}
// 停止索引进度轮询
function stopIndexProgressPolling() {
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
const progressContainer = document.getElementById('knowledge-index-progress');
if (progressContainer) {
progressContainer.style.display = 'none';
}
}
File diff suppressed because it is too large Load Diff
+24 -1
View File
@@ -8,7 +8,7 @@ function initRouter() {
if (hash) {
const hashParts = hash.split('?');
const pageId = hashParts[0];
if (pageId && ['chat', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'settings', 'tasks'].includes(pageId)) {
if (pageId && ['chat', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'roles-management', 'settings', 'tasks'].includes(pageId)) {
switchPage(pageId);
// 如果是chat页面且带有conversation参数,加载对应对话
@@ -94,6 +94,19 @@ function updateNavState(pageId) {
knowledgeItem.classList.add('expanded');
}
const submenuItem = document.querySelector(`.nav-submenu-item[data-page="${pageId}"]`);
if (submenuItem) {
submenuItem.classList.add('active');
}
} else if (pageId === 'roles-management') {
// 角色子菜单项
const rolesItem = document.querySelector('.nav-item[data-page="roles"]');
if (rolesItem) {
rolesItem.classList.add('active');
// 展开角色子菜单
rolesItem.classList.add('expanded');
}
const submenuItem = document.querySelector(`.nav-submenu-item[data-page="${pageId}"]`);
if (submenuItem) {
submenuItem.classList.add('active');
@@ -239,6 +252,16 @@ function initPage(pageId) {
loadConfig(false);
}
break;
case 'roles-management':
// 初始化角色管理页面
if (typeof loadRoles === 'function') {
loadRoles().then(() => {
if (typeof renderRolesList === 'function') {
renderRolesList();
}
});
}
break;
}
// 清理其他页面的定时器
+91 -40
View File
@@ -2,8 +2,18 @@
let currentConfig = null;
let allTools = [];
// 全局工具状态映射,用于保存用户在所有页面的修改
// key: tool.name, value: { enabled: boolean, is_external: boolean, external_mcp: string }
// key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string }
let toolStateMap = new Map();
// 生成工具的唯一标识符,用于区分同名但来源不同的工具
function getToolKey(tool) {
// 如果是外部工具,使用 external_mcp::tool.name 作为唯一标识
// 如果是内部工具,使用 tool.name 作为标识
if (tool.is_external && tool.external_mcp) {
return `${tool.external_mcp}::${tool.name}`;
}
return tool.name;
}
// 从localStorage读取每页显示数量,默认为20
const getToolsPageSize = () => {
const saved = localStorage.getItem('toolsPageSize');
@@ -199,11 +209,13 @@ async function loadToolsList(page = 1, searchKeyword = '') {
// 初始化工具状态映射(如果工具不在映射中,使用服务器返回的状态)
allTools.forEach(tool => {
if (!toolStateMap.has(tool.name)) {
toolStateMap.set(tool.name, {
const toolKey = getToolKey(tool);
if (!toolStateMap.has(toolKey)) {
toolStateMap.set(toolKey, {
enabled: tool.enabled,
is_external: tool.is_external || false,
external_mcp: tool.external_mcp || ''
external_mcp: tool.external_mcp || '',
name: tool.name // 保存原始工具名称
});
}
});
@@ -223,14 +235,16 @@ async function loadToolsList(page = 1, searchKeyword = '') {
function saveCurrentPageToolStates() {
document.querySelectorAll('#tools-list .tool-item').forEach(item => {
const checkbox = item.querySelector('input[type="checkbox"]');
const toolKey = item.dataset.toolKey; // 使用唯一标识符
const toolName = item.dataset.toolName;
const isExternal = item.dataset.isExternal === 'true';
const externalMcp = item.dataset.externalMcp || '';
if (toolName && checkbox) {
toolStateMap.set(toolName, {
if (toolKey && checkbox) {
toolStateMap.set(toolKey, {
enabled: checkbox.checked,
is_external: isExternal,
external_mcp: externalMcp
external_mcp: externalMcp,
name: toolName // 保存原始工具名称
});
}
});
@@ -283,14 +297,16 @@ function renderToolsList() {
}
allTools.forEach(tool => {
const toolKey = getToolKey(tool); // 生成唯一标识符
const toolItem = document.createElement('div');
toolItem.className = 'tool-item';
toolItem.dataset.toolKey = toolKey; // 保存唯一标识符
toolItem.dataset.toolName = tool.name; // 保存原始工具名称
toolItem.dataset.isExternal = tool.is_external ? 'true' : 'false';
toolItem.dataset.externalMcp = tool.external_mcp || '';
// 从全局状态映射获取工具状态,如果不存在则使用服务器返回的状态
const toolState = toolStateMap.get(tool.name) || {
const toolState = toolStateMap.get(toolKey) || {
enabled: tool.enabled,
is_external: tool.is_external || false,
external_mcp: tool.external_mcp || ''
@@ -298,15 +314,18 @@ function renderToolsList() {
// 外部工具标签,显示来源信息
let externalBadge = '';
if (toolState.is_external) {
const externalMcpName = toolState.external_mcp || '';
if (toolState.is_external || tool.is_external) {
const externalMcpName = toolState.external_mcp || tool.external_mcp || '';
const badgeText = externalMcpName ? `外部 (${escapeHtml(externalMcpName)})` : '外部';
const badgeTitle = externalMcpName ? `外部MCP工具 - 来源:${escapeHtml(externalMcpName)}` : '外部MCP工具';
externalBadge = `<span class="external-tool-badge" title="${badgeTitle}">${badgeText}</span>`;
}
// 生成唯一的checkbox id,使用工具唯一标识符
const checkboxId = `tool-${escapeHtml(toolKey).replace(/::/g, '--')}`;
toolItem.innerHTML = `
<input type="checkbox" id="tool-${tool.name}" ${toolState.enabled ? 'checked' : ''} ${toolState.is_external ? 'data-external="true"' : ''} onchange="handleToolCheckboxChange('${tool.name}', this.checked)" />
<input type="checkbox" id="${checkboxId}" ${toolState.enabled ? 'checked' : ''} ${toolState.is_external || tool.is_external ? 'data-external="true"' : ''} onchange="handleToolCheckboxChange('${escapeHtml(toolKey)}', this.checked)" />
<div class="tool-item-info">
<div class="tool-item-name">
${escapeHtml(tool.name)}
@@ -376,16 +395,18 @@ function renderToolsPagination() {
}
// 处理工具checkbox状态变化
function handleToolCheckboxChange(toolName, enabled) {
function handleToolCheckboxChange(toolKey, enabled) {
// 更新全局状态映射
const toolItem = document.querySelector(`.tool-item[data-tool-name="${toolName}"]`);
const toolItem = document.querySelector(`.tool-item[data-tool-key="${toolKey}"]`);
if (toolItem) {
const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || '';
toolStateMap.set(toolName, {
toolStateMap.set(toolKey, {
enabled: enabled,
is_external: isExternal,
external_mcp: externalMcp
external_mcp: externalMcp,
name: toolName // 保存原始工具名称
});
}
updateToolsStats();
@@ -398,14 +419,16 @@ function selectAllTools() {
// 更新全局状态映射
const toolItem = checkbox.closest('.tool-item');
if (toolItem) {
const toolKey = toolItem.dataset.toolKey;
const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || '';
if (toolName) {
toolStateMap.set(toolName, {
if (toolKey) {
toolStateMap.set(toolKey, {
enabled: true,
is_external: isExternal,
external_mcp: externalMcp
external_mcp: externalMcp,
name: toolName // 保存原始工具名称
});
}
}
@@ -420,14 +443,16 @@ function deselectAllTools() {
// 更新全局状态映射
const toolItem = checkbox.closest('.tool-item');
if (toolItem) {
const toolKey = toolItem.dataset.toolKey;
const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || '';
if (toolName) {
toolStateMap.set(toolName, {
if (toolKey) {
toolStateMap.set(toolKey, {
enabled: false,
is_external: isExternal,
external_mcp: externalMcp
external_mcp: externalMcp,
name: toolName // 保存原始工具名称
});
}
}
@@ -484,11 +509,13 @@ async function updateToolsStats() {
totalTools = allTools.length;
totalEnabled = allTools.filter(tool => {
// 优先使用全局状态映射,否则使用checkbox状态,最后使用服务器返回的状态
const savedState = toolStateMap.get(tool.name);
const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
if (savedState !== undefined) {
return savedState.enabled;
}
const checkbox = document.getElementById(`tool-${tool.name}`);
const checkboxId = `tool-${toolKey.replace(/::/g, '--')}`;
const checkbox = document.getElementById(checkboxId);
return checkbox ? checkbox.checked : tool.enabled;
}).length;
} else {
@@ -498,16 +525,18 @@ async function updateToolsStats() {
// 从当前页的checkbox获取状态(如果全局映射中没有)
allTools.forEach(tool => {
const savedState = toolStateMap.get(tool.name);
const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
if (savedState !== undefined) {
localStateMap.set(tool.name, savedState.enabled);
localStateMap.set(toolKey, savedState.enabled);
} else {
const checkbox = document.getElementById(`tool-${tool.name}`);
const checkboxId = `tool-${toolKey.replace(/::/g, '--')}`;
const checkbox = document.getElementById(checkboxId);
if (checkbox) {
localStateMap.set(tool.name, checkbox.checked);
localStateMap.set(toolKey, checkbox.checked);
} else {
// 如果checkbox不存在(不在当前页),使用工具原始状态
localStateMap.set(tool.name, tool.enabled);
localStateMap.set(toolKey, tool.enabled);
}
}
});
@@ -527,9 +556,10 @@ async function updateToolsStats() {
const pageResult = await pageResponse.json();
pageResult.tools.forEach(tool => {
// 优先使用全局状态映射,否则使用服务器返回的状态
if (!localStateMap.has(tool.name)) {
const savedState = toolStateMap.get(tool.name);
localStateMap.set(tool.name, savedState ? savedState.enabled : tool.enabled);
const toolKey = getToolKey(tool);
if (!localStateMap.has(toolKey)) {
const savedState = toolStateMap.get(toolKey);
localStateMap.set(toolKey, savedState ? savedState.enabled : tool.enabled);
}
});
@@ -665,8 +695,9 @@ async function applySettings() {
// 将工具添加到映射中
// 优先使用全局状态映射中的状态(用户修改过的),否则使用服务器返回的状态
pageResult.tools.forEach(tool => {
const savedState = toolStateMap.get(tool.name);
allToolsMap.set(tool.name, {
const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
allToolsMap.set(toolKey, {
name: tool.name,
enabled: savedState ? savedState.enabled : tool.enabled,
is_external: savedState ? savedState.is_external : (tool.is_external || false),
@@ -683,7 +714,7 @@ async function applySettings() {
}
// 将所有工具添加到配置中
allToolsMap.forEach(tool => {
allToolsMap.forEach((tool, toolKey) => {
config.tools.push({
name: tool.name,
enabled: tool.enabled,
@@ -694,7 +725,9 @@ async function applySettings() {
} catch (error) {
console.warn('获取所有工具列表失败,仅使用全局状态映射', error);
// 如果获取失败,使用全局状态映射
toolStateMap.forEach((toolData, toolName) => {
toolStateMap.forEach((toolData, toolKey) => {
// toolData.name 保存了原始工具名称
const toolName = toolData.name || toolKey.split('::').pop();
config.tools.push({
name: toolName,
enabled: toolData.enabled,
@@ -777,8 +810,9 @@ async function saveToolsConfig() {
// 将工具添加到映射中
pageResult.tools.forEach(tool => {
const savedState = toolStateMap.get(tool.name);
allToolsMap.set(tool.name, {
const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
allToolsMap.set(toolKey, {
name: tool.name,
enabled: savedState ? savedState.enabled : tool.enabled,
is_external: savedState ? savedState.is_external : (tool.is_external || false),
@@ -795,7 +829,7 @@ async function saveToolsConfig() {
}
// 将所有工具添加到配置中
allToolsMap.forEach(tool => {
allToolsMap.forEach((tool, toolKey) => {
config.tools.push({
name: tool.name,
enabled: tool.enabled,
@@ -806,7 +840,9 @@ async function saveToolsConfig() {
} catch (error) {
console.warn('获取所有工具列表失败,仅使用全局状态映射', error);
// 如果获取失败,使用全局状态映射
toolStateMap.forEach((toolData, toolName) => {
toolStateMap.forEach((toolData, toolKey) => {
// toolData.name 保存了原始工具名称
const toolName = toolData.name || toolKey.split('::').pop();
config.tools.push({
name: toolName,
enabled: toolData.enabled,
@@ -1158,6 +1194,14 @@ function loadExternalMCPExample() {
],
description: "示例描述",
timeout: 300
},
"cyberstrike-ai-http": {
transport: "http",
url: "http://127.0.0.1:8081/mcp"
},
"cyberstrike-ai-sse": {
transport: "sse",
url: "http://127.0.0.1:8081/mcp/sse"
}
};
@@ -1231,7 +1275,7 @@ async function saveExternalMCP() {
// 验证配置内容
const transport = config.transport || (config.command ? 'stdio' : config.url ? 'http' : '');
if (!transport) {
errorDiv.textContent = `配置错误: "${name}" 需要指定commandstdio模式)或urlhttp模式)`;
errorDiv.textContent = `配置错误: "${name}" 需要指定commandstdio模式)或urlhttp/sse模式)`;
errorDiv.style.display = 'block';
jsonTextarea.classList.add('error');
return;
@@ -1250,6 +1294,13 @@ async function saveExternalMCP() {
jsonTextarea.classList.add('error');
return;
}
if (transport === 'sse' && !config.url) {
errorDiv.textContent = `配置错误: "${name}" sse模式需要url字段`;
errorDiv.style.display = 'block';
jsonTextarea.classList.add('error');
return;
}
}
// 清除错误提示
+42 -8
View File
@@ -720,8 +720,12 @@ const batchQueuesState = {
function showBatchImportModal() {
const modal = document.getElementById('batch-import-modal');
const input = document.getElementById('batch-tasks-input');
const titleInput = document.getElementById('batch-queue-title');
if (modal && input) {
input.value = '';
if (titleInput) {
titleInput.value = '';
}
updateBatchImportStats('');
modal.style.display = 'block';
input.focus();
@@ -765,6 +769,7 @@ document.addEventListener('DOMContentLoaded', function() {
// 创建批量任务队列
async function createBatchQueue() {
const input = document.getElementById('batch-tasks-input');
const titleInput = document.getElementById('batch-queue-title');
if (!input) return;
const text = input.value.trim();
@@ -780,13 +785,16 @@ async function createBatchQueue() {
return;
}
// 获取标题(可选)
const title = titleInput ? titleInput.value.trim() : '';
try {
const response = await apiFetch('/api/batch-tasks', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ tasks }),
body: JSON.stringify({ title, tasks }),
});
if (!response.ok) {
@@ -885,6 +893,11 @@ function renderBatchQueues() {
return;
}
// 确保分页控件可见(重置之前可能设置的 display: none
if (pagination) {
pagination.style.display = '';
}
list.innerHTML = queues.map(queue => {
const statusMap = {
'pending': { text: '待执行', class: 'batch-queue-status-pending' },
@@ -918,10 +931,13 @@ function renderBatchQueues() {
// 允许删除待执行、已完成或已取消状态的队列
const canDelete = queue.status === 'pending' || queue.status === 'completed' || queue.status === 'cancelled';
const titleDisplay = queue.title ? `<span class="batch-queue-title" style="font-weight: 600; color: var(--text-primary); margin-right: 8px;">${escapeHtml(queue.title)}</span>` : '';
return `
<div class="batch-queue-item" data-queue-id="${queue.id}" onclick="showBatchQueueDetail('${queue.id}')">
<div class="batch-queue-header">
<div class="batch-queue-info" style="flex: 1;">
${titleDisplay}
<span class="batch-queue-status ${status.class}">${status.text}</span>
<span class="batch-queue-id">队列ID: ${escapeHtml(queue.id)}</span>
<span class="batch-queue-time">创建时间: ${new Date(queue.createdAt).toLocaleString('zh-CN')}</span>
@@ -962,9 +978,13 @@ function renderBatchQueuesPagination() {
// 如果没有数据,不显示分页控件
if (total === 0) {
paginationContainer.innerHTML = '';
paginationContainer.style.display = 'none';
return;
}
// 确保分页控件可见
paginationContainer.style.display = '';
// 即使只有一页,也显示分页信息(总数和每页条数选择器)
// 计算显示的页码范围
@@ -1100,7 +1120,7 @@ async function showBatchQueueDetail(queueId) {
batchQueuesState.currentQueueId = queueId;
if (title) {
title.textContent = '批量任务队列';
title.textContent = queue.title ? `批量任务队列 - ${escapeHtml(queue.title)}` : '批量任务队列';
}
// 更新按钮显示
@@ -1146,19 +1166,33 @@ async function showBatchQueueDetail(queueId) {
content.innerHTML = `
<div class="batch-queue-detail-info">
${queue.title ? `<div class="detail-item">
<span class="detail-label">任务标题</span>
<span class="detail-value">${escapeHtml(queue.title)}</span>
</div>` : ''}
<div class="detail-item">
<strong>队列ID:</strong> <code>${escapeHtml(queue.id)}</code>
<span class="detail-label">队列ID</span>
<span class="detail-value"><code>${escapeHtml(queue.id)}</code></span>
</div>
<div class="detail-item">
<strong>状态:</strong> <span class="batch-queue-status ${queueStatusMap[queue.status]?.class || ''}">${queueStatusMap[queue.status]?.text || queue.status}</span>
<span class="detail-label">状态</span>
<span class="detail-value"><span class="batch-queue-status ${queueStatusMap[queue.status]?.class || ''}">${queueStatusMap[queue.status]?.text || queue.status}</span></span>
</div>
<div class="detail-item">
<strong>创建时间:</strong> ${new Date(queue.createdAt).toLocaleString('zh-CN')}
<span class="detail-label">创建时间</span>
<span class="detail-value">${new Date(queue.createdAt).toLocaleString('zh-CN')}</span>
</div>
${queue.startedAt ? `<div class="detail-item"><strong>开始时间:</strong> ${new Date(queue.startedAt).toLocaleString('zh-CN')}</div>` : ''}
${queue.completedAt ? `<div class="detail-item"><strong>完成时间:</strong> ${new Date(queue.completedAt).toLocaleString('zh-CN')}</div>` : ''}
${queue.startedAt ? `<div class="detail-item">
<span class="detail-label">开始时间</span>
<span class="detail-value">${new Date(queue.startedAt).toLocaleString('zh-CN')}</span>
</div>` : ''}
${queue.completedAt ? `<div class="detail-item">
<span class="detail-label">完成时间</span>
<span class="detail-value">${new Date(queue.completedAt).toLocaleString('zh-CN')}</span>
</div>` : ''}
<div class="detail-item">
<strong>任务总数:</strong> ${queue.tasks.length}
<span class="detail-label">任务总数</span>
<span class="detail-value">${queue.tasks.length}</span>
</div>
</div>
<div class="batch-queue-tasks-list">
+195 -11
View File
@@ -136,6 +136,23 @@
</div>
</div>
</div>
<div class="nav-item nav-item-has-submenu" data-page="roles">
<div class="nav-item-content" data-title="角色" onclick="toggleSubmenu('roles')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M20 21v-2a4 4 0 0 0-4-4H8a4 4 0 0 0-4 4v2" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<circle cx="12" cy="7" r="4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span>角色</span>
<svg class="submenu-arrow" width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M9 18l6-6-6-6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</div>
<div class="nav-submenu">
<div class="nav-submenu-item" data-page="roles-management" onclick="switchPage('roles-management')">
<span>角色管理</span>
</div>
</div>
</div>
<div class="nav-item" data-page="settings">
<div class="nav-item-content" data-title="系统设置" onclick="switchPage('settings')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -162,7 +179,7 @@
</div>
<div class="sidebar-content">
<!-- 全局搜索 -->
<div class="conversation-search-box" style="margin-bottom: 16px;">
<div class="conversation-search-box" style="margin-bottom: 16px; margin-top: 16px;">
<input type="text" id="conversation-search-input" placeholder="搜索历史记录..."
oninput="handleConversationSearch(this.value)"
onkeypress="if(event.key === 'Enter') handleConversationSearch(this.value)" />
@@ -218,7 +235,7 @@
</button>
<h2 id="group-detail-title" class="group-detail-title"></h2>
<div class="group-detail-actions">
<button class="group-action-btn" onclick="searchInGroup()" title="搜索">
<button class="group-action-btn" onclick="toggleGroupSearch()" title="搜索" id="group-search-toggle-btn">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="11" cy="11" r="8" stroke="currentColor" stroke-width="2"/>
<path d="m21 21-4.35-4.35" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
@@ -237,6 +254,17 @@
</button>
</div>
</div>
<div id="group-search-container" class="group-search-container" style="display: none;">
<div class="group-search-input-wrapper">
<input type="text" id="group-search-input" class="group-search-input" placeholder="搜索分组中的对话..." onkeyup="handleGroupSearchInput(event)" oninput="handleGroupSearchInput(event)">
<button class="group-search-clear-btn" onclick="clearGroupSearch()" title="清除搜索" id="group-search-clear-btn" style="display: none;">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="12" cy="12" r="10" stroke="currentColor" stroke-width="2"/>
<path d="m8 8 8 8M16 8l-8 8" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
</svg>
</button>
</div>
</div>
<div class="group-detail-content">
<div id="group-conversations-list" class="group-conversations-list"></div>
</div>
@@ -258,11 +286,37 @@
<div id="active-tasks-bar" class="active-tasks-bar"></div>
<div id="chat-messages" class="chat-messages"></div>
<div class="chat-input-container">
<div class="role-selector-wrapper">
<button id="role-selector-btn" class="role-selector-btn" onclick="toggleRoleSelectionPanel()" title="选择角色">
<span id="role-selector-icon" class="role-selector-icon">🔵</span>
<span id="role-selector-text" class="role-selector-text">默认</span>
<svg class="role-selector-arrow" width="10" height="10" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 9l6 6 6-6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
<!-- 角色选择下拉面板 -->
<div id="role-selection-panel" class="role-selection-panel" style="display: none;">
<div class="role-selection-panel-header">
<h3 class="role-selection-panel-title">选择角色</h3>
<button class="role-selection-panel-close" onclick="closeRoleSelectionPanel()" title="关闭">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18 6L6 18M6 6l12 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
</div>
<div id="role-selection-list" class="role-selection-list-main"></div>
</div>
</div>
<div class="chat-input-field">
<textarea id="chat-input" placeholder="输入测试目标或命令... (输入 @ 选择工具 | Shift+Enter 换行,Enter 发送)" rows="1"></textarea>
<div id="mention-suggestions" class="mention-suggestions" role="listbox" aria-label="工具提及候选"></div>
</div>
<button onclick="sendMessage()">发送</button>
<button class="send-btn" onclick="sendMessage()">
<span>发送</span>
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5 12h14M12 5l7 7-7 7" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
</div>
</div>
</div>
@@ -568,9 +622,6 @@
<div class="page-content">
<!-- 批量任务队列列表 -->
<div class="batch-queues-section" id="batch-queues-section" style="display: none;">
<div class="batch-queues-header">
<h3>批量任务队列</h3>
</div>
<!-- 筛选控件 -->
<div class="batch-queues-filters tasks-filters">
<label>
@@ -585,7 +636,7 @@
</select>
</label>
<label style="flex: 1; max-width: 300px;">
<span>搜索队列ID或创建时间</span>
<span>搜索队列ID、标题或创建时间</span>
<input type="text" id="batch-queues-search" placeholder="输入关键字搜索..."
oninput="filterBatchQueues()">
</label>
@@ -597,6 +648,34 @@
</div>
</div>
<!-- 角色管理页面 -->
<div id="page-roles-management" class="page">
<div class="page-header">
<h2>角色管理</h2>
<div class="page-header-actions">
<button class="btn-secondary" onclick="refreshRoles()">刷新</button>
<button class="btn-primary" onclick="showAddRoleModal()">添加角色</button>
</div>
</div>
<div class="page-content">
<div class="roles-controls">
<div class="roles-stats-bar" id="roles-stats">
<div class="role-stat-item">
<span class="role-stat-label">总角色数</span>
<span class="role-stat-value">-</span>
</div>
<div class="role-stat-item">
<span class="role-stat-label">已启用</span>
<span class="role-stat-value">-</span>
</div>
</div>
</div>
<div id="roles-list" class="roles-list">
<div class="loading-spinner">加载中...</div>
</div>
</div>
</div>
<!-- 系统设置页面 -->
<div id="page-settings" class="page">
<div class="page-header">
@@ -679,10 +758,6 @@
<option value="openai">OpenAI</option>
</select>
</div>
<div class="form-group">
<label for="knowledge-embedding-model">模型名称</label>
<input type="text" id="knowledge-embedding-model" placeholder="text-embedding-v4" />
</div>
<div class="form-group">
<label for="knowledge-embedding-base-url">Base URL</label>
<input type="text" id="knowledge-embedding-base-url" placeholder="留空则使用OpenAI配置的base_url" />
@@ -693,6 +768,10 @@
<input type="password" id="knowledge-embedding-api-key" placeholder="留空则使用OpenAI配置的api_key" />
<small class="form-hint">留空则使用OpenAI配置的api_key</small>
</div>
<div class="form-group">
<label for="knowledge-embedding-model">模型名称</label>
<input type="text" id="knowledge-embedding-model" placeholder="text-embedding-v4" />
</div>
<div class="settings-subsection-header">
<h5>检索配置</h5>
@@ -857,6 +936,13 @@
"transport": "http",
"url": "http://127.0.0.1:8081/mcp"
}
}</code>
<strong>SSE模式:</strong><br>
<code style="display: block; margin: 8px 0; padding: 8px; background: var(--bg-secondary); border-radius: 4px; white-space: pre-wrap;">{
"cyberstrike-ai-sse": {
"transport": "sse",
"url": "http://127.0.0.1:8081/mcp/sse"
}
}</code>
</div>
<div id="external-mcp-json-error" class="error-message" style="display: none; margin-top: 8px; padding: 8px; background: rgba(220, 53, 69, 0.1); border: 1px solid rgba(220, 53, 69, 0.3); border-radius: 4px; color: var(--error-color); font-size: 0.875rem;"></div>
@@ -1160,6 +1246,13 @@
<span class="modal-close" onclick="closeBatchImportModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="batch-queue-title">任务标题</label>
<input type="text" id="batch-queue-title" placeholder="请输入任务标题(可选,用于标识和筛选)" />
<div class="form-hint" style="margin-top: 4px;">
为批量任务队列设置一个标题,方便后续查找和管理。
</div>
</div>
<div class="form-group">
<label for="batch-tasks-input">任务列表(每行一个任务)<span style="color: red;">*</span></label>
<textarea id="batch-tasks-input" rows="15" placeholder="请输入任务列表,每行一个任务,例如:&#10;扫描 192.168.1.1 的开放端口&#10;检查 https://example.com 是否存在SQL注入&#10;枚举 example.com 的子域名" style="font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; font-size: 0.875rem; line-height: 1.5;"></textarea>
@@ -1307,6 +1400,96 @@
</div>
</div>
<!-- 角色选择弹窗 -->
<div id="role-select-modal" class="modal">
<div class="modal-content role-select-modal-content">
<div class="modal-header">
<h2>选择角色</h2>
<span class="modal-close" onclick="closeRoleSelectModal()">&times;</span>
</div>
<div class="modal-body role-select-body">
<div id="role-select-list" class="role-select-list"></div>
</div>
</div>
</div>
<!-- 角色编辑模态框 -->
<div id="role-modal" class="modal">
<div class="modal-content" style="max-width: 900px;">
<div class="modal-header">
<h2 id="role-modal-title">添加角色</h2>
<span class="modal-close" onclick="closeRoleModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="role-name">角色名称 <span style="color: red;">*</span></label>
<input type="text" id="role-name" placeholder="输入角色名称" required />
</div>
<div class="form-group">
<label for="role-description">角色描述</label>
<input type="text" id="role-description" placeholder="输入角色描述" />
</div>
<div class="form-group">
<label for="role-icon">角色图标</label>
<input type="text" id="role-icon" placeholder="输入emoji图标,例如: 🏆" maxlength="10" />
<small class="form-hint">输入一个emoji作为角色的图标,将显示在角色选择器中。</small>
</div>
<div class="form-group">
<label for="role-user-prompt">用户提示词</label>
<textarea id="role-user-prompt" rows="10" placeholder="输入用户提示词,会在用户消息前追加此提示词..."></textarea>
<small class="form-hint">此提示词会追加到用户消息前,用于指导AI的行为。注意:这不会修改系统提示词。</small>
</div>
<div class="form-group" id="role-tools-section">
<label>关联的工具(可选)</label>
<div id="role-tools-default-hint" class="role-tools-default-hint" style="display: none;">
<div class="role-tools-default-info">
<span class="role-tools-default-icon"></span>
<div class="role-tools-default-content">
<div class="role-tools-default-title">默认角色使用所有工具</div>
<div class="role-tools-default-desc">默认角色会自动使用MCP管理中启用的所有工具,无需单独配置。</div>
</div>
</div>
</div>
<div class="role-tools-controls">
<div class="role-tools-actions">
<button type="button" class="btn-secondary" onclick="selectAllRoleTools()">全选</button>
<button type="button" class="btn-secondary" onclick="deselectAllRoleTools()">全不选</button>
<div class="role-tools-search-box">
<input type="text" id="role-tools-search" placeholder="搜索工具..."
oninput="searchRoleTools(this.value)"
onkeypress="if(event.key === 'Enter') searchRoleTools(this.value)" />
<button class="role-tools-search-clear" id="role-tools-search-clear"
onclick="clearRoleToolsSearch()" style="display: none;" title="清除搜索">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="12" cy="12" r="10" stroke="currentColor" stroke-width="2"/>
<path d="M15 9l-6 6M9 9l6 6" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
</svg>
</button>
</div>
</div>
<div id="role-tools-stats" class="role-tools-stats"></div>
</div>
<div id="role-tools-list" class="role-tools-list">
<div class="tools-loading">正在加载工具列表...</div>
</div>
<small class="form-hint">勾选要关联的工具,留空则使用MCP管理中的全部工具配置。</small>
</div>
<div class="form-group">
<label class="checkbox-label">
<input type="checkbox" id="role-enabled" class="modern-checkbox" checked />
<span class="checkbox-custom"></span>
<span class="checkbox-text">启用此角色</span>
</label>
</div>
</div>
<div class="modal-footer">
<button class="btn-secondary" onclick="closeRoleModal()">取消</button>
<button class="btn-primary" onclick="saveRole()">保存</button>
</div>
</div>
</div>
<script src="/static/js/builtin-tools.js"></script>
<script src="/static/js/auth.js"></script>
<script src="/static/js/router.js"></script>
<script src="/static/js/monitor.js"></script>
@@ -1315,6 +1498,7 @@
<script src="/static/js/knowledge.js"></script>
<script src="/static/js/vulnerability.js?v=4"></script>
<script src="/static/js/tasks.js"></script>
<script src="/static/js/roles.js"></script>
</body>
</html>