Compare commits

...

52 Commits

Author SHA1 Message Date
公明 1336d6f9a6 Add files via upload 2026-01-11 02:17:07 +08:00
公明 5ce1fb7501 Add files via upload 2026-01-11 02:05:11 +08:00
公明 aa9819a2c8 Update config.yaml 2026-01-11 02:04:21 +08:00
公明 3aee7022c4 Add files via upload 2026-01-11 02:03:33 +08:00
公明 4ca1aa9aa8 Add files via upload 2026-01-09 23:00:32 +08:00
公明 3448c661b8 Add files via upload 2026-01-09 22:16:09 +08:00
公明 b524ce68ea Add files via upload 2026-01-09 21:16:38 +08:00
公明 2c973f8c3b Add files via upload 2026-01-09 19:44:59 +08:00
公明 c3a1d95a92 Add files via upload 2026-01-09 19:32:14 +08:00
公明 60e3795322 Add files via upload 2026-01-09 19:02:16 +08:00
公明 28ca7f1851 Add files via upload 2026-01-09 18:52:38 +08:00
公明 14e9b986b0 Add files via upload 2026-01-08 23:43:09 +08:00
公明 dccbb80fa4 Add files via upload 2026-01-08 22:54:36 +08:00
公明 3043232937 Add files via upload 2026-01-08 22:43:41 +08:00
公明 2aeb2705e9 Add files via upload 2026-01-07 19:41:35 +08:00
公明 6bd558cbd4 Add files via upload 2026-01-07 19:38:36 +08:00
公明 71abfb2384 Update README.md 2026-01-07 14:10:29 +08:00
公明 d3f6a87448 Update README.md 2026-01-07 14:07:23 +08:00
公明 2076266844 Update README.md 2026-01-07 14:05:25 +08:00
公明 42293a9f49 Update README.md 2026-01-06 00:58:55 +08:00
公明 92580bebd5 Update README_CN.md 2026-01-06 00:58:23 +08:00
公明 23fd79d50d Update README_CN.md 2026-01-06 00:57:58 +08:00
公明 5216cebb2f Update README.md 2026-01-06 00:52:43 +08:00
公明 e55dd0265e Update README.md 2026-01-06 00:52:07 +08:00
公明 d550853b56 Add files via upload 2026-01-02 04:11:46 +08:00
公明 87e8f07738 Add files via upload 2026-01-02 02:22:43 +08:00
公明 044480a427 Add files via upload 2026-01-02 02:16:44 +08:00
公明 88e710d7e9 Add files via upload 2026-01-02 01:29:02 +08:00
公明 74b2edad29 Add files via upload 2026-01-02 01:05:02 +08:00
公明 cfc59ed895 Add files via upload 2026-01-02 01:02:01 +08:00
公明 9c5a115814 Delete img/效果.png 2026-01-02 01:01:26 +08:00
公明 a173dce667 Add files via upload 2026-01-02 00:52:25 +08:00
公明 b5d3396159 Add files via upload 2026-01-02 00:45:28 +08:00
公明 dcca3f014d Update README_CN.md 2026-01-02 00:20:20 +08:00
公明 ca8fb8b60b Update README.md 2026-01-02 00:19:36 +08:00
公明 7b9dee7268 Add files via upload 2026-01-02 00:18:39 +08:00
公明 b90a29fdd7 Add files via upload
1、修复删除知识项后总分类数统计错误:将 updateKnowledgeStats 中的 || 改为 != null 检查,并移除会错误更新统计的 updateKnowledgeStatsAfterDelete 调用。
2、为 MCP 状态监控页面添加了批量删除功能(复选框、全选、批量删除按钮)和每页显示数量配置(选择器位于分页控件左侧,设置保存到 localStorage)。
2025-12-31 19:20:58 +08:00
公明 24aa12cf33 Add files via upload 2025-12-31 19:02:15 +08:00
公明 7b8a220123 Update README.md 2025-12-31 09:08:09 +08:00
公明 99552a1812 Update README_CN.md 2025-12-31 09:07:07 +08:00
公明 e971e1eee2 Update README.md 2025-12-31 09:06:32 +08:00
公明 4fb1c7b911 Add files via upload 2025-12-31 09:06:08 +08:00
公明 9ebf9c2252 Delete img/外部MCP接入.png 2025-12-31 09:05:16 +08:00
公明 7fcfbe60c5 Add files via upload 2025-12-31 09:02:04 +08:00
公明 0c4f934b24 Delete img/效果.png 2025-12-31 09:01:51 +08:00
公明 90bafc2f1c Add files via upload 2025-12-31 08:57:48 +08:00
公明 adfd45e11e Add files via upload 2025-12-31 08:57:09 +08:00
公明 63f2a6fc3a Create feature_request.md 2025-12-31 01:31:29 +08:00
公明 4fecdad152 Create bug_report.md 2025-12-31 01:30:34 +08:00
公明 a32ba40353 Add files via upload 2025-12-30 23:21:09 +08:00
公明 d48238f6a0 Add files via upload 2025-12-30 22:02:13 +08:00
公明 98713236b7 Add files via upload 2025-12-29 19:17:16 +08:00
61 changed files with 12359 additions and 640 deletions
+78
View File
@@ -0,0 +1,78 @@
---
name: 🐛 Bug / 异常问题反馈
about: 报告一个 Bug 或异常问题
title: '[BUG] '
labels: ['bug', '待确认']
assignees: ''
---
## 📋 问题描述
<!-- 请清晰、简洁地描述遇到的问题 -->
## 🔄 复现步骤
<!-- 请详细描述如何复现这个问题 -->
1.
2.
3.
4.
## ✅ 期望行为
<!-- 描述你期望的正确行为是什么 -->
## ❌ 实际行为
<!-- 描述实际发生了什么 -->
## 📸 截图/录屏
<!--
⚠️ 重要:请提供完整的截图或录屏,确保包含:
- 完整的错误信息
- 相关的界面元素
- 浏览器控制台错误(如有)
- 终端输出(如有)
如果截图不完整,issue 可能会被关闭。
-->
<!-- 请在此处拖拽或粘贴截图 -->
## 📝 报错日志(脱敏后)
<!--
⚠️ 重要:请提供完整的、脱敏后的报错日志。
脱敏要求:
- 移除所有敏感信息(API Key、密码、Token、真实IP地址、域名等)
- 使用占位符替换,如:`sk-xxx``password: ***``192.168.x.x``example.com`
- 保留完整的错误堆栈信息
- 保留时间戳和日志级别
请从以下位置收集日志:
1. MCP状态监控 页面
2. 服务器终端输出
3. 日志文件(如果配置了文件输出)
4. 浏览器控制台(F12 → Console
-->
```
请在此处粘贴脱敏后的完整报错日志
```
## ✅ 检查清单
<!-- 提交前请确认以下项目 -->
- [ ] 我已阅读并理解项目的 Issue 规范
- [ ] 我已提供完整的、脱敏后的报错日志
- [ ] 我已提供完整的截图(如适用)
- [ ] 我已提供详细的复现步骤
- [ ] 我已填写所有必要的环境信息
- [ ] 我已脱敏所有敏感信息(API Key、密码、IP 等)
- [ ] 我已确认这不是重复的 issue
---
**注意**:如果缺少必要的日志或截图,此 issue 可能会被标记为 `需要更多信息` 或直接关闭。请确保提供完整的信息以便我们能够快速定位和解决问题。
+68
View File
@@ -0,0 +1,68 @@
---
name: ✨ 功能优化建议
about: 提出新功能或优化建议
title: '[FEATURE] '
labels: ['enhancement', '待讨论']
assignees: ''
---
## 💡 功能描述
<!-- 请清晰、简洁地描述你希望添加或优化的功能 -->
## 🎯 使用场景
<!-- 描述这个功能的使用场景,解决什么问题 -->
<!-- 例如:在什么情况下会用到这个功能?它如何改善用户体验? -->
## 🔄 当前行为
<!-- 描述当前系统是如何处理相关需求的,或者为什么需要这个功能 -->
## ✨ 期望行为
<!-- 详细描述你期望的新功能或优化后的行为 -->
## 📸 参考示例(如有)
<!--
如果有其他项目的类似功能实现,可以在此提供截图或链接作为参考
⚠️ 请确保截图完整,包含所有相关界面元素
-->
<!-- 请在此处拖拽或粘贴参考截图 -->
## 🛠️ 实现建议(可选)
<!-- 如果你有具体的实现思路或技术建议,可以在此描述 -->
## 📊 优先级评估
<!-- 请选择你认为的优先级 -->
- [ ] 🔴 高优先级(严重影响使用体验或功能缺失)
- [ ] 🟡 中优先级(能显著改善体验)
- [ ] 🟢 低优先级(锦上添花的功能)
## 🔍 相关功能
<!-- 这个功能是否与现有功能相关? -->
<!-- 例如:是否与工具管理、攻击链分析、知识库等功能相关? -->
## 📝 额外信息
<!-- 任何其他有助于理解需求的信息 -->
- 是否已有替代方案?
- 这个功能是否会影响现有功能?
- 是否有相关的其他 issue 或讨论?
## ✅ 检查清单
<!-- 提交前请确认以下项目 -->
- [ ] 我已清晰描述了功能需求和使用场景
- [ ] 我已提供完整的参考截图(如有)
- [ ] 我已评估了功能的优先级
- [ ] 我已确认这不是重复的 issue
- [ ] 我已考虑了对现有功能的影响
---
**注意**:请提供尽可能详细的信息,包括使用场景、参考示例等,这将有助于我们更好地理解和实现你的需求。
+169
View File
@@ -0,0 +1,169 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [1.2.0] - 2026-01-11
### Added
- Role-based testing feature: predefined security testing roles with custom system prompts and tool restrictions. Users can select roles (Penetration Testing, CTF, Web App Scanning, etc.) from the chat interface to customize AI behavior and available tools. Roles are defined as YAML files in the `roles/` directory with support for hot-reload.
## [1.1.0] - 2026-01-08
### Added
- SSE (Server-Sent Events) transport mode support for external MCP servers. External MCP federation now supports HTTP, stdio, and SSE modes. SSE mode enables real-time streaming communication for push-based scenarios.
## [1.0.0] - 2026-01-01
### Added
- Batch task management feature: create task queues with multiple tasks, add/edit/delete tasks before execution, and execute them sequentially. Each task runs as a separate conversation with status tracking (pending/running/completed/failed/cancelled). All queues and tasks are persisted in the database.
## [0.7.0] - 2025-12-25
### Added
- Vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
- Conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
## [0.6.1] - 2025-12-24
### Changed
- Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
## [0.6.0] - 2025-12-20
### Added
- Knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
## [0.5.1] - 2025-12-19
### Added
- ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
## [0.5.0] - 2025-12-18
### Changed
- Optimized web frontend with enhanced sidebar navigation and improved user experience.
## [0.4.1] - 2025-12-07
### Added
- FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
### Fixed
- Positional parameter handling bug: ensure correct parameter position when using default values.
## [0.4.0] - 2025-11-20
### Added
- Automatic compression/summarization for oversized tool logs and MCP transcripts.
## [0.3.0] - 2025-11-17
### Added
- AI-built attack-chain visualization with interactive graph and risk scoring.
## [0.2.0] - 2025-11-15
### Added
- Large-result pagination, advanced filtering, and external MCP federation.
## [0.1.1] - 2025-11-14
### Changed
- Optimized tool lookups to O(1) time complexity.
- Execution record cleanup and DB pagination improvements.
## [0.1.0] - 2025-11-13
### Added
- Web authentication, settings UI, and MCP stdio mode integration.
---
# 更新日志
本项目的重要变更将记录在此文件中。
格式基于 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
并遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
## [未发布]
## [1.2.0] - 2026-01-11
### 新增
- 角色化测试功能:预设安全测试角色,支持自定义系统提示词和工具限制。用户可在聊天界面选择角色(渗透测试、CTF、Web 应用扫描等),以自定义 AI 行为和可用工具。角色以 YAML 文件形式定义在 `roles/` 目录,支持热加载。
## [1.1.0] - 2026-01-08
### 新增
- SSEServer-Sent Events)传输模式支持,外部 MCP 联邦现支持 HTTP、stdio 和 SSE 三种模式。SSE 模式支持实时流式通信,适用于基于推送的场景。
## [1.0.0] - 2026-01-01
### 新增
- 批量任务管理功能:支持创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务作为独立对话运行,支持状态跟踪(待执行/执行中/已完成/失败/已取消),所有队列和任务数据持久化存储到数据库。
## [0.7.0] - 2025-12-25
### 新增
- 漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。
- 对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。
## [0.6.1] - 2025-12-24
### 变更
- 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
## [0.6.0] - 2025-12-20
### 新增
- 知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。
## [0.5.1] - 2025-12-19
### 新增
- 钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
## [0.5.0] - 2025-12-18
### 变更
- 优化 Web 前端界面,增加侧边栏导航,提升用户体验。
## [0.4.1] - 2025-12-07
### 新增
- FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。
### 修复
- 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。
## [0.4.0] - 2025-11-20
### 新增
- 支持超大日志/MCP 记录的自动压缩与摘要回写。
## [0.3.0] - 2025-11-17
### 新增
- 上线 AI 驱动的攻击链图谱与风险评分。
## [0.2.0] - 2025-11-15
### 新增
- 提供大结果分页检索与外部 MCP 挂载能力。
## [0.1.1] - 2025-11-14
### 变更
- 工具检索优化至 O(1) 时间复杂度。
- 执行记录清理、数据库分页优化。
## [0.1.0] - 2025-11-13
### 新增
- Web 鉴权、Settings 面板与 MCP stdio 模式发布。
+191 -64
View File
@@ -9,22 +9,31 @@
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams. CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
> In security, what is truly scarce is not tools, but judgment; judgment is born from experience, and experience is often bound to individuals, difficult to inherit and difficult to reuse. CyberStrikeAI does not attempt to automate attacks, but instead focuses on a harder problem: in complex, dynamic, and uncertain environments, what should be done next, and why. It does not pursue more aggressive automation; rather, it seeks to transform the way security experts think—their decision paths and lessons learned from failure—into a system capability that is constrained, auditable, and evolvable. If experience can exist beyond individuals, then security can finally become something that can be inherited by systems.
## Interface & Integration Preview ## Interface & Integration Preview
- Web console
<img src="./img/效果.png" alt="Preview" width="560"> ### Web Console
- MCP stdio mode <img src="./img/效果.png" alt="Web Console" width="560">
<img src="./img/mcp-stdio2.png" alt="Preview" width="560">
- External MCP servers & attack-chain view ### MCP Integration
<img src="./img/外部MCP接入.png" alt="Preview" width="560"> - **MCP stdio mode**
<img src="./img/攻击链.png" alt="Preview" width="560"> <img src="./img/mcp-stdio2.png" alt="MCP stdio mode" width="560">
- **MCP management**
<img src="./img/MCP管理.png" alt="MCP management" width="560">
### Attack Chain Visualization
<img src="./img/攻击链.png" alt="Attack Chain" width="560">
### Vulnerability Management
<img src="./img/漏洞管理.png" alt="Vulnerability Management" width="560">
### Task Management
<img src="./img/任务.png" alt="Task Management" width="560">
## Highlights ## Highlights
- 🤖 AI decision engine with OpenAI-compatible models (GPT, Claude, DeepSeek, etc.) - 🤖 AI decision engine with OpenAI-compatible models (GPT, Claude, DeepSeek, etc.)
- 🔌 Native MCP implementation with HTTP/stdio transports and external MCP federation - 🔌 Native MCP implementation with HTTP/stdio/SSE transports and external MCP federation
- 🧰 100+ prebuilt tool recipes + YAML-based extension system - 🧰 100+ prebuilt tool recipes + YAML-based extension system
- 📄 Large-result pagination, compression, and searchable archives - 📄 Large-result pagination, compression, and searchable archives
- 🔗 Attack-chain graph, risk scoring, and step-by-step replay - 🔗 Attack-chain graph, risk scoring, and step-by-step replay
@@ -32,6 +41,8 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 📚 Knowledge base with vector search and hybrid retrieval for security expertise - 📚 Knowledge base with vector search and hybrid retrieval for security expertise
- 📁 Conversation grouping with pinning, rename, and batch management - 📁 Conversation grouping with pinning, rename, and batch management
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics - 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
## Tool Overview ## Tool Overview
@@ -55,35 +66,40 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
## Basic Usage ## Basic Usage
### Quick Start ### Quick Start (One-Command Deployment)
1. **Clone & install**
```bash **Prerequisites:**
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git - Go 1.21+ ([Install](https://go.dev/dl/))
cd CyberStrikeAI-main - Python 3.10+ ([Install](https://www.python.org/downloads/))
go mod download
``` **One-Command Deployment:**
2. **Set up the Python tooling stack (required for the YAML tools directory)** ```bash
A large portion of `tools/*.yaml` recipes wrap Python utilities (`api-fuzzer`, `http-framework-test`, `install-python-package`, etc.). Create the project-local virtual environment once and install the shared dependencies: git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
```bash cd CyberStrikeAI-main
python3 -m venv venv chmod +x run.sh && ./run.sh
source venv/bin/activate ```
pip install -r requirements.txt
``` The `run.sh` script will automatically:
The helper tools automatically detect this `venv` (or any already active `$VIRTUAL_ENV`), so the default `env_name` works out of the box unless you intentionally supply another target. - ✅ Check and validate Go & Python environments
3. **Configure OpenAI-compatible access** - ✅ Create Python virtual environment
Either open the in-app `Settings` panel after launch or edit `config.yaml`: - ✅ Install Python dependencies
```yaml - ✅ Download Go dependencies
openai: - ✅ Build the project
api_key: "sk-your-key" - ✅ Start the server
base_url: "https://api.openai.com/v1"
model: "gpt-4o" **First-Time Configuration:**
auth: 1. **Configure OpenAI-compatible API** (required before first use)
password: "" # empty = auto-generate & log once - Open http://localhost:8080 after launch
session_duration_hours: 12 - Go to `Settings` → Fill in your API credentials:
security: ```yaml
tools_dir: "tools" openai:
``` api_key: "sk-your-key"
4. **Install the tooling you need (optional)** base_url: "https://api.openai.com/v1" # or https://api.deepseek.com/v1
model: "gpt-4o" # or deepseek-chat, claude-3-opus, etc.
```
- Or edit `config.yaml` directly before launching
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
3. **Install security tools (optional)** - Install tools as needed:
```bash ```bash
# macOS # macOS
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
@@ -91,22 +107,27 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
``` ```
AI automatically falls back to alternatives when a tool is missing. AI automatically falls back to alternatives when a tool is missing.
5. **Launch**
```bash **Alternative Launch Methods:**
chmod +x run.sh && ./run.sh ```bash
# or # Direct Go run (requires manual setup)
go run cmd/server/main.go go run cmd/server/main.go
# or
go build -o cyberstrike-ai cmd/server/main.go # Manual build
``` go build -o cyberstrike-ai cmd/server/main.go
6. **Open the console** at http://localhost:8080, log in with the generated password, and start chatting. ./cyberstrike-ai
```
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
### Core Workflows ### Core Workflows
- **Conversation testing** Natural-language prompts trigger toolchains with streaming SSE output. - **Conversation testing** Natural-language prompts trigger toolchains with streaming SSE output.
- **Role-based testing** Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
- **Tool monitor** Inspect running jobs, execution logs, and large-result attachments. - **Tool monitor** Inspect running jobs, execution logs, and large-result attachments.
- **History & audit** Every conversation and tool invocation is stored in SQLite with replay. - **History & audit** Every conversation and tool invocation is stored in SQLite with replay.
- **Conversation groups** Organize conversations into groups, pin important groups, rename or delete groups via context menu. - **Conversation groups** Organize conversations into groups, pin important groups, rename or delete groups via context menu.
- **Vulnerability management** Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings. - **Vulnerability management** Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings.
- **Batch task management** Create task queues with multiple tasks, add or edit tasks before execution, and run them sequentially. Each task executes as a separate conversation, with status tracking (pending/running/completed/failed/cancelled) and full execution history.
- **Settings** Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits. - **Settings** Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits.
### Built-in Safeguards ### Built-in Safeguards
@@ -117,6 +138,28 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
## Advanced Usage ## Advanced Usage
### Role-Based Testing
- **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
- **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
- **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
- **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
**Creating a custom role (example):**
1. Create a YAML file in `roles/` (e.g., `roles/custom-role.yaml`):
```yaml
name: Custom Role
description: Specialized testing scenario
user_prompt: You are a specialized security tester focusing on API security...
icon: "\U0001F4E1"
tools:
- api-fuzzer
- arjun
- graphql-scanner
enabled: true
```
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
### Tool Orchestration & Extensions ### Tool Orchestration & Extensions
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata. - **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
- **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments. - **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
@@ -138,7 +181,7 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
### MCP Everywhere ### MCP Everywhere
- **Web mode** ships with HTTP MCP server automatically consumed by the UI. - **Web mode** ships with HTTP MCP server automatically consumed by the UI.
- **MCP stdio mode** `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI. - **MCP stdio mode** `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI.
- **External MCP federation** register third-party MCP servers (HTTP or stdio) from the UI, toggle them per engagement, and monitor their health and call volume in real time. - **External MCP federation** register third-party MCP servers (HTTP, stdio, or SSE) from the UI, toggle them per engagement, and monitor their health and call volume in real time.
#### MCP stdio quick start #### MCP stdio quick start
1. **Build the binary** (run from the project root): 1. **Build the binary** (run from the project root):
@@ -178,6 +221,62 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
} }
``` ```
#### External MCP federation (HTTP/stdio/SSE)
CyberStrikeAI supports connecting to external MCP servers via three transport modes:
- **HTTP mode** traditional request/response over HTTP POST
- **stdio mode** process-based communication via standard input/output
- **SSE mode** Server-Sent Events for real-time streaming communication
To add an external MCP server:
1. Open the Web UI and navigate to **Settings → External MCP**.
2. Click **Add External MCP** and provide the configuration in JSON format:
**HTTP mode example:**
```json
{
"my-http-mcp": {
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"description": "HTTP MCP server",
"timeout": 30
}
}
```
**stdio mode example:**
```json
{
"my-stdio-mcp": {
"command": "python3",
"args": ["/path/to/mcp-server.py"],
"description": "stdio MCP server",
"timeout": 30
}
}
```
**SSE mode example:**
```json
{
"my-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP server",
"timeout": 30
}
}
```
3. Click **Save** and then **Start** to connect to the server.
4. Monitor the connection status, tool count, and health in real time.
**SSE mode benefits:**
- Real-time bidirectional communication via Server-Sent Events
- Suitable for scenarios requiring continuous data streaming
- Lower latency for push-based notifications
A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation purposes.
### Knowledge Base ### Knowledge Base
- **Vector search** AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool. - **Vector search** AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool.
- **Hybrid retrieval** combines vector similarity search with keyword matching for better accuracy. - **Hybrid retrieval** combines vector similarity search with keyword matching for better accuracy.
@@ -217,8 +316,10 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
### Automation Hooks ### Automation Hooks
- **REST APIs** everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities) is available over JSON. - **REST APIs** everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities, roles) is available over JSON.
- **Role APIs** manage security testing roles via `/api/roles` endpoints: `GET /api/roles` (list all roles), `GET /api/roles/:name` (get role), `POST /api/roles` (create role), `PUT /api/roles/:name` (update role), `DELETE /api/roles/:name` (delete role). Roles are stored as YAML files in the `roles/` directory and support hot-reload.
- **Vulnerability APIs** manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics). - **Vulnerability APIs** manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics).
- **Batch Task APIs** manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking.
- **Task control** pause/resume/stop long scans, re-run steps with new params, or stream transcripts. - **Task control** pause/resume/stop long scans, re-run steps with new params, or stream transcripts.
- **Audit & security** rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service. - **Audit & security** rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service.
@@ -259,6 +360,7 @@ knowledge:
top_k: 5 # Number of top results to return top_k: 5 # Number of top results to return
similarity_threshold: 0.7 # Minimum similarity score (0-1) similarity_threshold: 0.7 # Minimum similarity score (0-1)
hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword) hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword)
roles_dir: "roles" # Role configuration directory (relative to config file)
``` ```
### Tool Definition Example (`tools/nmap.yaml`) ### Tool Definition Example (`tools/nmap.yaml`)
@@ -281,6 +383,26 @@ parameters:
description: "Range, e.g. 1-1000" description: "Range, e.g. 1-1000"
``` ```
### Role Definition Example (`roles/penetration-testing.yaml`)
```yaml
name: Penetration Testing
description: Professional penetration testing expert for comprehensive security testing
user_prompt: You are a professional cybersecurity penetration testing expert. Please use professional penetration testing methods and tools to conduct comprehensive security testing on targets, including but not limited to SQL injection, XSS, CSRF, file inclusion, command execution and other common vulnerabilities.
icon: "\U0001F3AF"
tools:
- nmap
- sqlmap
- nuclei
- burpsuite
- metasploit
- httpx
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
```
## Project Layout ## Project Layout
``` ```
@@ -289,6 +411,7 @@ CyberStrikeAI/
├── internal/ # Agent, MCP core, handlers, security executor ├── internal/ # Agent, MCP core, handlers, security executor
├── web/ # Static SPA + templates ├── web/ # Static SPA + templates
├── tools/ # YAML tool recipes (100+ examples provided) ├── tools/ # YAML tool recipes (100+ examples provided)
├── roles/ # Role configurations (12+ predefined security testing roles)
├── img/ # Docs screenshots & diagrams ├── img/ # Docs screenshots & diagrams
├── config.yaml # Runtime configuration ├── config.yaml # Runtime configuration
├── run.sh # Convenience launcher ├── run.sh # Convenience launcher
@@ -314,21 +437,22 @@ Compress the 5 MB nuclei report, summarize critical CVEs, and attach the artifac
Build an attack chain for the latest engagement and export the node list with severity >= high. Build an attack chain for the latest engagement and export the node list with severity >= high.
``` ```
## Changelog (Recent) ## Changelog
See [CHANGELOG.md](CHANGELOG.md) for detailed version history and all changes.
### Recent Highlights
- **2026-01-11** Role-based testing with predefined security testing roles
- **2026-01-08** SSE transport mode support for external MCP servers
- **2026-01-01** Batch task management with queue-based execution
- **2025-12-25** Vulnerability management and conversation grouping features
- **2025-12-20** Knowledge base with vector search and hybrid retrieval
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
- 2025-12-25 Added vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
- 2025-12-25 Added conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
- 2025-12-24 Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
- 2025-12-20 Added knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
- 2025-12-19 Added ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
- 2025-12-18 Optimized web frontend with enhanced sidebar navigation and improved user experience.
- 2025-12-07 Added FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
- 2025-12-07 Fixed positional parameter handling bug: ensure correct parameter position when using default values.
- 2025-11-20 Added automatic compression/summarization for oversized tool logs and MCP transcripts.
- 2025-11-17 Introduced AI-built attack-chain visualization with interactive graph and risk scoring.
- 2025-11-15 Delivered large-result pagination, advanced filtering, and external MCP federation.
- 2025-11-14 Optimized tool lookups (O(1)), execution record cleanup, and DB pagination.
- 2025-11-13 Added web authentication, settings UI, and MCP stdio mode integration.
## 404Starlink ## 404Starlink
@@ -344,6 +468,9 @@ CyberStrikeAI has joined [404Starlink](https://github.com/knownsec/404StarLink)
</div> </div>
--- ---
Need help or want to contribute? Open an issue or PR—community tooling additions are welcome! Need help or want to contribute? Open an issue or PR—community tooling additions are welcome!
+188 -63
View File
@@ -8,21 +8,31 @@
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎与完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。 CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎与完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
> 在安全领域,真正稀缺的从来不是工具,而是判断。判断来自经验,而经验往往只能依附在个人身上,难以继承、难以复用。CyberStrikeAI 尝试解决的不是“如何自动化攻击”,而是:在复杂、多变、充满不确定性的环境中,下一步“应该做什么”,以及“为什么”。它不追求更激进的自动化,而是试图把安全专家的思考方式、决策路径与失败经验,转化为一个可约束、可复盘、可演进的系统能力。如果经验可以脱离个人而存在,那么安全,才真正具备了被系统性继承的可能。
## 界面与集成预览 ## 界面与集成预览
- Web 控制台
<img src="./img/效果.png" alt="Preview" width="560"> ### Web 控制台
- MCP stdio 模式 <img src="./img/效果.png" alt="Web 控制台" width="560">
<img src="./img/mcp-stdio2.png" alt="Preview" width="560">
- 外部 MCP 服务器 & 攻击链视图 ### MCP 集成
<img src="./img/外部MCP接入.png" alt="Preview" width="560"> - **MCP stdio 模式**
<img src="./img/攻击链.png" alt="Preview" width="560"> <img src="./img/mcp-stdio2.png" alt="MCP stdio 模式" width="560">
- **MCP 管理**
<img src="./img/MCP管理.png" alt="MCP 管理" width="560">
### 攻击链可视化
<img src="./img/攻击链.png" alt="攻击链" width="560">
### 漏洞管理
<img src="./img/漏洞管理.png" alt="漏洞管理" width="560">
### 任务管理
<img src="./img/任务.png" alt="任务管理" width="560">
## 特性速览 ## 特性速览
- 🤖 兼容 OpenAI/DeepSeek/Claude 等模型的智能决策引擎 - 🤖 兼容 OpenAI/DeepSeek/Claude 等模型的智能决策引擎
- 🔌 原生 MCP 协议,支持 HTTP / stdio 以及外部 MCP 接入 - 🔌 原生 MCP 协议,支持 HTTP / stdio / SSE 传输模式以及外部 MCP 接入
- 🧰 100+ 现成工具模版 + YAML 扩展能力 - 🧰 100+ 现成工具模版 + YAML 扩展能力
- 📄 大结果分页、压缩与全文检索 - 📄 大结果分页、压缩与全文检索
- 🔗 攻击链可视化、风险打分与步骤回放 - 🔗 攻击链可视化、风险打分与步骤回放
@@ -30,6 +40,8 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 📚 知识库功能:向量检索与混合搜索,为 AI 提供安全专业知识 - 📚 知识库功能:向量检索与混合搜索,为 AI 提供安全专业知识
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作 - 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板 - 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
## 工具概览 ## 工具概览
@@ -53,35 +65,40 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
## 基础使用 ## 基础使用
### 快速上手 ### 快速上手(一条命令部署)
1. **获取代码并安装依赖**
```bash **环境要求:**
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git - Go 1.21+ ([下载安装](https://go.dev/dl/))
cd CyberStrikeAI-main - Python 3.10+ ([下载安装](https://www.python.org/downloads/))
go mod download
``` **一条命令部署:**
2. **初始化 Python 虚拟环境(tools 目录所需)** ```bash
`tools/*.yaml` 中大量工具(如 `api-fuzzer`、`http-framework-test`、`install-python-package` 等)依赖 Python 生态。首次进入项目根目录时请创建本地虚拟环境并安装依赖: git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
```bash cd CyberStrikeAI-main
python3 -m venv venv chmod +x run.sh && ./run.sh
source venv/bin/activate ```
pip install -r requirements.txt
``` `run.sh` 脚本会自动完成:
两个 Python 专用工具(`install-python-package` 与 `execute-python-script`)会自动检测该 `venv`(或已经激活的 `$VIRTUAL_ENV`),因此默认 `env_name` 即可满足大多数场景。 - ✅ 检查并验证 Go 和 Python 环境
3. **配置模型与鉴权** - ✅ 创建 Python 虚拟环境
启动后在 Web 端 `Settings` 填写,或直接编辑 `config.yaml` - ✅ 安装 Python 依赖包
```yaml - ✅ 下载 Go 依赖模块
openai: - ✅ 编译构建项目
api_key: "sk-your-key" - ✅ 启动服务器
base_url: "https://api.openai.com/v1"
model: "gpt-4o" **首次配置:**
auth: 1. **配置 AI 模型 API**(首次使用前必填)
password: "" # 为空则首次启动自动生成强口令 - 启动后访问 http://localhost:8080
session_duration_hours: 12 - 进入 `设置` → 填写 API 配置信息:
security: ```yaml
tools_dir: "tools" openai:
``` api_key: "sk-your-key"
4. **按需安装安全工具(可选)** base_url: "https://api.openai.com/v1" # 或 https://api.deepseek.com/v1
model: "gpt-4o" # 或 deepseek-chat, claude-3-opus 等
```
- 或启动前直接编辑 `config.yaml` 文件
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`
3. **安装安全工具(可选)** - 按需安装所需工具:
```bash ```bash
# macOS # macOS
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
@@ -89,22 +106,27 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
``` ```
未安装的工具会自动跳过或改用替代方案。 未安装的工具会自动跳过或改用替代方案。
5. **启动服务**
```bash **其他启动方式:**
chmod +x run.sh && ./run.sh ```bash
# 或 # 直接运行(需手动配置环境)
go run cmd/server/main.go go run cmd/server/main.go
# 或
go build -o cyberstrike-ai cmd/server/main.go # 手动编译
``` go build -o cyberstrike-ai cmd/server/main.go
6. **浏览器访问** http://localhost:8080 ,使用日志中提示的密码登录并开始对话。 ./cyberstrike-ai
```
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
### 常用流程 ### 常用流程
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。 - **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
- **工具监控**:查看任务队列、执行日志、大文件附件。 - **工具监控**:查看任务队列、执行日志、大文件附件。
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。 - **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
- **对话分组**:将对话按项目或主题组织到不同分组,支持置顶、重命名、删除等操作,所有数据持久化存储。 - **对话分组**:将对话按项目或主题组织到不同分组,支持置顶、重命名、删除等操作,所有数据持久化存储。
- **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。 - **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。
- **批量任务管理**:创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务会作为独立对话执行,支持完整的状态跟踪(待执行/执行中/已完成/失败/已取消)和执行历史。
- **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。 - **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。
### 默认安全措施 ### 默认安全措施
@@ -115,6 +137,28 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
## 进阶使用 ## 进阶使用
### 角色化测试
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
**创建自定义角色示例:**
1. 在 `roles/` 目录创建 YAML 文件(如 `roles/custom-role.yaml`):
```yaml
name: 自定义角色
description: 专用测试场景
user_prompt: 你是一个专注于 API 安全的专业安全测试人员...
icon: "\U0001F4E1"
tools:
- api-fuzzer
- arjun
- graphql-scanner
enabled: true
```
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
### 工具编排与扩展 ### 工具编排与扩展
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。 - `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。 - `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
@@ -135,7 +179,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
### MCP 全场景 ### MCP 全场景
- **Web 模式**:自带 HTTP MCP 服务供前端调用。 - **Web 模式**:自带 HTTP MCP 服务供前端调用。
- **MCP stdio 模式**`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。 - **MCP stdio 模式**`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。
- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio),按需启停并实时查看调用统计与健康度。 - **外部 MCP 联邦**:在设置中注册第三方 MCPHTTP/stdio/SSE),按需启停并实时查看调用统计与健康度。
#### MCP stdio 快速集成 #### MCP stdio 快速集成
1. **编译可执行文件**(在项目根目录执行): 1. **编译可执行文件**(在项目根目录执行):
@@ -175,6 +219,62 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
} }
``` ```
#### 外部 MCP 联邦(HTTP/stdio/SSE
CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
- **HTTP 模式** 通过 HTTP POST 进行传统的请求/响应通信
- **stdio 模式** – 通过标准输入/输出进行进程间通信
- **SSE 模式** 通过 Server-Sent Events 实现实时流式通信
添加外部 MCP 服务器:
1. 打开 Web 界面,进入 **设置 → 外部MCP**。
2. 点击 **添加外部MCP**,以 JSON 格式提供配置:
**HTTP 模式示例:**
```json
{
"my-http-mcp": {
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"description": "HTTP MCP 服务器",
"timeout": 30
}
}
```
**stdio 模式示例:**
```json
{
"my-stdio-mcp": {
"command": "python3",
"args": ["/path/to/mcp-server.py"],
"description": "stdio MCP 服务器",
"timeout": 30
}
}
```
**SSE 模式示例:**
```json
{
"my-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP 服务器",
"timeout": 30
}
}
```
3. 点击 **保存**,然后点击 **启动** 连接服务器。
4. 实时监控连接状态、工具数量和健康度。
**SSE 模式优势:**
- 通过 Server-Sent Events 实现实时双向通信
- 适用于需要持续数据流的场景
- 对于基于推送的通知,延迟更低
可在 `cmd/test-sse-mcp-server/` 目录找到用于验证的测试 SSE MCP 服务器。
### 知识库功能 ### 知识库功能
- **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。 - **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。
@@ -215,8 +315,10 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
### 自动化与安全 ### 自动化与安全
- **REST API**:认证、会话、任务、监控、漏洞管理等接口全部开放,可与 CI/CD 集成。 - **REST API**:认证、会话、任务、监控、漏洞管理、角色管理等接口全部开放,可与 CI/CD 集成。
- **角色管理 API**:通过 `/api/roles` 端点管理安全测试角色:`GET /api/roles`(列表)、`GET /api/roles/:name`(获取角色)、`POST /api/roles`(创建角色)、`PUT /api/roles/:name`(更新角色)、`DELETE /api/roles/:name`(删除角色)。角色以 YAML 文件形式存储在 `roles/` 目录,支持热加载。
- **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。 - **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。
- **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。
- **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。 - **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。
- **安全管理**`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。 - **安全管理**`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。
@@ -257,6 +359,7 @@ knowledge:
top_k: 5 # 检索返回的 Top-K 结果数量 top_k: 5 # 检索返回的 Top-K 结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤 similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索 hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索
roles_dir: "roles" # 角色配置文件目录(相对于配置文件所在目录)
``` ```
### 工具模版示例(`tools/nmap.yaml` ### 工具模版示例(`tools/nmap.yaml`
@@ -279,6 +382,26 @@ parameters:
description: "端口范围,如 1-1000" description: "端口范围,如 1-1000"
``` ```
### 角色配置示例(`roles/渗透测试.yaml`
```yaml
name: 渗透测试
description: 专业渗透测试专家,全面深入的漏洞检测
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
icon: "\U0001F3AF"
tools:
- nmap
- sqlmap
- nuclei
- burpsuite
- metasploit
- httpx
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
```
## 项目结构 ## 项目结构
``` ```
@@ -287,6 +410,7 @@ CyberStrikeAI/
├── internal/ # Agent、MCP 核心、路由与执行器 ├── internal/ # Agent、MCP 核心、路由与执行器
├── web/ # 前端静态资源与模板 ├── web/ # 前端静态资源与模板
├── tools/ # YAML 工具目录(含 100+ 示例) ├── tools/ # YAML 工具目录(含 100+ 示例)
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
├── img/ # 文档配图 ├── img/ # 文档配图
├── config.yaml # 运行配置 ├── config.yaml # 运行配置
├── run.sh # 启动脚本 ├── run.sh # 启动脚本
@@ -312,20 +436,21 @@ CyberStrikeAI/
构建最新一次测试的攻击链,只导出风险 >= 高的节点列表。 构建最新一次测试的攻击链,只导出风险 >= 高的节点列表。
``` ```
## Changelog(近期) ## 更新日志
- 2025-12-25 —— 新增漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。
- 2025-12-25 —— 新增对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库 详细版本历史和所有变更请查看 [CHANGELOG.md](CHANGELOG.md)
- 2025-12-24 —— 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
- 2025-12-20 —— 新增知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。 ### 近期亮点
- 2025-12-19 —— 新增钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
- 2025-12-18 —— 优化 Web 前端界面,增加侧边栏导航,提升用户体验。 - **2026-01-11** – 新增角色化测试功能,支持预设安全测试角色
- 2025-12-07 —— 新增 FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。 - **2026-01-08** 新增 SSE 传输模式支持,外部 MCP 联邦支持三种模式
- 2025-12-07 —— 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。 - **2026-01-01** – 新增批量任务管理功能,支持队列式任务执行
- 2025-11-20 —— 支持超大日志/MCP 记录的自动压缩与摘要回写。 - **2025-12-25** 新增漏洞管理和对话分组功能
- 2025-11-17 —— 上线 AI 驱动的攻击链图谱与风险评分。 - **2025-12-20** – 新增知识库功能,支持向量检索和混合搜索
- 2025-11-15 —— 提供大结果分页检索与外部 MCP 挂载能力。
- 2025-11-14 —— 工具检索 O(1)、执行记录清理、数据库分页优化。 ## Star History
- 2025-11-13 —— Web 鉴权、Settings 面板与 MCP stdio 模式发布。
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
## 404星链计划 ## 404星链计划
<img src="./img/404StarLinkLogo.png" width="30%"> <img src="./img/404StarLinkLogo.png" width="30%">
+56
View File
@@ -0,0 +1,56 @@
# SSE MCP 测试服务器
这是一个用于验证SSE模式外部MCP功能的测试服务器。
## 使用方法
### 1. 启动测试服务器
```bash
cd cmd/test-sse-mcp-server
go run main.go
```
服务器将在 `http://127.0.0.1:8082` 启动,提供以下端点:
- `GET /sse` - SSE事件流端点
- `POST /message` - 消息接收端点
### 2. 在CyberStrikeAI中添加配置
在Web界面中添加外部MCP配置,使用以下JSON:
```json
{
"test-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse",
"description": "SSE MCP测试服务器",
"timeout": 30
}
}
```
### 3. 测试功能
测试服务器提供两个测试工具:
1. **test_echo** - 回显输入的文本
- 参数:`text` (string) - 要回显的文本
2. **test_add** - 计算两个数字的和
- 参数:`a` (number) - 第一个数字
- 参数:`b` (number) - 第二个数字
## 工作原理
1. 客户端通过 `GET /sse` 建立SSE连接,接收服务器推送的事件
2. 客户端通过 `POST /message` 发送MCP协议消息
3. 服务器处理消息后,通过SSE连接推送响应
## 日志
服务器会输出以下日志:
- SSE客户端连接/断开
- 收到的请求(方法名和ID
- 工具调用详情
+395
View File
@@ -0,0 +1,395 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/google/uuid"
)
const ProtocolVersion = "2024-11-05"
// Message MCP消息
type Message struct {
ID interface{} `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"`
}
// Error MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// Tool 工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// SSEServer SSE MCP服务器
type SSEServer struct {
sseClients map[string]chan []byte
mu sync.RWMutex
}
func NewSSEServer() *SSEServer {
return &SSEServer{
sseClients: make(map[string]chan []byte),
}
}
// handleSSE 处理SSE连接
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
clientID := uuid.New().String()
clientChan := make(chan []byte, 10)
s.mu.Lock()
s.sseClients[clientID] = clientChan
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.sseClients, clientID)
close(clientChan)
s.mu.Unlock()
}()
// 发送初始ready事件
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
flusher.Flush()
log.Printf("SSE客户端连接: %s", clientID)
// 心跳
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
log.Printf("SSE客户端断开: %s", clientID)
return
case msg, ok := <-clientChan:
if !ok {
return
}
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
flusher.Flush()
case <-ticker.C:
// 心跳
fmt.Fprintf(w, ": ping\n\n")
flusher.Flush()
}
}
}
// handleMessage 处理POST消息
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var msg Message
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
// 处理消息
response := s.processMessage(&msg)
// 如果有SSE客户端,通过SSE推送响应
if response != nil {
responseJSON, _ := json.Marshal(response)
s.mu.RLock()
// 发送给所有SSE客户端
for _, ch := range s.sseClients {
select {
case ch <- responseJSON:
default:
}
}
s.mu.RUnlock()
}
// 也直接返回响应(兼容非SSE模式)
if response != nil {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
} else {
w.WriteHeader(http.StatusOK)
}
}
// processMessage 处理MCP消息
func (s *SSEServer) processMessage(msg *Message) *Message {
switch msg.Method {
case "initialize":
return s.handleInitialize(msg)
case "tools/list":
return s.handleListTools(msg)
case "tools/call":
return s.handleCallTool(msg)
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Method not found",
},
}
}
}
// handleInitialize 处理初始化
func (s *SSEServer) handleInitialize(msg *Message) *Message {
var req InitializeRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
response := InitializeResponse{
ProtocolVersion: ProtocolVersion,
Capabilities: ServerCapabilities{
Tools: map[string]interface{}{
"listChanged": true,
},
},
ServerInfo: ServerInfo{
Name: "Test SSE MCP Server",
Version: "1.0.0",
},
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleListTools 处理列出工具
func (s *SSEServer) handleListTools(msg *Message) *Message {
tools := []Tool{
{
Name: "test_echo",
Description: "回显输入的文本,用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"text": map[string]interface{}{
"type": "string",
"description": "要回显的文本",
},
},
"required": []string{"text"},
},
},
{
Name: "test_add",
Description: "计算两个数字的和,用于测试SSE MCP服务器",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"a": map[string]interface{}{
"type": "number",
"description": "第一个数字",
},
"b": map[string]interface{}{
"type": "number",
"description": "第二个数字",
},
},
"required": []string{"a", "b"},
},
},
}
response := ListToolsResponse{Tools: tools}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
// handleCallTool 处理工具调用
func (s *SSEServer) handleCallTool(msg *Message) *Message {
var req CallToolRequest
if err := json.Unmarshal(msg.Params, &req); err != nil {
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32602,
Message: "Invalid params",
},
}
}
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
var content []Content
switch req.Name {
case "test_echo":
text, _ := req.Arguments["text"].(string)
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("回显: %s", text),
},
}
case "test_add":
var a, b float64
if val, ok := req.Arguments["a"].(float64); ok {
a = val
}
if val, ok := req.Arguments["b"].(float64); ok {
b = val
}
sum := a + b
content = []Content{
{
Type: "text",
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
},
}
default:
return &Message{
ID: msg.ID,
Version: "2.0",
Error: &Error{
Code: -32601,
Message: "Tool not found",
},
}
}
response := CallToolResponse{
Content: content,
IsError: false,
}
result, _ := json.Marshal(response)
return &Message{
ID: msg.ID,
Version: "2.0",
Result: result,
}
}
func main() {
server := NewSSEServer()
http.HandleFunc("/sse", server.handleSSE)
http.HandleFunc("/message", server.handleMessage)
port := ":8082"
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
log.Printf("SSE端点: http://localhost%s/sse", port)
log.Printf("消息端点: http://localhost%s/message", port)
log.Printf("配置示例:")
log.Printf(`{
"test-sse-mcp": {
"transport": "sse",
"url": "http://127.0.0.1:8082/sse"
}
}`)
if err := http.ListenAndServe(port, nil); err != nil {
log.Fatal(err)
}
}
+4
View File
@@ -66,3 +66,7 @@ knowledge:
top_k: 5 # 检索返回的Top-K结果数量 top_k: 5 # 检索返回的Top-K结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤 similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索 hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
# 角色配置
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 280 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 273 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 88 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 414 KiB

After

Width:  |  Height:  |  Size: 331 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

+37 -9
View File
@@ -12,6 +12,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage" "cyberstrike-ai/internal/storage"
@@ -302,16 +303,16 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环 // AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil) return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
} }
// AgentLoopWithConversationID 执行Agent循环(带对话ID // AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) { func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil) return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
} }
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID) // AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) { func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
// 设置当前对话ID // 设置当前对话ID
a.mu.Lock() a.mu.Lock()
a.currentConversationID = conversationID a.currentConversationID = conversationID
@@ -401,8 +402,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
漏洞记录要求: 漏洞记录要求:
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情 - 当你发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 工具记录漏洞详情
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议 ` + `- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 严重程度评估标准: - 严重程度评估标准:
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等 * critical(严重):可导致系统完全被控制、数据泄露、服务中断等
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等 * high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
@@ -512,7 +513,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
// 获取可用工具 // 获取可用工具
tools := a.getAvailableTools() tools := a.getAvailableTools(roleTools)
// 记录当前上下文的Token用量,展示压缩器运行状态 // 记录当前上下文的Token用量,展示压缩器运行状态
if a.memoryCompressor != nil { if a.memoryCompressor != nil {
@@ -837,13 +838,29 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// getAvailableTools 获取可用工具 // getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 // 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
func (a *Agent) getAvailableTools() []Tool { // roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找)
roleToolSet := make(map[string]bool)
if len(roleTools) > 0 {
for _, toolKey := range roleTools {
roleToolSet[toolKey] = true
}
}
// 从MCP服务器获取所有已注册的内部工具 // 从MCP服务器获取所有已注册的内部工具
mcpTools := a.mcpServer.GetAllTools() mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义 // 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools)) tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools { for _, mcpTool := range mcpTools {
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
toolKey := mcpTool.Name // 内置工具使用工具名称作为key
if !roleToolSet[toolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 使用简短描述(如果存在),否则使用详细描述 // 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription description := mcpTool.ShortDescription
if description == "" { if description == "" {
@@ -865,7 +882,8 @@ func (a *Agent) getAvailableTools() []Tool {
// 获取外部MCP工具 // 获取外部MCP工具
if a.externalMCPMgr != nil { if a.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
externalTools, err := a.externalMCPMgr.GetAllTools(ctx) externalTools, err := a.externalMCPMgr.GetAllTools(ctx)
@@ -882,6 +900,16 @@ func (a *Agent) getAvailableTools() []Tool {
// 将外部MCP工具添加到工具列表(只添加启用的工具) // 将外部MCP工具添加到工具列表(只添加启用的工具)
for _, externalTool := range externalTools { for _, externalTool := range externalTools {
// 外部工具使用 "mcpName::toolName" 作为toolKey
externalToolKey := externalTool.Name
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
if !roleToolSet[externalToolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 解析工具名称:mcpName::toolName // 解析工具名称:mcpName::toolName
var mcpName, actualToolName string var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 { if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
@@ -1135,7 +1163,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
) )
// 如果是record_vulnerability工具,自动添加conversation_id // 如果是record_vulnerability工具,自动添加conversation_id
if toolName == "record_vulnerability" { if toolName == builtin.ToolRecordVulnerability {
a.mu.RLock() a.mu.RLock()
conversationID := a.currentConversationID conversationID := a.currentConversationID
a.mu.RUnlock() a.mu.RUnlock()
+102 -18
View File
@@ -16,6 +16,7 @@ import (
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage" "cyberstrike-ai/internal/storage"
@@ -214,23 +215,53 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
return return
} }
if hasIndex { if hasIndex {
// 如果已有索引,只索引新添加或更新的项 // 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 { if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background() ctx := context.Background()
for _, itemID := range itemsToIndex { consecutiveFailures := 0
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { var firstFailureItemID string
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) var firstFailureError error
continue failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
} }
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
} }
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) return
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
} }
return
}
// 只有在没有索引时才自动重建 // 只有在没有索引时才自动重建
log.Logger.Info("未检测到知识库索引,开始自动构建索引") log.Logger.Info("未检测到知识库索引,开始自动构建索引")
@@ -248,7 +279,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
} }
// 创建处理器 // 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger) agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志 // 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil { if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager) agentHandler.SetKnowledgeManager(knowledgeManager)
@@ -262,6 +293,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
// 创建 App 实例(部分字段稍后填充) // 创建 App 实例(部分字段稍后填充)
app := &App{ app := &App{
@@ -338,6 +370,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
attackChainHandler, attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler, vulnerabilityHandler,
roleHandler,
mcpServer, mcpServer,
authManager, authManager,
) )
@@ -398,6 +431,7 @@ func setupRoutes(
attackChainHandler *handler.AttackChainHandler, attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler, vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
mcpServer *mcp.Server, mcpServer *mcp.Server,
authManager *security.AuthManager, authManager *security.AuthManager,
) { ) {
@@ -423,6 +457,18 @@ func setupRoutes(
// Agent Loop 取消与任务列表 // Agent Loop 取消与任务列表
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
// 批量任务管理
protected.POST("/batch-tasks", agentHandler.CreateBatchQueue)
protected.GET("/batch-tasks", agentHandler.ListBatchQueues)
protected.GET("/batch-tasks/:queueId", agentHandler.GetBatchQueue)
protected.POST("/batch-tasks/:queueId/start", agentHandler.StartBatchQueue)
protected.POST("/batch-tasks/:queueId/pause", agentHandler.PauseBatchQueue)
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
// 对话历史 // 对话历史
protected.POST("/conversations", conversationHandler.CreateConversation) protected.POST("/conversations", conversationHandler.CreateConversation)
@@ -448,6 +494,7 @@ func setupRoutes(
protected.GET("/monitor", monitorHandler.Monitor) protected.GET("/monitor", monitorHandler.Monitor)
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution) protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution) protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
protected.GET("/monitor/stats", monitorHandler.GetStats) protected.GET("/monitor/stats", monitorHandler.GetStats)
// 配置管理 // 配置管理
@@ -610,6 +657,13 @@ func setupRoutes(
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability) protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability) protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// MCP端点 // MCP端点
protected.POST("/mcp", func(c *gin.Context) { protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request) mcpServer.HandleHTTP(c.Writer, c.Request)
@@ -629,7 +683,7 @@ func setupRoutes(
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器 // registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) { func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{ tool := mcp.Tool{
Name: "record_vulnerability", Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。", Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统", ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
@@ -921,13 +975,43 @@ func initializeKnowledge(
if len(itemsToIndex) > 0 { if len(itemsToIndex) > 0 {
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex))) logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background() ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex { for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil { if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue continue
} }
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
} }
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else { } else {
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项") logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
} }
+136
View File
@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -22,6 +23,8 @@ type Config struct {
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
} }
type ServerConfig struct { type ServerConfig struct {
@@ -81,6 +84,7 @@ type ExternalMCPServerConfig struct {
// stdio模式配置 // stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"` Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"` Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
// HTTP模式配置 // HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio" Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
@@ -206,6 +210,29 @@ func Load(path string) (*Config, error) {
} }
} }
// 从角色目录加载角色配置
if cfg.RolesDir != "" {
configDir := filepath.Dir(path)
rolesDir := cfg.RolesDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
roles, err := LoadRolesFromDir(rolesDir)
if err != nil {
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
}
cfg.Roles = roles
} else {
// 如果未配置 roles_dir,初始化为空 map
if cfg.Roles == nil {
cfg.Roles = make(map[string]RoleConfig)
}
}
return &cfg, nil return &cfg, nil
} }
@@ -374,6 +401,98 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
return &tool, nil return &tool, nil
} }
// LoadRolesFromDir 从目录加载所有角色配置文件
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
roles := make(map[string]RoleConfig)
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return roles, nil // 目录不存在时返回空map,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取角色目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
role, err := LoadRoleFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
continue
}
// 使用角色名称作为key
roleName := role.Name
if roleName == "" {
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
role.Name = roleName
}
roles[roleName] = *role
}
return roles, nil
}
// LoadRoleFromFile 从单个文件加载角色配置
func LoadRoleFromFile(path string) (*RoleConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var role RoleConfig
if err := yaml.Unmarshal(data, &role); err != nil {
return nil, fmt.Errorf("解析角色配置失败: %w", err)
}
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
if role.Icon != "" {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制)或 \uXXXX(4位十六进制)
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
// \U0001F3C6 格式(8位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
} else if icon[1] == 'u' && len(icon) >= 6 {
// \uXXXX 格式(4位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
}
}
}
// 验证必需字段
if role.Name == "" {
// 如果名称为空,尝试从文件名获取
baseName := filepath.Base(path)
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
}
return &role, nil
}
func Default() *Config { func Default() *Config {
return &Config{ return &Config{
Server: ServerConfig{ Server: ServerConfig{
@@ -447,3 +566,20 @@ type RetrievalConfig struct {
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值 SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值
HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1 HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1
} }
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
type RolesConfig struct {
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
}
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}
+389
View File
@@ -0,0 +1,389 @@
package database
import (
"database/sql"
"fmt"
"time"
"go.uber.org/zap"
)
// BatchTaskQueueRow 批量任务队列数据库行
type BatchTaskQueueRow struct {
ID string
Title sql.NullString
Status string
CreatedAt time.Time
StartedAt sql.NullTime
CompletedAt sql.NullTime
CurrentIndex int
}
// BatchTaskRow 批量任务数据库行
type BatchTaskRow struct {
ID string
QueueID string
Message string
ConversationID sql.NullString
Status string
StartedAt sql.NullTime
CompletedAt sql.NullTime
Error sql.NullString
Result sql.NullString
}
// CreateBatchQueue 创建批量任务队列
func (db *DB) CreateBatchQueue(queueID string, title string, tasks []map[string]interface{}) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
now := time.Now()
_, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, status, created_at, current_index) VALUES (?, ?, ?, ?, ?)",
queueID, title, "pending", now, 0,
)
if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err)
}
// 插入任务
for _, task := range tasks {
taskID, ok := task["id"].(string)
if !ok {
continue
}
message, ok := task["message"].(string)
if !ok {
continue
}
_, err = tx.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("创建批量任务失败: %w", err)
}
}
return tx.Commit()
}
// GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow
var createdAt string
err := db.QueryRow(
"SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
queueID,
).Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
// 尝试其他时间格式
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
return &row, nil
}
// GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query(
"SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// CountBatchQueues 统计批量任务队列总数(支持筛选条件)
func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err)
}
return count, nil
}
// GetBatchTasks 获取批量任务队列的所有任务
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
rows, err := db.Query(
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id",
queueID,
)
if err != nil {
return nil, fmt.Errorf("查询批量任务失败: %w", err)
}
defer rows.Close()
var tasks []*BatchTaskRow
for rows.Next() {
var task BatchTaskRow
if err := rows.Scan(
&task.ID, &task.QueueID, &task.Message, &task.ConversationID,
&task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result,
); err != nil {
return nil, fmt.Errorf("扫描批量任务失败: %w", err)
}
tasks = append(tasks, &task)
}
return tasks, nil
}
// UpdateBatchQueueStatus 更新批量任务队列状态
func (db *DB) UpdateBatchQueueStatus(queueID, status string) error {
var err error
now := time.Now()
if status == "running" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?",
status, now, queueID,
)
} else if status == "completed" || status == "cancelled" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?",
status, now, queueID,
)
} else {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ? WHERE id = ?",
status, queueID,
)
}
if err != nil {
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
}
return nil
}
// UpdateBatchTaskStatus 更新批量任务状态
func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error {
var err error
now := time.Now()
// 构建更新语句
var updates []string
var args []interface{}
updates = append(updates, "status = ?")
args = append(args, status)
if conversationID != "" {
updates = append(updates, "conversation_id = ?")
args = append(args, conversationID)
}
if result != "" {
updates = append(updates, "result = ?")
args = append(args, result)
}
if errorMsg != "" {
updates = append(updates, "error = ?")
args = append(args, errorMsg)
}
if status == "running" {
updates = append(updates, "started_at = COALESCE(started_at, ?)")
args = append(args, now)
}
if status == "completed" || status == "failed" || status == "cancelled" {
updates = append(updates, "completed_at = COALESCE(completed_at, ?)")
args = append(args, now)
}
args = append(args, queueID, taskID)
// 构建SQL语句
sql := "UPDATE batch_tasks SET "
for i, update := range updates {
if i > 0 {
sql += ", "
}
sql += update
}
sql += " WHERE queue_id = ? AND id = ?"
_, err = db.Exec(sql, args...)
if err != nil {
return fmt.Errorf("更新批量任务状态失败: %w", err)
}
return nil
}
// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引
func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET current_index = ? WHERE id = ?",
currentIndex, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务队列当前索引失败: %w", err)
}
return nil
}
// UpdateBatchTaskMessage 更新批量任务消息
func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error {
_, err := db.Exec(
"UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?",
message, queueID, taskID,
)
if err != nil {
return fmt.Errorf("更新批量任务消息失败: %w", err)
}
return nil
}
// AddBatchTask 添加任务到批量任务队列
func (db *DB) AddBatchTask(queueID, taskID, message string) error {
_, err := db.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("添加批量任务失败: %w", err)
}
return nil
}
// DeleteBatchTask 删除批量任务
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
_, err := db.Exec(
"DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?",
queueID, taskID,
)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
return nil
}
// DeleteBatchQueue 删除批量任务队列
func (db *DB) DeleteBatchQueue(queueID string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
// 删除任务(外键会自动级联删除)
_, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
// 删除队列
_, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务队列失败: %w", err)
}
return tx.Commit()
}
+67
View File
@@ -189,6 +189,33 @@ func (db *DB) initTables() error {
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);` );`
// 创建批量任务队列表
createBatchTaskQueuesTable := `
CREATE TABLE IF NOT EXISTS batch_task_queues (
id TEXT PRIMARY KEY,
title TEXT,
status TEXT NOT NULL,
created_at DATETIME NOT NULL,
started_at DATETIME,
completed_at DATETIME,
current_index INTEGER NOT NULL DEFAULT 0
);`
// 创建批量任务表
createBatchTasksTable := `
CREATE TABLE IF NOT EXISTS batch_tasks (
id TEXT PRIMARY KEY,
queue_id TEXT NOT NULL,
message TEXT NOT NULL,
conversation_id TEXT,
status TEXT NOT NULL,
started_at DATETIME,
completed_at DATETIME,
error TEXT,
result TEXT,
FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE
);`
// 创建索引 // 创建索引
createIndexes := ` createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
@@ -212,6 +239,9 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
` `
if _, err := db.Exec(createConversationsTable); err != nil { if _, err := db.Exec(createConversationsTable); err != nil {
@@ -258,6 +288,14 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建vulnerabilities表失败: %w", err) return fmt.Errorf("创建vulnerabilities表失败: %w", err)
} }
if _, err := db.Exec(createBatchTaskQueuesTable); err != nil {
return fmt.Errorf("创建batch_task_queues表失败: %w", err)
}
if _, err := db.Exec(createBatchTasksTable); err != nil {
return fmt.Errorf("创建batch_tasks表失败: %w", err)
}
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前 // 为已有表添加新字段(如果不存在)- 必须在创建索引之前
if err := db.migrateConversationsTable(); err != nil { if err := db.migrateConversationsTable(); err != nil {
db.logger.Warn("迁移conversations表失败", zap.Error(err)) db.logger.Warn("迁移conversations表失败", zap.Error(err))
@@ -274,6 +312,11 @@ func (db *DB) initTables() error {
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateBatchTaskQueuesTable(); err != nil {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil { if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err) return fmt.Errorf("创建索引失败: %w", err)
} }
@@ -390,6 +433,30 @@ func (db *DB) migrateConversationGroupMappingsTable() error {
return nil return nil
} }
// migrateBatchTaskQueuesTable 迁移batch_task_queues表,添加title字段
func (db *DB) migrateBatchTaskQueuesTable() error {
// 检查title字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加title字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil {
db.logger.Warn("添加title字段失败", zap.Error(err))
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
+73 -1
View File
@@ -205,7 +205,7 @@ func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除对话旧分组关联失败: %w", err) return fmt.Errorf("删除对话旧分组关联失败: %w", err)
} }
// 然后插入新的分组关联 // 然后插入新的分组关联
id := uuid.New().String() id := uuid.New().String()
_, err = db.Exec( _, err = db.Exec(
@@ -282,6 +282,78 @@ func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
return conversations, nil return conversations, nil
} }
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
// 构建SQL查询,支持按标题和消息内容搜索
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?`
args := []interface{}{groupID}
// 如果有搜索关键词,添加标题和消息内容搜索条件
if searchQuery != "" {
searchPattern := "%" + searchQuery + "%"
// 搜索标题或消息内容
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
query += ` AND (
LOWER(c.title) LIKE LOWER(?)
OR EXISTS (
SELECT 1 FROM messages m
WHERE m.conversation_id = c.id
AND LOWER(m.content) LIKE LOWER(?)
)
)`
args = append(args, searchPattern, searchPattern)
}
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// GetGroupByConversation 获取对话所属的分组 // GetGroupByConversation 获取对话所属的分组
func (db *DB) GetGroupByConversation(conversationID string) (string, error) { func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
var groupID string var groupID string
+144 -6
View File
@@ -3,9 +3,11 @@ package database
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strings"
"time" "time"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -70,13 +72,25 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
} }
// CountToolExecutions 统计工具执行记录总数 // CountToolExecutions 统计工具执行记录总数
func (db *DB) CountToolExecutions(status string) (int, error) { func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
query := `SELECT COUNT(*) FROM tool_executions` query := `SELECT COUNT(*) FROM tool_executions`
args := []interface{}{} args := []interface{}{}
conditions := []string{}
if status != "" { if status != "" {
query += ` WHERE status = ?` conditions = append(conditions, "status = ?")
args = append(args, status) args = append(args, status)
} }
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
var count int var count int
err := db.QueryRow(query, args...).Scan(&count) err := db.QueryRow(query, args...).Scan(&count)
if err != nil { if err != nil {
@@ -87,30 +101,43 @@ func (db *DB) CountToolExecutions(status string) (int, error) {
// LoadToolExecutions 加载所有工具执行记录(支持分页) // LoadToolExecutions 加载所有工具执行记录(支持分页)
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) { func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
return db.LoadToolExecutionsWithPagination(0, 1000, "") return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
} }
// LoadToolExecutionsWithPagination 分页加载工具执行记录 // LoadToolExecutionsWithPagination 分页加载工具执行记录
// limit: 最大返回记录数,0 表示使用默认值 1000 // limit: 最大返回记录数,0 表示使用默认值 1000
// offset: 跳过的记录数,用于分页 // offset: 跳过的记录数,用于分页
// status: 状态筛选,空字符串表示不过滤 // status: 状态筛选,空字符串表示不过滤
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status string) ([]*mcp.ToolExecution, error) { // toolName: 工具名称筛选,空字符串表示不过滤
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
if limit <= 0 { if limit <= 0 {
limit = 1000 // 默认限制 limit = 1000 // 默认限制
} }
if limit > 10000 { if limit > 10000 {
limit = 10000 // 最大限制,防止一次性加载过多数据 limit = 10000 // 最大限制,防止一次性加载过多数据
} }
query := ` query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions FROM tool_executions
` `
args := []interface{}{} args := []interface{}{}
conditions := []string{}
if status != "" { if status != "" {
query += ` WHERE status = ?` conditions = append(conditions, "status = ?")
args = append(args, status) args = append(args, status)
} }
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?` query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
args = append(args, limit, offset) args = append(args, limit, offset)
@@ -254,6 +281,117 @@ func (db *DB) DeleteToolExecution(id string) error {
return nil return nil
} }
// DeleteToolExecutions 批量删除工具执行记录
func (db *DB) DeleteToolExecutions(ids []string) error {
if len(ids) == 0 {
return nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
_, err := db.Exec(query, args...)
if err != nil {
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
return err
}
return nil
}
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
if len(ids) == 0 {
return []*mcp.ToolExecution{}, nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
WHERE id IN (` + strings.Join(placeholders, ",") + `)
`
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var executions []*mcp.ToolExecution
for rows.Next() {
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := rows.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
db.logger.Warn("加载执行记录失败", zap.Error(err))
continue
}
// 解析参数
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
// 解析结果
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
// 设置错误
if errorText.Valid {
exec.Error = errorText.String
}
// 设置结束时间
if endTime.Valid {
exec.EndTime = &endTime.Time
}
// 设置持续时间
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// SaveToolStats 保存工具统计信息 // SaveToolStats 保存工具统计信息
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
var lastCallTime sql.NullTime var lastCallTime sql.NullTime
+675 -152
View File
@@ -6,12 +6,15 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"unicode/utf8"
"time" "time"
"unicode/utf8"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -25,16 +28,16 @@ func safeTruncateString(s string, maxLen int) string {
if utf8.RuneCountInString(s) <= maxLen { if utf8.RuneCountInString(s) <= maxLen {
return s return s
} }
// 将字符串转换为 rune 切片以正确计算字符数 // 将字符串转换为 rune 切片以正确计算字符数
runes := []rune(s) runes := []rune(s)
if len(runes) <= maxLen { if len(runes) <= maxLen {
return s return s
} }
// 截断到最大长度 // 截断到最大长度
truncated := string(runes[:maxLen]) truncated := string(runes[:maxLen])
// 尝试在标点符号或空格处截断,使截断更自然 // 尝试在标点符号或空格处截断,使截断更自然
// 在截断点往前查找合适的断点(不超过20%的长度) // 在截断点往前查找合适的断点(不超过20%的长度)
searchRange := maxLen / 5 searchRange := maxLen / 5
@@ -43,7 +46,7 @@ func safeTruncateString(s string, maxLen int) string {
} }
breakChars := []rune(",。、 ,.;:!?!?/\\-_") breakChars := []rune(",。、 ,.;:!?!?/\\-_")
bestBreakPos := len(runes[:maxLen]) bestBreakPos := len(runes[:maxLen])
for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- { for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- {
for _, breakChar := range breakChars { for _, breakChar := range breakChars {
if runes[i] == breakChar { if runes[i] == breakChar {
@@ -52,7 +55,7 @@ func safeTruncateString(s string, maxLen int) string {
} }
} }
} }
found: found:
truncated = string(runes[:bestBreakPos]) truncated = string(runes[:bestBreakPos])
return truncated + "..." return truncated + "..."
@@ -64,18 +67,30 @@ type AgentHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
tasks *AgentTaskManager tasks *AgentTaskManager
knowledgeManager interface { // 知识库管理器接口 batchTaskManager *BatchTaskManager
config *config.Config // 配置引用,用于获取角色信息
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
} }
} }
// NewAgentHandler 创建新的Agent处理器 // NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler { func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
batchTaskManager := NewBatchTaskManager()
batchTaskManager.SetDB(db)
// 从数据库加载所有批量任务队列
if err := batchTaskManager.LoadFromDB(); err != nil {
logger.Warn("从数据库加载批量任务队列失败", zap.Error(err))
}
return &AgentHandler{ return &AgentHandler{
agent: agent, agent: agent,
db: db, db: db,
logger: logger, logger: logger,
tasks: NewAgentTaskManager(), tasks: NewAgentTaskManager(),
batchTaskManager: batchTaskManager,
config: cfg,
} }
} }
@@ -90,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
type ChatRequest struct { type ChatRequest struct {
Message string `json:"message" binding:"required"` Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"` ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
} }
// ChatResponse 聊天响应 // ChatResponse 聊天响应
@@ -150,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 保存用户消息 // 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil) _, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil { if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err)) h.logger.Error("保存用户消息失败", zap.Error(err))
} }
// 执行Agent Loop,传入历史消息和对话ID // 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID) result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -204,147 +240,23 @@ type StreamEvent struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
// AgentLoopStream 处理Agent Loop流式请求 // createProgressCallback 创建进度回调函数,用于保存processDetails
func (h *AgentHandler) AgentLoopStream(c *gin.Context) { // sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
var req ChatRequest func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
if err := c.ShouldBindJSON(&req); err != nil {
// 对于流式请求,也发送SSE格式的错误
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
event := StreamEvent{
Type: "error",
Message: "请求参数错误: " + err.Error(),
}
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
return
}
h.logger.Info("收到Agent Loop流式请求",
zap.String("message", req.Message),
zap.String("conversationId", req.ConversationID),
)
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
// 发送初始事件
// 用于跟踪客户端是否已断开连接
clientDisconnected := false
sendEvent := func(eventType, message string, data interface{}) {
// 如果客户端已断开,不再发送事件
if clientDisconnected {
return
}
// 检查请求上下文是否被取消(客户端断开)
select {
case <-c.Request.Context().Done():
clientDisconnected = true
return
default:
}
event := StreamEvent{
Type: eventType,
Message: message,
Data: data,
}
eventJSON, _ := json.Marshal(event)
// 尝试写入事件,如果失败则标记客户端断开
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
clientDisconnected = true
h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err))
return
}
// 刷新响应,如果失败则标记客户端断开
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
c.Writer.Flush()
}
}
// 如果没有对话ID,创建新对话
conversationID := req.ConversationID
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
sendEvent("error", "创建对话失败: "+err.Error(), nil)
return
}
conversationID = conv.ID
}
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": conversationID,
})
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
agentHistoryMessages = []agent.ChatMessage{}
} else {
// 将数据库消息转换为Agent消息格式
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,同时保存到数据库
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
// 用于保存tool_call事件中的参数,以便在tool_result时使用 // 用于保存tool_call事件中的参数,以便在tool_result时使用
toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments
progressCallback := func(eventType, message string, data interface{}) { return func(eventType, message string, data interface{}) {
sendEvent(eventType, message, data) // 如果提供了sendEventFunc,发送流式事件
if sendEventFunc != nil {
sendEventFunc(eventType, message, data)
}
// 保存tool_call事件中的参数 // 保存tool_call事件中的参数
if eventType == "tool_call" { if eventType == "tool_call" {
if dataMap, ok := data.(map[string]interface{}); ok { if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string) toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" { if toolName == builtin.ToolSearchKnowledgeBase {
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" { if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok { if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
toolCallCache[toolCallId] = argumentsObj toolCallCache[toolCallId] = argumentsObj
@@ -358,7 +270,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
if eventType == "tool_result" && h.knowledgeManager != nil { if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok { if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string) toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" { if toolName == builtin.ToolSearchKnowledgeBase {
// 提取检索信息 // 提取检索信息
query := "" query := ""
riskType := "" riskType := ""
@@ -471,6 +383,165 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
} }
} }
}
// AgentLoopStream 处理Agent Loop流式请求
func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
// 对于流式请求,也发送SSE格式的错误
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
event := StreamEvent{
Type: "error",
Message: "请求参数错误: " + err.Error(),
}
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
return
}
h.logger.Info("收到Agent Loop流式请求",
zap.String("message", req.Message),
zap.String("conversationId", req.ConversationID),
)
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
// 发送初始事件
// 用于跟踪客户端是否已断开连接
clientDisconnected := false
sendEvent := func(eventType, message string, data interface{}) {
// 如果客户端已断开,不再发送事件
if clientDisconnected {
return
}
// 检查请求上下文是否被取消(客户端断开)
select {
case <-c.Request.Context().Done():
clientDisconnected = true
return
default:
}
event := StreamEvent{
Type: eventType,
Message: message,
Data: data,
}
eventJSON, _ := json.Marshal(event)
// 尝试写入事件,如果失败则标记客户端断开
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
clientDisconnected = true
h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err))
return
}
// 刷新响应,如果失败则标记客户端断开
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
c.Writer.Flush()
}
}
// 如果没有对话ID,创建新对话
conversationID := req.ConversationID
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
sendEvent("error", "创建对话失败: "+err.Error(), nil)
return
}
conversationID = conv.ID
}
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": conversationID,
})
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
agentHistoryMessages = []agent.ChatMessage{}
} else {
// 将数据库消息转换为Agent消息格式
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
} else if len(role.MCPs) > 0 {
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
// 因为mcps是MCP服务器名称,不是工具列表
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
}
}
}
}
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,同时保存到数据库
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
// 创建进度回调函数,复用统一逻辑
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
// 创建一个独立的上下文用于任务执行,不随HTTP请求取消 // 创建一个独立的上下文用于任务执行,不随HTTP请求取消
// 这样即使客户端断开连接(如刷新页面),任务也能继续执行 // 这样即使客户端断开连接(如刷新页面),任务也能继续执行
@@ -526,15 +597,20 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
taskStatus := "completed" taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus) defer h.tasks.FinishTask(conversationID, taskStatus)
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断 // 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
sendEvent("progress", "正在分析您的请求...", nil) sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback) result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
// 检查是否是用户取消:context的cause是ErrTaskCancelled
// 如果cause是ErrTaskCancelled,无论错误是什么类型(包括context.Canceled),都视为用户取消
// 这样可以正确处理在API调用过程中被取消的情况
isCancelled := errors.Is(cause, ErrTaskCancelled)
switch { switch {
case errors.Is(cause, ErrTaskCancelled): case isCancelled:
taskStatus = "cancelled" taskStatus = "cancelled"
cancelMsg := "任务已被用户取消,后续操作已停止。" cancelMsg := "任务已被用户取消,后续操作已停止。"
@@ -724,6 +800,453 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
}) })
} }
// ListCompletedTasks 列出最近完成的任务历史
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"tasks": h.tasks.GetCompletedTasks(),
})
}
// BatchTaskRequest 批量任务请求
type BatchTaskRequest struct {
Title string `json:"title"` // 任务标题(可选)
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
}
// CreateBatchQueue 创建批量任务队列
func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
var req BatchTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(req.Tasks) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务列表不能为空"})
return
}
// 过滤空任务
validTasks := make([]string, 0, len(req.Tasks))
for _, task := range req.Tasks {
if task != "" {
validTasks = append(validTasks, task)
}
}
if len(validTasks) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的任务"})
return
}
queue := h.batchTaskManager.CreateBatchQueue(req.Title, validTasks)
c.JSON(http.StatusOK, gin.H{
"queueId": queue.ID,
"queue": queue,
})
}
// GetBatchQueue 获取批量任务队列
func (h *AgentHandler) GetBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"queue": queue})
}
// ListBatchQueuesResponse 批量任务队列列表响应
type ListBatchQueuesResponse struct {
Queues []*BatchTaskQueue `json:"queues"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// ListBatchQueues 列出所有批量任务队列(支持筛选和分页)
func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "10")
offsetStr := c.DefaultQuery("offset", "0")
pageStr := c.Query("page")
status := c.Query("status")
keyword := c.Query("keyword")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
page := 1
// 如果提供了page参数,优先使用page计算offset
if pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
offset = (page - 1) * limit
}
}
// 限制pageSize范围
if limit <= 0 || limit > 100 {
limit = 10
}
if offset < 0 {
offset = 0
}
// 默认status为"all"
if status == "" {
status = "all"
}
// 获取队列列表和总数
queues, total, err := h.batchTaskManager.ListQueues(limit, offset, status, keyword)
if err != nil {
h.logger.Error("获取批量任务队列列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 计算总页数
totalPages := (total + limit - 1) / limit
if totalPages == 0 {
totalPages = 1
}
// 如果使用offset计算page,需要重新计算
if pageStr == "" {
page = (offset / limit) + 1
}
response := ListBatchQueuesResponse{
Queues: queues,
Total: total,
Page: page,
PageSize: limit,
TotalPages: totalPages,
}
c.JSON(http.StatusOK, response)
}
// StartBatchQueue 开始执行批量任务队列
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if queue.Status != "pending" && queue.Status != "paused" {
c.JSON(http.StatusBadRequest, gin.H{"error": "队列状态不允许启动"})
return
}
// 在后台执行批量任务
go h.executeBatchQueue(queueID)
h.batchTaskManager.UpdateQueueStatus(queueID, "running")
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
}
// PauseBatchQueue 暂停批量任务队列
func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
success := h.batchTaskManager.PauseQueue(queueID)
if !success {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
}
// DeleteBatchQueue 删除批量任务队列
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
success := h.batchTaskManager.DeleteQueue(queueID)
if !success {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
}
// UpdateBatchTask 更新批量任务消息
func (h *AgentHandler) UpdateBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
taskID := c.Param("taskId")
var req struct {
Message string `json:"message" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Message == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
return
}
err := h.batchTaskManager.UpdateTaskMessage(queueID, taskID, req.Message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已更新", "queue": queue})
}
// AddBatchTask 添加任务到批量任务队列
func (h *AgentHandler) AddBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
var req struct {
Message string `json:"message" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Message == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
return
}
task, err := h.batchTaskManager.AddTaskToQueue(queueID, req.Message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
}
// DeleteBatchTask 删除批量任务
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
taskID := c.Param("taskId")
err := h.batchTaskManager.DeleteTask(queueID, taskID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
}
// executeBatchQueue 执行批量任务队列
func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
for {
// 检查队列状态
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
break
}
// 获取下一个任务
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
if !hasNext {
// 所有任务完成
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
break
}
// 更新任务状态为运行中
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
// 创建新对话
title := safeTruncateString(task.Message, 50)
conv, err := h.db.CreateConversation(title)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID)
continue
}
conversationID = conv.ID
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil)
// 执行任务
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
// 任务执行完成,清理取消函数
h.batchTaskManager.SetTaskCancel(queueID, nil)
cancel()
if err != nil {
// 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
// 3. 检查 result.Response 中是否包含取消相关的消息
errStr := err.Error()
isCancelled := errors.Is(err, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(result != nil && result.Response != "" && (strings.Contains(result.Response, "任务已被取消") || strings.Contains(result.Response, "任务执行中断")))
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
// 如果result中有更具体的取消消息,使用它
if result != nil && result.Response != "" && (strings.Contains(result.Response, "任务已被取消") || strings.Contains(result.Response, "任务执行中断")) {
cancelMsg = result.Response
}
// 更新助手消息内容
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ? WHERE id = ?",
cancelMsg,
assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存取消详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
if errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
// 保存ReAct数据(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
} else {
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
errorMsg := "执行失败: " + err.Error()
// 更新助手消息内容
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ? WHERE id = ?",
errorMsg,
assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存错误详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
}
} else {
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
// 更新助手消息内容
if assistantMessageID != "" {
mcpIDsJSON := ""
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
// 如果更新失败,尝试创建新消息
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
// 保存ReAct数据
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
// 保存结果
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", result.Response, "", conversationID)
}
// 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID)
// 检查是否被取消或暂停
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" || queue.Status == "paused" {
break
}
}
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 // loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表 // 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
+755
View File
@@ -0,0 +1,755 @@
package handler
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"sort"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
)
// BatchTask 批量任务项
type BatchTask struct {
ID string `json:"id"`
Message string `json:"message"`
ConversationID string `json:"conversationId,omitempty"`
Status string `json:"status"` // pending, running, completed, failed, cancelled
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
Error string `json:"error,omitempty"`
Result string `json:"result,omitempty"`
}
// BatchTaskQueue 批量任务队列
type BatchTaskQueue struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"`
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"`
mu sync.RWMutex
}
// BatchTaskManager 批量任务管理器
type BatchTaskManager struct {
db *database.DB
queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
mu sync.RWMutex
}
// NewBatchTaskManager 创建批量任务管理器
func NewBatchTaskManager() *BatchTaskManager {
return &BatchTaskManager{
queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc),
}
}
// SetDB 设置数据库连接
func (m *BatchTaskManager) SetDB(db *database.DB) {
m.mu.Lock()
defer m.mu.Unlock()
m.db = db
}
// CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue(title string, tasks []string) *BatchTaskQueue {
m.mu.Lock()
defer m.mu.Unlock()
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
queue := &BatchTaskQueue{
ID: queueID,
Title: title,
Tasks: make([]*BatchTask, 0, len(tasks)),
Status: "pending",
CreatedAt: time.Now(),
CurrentIndex: 0,
}
// 准备数据库保存的任务数据
dbTasks := make([]map[string]interface{}, 0, len(tasks))
for _, message := range tasks {
if message == "" {
continue // 跳过空行
}
taskID := generateShortID()
task := &BatchTask{
ID: taskID,
Message: message,
Status: "pending",
}
queue.Tasks = append(queue.Tasks, task)
dbTasks = append(dbTasks, map[string]interface{}{
"id": taskID,
"message": message,
})
}
// 保存到数据库
if m.db != nil {
if err := m.db.CreateBatchQueue(queueID, title, dbTasks); err != nil {
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
// 这里可以添加日志记录
}
}
m.queues[queueID] = queue
return queue
}
// GetBatchQueue 获取批量任务队列
func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) {
m.mu.RLock()
queue, exists := m.queues[queueID]
m.mu.RUnlock()
if exists {
return queue, true
}
// 如果内存中不存在,尝试从数据库加载
if m.db != nil {
if queue := m.loadQueueFromDB(queueID); queue != nil {
m.mu.Lock()
m.queues[queueID] = queue
m.mu.Unlock()
return queue, true
}
}
return nil, false
}
// loadQueueFromDB 从数据库加载单个队列
func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if m.db == nil {
return nil
}
queueRow, err := m.db.GetBatchQueue(queueID)
if err != nil || queueRow == nil {
return nil
}
taskRows, err := m.db.GetBatchTasks(queueID)
if err != nil {
return nil
}
queue := &BatchTaskQueue{
ID: queueRow.ID,
Status: queueRow.Status,
CreatedAt: queueRow.CreatedAt,
CurrentIndex: queueRow.CurrentIndex,
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.Title.Valid {
queue.Title = queueRow.Title.String
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
if queueRow.CompletedAt.Valid {
queue.CompletedAt = &queueRow.CompletedAt.Time
}
for _, taskRow := range taskRows {
task := &BatchTask{
ID: taskRow.ID,
Message: taskRow.Message,
Status: taskRow.Status,
}
if taskRow.ConversationID.Valid {
task.ConversationID = taskRow.ConversationID.String
}
if taskRow.StartedAt.Valid {
task.StartedAt = &taskRow.StartedAt.Time
}
if taskRow.CompletedAt.Valid {
task.CompletedAt = &taskRow.CompletedAt.Time
}
if taskRow.Error.Valid {
task.Error = taskRow.Error.String
}
if taskRow.Result.Valid {
task.Result = taskRow.Result.String
}
queue.Tasks = append(queue.Tasks, task)
}
return queue
}
// GetAllQueues 获取所有队列
func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue {
m.mu.RLock()
result := make([]*BatchTaskQueue, 0, len(m.queues))
for _, queue := range m.queues {
result = append(result, queue)
}
m.mu.RUnlock()
// 如果数据库可用,确保所有数据库中的队列都已加载到内存
if m.db != nil {
dbQueues, err := m.db.GetAllBatchQueues()
if err == nil {
m.mu.Lock()
for _, queueRow := range dbQueues {
if _, exists := m.queues[queueRow.ID]; !exists {
if queue := m.loadQueueFromDB(queueRow.ID); queue != nil {
m.queues[queueRow.ID] = queue
result = append(result, queue)
}
}
}
m.mu.Unlock()
}
}
return result
}
// ListQueues 列出队列(支持筛选和分页)
func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueue, int, error) {
var queues []*BatchTaskQueue
var total int
// 如果数据库可用,从数据库查询
if m.db != nil {
// 获取总数
count, err := m.db.CountBatchQueues(status, keyword)
if err != nil {
return nil, 0, fmt.Errorf("统计队列总数失败: %w", err)
}
total = count
// 获取队列列表(只获取ID
queueRows, err := m.db.ListBatchQueues(limit, offset, status, keyword)
if err != nil {
return nil, 0, fmt.Errorf("查询队列列表失败: %w", err)
}
// 加载完整的队列信息(从内存或数据库)
m.mu.Lock()
for _, queueRow := range queueRows {
var queue *BatchTaskQueue
// 先从内存查找
if cached, exists := m.queues[queueRow.ID]; exists {
queue = cached
} else {
// 从数据库加载
queue = m.loadQueueFromDB(queueRow.ID)
if queue != nil {
m.queues[queueRow.ID] = queue
}
}
if queue != nil {
queues = append(queues, queue)
}
}
m.mu.Unlock()
} else {
// 没有数据库,从内存中筛选和分页
m.mu.RLock()
allQueues := make([]*BatchTaskQueue, 0, len(m.queues))
for _, queue := range m.queues {
allQueues = append(allQueues, queue)
}
m.mu.RUnlock()
// 筛选
filtered := make([]*BatchTaskQueue, 0)
for _, queue := range allQueues {
// 状态筛选
if status != "" && status != "all" && queue.Status != status {
continue
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
keywordLower := strings.ToLower(keyword)
queueIDLower := strings.ToLower(queue.ID)
queueTitleLower := strings.ToLower(queue.Title)
if !strings.Contains(queueIDLower, keywordLower) && !strings.Contains(queueTitleLower, keywordLower) {
// 也可以搜索创建时间
createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05")
if !strings.Contains(createdAtStr, keyword) {
continue
}
}
}
filtered = append(filtered, queue)
}
// 按创建时间倒序排序
sort.Slice(filtered, func(i, j int) bool {
return filtered[i].CreatedAt.After(filtered[j].CreatedAt)
})
total = len(filtered)
// 分页
start := offset
if start > len(filtered) {
start = len(filtered)
}
end := start + limit
if end > len(filtered) {
end = len(filtered)
}
if start < len(filtered) {
queues = filtered[start:end]
}
}
return queues, total, nil
}
// LoadFromDB 从数据库加载所有队列
func (m *BatchTaskManager) LoadFromDB() error {
if m.db == nil {
return nil
}
queueRows, err := m.db.GetAllBatchQueues()
if err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
for _, queueRow := range queueRows {
if _, exists := m.queues[queueRow.ID]; exists {
continue // 已存在,跳过
}
taskRows, err := m.db.GetBatchTasks(queueRow.ID)
if err != nil {
continue // 跳过加载失败的任务
}
queue := &BatchTaskQueue{
ID: queueRow.ID,
Status: queueRow.Status,
CreatedAt: queueRow.CreatedAt,
CurrentIndex: queueRow.CurrentIndex,
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.Title.Valid {
queue.Title = queueRow.Title.String
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
if queueRow.CompletedAt.Valid {
queue.CompletedAt = &queueRow.CompletedAt.Time
}
for _, taskRow := range taskRows {
task := &BatchTask{
ID: taskRow.ID,
Message: taskRow.Message,
Status: taskRow.Status,
}
if taskRow.ConversationID.Valid {
task.ConversationID = taskRow.ConversationID.String
}
if taskRow.StartedAt.Valid {
task.StartedAt = &taskRow.StartedAt.Time
}
if taskRow.CompletedAt.Valid {
task.CompletedAt = &taskRow.CompletedAt.Time
}
if taskRow.Error.Valid {
task.Error = taskRow.Error.String
}
if taskRow.Result.Valid {
task.Result = taskRow.Result.String
}
queue.Tasks = append(queue.Tasks, task)
}
m.queues[queueRow.ID] = queue
}
return nil
}
// UpdateTaskStatus 更新任务状态
func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) {
m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "")
}
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
for _, task := range queue.Tasks {
if task.ID == taskID {
task.Status = status
if result != "" {
task.Result = result
}
if errorMsg != "" {
task.Error = errorMsg
}
if conversationID != "" {
task.ConversationID = conversationID
}
now := time.Now()
if status == "running" && task.StartedAt == nil {
task.StartedAt = &now
}
if status == "completed" || status == "failed" || status == "cancelled" {
task.CompletedAt = &now
}
break
}
}
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// UpdateQueueStatus 更新队列状态
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
queue.Status = status
now := time.Now()
if status == "running" && queue.StartedAt == nil {
queue.StartedAt = &now
}
if status == "completed" || status == "cancelled" {
queue.CompletedAt = &now
}
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// UpdateTaskMessage 更新任务消息(仅限待执行状态)
func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能编辑任务
if queue.Status != "pending" {
return fmt.Errorf("只有待执行状态的队列才能编辑任务")
}
// 查找并更新任务
for _, task := range queue.Tasks {
if task.ID == taskID {
// 只有待执行状态的任务才能编辑
if task.Status != "pending" {
return fmt.Errorf("只有待执行状态的任务才能编辑")
}
task.Message = message
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil {
return fmt.Errorf("更新任务消息失败: %w", err)
}
}
return nil
}
}
return fmt.Errorf("任务不存在")
}
// AddTaskToQueue 添加任务到队列(仅限待执行状态)
func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return nil, fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能添加任务
if queue.Status != "pending" {
return nil, fmt.Errorf("只有待执行状态的队列才能添加任务")
}
if message == "" {
return nil, fmt.Errorf("任务消息不能为空")
}
// 生成任务ID
taskID := generateShortID()
task := &BatchTask{
ID: taskID,
Message: message,
Status: "pending",
}
// 添加到内存队列
queue.Tasks = append(queue.Tasks, task)
// 同步到数据库
if m.db != nil {
if err := m.db.AddBatchTask(queueID, taskID, message); err != nil {
// 如果数据库保存失败,从内存中移除
queue.Tasks = queue.Tasks[:len(queue.Tasks)-1]
return nil, fmt.Errorf("添加任务失败: %w", err)
}
}
return task, nil
}
// DeleteTask 删除任务(仅限待执行状态)
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能删除任务
if queue.Status != "pending" {
return fmt.Errorf("只有待执行状态的队列才能删除任务")
}
// 查找并删除任务
taskIndex := -1
for i, task := range queue.Tasks {
if task.ID == taskID {
// 只有待执行状态的任务才能删除
if task.Status != "pending" {
return fmt.Errorf("只有待执行状态的任务才能删除")
}
taskIndex = i
break
}
}
if taskIndex == -1 {
return fmt.Errorf("任务不存在")
}
// 从内存队列中删除
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
// 同步到数据库
if m.db != nil {
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
// 如果数据库删除失败,恢复内存中的任务
// 这里需要重新插入,但为了简化,我们只记录错误
return fmt.Errorf("删除任务失败: %w", err)
}
}
return nil
}
// GetNextTask 获取下一个待执行的任务
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists {
return nil, false
}
for i := queue.CurrentIndex; i < len(queue.Tasks); i++ {
task := queue.Tasks[i]
if task.Status == "pending" {
queue.CurrentIndex = i
return task, true
}
}
return nil, false
}
// MoveToNextTask 移动到下一个任务
func (m *BatchTaskManager) MoveToNextTask(queueID string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
queue.CurrentIndex++
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// SetTaskCancel 设置当前任务的取消函数
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
if cancel != nil {
m.taskCancels[queueID] = cancel
} else {
delete(m.taskCancels, queueID)
}
}
// PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
m.mu.Lock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return false
}
if queue.Status != "running" {
m.mu.Unlock()
return false
}
queue.Status = "paused"
// 取消当前正在执行的任务(通过取消context)
if cancel, exists := m.taskCancels[queueID]; exists {
cancel()
delete(m.taskCancels, queueID)
}
m.mu.Unlock()
// 同步队列状态到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, "paused"); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
return true
}
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
m.mu.Lock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return false
}
if queue.Status == "completed" || queue.Status == "cancelled" {
m.mu.Unlock()
return false
}
queue.Status = "cancelled"
now := time.Now()
queue.CompletedAt = &now
// 取消所有待执行的任务
for _, task := range queue.Tasks {
if task.Status == "pending" {
task.Status = "cancelled"
task.CompletedAt = &now
// 同步到数据库
if m.db != nil {
m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "")
}
}
}
// 取消当前正在执行的任务
if cancel, exists := m.taskCancels[queueID]; exists {
cancel()
delete(m.taskCancels, queueID)
}
m.mu.Unlock()
// 同步队列状态到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
return true
}
// DeleteQueue 删除队列
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
_, exists := m.queues[queueID]
if !exists {
return false
}
// 清理取消函数
delete(m.taskCancels, queueID)
// 从数据库删除
if m.db != nil {
if err := m.db.DeleteBatchQueue(queueID); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
delete(m.queues, queueID)
return true
}
// generateShortID 生成短ID
func generateShortID() string {
b := make([]byte, 4)
rand.Read(b)
return time.Now().Format("150405") + "-" + hex.EncodeToString(b)
}
+208 -28
View File
@@ -47,16 +47,17 @@ type ConfigHandler struct {
config *config.Config config *config.Config
mcpServer *mcp.Server mcpServer *mcp.Server
executor *security.Executor executor *security.Executor
agent AgentUpdater // Agent接口,用于更新Agent配置 agent AgentUpdater // Agent接口,用于更新Agent配置
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选) vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选) retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选) appUpdater AppUpdater // App更新器(可选)
logger *zap.Logger logger *zap.Logger
mu sync.RWMutex mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
} }
// AttackChainUpdater 攻击链处理器更新接口 // AttackChainUpdater 攻击链处理器更新接口
@@ -72,15 +73,26 @@ type AgentUpdater interface {
// NewConfigHandler 创建新的配置处理器 // NewConfigHandler 创建新的配置处理器
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler { func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
// 保存初始的嵌入模型配置(如果知识库已启用)
var lastEmbeddingConfig *config.EmbeddingConfig
if cfg.Knowledge.Enabled {
lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: cfg.Knowledge.Embedding.Provider,
Model: cfg.Knowledge.Embedding.Model,
BaseURL: cfg.Knowledge.Embedding.BaseURL,
APIKey: cfg.Knowledge.Embedding.APIKey,
}
}
return &ConfigHandler{ return &ConfigHandler{
configPath: configPath, configPath: configPath,
config: cfg, config: cfg,
mcpServer: mcpServer, mcpServer: mcpServer,
executor: executor, executor: executor,
agent: agent, agent: agent,
attackChainHandler: attackChainHandler, attackChainHandler: attackChainHandler,
externalMCPMgr: externalMCPMgr, externalMCPMgr: externalMCPMgr,
logger: logger, logger: logger,
lastEmbeddingConfig: lastEmbeddingConfig,
} }
} }
@@ -135,6 +147,7 @@ type ToolConfigInfo struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具 IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具) ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
} }
// GetConfig 获取当前配置 // GetConfig 获取当前配置
@@ -191,7 +204,8 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// 获取外部MCP工具 // 获取外部MCP工具
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx) externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
@@ -259,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// GetToolsResponse 获取工具列表响应(分页) // GetToolsResponse 获取工具列表响应(分页)
type GetToolsResponse struct { type GetToolsResponse struct {
Tools []ToolConfigInfo `json:"tools"` Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"` Total int `json:"total"`
Page int `json:"page"` TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
PageSize int `json:"page_size"` Page int `json:"page"`
TotalPages int `json:"total_pages"` PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
} }
// GetTools 获取工具列表(支持分页和搜索) // GetTools 获取工具列表(支持分页和搜索)
@@ -292,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
searchTermLower = strings.ToLower(searchTerm) searchTermLower = strings.ToLower(searchTerm)
} }
// 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool)
for _, toolKey := range role.Tools {
roleToolsSet[toolKey] = true
}
roleUsesAllTools = false
}
}
}
// 获取所有内部工具并应用搜索过滤 // 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool) configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
@@ -312,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
toolInfo.Description = desc toolInfo.Description = desc
} }
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
if tool.Enabled {
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[tool.Name] {
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤 // 如果有关键词,进行搜索过滤
if searchTermLower != "" { if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name) nameLower := strings.ToLower(toolInfo.Name)
@@ -348,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
IsExternal: false, IsExternal: false,
} }
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,直接注册的工具默认启用
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[mcpTool.Name] {
roleEnabled := true // 在角色列表中且工具本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤 // 如果有关键词,进行搜索过滤
if searchTermLower != "" { if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name) nameLower := strings.ToLower(toolInfo.Name)
@@ -363,7 +440,8 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 获取外部MCP工具 // 获取外部MCP工具
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx) externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
@@ -425,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
} }
allTools = append(allTools, ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀 Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description, Description: description,
Enabled: enabled, Enabled: enabled,
IsExternal: true, IsExternal: true,
ExternalMCP: mcpName, ExternalMCP: mcpName,
}) }
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
toolInfo.RoleEnabled = &enabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
if roleToolsSet[externalToolKey] {
roleEnabled := enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
} }
} }
} }
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
total := len(allTools) total := len(allTools)
// 统计已启用的工具数(在角色中的启用工具数)
totalEnabled := 0
for _, tool := range allTools {
if tool.RoleEnabled != nil && *tool.RoleEnabled {
totalEnabled++
} else if tool.RoleEnabled == nil && tool.Enabled {
// 如果未指定角色,统计所有启用的工具
totalEnabled++
}
}
totalPages := (total + pageSize - 1) / pageSize totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 { if totalPages == 0 {
totalPages = 1 totalPages = 1
@@ -457,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
c.JSON(http.StatusOK, GetToolsResponse{ c.JSON(http.StatusOK, GetToolsResponse{
Tools: tools, Tools: tools,
Total: total, Total: total,
Page: page, TotalEnabled: totalEnabled,
PageSize: pageSize, Page: page,
TotalPages: totalPages, PageSize: pageSize,
TotalPages: totalPages,
}) })
} }
@@ -522,6 +638,15 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 更新Knowledge配置 // 更新Knowledge配置
if req.Knowledge != nil { if req.Knowledge != nil {
// 保存旧的嵌入模型配置(用于检测变更)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
}
h.config.Knowledge = *req.Knowledge h.config.Knowledge = *req.Knowledge
h.logger.Info("更新Knowledge配置", h.logger.Info("更新Knowledge配置",
zap.Bool("enabled", h.config.Knowledge.Enabled), zap.Bool("enabled", h.config.Knowledge.Enabled),
@@ -676,10 +801,55 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("知识库动态初始化完成,工具已注册") h.logger.Info("知识库动态初始化完成,工具已注册")
} }
// 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞)
var needReinitKnowledge bool
var reinitKnowledgeInitializer KnowledgeInitializer
h.mu.RLock()
if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil {
// 检查嵌入模型配置是否变更
currentEmbedding := h.config.Knowledge.Embedding
if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider ||
currentEmbedding.Model != h.lastEmbeddingConfig.Model ||
currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL ||
currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey {
needReinitKnowledge = true
reinitKnowledgeInitializer = h.knowledgeInitializer
h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件",
zap.String("old_model", h.lastEmbeddingConfig.Model),
zap.String("new_model", currentEmbedding.Model),
zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL),
zap.String("new_base_url", currentEmbedding.BaseURL),
)
}
}
h.mu.RUnlock()
// 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行
if needReinitKnowledge {
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
if _, err := reinitKnowledgeInitializer(); err != nil {
h.logger.Error("重新初始化知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
return
}
h.logger.Info("知识库组件重新初始化完成")
}
// 现在获取写锁,执行快速的操作 // 现在获取写锁,执行快速的操作
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
// 如果重新初始化了知识库,更新嵌入模型配置记录
if needReinitKnowledge && h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
h.logger.Info("已更新嵌入模型配置记录")
}
// 重新注册工具(根据新的启用状态) // 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具") h.logger.Info("重新注册工具")
@@ -737,6 +907,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
) )
} }
// 更新嵌入模型配置记录(如果知识库启用)
if h.config.Knowledge.Enabled {
h.lastEmbeddingConfig = &config.EmbeddingConfig{
Provider: h.config.Knowledge.Embedding.Provider,
Model: h.config.Knowledge.Embedding.Model,
BaseURL: h.config.Knowledge.Embedding.BaseURL,
APIKey: h.config.Knowledge.Embedding.APIKey,
}
}
h.logger.Info("配置已应用", h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)), zap.Int("tools_count", len(h.config.Security.Tools)),
) )
+13 -2
View File
@@ -324,7 +324,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
} else if cfg.URL != "" { } else if cfg.URL != "" {
transport = "http" transport = "http"
} else { } else {
return fmt.Errorf("需要指定commandstdio模式)或urlhttp模式)") return fmt.Errorf("需要指定commandstdio模式)或urlhttp/sse模式)")
} }
} }
@@ -337,8 +337,12 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
if cfg.Command == "" { if cfg.Command == "" {
return fmt.Errorf("stdio模式需要command") return fmt.Errorf("stdio模式需要command")
} }
case "sse":
if cfg.URL == "" {
return fmt.Errorf("SSE模式需要URL")
}
default: default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport) return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
} }
return nil return nil
@@ -442,6 +446,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if len(serverCfg.Args) > 0 { if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args) setStringArrayInMap(serverNode, "args", serverCfg.Args)
} }
// 保存 env 字段(环境变量)
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.Transport != "" { if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport) setStringInMap(serverNode, "transport", serverCfg.Transport)
} }
+11 -1
View File
@@ -189,8 +189,18 @@ type GroupConversation struct {
// GetGroupConversations 获取分组中的所有对话 // GetGroupConversations 获取分组中的所有对话
func (h *GroupHandler) GetGroupConversations(c *gin.Context) { func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
groupID := c.Param("id") groupID := c.Param("id")
searchQuery := c.Query("search") // 获取搜索参数
var conversations []*database.Conversation
var err error
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
if searchQuery != "" {
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
} else {
conversations, err = h.db.GetConversationsByGroup(groupID)
}
conversations, err := h.db.GetConversationsByGroup(groupID)
if err != nil { if err != nil {
h.logger.Error("获取分组对话失败", zap.Error(err)) h.logger.Error("获取分组对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+56 -3
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
@@ -336,14 +337,54 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
go func() { go func() {
ctx := context.Background() ctx := context.Background()
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex))) h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
failedCount := 0
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
for i, itemID := range itemsToIndex { for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil { if err := h.indexer.IndexItem(ctx, itemID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
h.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemsToIndex)),
zap.Error(err),
)
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue continue
} }
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
}
} }
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex))) h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
}() }()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -396,6 +437,18 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
return return
} }
// 获取索引器的错误信息
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
}
c.JSON(http.StatusOK, status) c.JSON(http.StatusOK, status)
} }
+96 -12
View File
@@ -3,6 +3,7 @@ package handler
import ( import (
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
@@ -66,8 +67,10 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析状态筛选参数 // 解析状态筛选参数
status := c.Query("status") status := c.Query("status")
// 解析工具筛选参数
toolName := c.Query("tool")
executions, total := h.loadExecutionsWithPagination(page, pageSize, status) executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
stats := h.loadStats() stats := h.loadStats()
totalPages := (total + pageSize - 1) / pageSize totalPages := (total + pageSize - 1) / pageSize
@@ -87,18 +90,21 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
} }
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution { func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
executions, _ := h.loadExecutionsWithPagination(1, 1000, "") executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
return executions return executions
} }
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status string) ([]*mcp.ToolExecution, int) { func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
if h.db == nil { if h.db == nil {
allExecutions := h.mcpServer.GetAllExecutions() allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选,先进行筛选 // 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" { if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0) filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions { for _, exec := range allExecutions {
if exec.Status == status { matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
} }
@@ -117,15 +123,18 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
} }
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status) executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName)
if err != nil { if err != nil {
h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err)) h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err))
allExecutions := h.mcpServer.GetAllExecutions() allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选,先进行筛选 // 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" { if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0) filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions { for _, exec := range allExecutions {
if exec.Status == status { matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
} }
@@ -143,8 +152,8 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
return allExecutions[offset:end], total return allExecutions[offset:end], total
} }
// 获取总数(考虑状态筛选) // 获取总数(考虑状态筛选和工具筛选
total, err := h.db.CountToolExecutions(status) total, err := h.db.CountToolExecutions(status, toolName)
if err != nil { if err != nil {
h.logger.Warn("获取执行记录总数失败", zap.Error(err)) h.logger.Warn("获取执行记录总数失败", zap.Error(err))
// 回退:使用已加载的记录数估算 // 回退:使用已加载的记录数估算
@@ -298,4 +307,79 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
// DeleteExecutions 批量删除执行记录
func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
var request struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
if len(request.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"})
return
}
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
if h.db != nil {
// 先获取执行记录信息(用于更新统计)
executions, err := h.db.GetToolExecutionsByIds(request.IDs)
if err != nil {
h.logger.Error("获取执行记录失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()})
return
}
// 按工具名称分组统计需要减少的数量
toolStats := make(map[string]struct {
totalCalls int
successCalls int
failedCalls int
})
for _, exec := range executions {
if exec.ToolName == "" {
continue
}
stats := toolStats[exec.ToolName]
stats.totalCalls++
if exec.Status == "failed" {
stats.failedCalls++
} else if exec.Status == "completed" {
stats.successCalls++
}
toolStats[exec.ToolName] = stats
}
// 批量删除执行记录
err = h.db.DeleteToolExecutions(request.IDs)
if err != nil {
h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs)))
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()})
return
}
// 更新统计信息(减少相应的计数)
for toolName, stats := range toolStats {
if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil {
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName))
// 不返回错误,因为记录已经删除成功
}
}
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return
}
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
+453
View File
@@ -0,0 +1,453 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}
+90 -3
View File
@@ -23,16 +23,31 @@ type AgentTask struct {
cancel func(error) cancel func(error)
} }
// CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
Status string `json:"status"`
}
// AgentTaskManager 管理正在运行的Agent任务 // AgentTaskManager 管理正在运行的Agent任务
type AgentTaskManager struct { type AgentTaskManager struct {
mu sync.RWMutex mu sync.RWMutex
tasks map[string]*AgentTask tasks map[string]*AgentTask
completedTasks []*CompletedTask // 最近完成的任务历史
maxHistorySize int // 最大历史记录数
historyRetention time.Duration // 历史记录保留时间
} }
// NewAgentTaskManager 创建任务管理器 // NewAgentTaskManager 创建任务管理器
func NewAgentTaskManager() *AgentTaskManager { func NewAgentTaskManager() *AgentTaskManager {
return &AgentTaskManager{ return &AgentTaskManager{
tasks: make(map[string]*AgentTask), tasks: make(map[string]*AgentTask),
completedTasks: make([]*CompletedTask, 0),
maxHistorySize: 50, // 最多保留50条历史记录
historyRetention: 24 * time.Hour, // 保留24小时
} }
} }
@@ -118,9 +133,49 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string)
task.Status = finalStatus task.Status = finalStatus
} }
// 保存到历史记录
completedTask := &CompletedTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
CompletedAt: time.Now(),
Status: finalStatus,
}
// 添加到历史记录
m.completedTasks = append(m.completedTasks, completedTask)
// 清理过期和过多的历史记录
m.cleanupHistory()
// 从运行任务中移除
delete(m.tasks, conversationID) delete(m.tasks, conversationID)
} }
// cleanupHistory 清理过期的历史记录
func (m *AgentTaskManager) cleanupHistory() {
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
// 过滤掉过期的记录
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
validTasks = append(validTasks, task)
}
}
// 如果仍然超过最大数量,只保留最新的
if len(validTasks) > m.maxHistorySize {
// 按完成时间排序,保留最新的
// 由于是追加的,最新的在最后,所以直接取最后N个
start := len(validTasks) - m.maxHistorySize
validTasks = validTasks[start:]
}
m.completedTasks = validTasks
}
// GetActiveTasks 返回所有正在运行的任务 // GetActiveTasks 返回所有正在运行的任务
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask { func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
m.mu.RLock() m.mu.RLock()
@@ -137,3 +192,35 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
} }
return result return result
} }
// GetCompletedTasks 返回最近完成的任务历史
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
m.mu.RLock()
defer m.mu.RUnlock()
// 清理过期记录(只读锁,不影响其他操作)
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
// 所以返回时过滤过期记录
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
result := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
result = append(result, task)
}
}
// 按完成时间倒序排序(最新的在前)
// 由于是追加的,最新的在最后,需要反转
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
// 限制返回数量
if len(result) > m.maxHistorySize {
result = result[:m.maxHistorySize]
}
return result
}
+130 -15
View File
@@ -7,6 +7,8 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
"sync"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
@@ -19,6 +21,12 @@ type Indexer struct {
logger *zap.Logger logger *zap.Logger
chunkSize int // 每个块的最大token数(估算) chunkSize int // 每个块的最大token数(估算)
overlap int // 块之间的重叠token数 overlap int // 块之间的重叠token数
// 错误跟踪
mu sync.RWMutex
lastError string // 最近一次错误信息
lastErrorTime time.Time // 最近一次错误时间
errorCount int // 连续错误计数
} }
// NewIndexer 创建新的索引器 // NewIndexer 创建新的索引器
@@ -267,13 +275,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
chunks := idx.ChunkText(content) chunks := idx.ChunkText(content)
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 跟踪该知识项的错误
itemErrorCount := 0
var firstError error
firstErrorChunkIndex := -1
// 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型) // 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型)
for i, chunk := range chunks { for i, chunk := range chunks {
chunkPreview := chunk
if len(chunkPreview) > 200 {
chunkPreview = chunkPreview[:200] + "..."
}
// 将category和title信息包含到向量化的文本中 // 将category和title信息包含到向量化的文本中
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}" // 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配 // 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
@@ -281,13 +289,43 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding) embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
if err != nil { if err != nil {
idx.logger.Warn("向量化失败", itemErrorCount++
zap.String("itemId", itemID), if firstError == nil {
zap.Int("chunkIndex", i), firstError = err
zap.Int("chunkLength", len(chunk)), firstErrorChunkIndex = i
zap.String("chunkPreview", chunkPreview), // 只在第一个块失败时记录详细日志
zap.Error(err), chunkPreview := chunk
) if len(chunkPreview) > 200 {
chunkPreview = chunkPreview[:200] + "..."
}
idx.logger.Warn("向量化失败",
zap.String("itemId", itemID),
zap.Int("chunkIndex", i),
zap.Int("totalChunks", len(chunks)),
zap.String("chunkPreview", chunkPreview),
zap.Error(err),
)
// 更新全局错误跟踪
errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
}
// 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止)
// 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题
if itemErrorCount >= 2 {
idx.logger.Error("知识项连续向量化失败,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
zap.Int("failedChunks", itemErrorCount),
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
zap.Error(firstError),
)
return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError)
}
continue continue
} }
@@ -321,6 +359,13 @@ func (idx *Indexer) HasIndex() (bool, error) {
// RebuildIndex 重建所有索引 // RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error { func (idx *Indexer) RebuildIndex(ctx context.Context) error {
// 重置错误跟踪
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil { if err != nil {
return fmt.Errorf("查询知识项失败: %w", err) return fmt.Errorf("查询知识项失败: %w", err)
@@ -348,14 +393,84 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.logger.Info("已清空旧索引,开始重建") idx.logger.Info("已清空旧索引,开始重建")
} }
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止)
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs { for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil { if err := idx.IndexItem(ctx, itemID); err != nil {
idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err)) failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
idx.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemIDs)),
zap.Error(err),
)
}
// 如果连续失败过多,可能是配置问题,立即停止索引
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError)
}
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%)
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
}
continue continue
} }
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
// 成功时重置连续失败计数和第一个失败信息
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率(每10个或每10%记录一次)
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
} }
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs))) idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil return nil
} }
// GetLastError 获取最近一次错误信息
func (idx *Indexer) GetLastError() (string, time.Time) {
idx.mu.RLock()
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
+17 -1
View File
@@ -639,7 +639,12 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
// 删除旧目录(如果为空) // 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath) oldDir := filepath.Dir(item.FilePath)
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 { if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
os.Remove(oldDir) // 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
}
}
} }
} }
@@ -686,6 +691,17 @@ func (m *Manager) DeleteItem(id string) error {
return fmt.Errorf("删除知识项失败: %w", err) return fmt.Errorf("删除知识项失败: %w", err)
} }
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
}
}
}
return nil return nil
} }
+8 -8
View File
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
// 查询所有向量(或按风险类型过滤) // 查询所有向量(或按风险类型过滤)
// 使用精确匹配(=)以提高性能和准确性 // 使用精确匹配(=)以提高性能和准确性
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称 // 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配 // 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows var rows *sql.Rows
if req.RiskType != "" { if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确 // 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性 // 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到 // 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称 // 建议用户先调用相应的内置工具获取准确的category名称
rows, err = r.db.Query(` rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e FROM knowledge_embeddings e
+12 -11
View File
@@ -8,6 +8,7 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
) { ) {
// 注册第一个工具:获取所有可用的风险类型列表 // 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{ listRiskTypesTool := mcp.Tool{
Name: "list_knowledge_risk_types", Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。", Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表", ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
for i, category := range categories { for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category)) resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
} }
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。") resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{ return &mcp.ToolResult{
Content: []mcp.Content{ Content: []mcp.Content{
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
// 注册第二个工具:搜索知识库(保持原有功能) // 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{ searchTool := mcp.Tool{
Name: "search_knowledge_base", Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)", ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
"type": "object", "type": "object",
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
}, },
"risk_type": map[string]interface{}{ "risk_type": map[string]interface{}{
"type": "string", "type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
}, },
}, },
"required": []string{"query"}, "required": []string{"query"},
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
// 按文档分组结果,以便更好地展示上下文 // 按文档分组结果,以便更好地展示上下文
// 使用有序的slice来保持文档顺序(按最高混合分数) // 使用有序的slice来保持文档顺序(按最高混合分数)
type itemGroup struct { type itemGroup struct {
itemID string itemID string
results []*RetrievalResult results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数 maxScore float64 // 该文档的最高混合分数
} }
itemGroups := make([]*itemGroup, 0) itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup) itemMap := make(map[string]*itemGroup)
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
group, exists := itemMap[itemID] group, exists := itemMap[itemID]
if !exists { if !exists {
group = &itemGroup{ group = &itemGroup{
itemID: itemID, itemID: itemID,
results: make([]*RetrievalResult, 0), results: make([]*RetrievalResult, 0),
maxScore: result.Score, maxScore: result.Score,
} }
itemMap[itemID] = group itemMap[itemID] = group
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
}) })
// 显示主结果(混合分数最高的,同时显示相似度和混合分数) // 显示主结果(混合分数最高的,同时显示相似度和混合分数)
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n", resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100, mainResult.Score*100)) resultIndex, mainResult.Similarity*100, mainResult.Score*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
+33
View File
@@ -0,0 +1,33 @@
package builtin
// 内置工具名称常量
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const (
// 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability"
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
)
// IsBuiltinTool 检查工具名称是否是内置工具
func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase:
return true
default:
return false
}
}
// GetAllBuiltinTools 返回所有内置工具名称列表
func GetAllBuiltinTools() []string {
return []string{
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
}
}
+559 -1
View File
@@ -1,13 +1,16 @@
package mcp package mcp
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"os/exec" "os/exec"
"strings"
"sync" "sync"
"time" "time"
@@ -100,6 +103,20 @@ func (c *HTTPMCPClient) Initialize(ctx context.Context) error {
return fmt.Errorf("初始化失败: %w", err) return fmt.Errorf("初始化失败: %w", err)
} }
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected") c.setStatus("connected")
return nil return nil
} }
@@ -193,6 +210,34 @@ func (c *HTTPMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message
return &mcpResp, nil return &mcpResp, nil
} }
func (c *HTTPMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("序列化通知失败: %w", err)
}
// 使用较短的超时发送通知
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("创建HTTP请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 发送通知,不等待响应(通知不需要响应)
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
resp.Body.Close()
return nil
}
func (c *HTTPMCPClient) Close() error { func (c *HTTPMCPClient) Close() error {
c.setStatus("disconnected") c.setStatus("disconnected")
return nil return nil
@@ -202,6 +247,7 @@ func (c *HTTPMCPClient) Close() error {
type StdioMCPClient struct { type StdioMCPClient struct {
command string command string
args []string args []string
env map[string]string
timeout time.Duration timeout time.Duration
cmd *exec.Cmd cmd *exec.Cmd
stdin io.WriteCloser stdin io.WriteCloser
@@ -219,7 +265,7 @@ type StdioMCPClient struct {
} }
// NewStdioMCPClient 创建stdio模式的MCP客户端 // NewStdioMCPClient 创建stdio模式的MCP客户端
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient { func NewStdioMCPClient(command string, args []string, env map[string]string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
if timeout <= 0 { if timeout <= 0 {
timeout = 30 * time.Second timeout = 30 * time.Second
} }
@@ -227,6 +273,7 @@ func NewStdioMCPClient(command string, args []string, timeout time.Duration, log
return &StdioMCPClient{ return &StdioMCPClient{
command: command, command: command,
args: args, args: args,
env: env,
timeout: timeout, timeout: timeout,
logger: logger, logger: logger,
status: "disconnected", status: "disconnected",
@@ -289,6 +336,20 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
return fmt.Errorf("初始化失败: %w", err) return fmt.Errorf("初始化失败: %w", err)
} }
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected") c.setStatus("connected")
return nil return nil
} }
@@ -296,6 +357,27 @@ func (c *StdioMCPClient) Initialize(ctx context.Context) error {
func (c *StdioMCPClient) startProcess() error { func (c *StdioMCPClient) startProcess() error {
cmd := exec.CommandContext(c.ctx, c.command, c.args...) cmd := exec.CommandContext(c.ctx, c.command, c.args...)
// 设置环境变量
if c.env != nil && len(c.env) > 0 {
// 获取当前环境变量
cmd.Env = os.Environ()
// 添加或覆盖配置的环境变量
for key, value := range c.env {
// 检查是否已存在该环境变量
found := false
for i, envVar := range cmd.Env {
if strings.HasPrefix(envVar, key+"=") {
cmd.Env[i] = key + "=" + value
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, key+"="+value)
}
}
}
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
return err return err
@@ -424,6 +506,20 @@ func (c *StdioMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
return listResp.Tools, nil return listResp.Tools, nil
} }
func (c *StdioMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
if c.encoder == nil {
return fmt.Errorf("进程未启动")
}
// 直接发送通知,不等待响应
if err := c.encoder.Encode(msg); err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
return nil
}
func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) { func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
req := Message{ req := Message{
ID: MessageID{value: uuid.New().String()}, ID: MessageID{value: uuid.New().String()},
@@ -472,3 +568,465 @@ func (c *StdioMCPClient) Close() error {
c.setStatus("disconnected") c.setStatus("disconnected")
return nil return nil
} }
// SSEMCPClient SSE模式的MCP客户端
type SSEMCPClient struct {
url string
timeout time.Duration
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
sseConn io.ReadCloser
sseCancel context.CancelFunc
requestID int64
responses map[string]chan *Message
responsesMu sync.Mutex
ctx context.Context
}
// NewSSEMCPClient 创建SSE模式的MCP客户端
func NewSSEMCPClient(url string, timeout time.Duration, logger *zap.Logger) *SSEMCPClient {
if timeout <= 0 {
timeout = 30 * time.Second
}
ctx, cancel := context.WithCancel(context.Background())
return &SSEMCPClient{
url: url,
timeout: timeout,
client: &http.Client{Timeout: timeout},
logger: logger,
status: "disconnected",
responses: make(map[string]chan *Message),
ctx: ctx,
sseCancel: cancel,
}
}
func (c *SSEMCPClient) setStatus(status string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = status
}
func (c *SSEMCPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *SSEMCPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *SSEMCPClient) Initialize(ctx context.Context) error {
c.setStatus("connecting")
// 建立SSE连接
if err := c.connectSSE(); err != nil {
c.setStatus("error")
return fmt.Errorf("建立SSE连接失败: %w", err)
}
// 启动响应读取goroutine
go c.readSSEResponses()
// 发送初始化请求
req := Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
}
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{
Name: "CyberStrikeAI",
Version: "1.0.0",
},
}
paramsJSON, _ := json.Marshal(params)
req.Params = paramsJSON
_, err := c.sendRequest(ctx, &req)
if err != nil {
c.setStatus("error")
c.Close()
return fmt.Errorf("初始化失败: %w", err)
}
// 发送 initialized 通知(MCP 协议要求:收到 initialize 响应后必须发送此通知)
notifyReq := Message{
ID: MessageID{value: nil}, // 通知没有 ID
Method: "notifications/initialized",
Version: "2.0",
}
notifyReq.Params = json.RawMessage("{}")
// 发送通知(不需要等待响应)
if err := c.sendNotification(&notifyReq); err != nil {
c.logger.Warn("发送 initialized 通知失败", zap.Error(err))
// 通知失败不应该导致初始化失败,只记录警告
}
c.setStatus("connected")
return nil
}
func (c *SSEMCPClient) connectSSE() error {
// 建立SSE连接(GET请求,Accept: text/event-stream
// SSE连接需要长连接,使用无超时的客户端
sseClient := &http.Client{
Timeout: 0, // 无超时,用于长连接
}
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
if err != nil {
return fmt.Errorf("创建SSE请求失败: %w", err)
}
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
resp, err := sseClient.Do(req)
if err != nil {
return fmt.Errorf("SSE连接失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return fmt.Errorf("SSE连接失败,状态码: %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
resp.Body.Close()
return fmt.Errorf("服务器不支持SSEContent-Type: %s", contentType)
}
c.sseConn = resp.Body
return nil
}
func (c *SSEMCPClient) readSSEResponses() {
defer func() {
if r := recover(); r != nil {
c.logger.Error("读取SSE响应时发生panic", zap.Any("error", r))
}
if c.sseConn != nil {
c.sseConn.Close()
}
c.setStatus("disconnected")
}()
if c.sseConn == nil {
return
}
scanner := &sseScanner{reader: bufio.NewReader(c.sseConn)}
for {
select {
case <-c.ctx.Done():
return
default:
}
// 读取SSE事件
event, err := scanner.readEvent()
if err != nil {
if err == io.EOF {
c.setStatus("disconnected")
return
}
c.logger.Error("读取SSE数据失败", zap.Error(err))
return
}
if event == nil || len(event.Data) == 0 {
continue
}
// 解析JSON消息
var msg Message
if err := json.Unmarshal(event.Data, &msg); err != nil {
c.logger.Warn("解析SSE消息失败", zap.Error(err), zap.String("data", string(event.Data)))
continue
}
// 处理响应
id := msg.ID.String()
c.responsesMu.Lock()
if ch, ok := c.responses[id]; ok {
select {
case ch <- &msg:
default:
}
delete(c.responses, id)
}
c.responsesMu.Unlock()
}
}
// sseEvent SSE事件
type sseEvent struct {
Event string
Data []byte
ID string
Retry int
}
// sseScanner SSE扫描器
type sseScanner struct {
reader *bufio.Reader
}
func (s *sseScanner) readEvent() (*sseEvent, error) {
event := &sseEvent{}
for {
line, err := s.reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimRight(line, "\r\n")
// 空行表示事件结束
if len(line) == 0 {
if len(event.Data) > 0 {
return event, nil
}
continue
}
// 解析SSE行
if strings.HasPrefix(line, "event: ") {
event.Event = strings.TrimSpace(line[7:])
} else if strings.HasPrefix(line, "data: ") {
data := []byte(strings.TrimSpace(line[6:]))
if len(event.Data) > 0 {
event.Data = append(event.Data, '\n')
}
event.Data = append(event.Data, data...)
} else if strings.HasPrefix(line, "id: ") {
event.ID = strings.TrimSpace(line[4:])
} else if strings.HasPrefix(line, "retry: ") {
fmt.Sscanf(line[7:], "%d", &event.Retry)
}
}
}
func (c *SSEMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
if c.sseConn == nil {
return nil, fmt.Errorf("SSE连接未建立")
}
id := msg.ID.String()
if id == "" {
c.mu.Lock()
c.requestID++
id = fmt.Sprintf("%d", c.requestID)
msg.ID = MessageID{value: id}
c.mu.Unlock()
}
// 创建响应通道
responseCh := make(chan *Message, 1)
c.responsesMu.Lock()
c.responses[id] = responseCh
c.responsesMu.Unlock()
// 通过HTTP POST发送请求(SSE用于接收响应,请求通过POST发送)
body, err := json.Marshal(msg)
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 使用POST请求发送消息(通常SSE服务器会提供两个端点:一个用于SSE,一个用于POST)
// 如果URL是SSE端点,尝试使用相同的URL但改为POST,或者使用URL + "/message"
postURL := c.url
if strings.HasSuffix(postURL, "/sse") {
postURL = strings.TrimSuffix(postURL, "/sse")
postURL += "/message"
} else if strings.HasSuffix(postURL, "/events") {
postURL = strings.TrimSuffix(postURL, "/events")
postURL += "/message"
} else if !strings.Contains(postURL, "/message") {
// 如果URL不包含/message,尝试添加
postURL = strings.TrimSuffix(postURL, "/")
postURL += "/message"
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("创建POST请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("发送POST请求失败: %w", err)
}
defer resp.Body.Close()
// 如果POST请求直接返回响应(非SSE模式),直接解析
if resp.StatusCode == http.StatusOK && resp.Header.Get("Content-Type") == "application/json" {
var mcpResp Message
if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if mcpResp.Error != nil {
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code)
}
return &mcpResp, nil
}
// 否则等待SSE响应
select {
case resp := <-responseCh:
if resp.Error != nil {
return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code)
}
return resp, nil
case <-ctx.Done():
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, ctx.Err()
case <-time.After(c.timeout):
c.responsesMu.Lock()
delete(c.responses, id)
c.responsesMu.Unlock()
return nil, fmt.Errorf("请求超时")
}
}
func (c *SSEMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
}
req.Params = json.RawMessage("{}")
resp, err := c.sendRequest(ctx, &req)
if err != nil {
return nil, fmt.Errorf("获取工具列表失败: %w", err)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, fmt.Errorf("解析工具列表失败: %w", err)
}
return listResp.Tools, nil
}
func (c *SSEMCPClient) sendNotification(msg *Message) error {
// 通知没有 ID,不需要等待响应
if c.sseConn == nil {
return fmt.Errorf("SSE连接未建立")
}
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("序列化通知失败: %w", err)
}
// 使用 POST 发送通知(与 sendRequest 类似的逻辑)
postURL := c.url
if strings.HasSuffix(postURL, "/sse") {
postURL = strings.TrimSuffix(postURL, "/sse")
postURL += "/message"
} else if strings.HasSuffix(postURL, "/events") {
postURL = strings.TrimSuffix(postURL, "/events")
postURL += "/message"
} else if !strings.Contains(postURL, "/message") {
postURL = strings.TrimSuffix(postURL, "/")
postURL += "/message"
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("创建POST请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 发送通知,不等待响应(通知不需要响应)
resp, err := c.client.Do(httpReq)
if err != nil {
return fmt.Errorf("发送通知失败: %w", err)
}
resp.Body.Close()
return nil
}
func (c *SSEMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
req := Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
}
callReq := CallToolRequest{
Name: name,
Arguments: args,
}
paramsJSON, _ := json.Marshal(callReq)
req.Params = paramsJSON
resp, err := c.sendRequest(ctx, &req)
if err != nil {
return nil, fmt.Errorf("调用工具失败: %w", err)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
}
return &ToolResult{
Content: callResp.Content,
IsError: callResp.IsError,
}, nil
}
func (c *SSEMCPClient) Close() error {
c.sseCancel()
if c.sseConn != nil {
c.sseConn.Close()
c.sseConn = nil
}
c.setStatus("disconnected")
return nil
}
+189 -46
View File
@@ -16,14 +16,18 @@ import (
// ExternalMCPManager 外部MCP管理器 // ExternalMCPManager 外部MCP管理器
type ExternalMCPManager struct { type ExternalMCPManager struct {
clients map[string]ExternalMCPClient clients map[string]ExternalMCPClient
configs map[string]config.ExternalMCPServerConfig configs map[string]config.ExternalMCPServerConfig
logger *zap.Logger logger *zap.Logger
storage MonitorStorage // 可选的持久化存储 storage MonitorStorage // 可选的持久化存储
executions map[string]*ToolExecution // 执行记录 executions map[string]*ToolExecution // 执行记录
stats map[string]*ToolStats // 工具统计信息 stats map[string]*ToolStats // 工具统计信息
errors map[string]string // 错误信息 errors map[string]string // 错误信息
mu sync.RWMutex toolCounts map[string]int // 工具数量缓存
toolCountsMu sync.RWMutex // 工具数量缓存的锁
stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
mu sync.RWMutex
} }
// NewExternalMCPManager 创建外部MCP管理器 // NewExternalMCPManager 创建外部MCP管理器
@@ -33,15 +37,20 @@ func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager {
// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储) // NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储)
func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager { func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager {
return &ExternalMCPManager{ manager := &ExternalMCPManager{
clients: make(map[string]ExternalMCPClient), clients: make(map[string]ExternalMCPClient),
configs: make(map[string]config.ExternalMCPServerConfig), configs: make(map[string]config.ExternalMCPServerConfig),
logger: logger, logger: logger,
storage: storage, storage: storage,
executions: make(map[string]*ToolExecution), executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats), stats: make(map[string]*ToolStats),
errors: make(map[string]string), errors: make(map[string]string),
toolCounts: make(map[string]int),
stopRefresh: make(chan struct{}),
} }
// 启动后台刷新工具数量的goroutine
manager.startToolCountRefresh()
return manager
} }
// LoadConfigs 加载配置 // LoadConfigs 加载配置
@@ -104,6 +113,12 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
} }
delete(m.configs, name) delete(m.configs, name)
// 清理工具数量缓存
m.toolCountsMu.Lock()
delete(m.toolCounts, name)
m.toolCountsMu.Unlock()
return nil return nil
} }
@@ -174,11 +189,15 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Lock() m.mu.Lock()
m.errors[name] = err.Error() m.errors[name] = err.Error()
m.mu.Unlock() m.mu.Unlock()
// 触发工具数量刷新(连接失败,工具数量应为0)
m.triggerToolCountRefresh()
} else { } else {
// 连接成功,清除错误信息 // 连接成功,清除错误信息
m.mu.Lock() m.mu.Lock()
delete(m.errors, name) delete(m.errors, name)
m.mu.Unlock() m.mu.Unlock()
// 连接成功,立即刷新工具数量
m.triggerToolCountRefresh()
} }
}() }()
@@ -204,6 +223,11 @@ func (m *ExternalMCPManager) StopClient(name string) error {
// 清除错误信息 // 清除错误信息
delete(m.errors, name) delete(m.errors, name)
// 更新工具数量缓存(停止后工具数量为0)
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
// 更新配置为禁用 // 更新配置为禁用
serverCfg.ExternalMCPEnable = false serverCfg.ExternalMCPEnable = false
m.configs[name] = serverCfg m.configs[name] = serverCfg
@@ -532,30 +556,50 @@ func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats {
return result return result
} }
// GetToolCount 获取指定外部MCP的工具数量 // GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞)
func (m *ExternalMCPManager) GetToolCount(name string) (int, error) { func (m *ExternalMCPManager) GetToolCount(name string) (int, error) {
// 先从缓存读取
m.toolCountsMu.RLock()
if count, exists := m.toolCounts[name]; exists {
m.toolCountsMu.RUnlock()
return count, nil
}
m.toolCountsMu.RUnlock()
// 如果缓存中没有,检查客户端状态
client, exists := m.GetClient(name) client, exists := m.GetClient(name)
if !exists { if !exists {
return 0, fmt.Errorf("客户端不存在: %s", name) return 0, fmt.Errorf("客户端不存在: %s", name)
} }
if !client.IsConnected() { if !client.IsConnected() {
// 未连接,缓存为0
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
return 0, nil return 0, nil
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞)
defer cancel() m.triggerToolCountRefresh()
return 0, nil
tools, err := client.ListTools(ctx)
if err != nil {
return 0, fmt.Errorf("获取工具列表失败: %w", err)
}
return len(tools), nil
} }
// GetToolCounts 获取所有外部MCP的工具数量 // GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞)
func (m *ExternalMCPManager) GetToolCounts() map[string]int { func (m *ExternalMCPManager) GetToolCounts() map[string]int {
m.toolCountsMu.RLock()
defer m.toolCountsMu.RUnlock()
// 返回缓存的副本,避免外部修改
result := make(map[string]int)
for k, v := range m.toolCounts {
result[k] = v
}
return result
}
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
func (m *ExternalMCPManager) refreshToolCounts() {
m.mu.RLock() m.mu.RLock()
clients := make(map[string]ExternalMCPClient) clients := make(map[string]ExternalMCPClient)
for k, v := range m.clients { for k, v := range m.clients {
@@ -563,30 +607,104 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int {
} }
m.mu.RUnlock() m.mu.RUnlock()
result := make(map[string]int) newCounts := make(map[string]int)
// 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞
type countResult struct {
name string
count int
}
resultChan := make(chan countResult, len(clients))
for name, client := range clients { for name, client := range clients {
if !client.IsConnected() { go func(n string, c ExternalMCPClient) {
result[name] = 0 if !c.IsConnected() {
continue resultChan <- countResult{name: n, count: 0}
} return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞
tools, err := client.ListTools(ctx) // 由于这是后台异步刷新,超时不会影响前端响应
cancel() ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
tools, err := c.ListTools(ctx)
cancel()
if err != nil { if err != nil {
m.logger.Warn("获取外部MCP工具数量失败", m.logger.Debug("获取外部MCP工具数量失败",
zap.String("name", name), zap.String("name", n),
zap.Error(err), zap.Error(err),
) )
result[name] = 0 // 如果获取失败,保留旧值(在更新时处理)
continue resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
} return
}
result[name] = len(tools) resultChan <- countResult{name: n, count: len(tools)}
}(name, client)
} }
return result // 收集结果
m.toolCountsMu.RLock()
oldCounts := make(map[string]int)
for k, v := range m.toolCounts {
oldCounts[k] = v
}
m.toolCountsMu.RUnlock()
for i := 0; i < len(clients); i++ {
result := <-resultChan
if result.count >= 0 {
newCounts[result.name] = result.count
} else {
// 获取失败,保留旧值
if oldCount, exists := oldCounts[result.name]; exists {
newCounts[result.name] = oldCount
} else {
newCounts[result.name] = 0
}
}
}
// 更新缓存
m.toolCountsMu.Lock()
// 更新所有获取到的值
for name, count := range newCounts {
m.toolCounts[name] = count
}
// 对于未连接的客户端,设置为0
for name, client := range clients {
if !client.IsConnected() {
m.toolCounts[name] = 0
}
}
m.toolCountsMu.Unlock()
}
// startToolCountRefresh 启动后台刷新工具数量的goroutine
func (m *ExternalMCPManager) startToolCountRefresh() {
m.refreshWg.Add(1)
go func() {
defer m.refreshWg.Done()
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次
defer ticker.Stop()
// 立即执行一次刷新
m.refreshToolCounts()
for {
select {
case <-ticker.C:
m.refreshToolCounts()
case <-m.stopRefresh:
return
}
}
}()
}
// triggerToolCountRefresh 触发立即刷新工具数量(异步)
func (m *ExternalMCPManager) triggerToolCountRefresh() {
go m.refreshToolCounts()
} }
// createClient 创建客户端(不连接) // createClient 创建客户端(不连接)
@@ -603,6 +721,7 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.Command != "" { if serverCfg.Command != "" {
transport = "stdio" transport = "stdio"
} else if serverCfg.URL != "" { } else if serverCfg.URL != "" {
// 默认使用http,但可以通过transport字段指定sse
transport = "http" transport = "http"
} else { } else {
return nil return nil
@@ -619,7 +738,12 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.Command == "" { if serverCfg.Command == "" {
return nil return nil
} }
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger) return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
case "sse":
if serverCfg.URL == "" {
return nil
}
return NewSSEMCPClient(serverCfg.URL, timeout, m.logger)
default: default:
return nil return nil
} }
@@ -654,6 +778,8 @@ func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status st
c.setStatus(status) c.setStatus(status)
case *StdioMCPClient: case *StdioMCPClient:
c.setStatus(status) c.setStatus(status)
case *SSEMCPClient:
c.setStatus(status)
} }
} }
@@ -693,6 +819,9 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
zap.String("name", name), zap.String("name", name),
) )
// 连接成功,触发工具数量刷新
m.triggerToolCountRefresh()
return nil return nil
} }
@@ -791,4 +920,18 @@ func (m *ExternalMCPManager) StopAll() {
client.Close() client.Close()
delete(m.clients, name) delete(m.clients, name)
} }
// 清理所有工具数量缓存
m.toolCountsMu.Lock()
m.toolCounts = make(map[string]int)
m.toolCountsMu.Unlock()
// 停止后台刷新(使用 select 避免重复关闭 channel
select {
case <-m.stopRefresh:
// 已经关闭,不需要再次关闭
default:
close(m.stopRefresh)
m.refreshWg.Wait()
}
} }
+1 -1
View File
@@ -180,7 +180,7 @@ func TestStdioMCPClient_Initialize(t *testing.T) {
// 注意:这个测试需要一个真实的stdio MCP服务器 // 注意:这个测试需要一个真实的stdio MCP服务器
// 如果没有服务器,这个测试会失败 // 如果没有服务器,这个测试会失败
logger := zap.NewNop() logger := zap.NewNop()
client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger) client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
+20
View File
@@ -0,0 +1,20 @@
name: API安全测试
description: API安全测试专家,专注于API接口安全检测
user_prompt: 你是一个专业的API安全测试专家。请使用专业的API测试工具对目标API接口进行全面的安全检测,包括GraphQL安全、API参数fuzzing、JWT分析、API架构分析等工作。
icon: "\U0001F4E1"
tools:
- api-fuzzer
- api-schema-analyzer
- graphql-scanner
- arjun
- jwt-analyzer
- http-intruder
- http-framework-test
- burpsuite
- httpx
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+33
View File
@@ -0,0 +1,33 @@
name: CTF
description: CTF竞赛专家,擅长解题和漏洞利用
user_prompt: 你是一个CTF竞赛专家。请使用CTF解题思维和方法,快速定位和利用漏洞,解决各类CTF题目。
icon: "\U0001F3C6"
tools:
- amass
- anew
- angr
- api-fuzzer
- api-schema-analyzer
- arjun
- arp-scan
- autorecon
- binwalk
- bloodhound
- burpsuite
- cat
- checkov
- checksec
- cloudmapper
- create-file
- cyberchef
- dalfox
- delete-file
- httpx
- http-framework-test
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+25
View File
@@ -0,0 +1,25 @@
name: Web应用扫描
description: Web应用漏洞扫描专家,全面的Web安全检测
user_prompt: 你是一个专业的Web应用漏洞扫描专家。请使用各种Web扫描工具对目标Web应用进行全面的安全检测,包括目录枚举、文件扫描、漏洞识别等工作。
icon: "\U0001F310"
tools:
- dirsearch
- dirb
- gobuster
- feroxbuster
- ffuf
- wfuzz
- sqlmap
- dalfox
- xsser
- nikto
- nuclei
- wpscan
- httpx
- http-framework-test
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+19
View File
@@ -0,0 +1,19 @@
name: Web框架测试
description: Web框架安全测试专家,专注于Web应用框架漏洞检测
user_prompt: 你是一个专业的Web框架安全测试专家。请使用专业的工具对Web应用框架进行安全测试,识别框架相关的安全漏洞和配置问题。
icon: "\U0001F310"
tools:
- http-framework-test
- nikto
- nuclei
- wafw00f
- wpscan
- httpx
- burpsuite
- zap
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+31
View File
@@ -0,0 +1,31 @@
name: 二进制分析
description: 二进制分析与利用专家,擅长逆向工程和密码破解
user_prompt: 你是一个专业的二进制分析与利用专家。请使用逆向工程工具分析二进制文件,识别漏洞,进行利用开发。同时擅长密码破解、哈希分析等技术。
icon: "\U0001F52C"
tools:
- dirsearch
- docker-bench-security
- exec
- execute-python-script
- install-python-package
- ghidra
- graphql-scanner
- hakrawler
- hash-identifier
- hashcat
- hashpump
- http-framework-test
- httpx
- gdb
- radare2
- objdump
- strings
- binwalk
- ropper
- ropgadget
- john
- cyberchef
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+17
View File
@@ -0,0 +1,17 @@
name: 云安全审计
description: 云安全审计专家,多云环境安全检测
user_prompt: 你是一个专业的云安全审计专家。请使用专业的云安全工具对AWS、Azure、GCP等云环境进行全面的安全审计,包括配置检查、合规性评估、权限审计、安全最佳实践验证等工作。
icon: "\U00002601"
tools:
- prowler
- scout-suite
- cloudmapper
- pacu
- terrascan
- checkov
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+31
View File
@@ -0,0 +1,31 @@
name: 信息收集
description: 资产发现与信息搜集专家
user_prompt: 你是一个专业的信息收集专家。请使用各种信息收集技术和工具,对目标进行全面的资产发现、子域名枚举、端口扫描、服务识别等信息收集工作。
icon: "\U0001F50D"
tools:
- amass
- subfinder
- dnsenum
- fierce
- fofa_search
- zoomeye_search
- nmap
- masscan
- rustscan
- arp-scan
- nbtscan
- httpx
- http-framework-test
- katana
- hakrawler
- waybackurls
- paramspider
- gau
- uro
- qsreplace
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+23
View File
@@ -0,0 +1,23 @@
name: 后渗透测试
description: 后渗透测试专家,权限维持与横向移动
user_prompt: 你是一个专业的后渗透测试专家。请使用专业的后渗透工具在获得初始访问权限后进行权限提升、横向移动、权限维持、数据收集等后渗透测试工作。
icon: "\U0001F575"
tools:
- linpeas
- winpeas
- mimikatz
- bloodhound
- impacket
- responder
- netexec
- rpcclient
- smbmap
- enum4linux
- enum4linux-ng
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+18
View File
@@ -0,0 +1,18 @@
name: 容器安全
description: 容器与Kubernetes安全专家,容器环境安全检测
user_prompt: 你是一个专业的容器与Kubernetes安全专家。请使用专业的容器安全工具对Docker容器和Kubernetes集群进行全面的安全检测,包括镜像漏洞扫描、配置检查、运行时安全等工作。
icon: "\U0001F6E1"
tools:
- trivy
- clair
- docker-bench-security
- kube-bench
- kube-hunter
- falco
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+24
View File
@@ -0,0 +1,24 @@
name: 数字取证
description: 数字取证与隐写分析专家,文件与内存取证
user_prompt: 你是一个专业的数字取证与隐写分析专家。请使用专业的取证工具对文件、磁盘镜像、内存转储进行分析,提取证据信息。同时擅长隐写分析、数据恢复、元数据提取等技术。
icon: "\U0001F50E"
tools:
- volatility
- volatility3
- foremost
- steghide
- stegsolve
- zsteg
- exiftool
- binwalk
- strings
- xxd
- fcrackzip
- pdfcrack
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+33
View File
@@ -0,0 +1,33 @@
name: 渗透测试
description: 专业渗透测试专家,全面深入的漏洞检测
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
icon: "\U0001F3AF"
tools:
- http-framework-test
- httpx
- amass
- anew
- angr
- api-fuzzer
- api-schema-analyzer
- arjun
- arp-scan
- autorecon
- binwalk
- bloodhound
- burpsuite
- cat
- checkov
- checksec
- cloudmapper
- create-file
- cyberchef
- dalfox
- delete-file
- exec
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+23
View File
@@ -0,0 +1,23 @@
name: 综合漏洞扫描
description: 综合漏洞扫描专家,多类型漏洞检测
user_prompt: 你是一个专业的综合漏洞扫描专家。请使用各种漏洞扫描工具对目标进行全面的安全检测,包括Web漏洞、网络服务漏洞、配置缺陷等多种类型的漏洞识别和分析。
icon: "\U000026A0"
tools:
- nuclei
- nikto
- sqlmap
- nmap
- masscan
- rustscan
- wafw00f
- dalfox
- xsser
- jaeles
- httpx
- http-framework-test
- execute-python-script
- install-python-package
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
enabled: true
+5
View File
@@ -0,0 +1,5 @@
name: 默认
description: 默认角色,不额外携带用户提示词,使用默认MCP
user_prompt: ""
icon: "\U0001F535"
enabled: true
+183 -37
View File
@@ -2,59 +2,205 @@
set -euo pipefail set -euo pipefail
# CyberStrikeAI 启动脚本 # CyberStrikeAI 一键部署启动脚本
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)" ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$ROOT_DIR" cd "$ROOT_DIR"
echo "🚀 启动 CyberStrikeAI..." # 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# 打印带颜色的消息
info() { echo -e "${BLUE}$1${NC}"; }
success() { echo -e "${GREEN}$1${NC}"; }
warning() { echo -e "${YELLOW}⚠️ $1${NC}"; }
error() { echo -e "${RED}$1${NC}"; }
echo ""
echo "=========================================="
echo " CyberStrikeAI 一键部署启动脚本"
echo "=========================================="
echo ""
CONFIG_FILE="$ROOT_DIR/config.yaml" CONFIG_FILE="$ROOT_DIR/config.yaml"
VENV_DIR="$ROOT_DIR/venv" VENV_DIR="$ROOT_DIR/venv"
REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt" REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt"
BINARY_NAME="cyberstrike-ai"
# 检查配置文件 # 检查配置文件
if [ ! -f "$CONFIG_FILE" ]; then if [ ! -f "$CONFIG_FILE" ]; then
echo "配置文件 config.yaml 不存在" error "配置文件 config.yaml 不存在"
info "请确保在项目根目录运行此脚本"
exit 1 exit 1
fi fi
# 检查 Python 环境 # 检查并安装 Python 环境
if ! command -v python3 >/dev/null 2>&1; then check_python() {
echo "❌ 未找到 python3,请先安装 Python 3.10+" if ! command -v python3 >/dev/null 2>&1; then
exit 1 error "未找到 python3"
fi echo ""
info "请先安装 Python 3.10 或更高版本:"
echo " macOS: brew install python3"
echo " Ubuntu: sudo apt-get install python3 python3-venv"
echo " CentOS: sudo yum install python3 python3-pip"
exit 1
fi
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d. -f1)
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d. -f2)
if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then
error "Python 版本过低: $PYTHON_VERSION (需要 3.10+)"
exit 1
fi
success "Python 环境检查通过: $PYTHON_VERSION"
}
# 创建并激活虚拟环境 # 检查并安装 Go 环境
if [ ! -d "$VENV_DIR" ]; then check_go() {
echo "🐍 创建 Python 虚拟环境..." if ! command -v go >/dev/null 2>&1; then
python3 -m venv "$VENV_DIR" error "未找到 Go"
fi echo ""
info "请先安装 Go 1.21 或更高版本:"
echo " macOS: brew install go"
echo " Ubuntu: sudo apt-get install golang-go"
echo " CentOS: sudo yum install golang"
echo " 或访问: https://go.dev/dl/"
exit 1
fi
GO_VERSION=$(go version | awk '{print $3}' | sed 's/go//')
GO_MAJOR=$(echo "$GO_VERSION" | cut -d. -f1)
GO_MINOR=$(echo "$GO_VERSION" | cut -d. -f2)
if [ "$GO_MAJOR" -lt 1 ] || ([ "$GO_MAJOR" -eq 1 ] && [ "$GO_MINOR" -lt 21 ]); then
error "Go 版本过低: $GO_VERSION (需要 1.21+)"
exit 1
fi
success "Go 环境检查通过: $(go version)"
}
echo "🐍 激活虚拟环境..." # 设置 Python 虚拟环境
# shellcheck disable=SC1091 setup_python_env() {
source "$VENV_DIR/bin/activate" if [ ! -d "$VENV_DIR" ]; then
info "创建 Python 虚拟环境..."
python3 -m venv "$VENV_DIR"
success "虚拟环境创建完成"
else
info "Python 虚拟环境已存在"
fi
info "激活虚拟环境..."
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"
if [ -f "$REQUIREMENTS_FILE" ]; then
info "安装/更新 Python 依赖..."
pip install --quiet --upgrade pip >/dev/null 2>&1 || true
# 尝试安装依赖,捕获错误输出
PIP_LOG=$(mktemp)
if pip install -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1; then
success "Python 依赖安装完成"
else
# 检查是否是 angr 安装失败(需要 Rust)
if grep -q "angr" "$PIP_LOG" && grep -q "Rust compiler\|can't find Rust" "$PIP_LOG"; then
warning "angr 安装失败(需要 Rust 编译器)"
echo ""
info "angr 是可选依赖,主要用于二进制分析工具"
info "如果需要使用 angr,请先安装 Rust"
echo " macOS: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " Ubuntu: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " 或访问: https://rustup.rs/"
echo ""
info "其他依赖已安装,可以继续使用(部分工具可能不可用)"
else
warning "部分 Python 依赖安装失败,但可以继续尝试运行"
warning "如果遇到问题,请检查错误信息并手动安装缺失的依赖"
# 显示最后几行错误信息
echo ""
info "错误详情(最后 10 行):"
tail -n 10 "$PIP_LOG" | sed 's/^/ /'
echo ""
fi
fi
rm -f "$PIP_LOG"
else
warning "未找到 requirements.txt,跳过 Python 依赖安装"
fi
}
if [ -f "$REQUIREMENTS_FILE" ]; then # 构建 Go 项目
echo "📦 安装/更新 Python 依赖..." build_go_project() {
pip install -r "$REQUIREMENTS_FILE" info "下载 Go 依赖..."
else go mod download >/dev/null 2>&1 || {
echo "⚠️ 未找到 requirements.txt,跳过 Python 依赖安装" error "Go 依赖下载失败"
fi exit 1
}
info "构建项目..."
if go build -o "$BINARY_NAME" cmd/server/main.go 2>&1; then
success "项目构建完成: $BINARY_NAME"
else
error "项目构建失败"
exit 1
fi
}
# 检查 Go 环境 # 检查是否需要重新构建
if ! command -v go >/dev/null 2>&1; then need_rebuild() {
echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本" if [ ! -f "$BINARY_NAME" ]; then
exit 1 return 0 # 需要构建
fi fi
# 检查源代码是否有更新
if [ "$BINARY_NAME" -ot cmd/server/main.go ] || \
[ "$BINARY_NAME" -ot go.mod ] || \
find internal cmd -name "*.go" -newer "$BINARY_NAME" 2>/dev/null | grep -q .; then
return 0 # 需要重新构建
fi
return 1 # 不需要构建
}
# 下载依赖 # 主流程
echo "📦 下载 Go 依赖..." main() {
go mod download # 环境检查
info "检查运行环境..."
check_python
check_go
echo ""
# 设置 Python 环境
info "设置 Python 环境..."
setup_python_env
echo ""
# 构建 Go 项目
if need_rebuild; then
info "准备构建项目..."
build_go_project
else
success "可执行文件已是最新,跳过构建"
fi
echo ""
# 启动服务器
success "所有准备工作完成!"
echo ""
info "启动 CyberStrikeAI 服务器..."
echo "=========================================="
echo ""
# 运行服务器
exec "./$BINARY_NAME"
}
# 构建项目 # 执行主流程
echo "🔨 构建项目..." main
go build -o cyberstrike-ai cmd/server/main.go
# 运行服务器
echo "✅ 启动服务器..."
./cyberstrike-ai
+2166 -16
View File
File diff suppressed because it is too large Load Diff
+27
View File
@@ -0,0 +1,27 @@
/**
* 内置工具名称常量
* 所有前端代码中使用内置工具名称的地方都应该使用这些常量而不是硬编码字符串
*
* 注意这些常量必须与后端的 internal/mcp/builtin/constants.go 中的常量保持一致
*/
// 内置工具名称常量
const BuiltinTools = {
// 漏洞管理工具
RECORD_VULNERABILITY: 'record_vulnerability',
// 知识库工具
LIST_KNOWLEDGE_RISK_TYPES: 'list_knowledge_risk_types',
SEARCH_KNOWLEDGE_BASE: 'search_knowledge_base'
};
// 检查是否是内置工具
function isBuiltinTool(toolName) {
return Object.values(BuiltinTools).includes(toolName);
}
// 获取所有内置工具名称列表
function getAllBuiltinTools() {
return Object.values(BuiltinTools);
}
+393 -66
View File
@@ -14,6 +14,9 @@ const mentionState = {
selectedIndex: 0, selectedIndex: 0,
}; };
// IME输入法状态跟踪
let isComposing = false;
// 输入框草稿保存相关 // 输入框草稿保存相关
const DRAFT_STORAGE_KEY = 'cyberstrike-chat-draft'; const DRAFT_STORAGE_KEY = 'cyberstrike-chat-draft';
let draftSaveTimer = null; let draftSaveTimer = null;
@@ -85,13 +88,20 @@ function clearChatDraft() {
function adjustTextareaHeight(textarea) { function adjustTextareaHeight(textarea) {
if (!textarea) return; if (!textarea) return;
// 重置高度以获取准确的scrollHeight // 重置高度为auto,然后立即设置为固定值,确保能准确获取scrollHeight
textarea.style.height = '44px'; textarea.style.height = 'auto';
// 强制浏览器重新计算布局
void textarea.offsetHeight;
// 计算新高度(最小44px,最大不超过300px) // 计算新高度(最小44px,最大不超过300px)
const scrollHeight = textarea.scrollHeight; const scrollHeight = textarea.scrollHeight;
const newHeight = Math.min(Math.max(scrollHeight, 44), 300); const newHeight = Math.min(Math.max(scrollHeight, 44), 300);
textarea.style.height = newHeight + 'px'; textarea.style.height = newHeight + 'px';
// 如果内容为空或只有很少内容,立即重置到最小高度
if (!textarea.value || textarea.value.trim().length === 0) {
textarea.style.height = '44px';
}
} }
// 发送消息 // 发送消息
@@ -135,6 +145,9 @@ async function sendMessage() {
let mcpExecutionIds = []; let mcpExecutionIds = [];
try { try {
// 获取当前选中的角色(从 roles.js 的函数获取)
const roleName = typeof getCurrentRole === 'function' ? getCurrentRole() : '';
const response = await apiFetch('/api/agent-loop/stream', { const response = await apiFetch('/api/agent-loop/stream', {
method: 'POST', method: 'POST',
headers: { headers: {
@@ -142,7 +155,8 @@ async function sendMessage() {
}, },
body: JSON.stringify({ body: JSON.stringify({
message: message, message: message,
conversationId: currentConversationId conversationId: currentConversationId,
role: roleName || undefined
}), }),
}); });
@@ -242,6 +256,13 @@ if (typeof window !== 'undefined') {
} }
function ensureMentionToolsLoaded() { function ensureMentionToolsLoaded() {
// 检查角色是否改变,如果改变则强制重新加载
if (typeof window !== 'undefined' && window._mentionToolsRoleChanged) {
mentionToolsLoaded = false;
mentionTools = [];
delete window._mentionToolsRoleChanged;
}
if (mentionToolsLoaded) { if (mentionToolsLoaded) {
return Promise.resolve(mentionTools); return Promise.resolve(mentionTools);
} }
@@ -254,6 +275,16 @@ function ensureMentionToolsLoaded() {
return mentionToolsLoadingPromise; return mentionToolsLoadingPromise;
} }
// 生成工具的唯一标识符,用于区分同名但来源不同的工具
function getToolKeyForMention(tool) {
// 如果是外部工具,使用 external_mcp::tool.name 作为唯一标识
// 如果是内部工具,使用 tool.name 作为标识
if (tool.is_external && tool.external_mcp) {
return `${tool.external_mcp}::${tool.name}`;
}
return tool.name;
}
async function fetchMentionTools() { async function fetchMentionTools() {
const pageSize = 100; const pageSize = 100;
let page = 1; let page = 1;
@@ -262,6 +293,9 @@ async function fetchMentionTools() {
const collected = []; const collected = [];
try { try {
// 获取当前选中的角色(从 roles.js 的函数获取)
const roleName = typeof getCurrentRole === 'function' ? getCurrentRole() : '';
// 同时获取外部MCP列表 // 同时获取外部MCP列表
try { try {
const mcpResponse = await apiFetch('/api/external-mcp'); const mcpResponse = await apiFetch('/api/external-mcp');
@@ -280,23 +314,45 @@ async function fetchMentionTools() {
} }
while (page <= totalPages && page <= 20) { while (page <= totalPages && page <= 20) {
const response = await apiFetch(`/api/config/tools?page=${page}&page_size=${pageSize}`); // 构建API URL,如果指定了角色,添加role查询参数
let url = `/api/config/tools?page=${page}&page_size=${pageSize}`;
if (roleName && roleName !== '默认') {
url += `&role=${encodeURIComponent(roleName)}`;
}
const response = await apiFetch(url);
if (!response.ok) { if (!response.ok) {
break; break;
} }
const result = await response.json(); const result = await response.json();
const tools = Array.isArray(result.tools) ? result.tools : []; const tools = Array.isArray(result.tools) ? result.tools : [];
tools.forEach(tool => { tools.forEach(tool => {
if (!tool || !tool.name || seen.has(tool.name)) { if (!tool || !tool.name) {
return; return;
} }
seen.add(tool.name); // 使用唯一标识符来去重,而不是只使用工具名称
const toolKey = getToolKeyForMention(tool);
if (seen.has(toolKey)) {
return;
}
seen.add(toolKey);
// 确定工具在当前角色中的启用状态
// 如果有 role_enabled 字段,使用它(表示指定了角色)
// 否则使用 enabled 字段(表示未指定角色或使用所有工具)
let roleEnabled = tool.enabled !== false;
if (tool.role_enabled !== undefined && tool.role_enabled !== null) {
roleEnabled = tool.role_enabled;
}
collected.push({ collected.push({
name: tool.name, name: tool.name,
description: tool.description || '', description: tool.description || '',
enabled: tool.enabled !== false, enabled: tool.enabled !== false, // 工具本身的启用状态
roleEnabled: roleEnabled, // 在当前角色中的启用状态
isExternal: !!tool.is_external, isExternal: !!tool.is_external,
externalMcp: tool.external_mcp || '', externalMcp: tool.external_mcp || '',
toolKey: toolKey, // 保存唯一标识符
}); });
}); });
totalPages = result.total_pages || 1; totalPages = result.total_pages || 1;
@@ -317,7 +373,10 @@ function handleChatInputInput(event) {
const textarea = event.target; const textarea = event.target;
updateMentionStateFromInput(textarea); updateMentionStateFromInput(textarea);
// 自动调整输入框高度 // 自动调整输入框高度
adjustTextareaHeight(textarea); // 使用requestAnimationFrame确保在DOM更新后立即调整,特别是在删除内容时
requestAnimationFrame(() => {
adjustTextareaHeight(textarea);
});
// 保存输入内容到localStorage(防抖) // 保存输入内容到localStorage(防抖)
saveChatDraftDebounced(textarea.value); saveChatDraftDebounced(textarea.value);
} }
@@ -327,6 +386,12 @@ function handleChatInputClick(event) {
} }
function handleChatInputKeydown(event) { function handleChatInputKeydown(event) {
// 如果正在使用输入法输入(IME),回车键应该用于确认候选词,而不是发送消息
// 使用 event.isComposing 或 isComposing 标志来判断
if (event.isComposing || isComposing) {
return;
}
if (mentionState.active && mentionSuggestionsEl && mentionSuggestionsEl.style.display !== 'none') { if (mentionState.active && mentionSuggestionsEl && mentionSuggestionsEl.style.display !== 'none') {
if (event.key === 'ArrowDown') { if (event.key === 'ArrowDown') {
event.preventDefault(); event.preventDefault();
@@ -453,6 +518,15 @@ function updateMentionCandidates() {
} }
filtered = filtered.slice().sort((a, b) => { filtered = filtered.slice().sort((a, b) => {
// 如果指定了角色,优先显示在当前角色中启用的工具
if (a.roleEnabled !== undefined || b.roleEnabled !== undefined) {
const aRoleEnabled = a.roleEnabled !== undefined ? a.roleEnabled : a.enabled;
const bRoleEnabled = b.roleEnabled !== undefined ? b.roleEnabled : b.enabled;
if (aRoleEnabled !== bRoleEnabled) {
return aRoleEnabled ? -1 : 1; // 启用的工具排在前面
}
}
if (normalizedQuery) { if (normalizedQuery) {
// 精确匹配MCP名称的工具优先显示 // 精确匹配MCP名称的工具优先显示
const aMcpExact = a.externalMcp && a.externalMcp.toLowerCase() === normalizedQuery; const aMcpExact = a.externalMcp && a.externalMcp.toLowerCase() === normalizedQuery;
@@ -467,8 +541,11 @@ function updateMentionCandidates() {
return aStarts ? -1 : 1; return aStarts ? -1 : 1;
} }
} }
if (a.enabled !== b.enabled) { // 如果指定了角色,使用 roleEnabled;否则使用 enabled
return a.enabled ? -1 : 1; const aEnabled = a.roleEnabled !== undefined ? a.roleEnabled : a.enabled;
const bEnabled = b.roleEnabled !== undefined ? b.roleEnabled : b.enabled;
if (aEnabled !== bEnabled) {
return aEnabled ? -1 : 1;
} }
return a.name.localeCompare(b.name, 'zh-CN'); return a.name.localeCompare(b.name, 'zh-CN');
}); });
@@ -510,13 +587,16 @@ function renderMentionSuggestions({ showLoading = false } = {}) {
const itemsHtml = mentionFilteredTools.map((tool, index) => { const itemsHtml = mentionFilteredTools.map((tool, index) => {
const activeClass = index === mentionState.selectedIndex ? 'active' : ''; const activeClass = index === mentionState.selectedIndex ? 'active' : '';
const disabledClass = tool.enabled ? '' : 'disabled'; // 如果工具有 roleEnabled 字段(指定了角色),使用它;否则使用 enabled
const toolEnabled = tool.roleEnabled !== undefined ? tool.roleEnabled : tool.enabled;
const disabledClass = toolEnabled ? '' : 'disabled';
const badge = tool.isExternal ? '<span class="mention-item-badge">外部</span>' : '<span class="mention-item-badge internal">内置</span>'; const badge = tool.isExternal ? '<span class="mention-item-badge">外部</span>' : '<span class="mention-item-badge internal">内置</span>';
const nameHtml = escapeHtml(tool.name); const nameHtml = escapeHtml(tool.name);
const description = tool.description && tool.description.length > 0 ? escapeHtml(tool.description) : '暂无描述'; const description = tool.description && tool.description.length > 0 ? escapeHtml(tool.description) : '暂无描述';
const descHtml = `<div class="mention-item-desc">${description}</div>`; const descHtml = `<div class="mention-item-desc">${description}</div>`;
const statusLabel = tool.enabled ? '可用' : '已禁用'; // 根据工具在当前角色中的启用状态显示状态标签
const statusClass = tool.enabled ? 'enabled' : 'disabled'; const statusLabel = toolEnabled ? '可用' : (tool.roleEnabled !== undefined ? '已禁用(当前角色)' : '已禁用');
const statusClass = toolEnabled ? 'enabled' : 'disabled';
const originLabel = tool.isExternal const originLabel = tool.isExternal
? (tool.externalMcp ? `来源:${escapeHtml(tool.externalMcp)}` : '来源:外部MCP') ? (tool.externalMcp ? `来源:${escapeHtml(tool.externalMcp)}` : '来源:外部MCP')
: '来源:内置工具'; : '来源:内置工具';
@@ -800,6 +880,24 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
bubble.innerHTML = formattedContent; bubble.innerHTML = formattedContent;
contentWrapper.appendChild(bubble); contentWrapper.appendChild(bubble);
// 保存原始内容到消息元素,用于复制功能
if (role === 'assistant') {
messageDiv.dataset.originalContent = content;
}
// 为助手消息添加复制按钮(复制整个回复内容)- 放在消息气泡右下角
if (role === 'assistant') {
const copyBtn = document.createElement('button');
copyBtn.className = 'message-copy-btn';
copyBtn.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><rect x="9" y="9" width="13" height="13" rx="2" ry="2" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>复制</span>';
copyBtn.title = '复制消息内容';
copyBtn.onclick = function(e) {
e.stopPropagation();
copyMessageToClipboard(messageDiv, this);
};
bubble.appendChild(copyBtn);
}
// 添加时间戳 // 添加时间戳
const timeDiv = document.createElement('div'); const timeDiv = document.createElement('div');
timeDiv.className = 'message-time'; timeDiv.className = 'message-time';
@@ -869,12 +967,71 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
return id; return id;
} }
// 复制消息内容到剪贴板(使用原始Markdown格式)
function copyMessageToClipboard(messageDiv, button) {
try {
// 获取保存的原始Markdown内容
const originalContent = messageDiv.dataset.originalContent;
if (!originalContent) {
// 如果没有保存原始内容,尝试从渲染后的HTML提取(降级方案)
const bubble = messageDiv.querySelector('.message-bubble');
if (bubble) {
const tempDiv = document.createElement('div');
tempDiv.innerHTML = bubble.innerHTML;
// 移除复制按钮本身(避免复制按钮文本)
const copyBtnInTemp = tempDiv.querySelector('.message-copy-btn');
if (copyBtnInTemp) {
copyBtnInTemp.remove();
}
// 提取纯文本内容
let textContent = tempDiv.textContent || tempDiv.innerText || '';
textContent = textContent.replace(/\n{3,}/g, '\n\n').trim();
navigator.clipboard.writeText(textContent).then(() => {
showCopySuccess(button);
}).catch(err => {
console.error('复制失败:', err);
alert('复制失败,请手动选择内容复制');
});
}
return;
}
// 使用原始Markdown内容
navigator.clipboard.writeText(originalContent).then(() => {
showCopySuccess(button);
}).catch(err => {
console.error('复制失败:', err);
alert('复制失败,请手动选择内容复制');
});
} catch (error) {
console.error('复制消息时出错:', error);
alert('复制失败,请手动选择内容复制');
}
}
// 显示复制成功提示
function showCopySuccess(button) {
if (button) {
const originalText = button.innerHTML;
button.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M20 6L9 17l-5-5" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>已复制</span>';
button.style.color = '#10b981';
button.style.background = 'rgba(16, 185, 129, 0.1)';
button.style.borderColor = 'rgba(16, 185, 129, 0.3)';
setTimeout(() => {
button.innerHTML = originalText;
button.style.color = '';
button.style.background = '';
button.style.borderColor = '';
}, 2000);
}
}
// 渲染过程详情 // 渲染过程详情
function renderProcessDetails(messageId, processDetails) { function renderProcessDetails(messageId, processDetails) {
if (!processDetails || processDetails.length === 0) {
return;
}
const messageElement = document.getElementById(messageId); const messageElement = document.getElementById(messageId);
if (!messageElement) { if (!messageElement) {
return; return;
@@ -942,7 +1099,7 @@ function renderProcessDetails(messageId, processDetails) {
} }
} }
// 创建时间线 // 创建时间线(即使没有processDetails也要创建,以便展开详情按钮能正常工作)
const timelineId = detailsId + '-timeline'; const timelineId = detailsId + '-timeline';
let timeline = document.getElementById(timelineId); let timeline = document.getElementById(timelineId);
@@ -958,9 +1115,19 @@ function renderProcessDetails(messageId, processDetails) {
detailsContainer.appendChild(contentDiv); detailsContainer.appendChild(contentDiv);
} }
// 如果没有processDetails或为空,显示空状态
if (!processDetails || processDetails.length === 0) {
// 显示空状态提示
timeline.innerHTML = '<div class="progress-timeline-empty">暂无过程详情(可能执行过快或未触发详细事件)</div>';
// 默认折叠
timeline.classList.remove('expanded');
return;
}
// 清空时间线并重新渲染 // 清空时间线并重新渲染
timeline.innerHTML = ''; timeline.innerHTML = '';
// 渲染每个过程详情事件 // 渲染每个过程详情事件
processDetails.forEach(detail => { processDetails.forEach(detail => {
const eventType = detail.eventType || ''; const eventType = detail.eventType || '';
@@ -987,7 +1154,7 @@ function renderProcessDetails(messageId, processDetails) {
itemTitle = `${statusIcon} 工具 ${escapeHtml(toolName)} 执行${success ? '完成' : '失败'}`; itemTitle = `${statusIcon} 工具 ${escapeHtml(toolName)} 执行${success ? '完成' : '失败'}`;
// 如果是知识检索工具,添加特殊标记 // 如果是知识检索工具,添加特殊标记
if (toolName === 'search_knowledge_base' && success) { if (toolName === BuiltinTools.SEARCH_KNOWLEDGE_BASE && success) {
itemTitle = `📚 ${itemTitle} - 知识检索`; itemTitle = `📚 ${itemTitle} - 知识检索`;
} }
} else if (eventType === 'knowledge_retrieval') { } else if (eventType === 'knowledge_retrieval') {
@@ -1035,6 +1202,13 @@ if (chatInput) {
chatInput.addEventListener('input', handleChatInputInput); chatInput.addEventListener('input', handleChatInputInput);
chatInput.addEventListener('click', handleChatInputClick); chatInput.addEventListener('click', handleChatInputClick);
chatInput.addEventListener('focus', handleChatInputClick); chatInput.addEventListener('focus', handleChatInputClick);
// IME输入法事件监听,用于跟踪输入法状态
chatInput.addEventListener('compositionstart', () => {
isComposing = true;
});
chatInput.addEventListener('compositionend', () => {
isComposing = false;
});
chatInput.addEventListener('blur', () => { chatInput.addEventListener('blur', () => {
setTimeout(() => { setTimeout(() => {
if (!chatInput.matches(':focus')) { if (!chatInput.matches(':focus')) {
@@ -1449,7 +1623,9 @@ function formatConversationTimestamp(dateObj, todayStart, yesterdayStart) {
if (!(dateObj instanceof Date) || isNaN(dateObj.getTime())) { if (!(dateObj instanceof Date) || isNaN(dateObj.getTime())) {
return ''; return '';
} }
const referenceToday = todayStart || new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate()); // 如果没有传入 todayStart,使用当前日期作为参考
const now = new Date();
const referenceToday = todayStart || new Date(now.getFullYear(), now.getMonth(), now.getDate());
const referenceYesterday = yesterdayStart || new Date(referenceToday.getTime() - 24 * 60 * 60 * 1000); const referenceYesterday = yesterdayStart || new Date(referenceToday.getTime() - 24 * 60 * 60 * 1000);
const messageDate = new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate()); const messageDate = new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate());
@@ -1604,17 +1780,19 @@ async function loadConversation(conversationId) {
// 传递消息的创建时间 // 传递消息的创建时间
const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msg.createdAt); const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msg.createdAt);
// 如果有过程详情,显示它们 // 对于助手消息,总是渲染过程详情(即使没有processDetails也要显示展开详情按钮)
if (msg.processDetails && msg.processDetails.length > 0 && msg.role === 'assistant') { if (msg.role === 'assistant') {
// 延迟一下,确保消息已经渲染 // 延迟一下,确保消息已经渲染
setTimeout(() => { setTimeout(() => {
renderProcessDetails(messageId, msg.processDetails); renderProcessDetails(messageId, msg.processDetails || []);
// 检查是否有错误或取消事件,如果有,确保详情默认折叠 // 如果有过程详情,检查是否有错误或取消事件,如果有,确保详情默认折叠
const hasErrorOrCancelled = msg.processDetails.some(d => if (msg.processDetails && msg.processDetails.length > 0) {
d.eventType === 'error' || d.eventType === 'cancelled' const hasErrorOrCancelled = msg.processDetails.some(d =>
); d.eventType === 'error' || d.eventType === 'cancelled'
if (hasErrorOrCancelled) { );
collapseAllProgressDetails(messageId, null); if (hasErrorOrCancelled) {
collapseAllProgressDetails(messageId, null);
}
} }
}, 100); }, 100);
} }
@@ -3770,12 +3948,24 @@ let contextMenuConversationId = null;
let contextMenuGroupId = null; let contextMenuGroupId = null;
let groupsCache = []; let groupsCache = [];
let conversationGroupMappingCache = {}; let conversationGroupMappingCache = {};
let pendingGroupMappings = {}; // 待保留的分组映射(用于处理后端API延迟的情况)
// 加载分组列表 // 加载分组列表
async function loadGroups() { async function loadGroups() {
try { try {
const response = await apiFetch('/api/groups'); const response = await apiFetch('/api/groups');
groupsCache = await response.json(); if (!response.ok) {
groupsCache = [];
return;
}
const data = await response.json();
// 确保groupsCache是有效数组
if (Array.isArray(data)) {
groupsCache = data;
} else {
// 如果返回的不是数组,使用空数组(不打印警告,因为可能后端返回了错误格式但我们要优雅处理)
groupsCache = [];
}
const groupsList = document.getElementById('conversation-groups-list'); const groupsList = document.getElementById('conversation-groups-list');
if (!groupsList) return; if (!groupsList) return;
@@ -3899,13 +4089,10 @@ async function loadConversationsWithGroups(searchQuery = '') {
} }
// 如果没有搜索关键词,使用原有逻辑 // 如果没有搜索关键词,使用原有逻辑
// 如果对话在某个分组中,且当前不在分组详情页,则跳过 // "最近对话"列表应该只显示不在任何分组中的对话
if (currentGroupId === null && conversationGroupMappingCache[conv.id]) { // 无论是否在分组详情页,都不应该在"最近对话"中显示分组中的对话
return; if (conversationGroupMappingCache[conv.id]) {
} // 对话在某个分组中,不应该显示在"最近对话"列表中
// 如果当前在分组详情页,只显示该分组的对话
if (currentGroupId !== null && conversationGroupMappingCache[conv.id] !== currentGroupId) {
return; return;
} }
@@ -4048,8 +4235,12 @@ async function showConversationContextMenu(event) {
if (convId) { if (convId) {
try { try {
let isPinned = false; let isPinned = false;
if (currentGroupId) { // 检查对话是否真的在当前分组中
// 如果在分组详情页面,获取分组内置顶状态 const conversationGroupId = conversationGroupMappingCache[convId];
const isInCurrentGroup = currentGroupId && conversationGroupId === currentGroupId;
if (isInCurrentGroup) {
// 对话在当前分组中,获取分组内置顶状态
const response = await apiFetch(`/api/groups/${currentGroupId}/conversations`); const response = await apiFetch(`/api/groups/${currentGroupId}/conversations`);
if (response.ok) { if (response.ok) {
const groupConvs = await response.json(); const groupConvs = await response.json();
@@ -4059,7 +4250,7 @@ async function showConversationContextMenu(event) {
} }
} }
} else { } else {
// 不在分组详情页面,获取全局置顶状态 // 不在分组详情页面,或者对话不在当前分组中,获取全局置顶状态
const response = await apiFetch(`/api/conversations/${convId}`); const response = await apiFetch(`/api/conversations/${convId}`);
if (response.ok) { if (response.ok) {
const conv = await response.json(); const conv = await response.json();
@@ -4314,8 +4505,14 @@ async function pinConversation() {
if (!convId) return; if (!convId) return;
try { try {
// 如果当前分组详情页面,使用分组内置顶 // 检查对话是否真的在当前分组
if (currentGroupId) { // 如果对话已经从分组移出,conversationGroupMappingCache 中不会有该对话的映射
// 或者映射的分组ID不等于当前分组ID
const conversationGroupId = conversationGroupMappingCache[convId];
const isInCurrentGroup = currentGroupId && conversationGroupId === currentGroupId;
// 如果当前在分组详情页面,且对话确实在当前分组中,使用分组内置顶
if (isInCurrentGroup) {
// 获取当前对话在分组中的置顶状态 // 获取当前对话在分组中的置顶状态
const response = await apiFetch(`/api/groups/${currentGroupId}/conversations`); const response = await apiFetch(`/api/groups/${currentGroupId}/conversations`);
const groupConvs = await response.json(); const groupConvs = await response.json();
@@ -4337,7 +4534,7 @@ async function pinConversation() {
// 重新加载分组对话 // 重新加载分组对话
loadGroupConversations(currentGroupId); loadGroupConversations(currentGroupId);
} else { } else {
// 不在分组详情页面,使用全局置顶 // 不在分组详情页面,或者对话不在当前分组中,使用全局置顶
const response = await apiFetch(`/api/conversations/${convId}`); const response = await apiFetch(`/api/conversations/${convId}`);
const conv = await response.json(); const conv = await response.json();
const newPinned = !conv.pinned; const newPinned = !conv.pinned;
@@ -4627,27 +4824,30 @@ async function moveConversationToGroup(convId, groupId) {
const oldGroupId = conversationGroupMappingCache[convId]; const oldGroupId = conversationGroupMappingCache[convId];
conversationGroupMappingCache[convId] = groupId; conversationGroupMappingCache[convId] = groupId;
// 将新移动的对话添加到待保留映射中,防止后端API延迟导致映射丢失
pendingGroupMappings[convId] = groupId;
// 如果移动的是当前对话,更新 currentConversationGroupId // 如果移动的是当前对话,更新 currentConversationGroupId
if (currentConversationId === convId) { if (currentConversationId === convId) {
currentConversationGroupId = groupId; currentConversationGroupId = groupId;
} }
// 重新加载分组映射缓存,确保数据同步
await loadConversationGroupMapping();
// 如果当前在分组详情页面,重新加载分组对话 // 如果当前在分组详情页面,重新加载分组对话
if (currentGroupId) { if (currentGroupId) {
// 如果从当前分组移出,或者移动到当前分组,都需要重新加载 // 如果从当前分组移出,或者移动到当前分组,都需要重新加载
if (currentGroupId === oldGroupId || currentGroupId === groupId) { if (currentGroupId === oldGroupId || currentGroupId === groupId) {
await loadGroupConversations(currentGroupId); await loadGroupConversations(currentGroupId);
} }
} else {
// 如果不在分组详情页面,刷新最近对话列表
loadConversationsWithGroups();
} }
// 如果旧分组和新分组不同,且用户正在查看旧分组,也需要刷新旧分组 // 无论是否在分组详情页面,都需要刷新最近对话列表
// 但上面的逻辑已经处理了这种情况(currentGroupId === oldGroupId // 因为最近对话列表会根据分组映射缓存来过滤显示,需要立即更新
// loadConversationsWithGroups 内部会调用 loadConversationGroupMapping
// loadConversationGroupMapping 会保留 pendingGroupMappings 中的映射
await loadConversationsWithGroups();
// 注意:pendingGroupMappings 中的映射会在下次 loadConversationGroupMapping
// 成功从后端加载时自动清理(在 loadConversationGroupMapping 中处理)
// 刷新分组列表,更新高亮状态 // 刷新分组列表,更新高亮状态
await loadGroups(); await loadGroups();
@@ -4668,6 +4868,8 @@ async function removeConversationFromGroup(convId, groupId) {
// 更新缓存 - 立即删除,确保后续加载时能正确识别 // 更新缓存 - 立即删除,确保后续加载时能正确识别
delete conversationGroupMappingCache[convId]; delete conversationGroupMappingCache[convId];
// 同时从待保留映射中移除
delete pendingGroupMappings[convId];
// 如果移除的是当前对话,清除 currentConversationGroupId // 如果移除的是当前对话,清除 currentConversationGroupId
if (currentConversationId === convId) { if (currentConversationId === convId) {
@@ -4708,14 +4910,24 @@ async function loadConversationGroupMapping() {
groups = groupsCache; groups = groupsCache;
} else { } else {
const response = await apiFetch('/api/groups'); const response = await apiFetch('/api/groups');
groups = await response.json(); if (!response.ok) {
// 如果API请求失败,使用空数组,不打印警告(这是正常错误处理)
groups = [];
} else {
groups = await response.json();
// 确保groups是有效数组,只在真正异常时才打印警告
if (!Array.isArray(groups)) {
// 只在返回的不是数组且不是null/undefined时才打印警告(可能是后端返回了错误格式)
if (groups !== null && groups !== undefined) {
console.warn('loadConversationGroupMapping: groups不是有效数组,使用空数组', groups);
}
groups = [];
}
}
} }
// 确保groups是有效数组 // 保存待保留的映射
if (!Array.isArray(groups)) { const preservedMappings = { ...pendingGroupMappings };
console.warn('loadConversationGroupMapping: groups不是有效数组,使用空数组');
groups = [];
}
conversationGroupMappingCache = {}; conversationGroupMappingCache = {};
@@ -4726,9 +4938,16 @@ async function loadConversationGroupMapping() {
if (Array.isArray(conversations)) { if (Array.isArray(conversations)) {
conversations.forEach(conv => { conversations.forEach(conv => {
conversationGroupMappingCache[conv.id] = group.id; conversationGroupMappingCache[conv.id] = group.id;
// 如果这个对话在待保留映射中,从待保留映射中移除(因为已经从后端加载了)
if (preservedMappings[conv.id] === group.id) {
delete pendingGroupMappings[conv.id];
}
}); });
} }
} }
// 恢复待保留的映射(这些是后端API尚未同步的映射)
Object.assign(conversationGroupMappingCache, preservedMappings);
} catch (error) { } catch (error) {
console.error('加载对话分组映射失败:', error); console.error('加载对话分组映射失败:', error);
} }
@@ -5097,7 +5316,8 @@ async function enterGroupDetail(groupId) {
// 刷新分组列表,确保当前分组高亮显示 // 刷新分组列表,确保当前分组高亮显示
await loadGroups(); await loadGroups();
loadGroupConversations(groupId); // 加载分组对话(如果有搜索查询则使用搜索查询)
loadGroupConversations(groupId, currentGroupSearchQuery);
} catch (error) { } catch (error) {
console.error('加载分组失败:', error); console.error('加载分组失败:', error);
currentGroupId = null; currentGroupId = null;
@@ -5107,6 +5327,14 @@ async function enterGroupDetail(groupId) {
// 退出分组详情 // 退出分组详情
function exitGroupDetail() { function exitGroupDetail() {
currentGroupId = null; currentGroupId = null;
currentGroupSearchQuery = ''; // 清除搜索状态
// 隐藏搜索框并清除搜索内容
const searchContainer = document.getElementById('group-search-container');
const searchInput = document.getElementById('group-search-input');
if (searchContainer) searchContainer.style.display = 'none';
if (searchInput) searchInput.value = '';
const sidebar = document.querySelector('.conversation-sidebar'); const sidebar = document.querySelector('.conversation-sidebar');
const groupDetailPage = document.getElementById('group-detail-page'); const groupDetailPage = document.getElementById('group-detail-page');
const chatContainer = document.querySelector('.chat-container'); const chatContainer = document.querySelector('.chat-container');
@@ -5121,7 +5349,7 @@ function exitGroupDetail() {
} }
// 加载分组中的对话 // 加载分组中的对话
async function loadGroupConversations(groupId) { async function loadGroupConversations(groupId, searchQuery = '') {
try { try {
if (!groupId) { if (!groupId) {
console.error('loadGroupConversations: groupId is null or undefined'); console.error('loadGroupConversations: groupId is null or undefined');
@@ -5139,10 +5367,20 @@ async function loadGroupConversations(groupId) {
console.error('group-conversations-list element not found'); console.error('group-conversations-list element not found');
return; return;
} }
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">加载中...</div>';
// 显示加载状态
if (searchQuery) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">搜索中...</div>';
} else {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">加载中...</div>';
}
// 确保使用正确的 groupId // 构建URL,如果有搜索关键词则添加search参数
const url = `/api/groups/${groupId}/conversations`; let url = `/api/groups/${groupId}/conversations`;
if (searchQuery && searchQuery.trim()) {
url += '?search=' + encodeURIComponent(searchQuery.trim());
}
const response = await apiFetch(url); const response = await apiFetch(url);
if (!response.ok) { if (!response.ok) {
console.error(`Failed to load conversations for group ${groupId}:`, response.statusText); console.error(`Failed to load conversations for group ${groupId}:`, response.statusText);
@@ -5184,7 +5422,11 @@ async function loadGroupConversations(groupId) {
list.innerHTML = ''; list.innerHTML = '';
if (groupConvs.length === 0) { if (groupConvs.length === 0) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">该分组暂无对话</div>'; if (searchQuery && searchQuery.trim()) {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">未找到匹配的对话</div>';
} else {
list.innerHTML = '<div style="padding: 40px; text-align: center; color: var(--text-muted);">该分组暂无对话</div>';
}
return; return;
} }
@@ -5580,9 +5822,94 @@ function closeGroupContextMenu() {
} }
// 分组搜索(占位函数) // 分组搜索相关变量
function searchInGroup() { let groupSearchTimer = null;
alert('搜索功能待实现'); let currentGroupSearchQuery = '';
// 切换分组搜索框显示/隐藏
function toggleGroupSearch() {
const searchContainer = document.getElementById('group-search-container');
const searchInput = document.getElementById('group-search-input');
if (!searchContainer || !searchInput) return;
if (searchContainer.style.display === 'none') {
searchContainer.style.display = 'block';
searchInput.focus();
} else {
searchContainer.style.display = 'none';
clearGroupSearch();
}
}
// 处理分组搜索输入
function handleGroupSearchInput(event) {
// 支持回车键搜索
if (event.key === 'Enter') {
event.preventDefault();
performGroupSearch();
return;
}
// 支持ESC键关闭搜索
if (event.key === 'Escape') {
clearGroupSearch();
toggleGroupSearch();
return;
}
const searchInput = document.getElementById('group-search-input');
const clearBtn = document.getElementById('group-search-clear-btn');
if (!searchInput) return;
const query = searchInput.value.trim();
// 显示/隐藏清除按钮
if (clearBtn) {
clearBtn.style.display = query ? 'block' : 'none';
}
// 防抖搜索
if (groupSearchTimer) {
clearTimeout(groupSearchTimer);
}
groupSearchTimer = setTimeout(() => {
performGroupSearch();
}, 300); // 300ms 防抖
}
// 执行分组搜索
async function performGroupSearch() {
const searchInput = document.getElementById('group-search-input');
if (!searchInput || !currentGroupId) return;
const query = searchInput.value.trim();
currentGroupSearchQuery = query;
// 加载搜索结果
await loadGroupConversations(currentGroupId, query);
}
// 清除分组搜索
function clearGroupSearch() {
const searchInput = document.getElementById('group-search-input');
const clearBtn = document.getElementById('group-search-clear-btn');
if (searchInput) {
searchInput.value = '';
}
if (clearBtn) {
clearBtn.style.display = 'none';
}
currentGroupSearchQuery = '';
// 重新加载分组对话(不搜索)
if (currentGroupId) {
loadGroupConversations(currentGroupId, '');
}
} }
// 初始化时加载分组 // 初始化时加载分组
+96 -6
View File
@@ -398,8 +398,8 @@ function updateKnowledgeStats(data, categoryCount) {
} }
} }
// 总分类数(来自分页信息) // 总分类数(来自分页信息,只有在未定义时才使用当前页分类数作为后备值
const totalCategories = knowledgePagination.total || categoryCount; const totalCategories = (knowledgePagination.total != null) ? knowledgePagination.total : categoryCount;
statsContainer.innerHTML = ` statsContainer.innerHTML = `
<div class="knowledge-stat-item"> <div class="knowledge-stat-item">
@@ -457,6 +457,7 @@ async function updateIndexProgress() {
const indexedItems = status.indexed_items || 0; const indexedItems = status.indexed_items || 0;
const progressPercent = status.progress_percent || 0; const progressPercent = status.progress_percent || 0;
const isComplete = status.is_complete || false; const isComplete = status.is_complete || false;
const lastError = status.last_error || '';
if (totalItems === 0) { if (totalItems === 0) {
// 没有知识项,隐藏进度条 // 没有知识项,隐藏进度条
@@ -471,6 +472,58 @@ async function updateIndexProgress() {
// 显示进度条 // 显示进度条
progressContainer.style.display = 'block'; progressContainer.style.display = 'block';
// 如果有错误信息,显示错误
if (lastError) {
progressContainer.innerHTML = `
<div class="knowledge-index-progress-error" style="
background: #fee;
border: 1px solid #fcc;
border-radius: 8px;
padding: 16px;
margin-bottom: 16px;
">
<div style="display: flex; align-items: center; margin-bottom: 8px;">
<span style="font-size: 20px; margin-right: 8px;"></span>
<span style="font-weight: bold; color: #c00;">索引构建失败</span>
</div>
<div style="color: #666; font-size: 14px; margin-bottom: 12px; line-height: 1.5;">
${escapeHtml(lastError)}
</div>
<div style="color: #999; font-size: 12px; margin-bottom: 12px;">
可能的原因嵌入模型配置错误API密钥无效余额不足等请检查配置后重试
</div>
<div style="display: flex; gap: 8px;">
<button onclick="rebuildKnowledgeIndex()" style="
background: #007bff;
color: white;
border: none;
padding: 6px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
">重试</button>
<button onclick="stopIndexProgressPolling()" style="
background: #6c757d;
color: white;
border: none;
padding: 6px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
">关闭</button>
</div>
</div>
`;
// 停止轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
// 显示错误通知
showNotification('索引构建失败: ' + lastError.substring(0, 100), 'error');
return;
}
if (isComplete) { if (isComplete) {
progressContainer.innerHTML = ` progressContainer.innerHTML = `
<div class="knowledge-index-progress-complete"> <div class="knowledge-index-progress-complete">
@@ -503,8 +556,46 @@ async function updateIndexProgress() {
} }
} }
} catch (error) { } catch (error) {
// 静默失败 // 显示错误信息
console.debug('获取索引状态失败:', error); console.error('获取索引状态失败:', error);
const progressContainer = document.getElementById('knowledge-index-progress');
if (progressContainer) {
progressContainer.style.display = 'block';
progressContainer.innerHTML = `
<div class="knowledge-index-progress-error" style="
background: #fee;
border: 1px solid #fcc;
border-radius: 8px;
padding: 16px;
margin-bottom: 16px;
">
<div style="display: flex; align-items: center; margin-bottom: 8px;">
<span style="font-size: 20px; margin-right: 8px;"></span>
<span style="font-weight: bold; color: #c00;">无法获取索引状态</span>
</div>
<div style="color: #666; font-size: 14px;">
无法连接到服务器获取索引状态请检查网络连接或刷新页面
</div>
</div>
`;
}
// 停止轮询
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
}
}
// 停止索引进度轮询
function stopIndexProgressPolling() {
if (indexProgressInterval) {
clearInterval(indexProgressInterval);
indexProgressInterval = null;
}
const progressContainer = document.getElementById('knowledge-index-progress');
if (progressContainer) {
progressContainer.style.display = 'none';
} }
} }
@@ -1035,8 +1126,7 @@ async function deleteKnowledgeItem(id) {
} }
} }
// 更新统计信息(临时更新,稍后会重新加载) // 不在这里更新统计信息,等待重新加载数据后由正确的逻辑更新
updateKnowledgeStatsAfterDelete();
} }
}, 300); }, 300);
} }
+288 -14
View File
@@ -3,6 +3,9 @@ let activeTaskInterval = null;
const ACTIVE_TASK_REFRESH_INTERVAL = 10000; // 10秒检查一次 const ACTIVE_TASK_REFRESH_INTERVAL = 10000; // 10秒检查一次
const TASK_FINAL_STATUSES = new Set(['failed', 'timeout', 'cancelled', 'completed']); const TASK_FINAL_STATUSES = new Set(['failed', 'timeout', 'cancelled', 'completed']);
// 存储工具调用ID到DOM元素的映射,用于更新执行状态
const toolCallStatusMap = new Map();
const conversationExecutionTracker = { const conversationExecutionTracker = {
activeConversations: new Set(), activeConversations: new Set(),
update(tasks = []) { update(tasks = []) {
@@ -493,12 +496,26 @@ function handleStreamEvent(event, progressElement, progressId,
const toolName = toolInfo.toolName || '未知工具'; const toolName = toolInfo.toolName || '未知工具';
const index = toolInfo.index || 0; const index = toolInfo.index || 0;
const total = toolInfo.total || 0; const total = toolInfo.total || 0;
addTimelineItem(timeline, 'tool_call', { const toolCallId = toolInfo.toolCallId || null;
// 添加工具调用项,并标记为执行中
const toolCallItemId = addTimelineItem(timeline, 'tool_call', {
title: `🔧 调用工具: ${escapeHtml(toolName)} (${index}/${total})`, title: `🔧 调用工具: ${escapeHtml(toolName)} (${index}/${total})`,
message: event.message, message: event.message,
data: toolInfo, data: toolInfo,
expanded: false expanded: false
}); });
// 如果有toolCallId,存储映射关系以便后续更新状态
if (toolCallId && toolCallItemId) {
toolCallStatusMap.set(toolCallId, {
itemId: toolCallItemId,
timeline: timeline
});
// 添加执行中状态指示器
updateToolCallStatus(toolCallId, 'running');
}
break; break;
case 'tool_result': case 'tool_result':
@@ -507,6 +524,15 @@ function handleStreamEvent(event, progressElement, progressId,
const resultToolName = resultInfo.toolName || '未知工具'; const resultToolName = resultInfo.toolName || '未知工具';
const success = resultInfo.success !== false; const success = resultInfo.success !== false;
const statusIcon = success ? '✅' : '❌'; const statusIcon = success ? '✅' : '❌';
const resultToolCallId = resultInfo.toolCallId || null;
// 如果有关联的toolCallId,更新工具调用项的状态
if (resultToolCallId && toolCallStatusMap.has(resultToolCallId)) {
updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed');
// 从映射中移除(已完成)
toolCallStatusMap.delete(resultToolCallId);
}
addTimelineItem(timeline, 'tool_result', { addTimelineItem(timeline, 'tool_result', {
title: `${statusIcon} 工具 ${escapeHtml(resultToolName)} 执行${success ? '完成' : '失败'}`, title: `${statusIcon} 工具 ${escapeHtml(resultToolName)} 执行${success ? '完成' : '失败'}`,
message: event.message, message: event.message,
@@ -767,9 +793,46 @@ function handleStreamEvent(event, progressElement, progressId,
messagesDiv.scrollTop = messagesDiv.scrollHeight; messagesDiv.scrollTop = messagesDiv.scrollHeight;
} }
// 更新工具调用状态
function updateToolCallStatus(toolCallId, status) {
const mapping = toolCallStatusMap.get(toolCallId);
if (!mapping) return;
const item = document.getElementById(mapping.itemId);
if (!item) return;
const titleElement = item.querySelector('.timeline-item-title');
if (!titleElement) return;
// 移除之前的状态类
item.classList.remove('tool-call-running', 'tool-call-completed', 'tool-call-failed');
// 根据状态更新样式和文本
let statusText = '';
if (status === 'running') {
item.classList.add('tool-call-running');
statusText = ' <span class="tool-status-badge tool-status-running">执行中...</span>';
} else if (status === 'completed') {
item.classList.add('tool-call-completed');
statusText = ' <span class="tool-status-badge tool-status-completed">✅ 已完成</span>';
} else if (status === 'failed') {
item.classList.add('tool-call-failed');
statusText = ' <span class="tool-status-badge tool-status-failed">❌ 执行失败</span>';
}
// 更新标题(保留原有文本,追加状态)
const originalText = titleElement.innerHTML;
// 移除之前可能存在的状态标记
const cleanText = originalText.replace(/\s*<span class="tool-status-badge[^>]*>.*?<\/span>/g, '');
titleElement.innerHTML = cleanText + statusText;
}
// 添加时间线项目 // 添加时间线项目
function addTimelineItem(timeline, type, options) { function addTimelineItem(timeline, type, options) {
const item = document.createElement('div'); const item = document.createElement('div');
// 生成唯一ID
const itemId = 'timeline-item-' + Date.now() + '-' + Math.random().toString(36).substr(2, 9);
item.id = itemId;
item.className = `timeline-item timeline-item-${type}`; item.className = `timeline-item timeline-item-${type}`;
const time = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit', second: '2-digit' }); const time = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit', second: '2-digit' });
@@ -828,6 +891,9 @@ function addTimelineItem(timeline, type, options) {
if (!expanded && (type === 'tool_call' || type === 'tool_result')) { if (!expanded && (type === 'tool_call' || type === 'tool_result')) {
// 对于工具调用和结果,默认显示摘要 // 对于工具调用和结果,默认显示摘要
} }
// 返回item ID以便后续更新
return itemId;
} }
// 加载活跃任务列表 // 加载活跃任务列表
@@ -942,7 +1008,11 @@ const monitorState = {
lastFetchedAt: null, lastFetchedAt: null,
pagination: { pagination: {
page: 1, page: 1,
pageSize: 20, pageSize: (() => {
// 从 localStorage 读取保存的每页显示数量,默认为 20
const saved = localStorage.getItem('monitorPageSize');
return saved ? parseInt(saved, 10) : 20;
})(),
total: 0, total: 0,
totalPages: 0 totalPages: 0
} }
@@ -953,6 +1023,39 @@ function openMonitorPanel() {
if (typeof switchPage === 'function') { if (typeof switchPage === 'function') {
switchPage('mcp-monitor'); switchPage('mcp-monitor');
} }
// 初始化每页显示数量选择器
initializeMonitorPageSize();
}
// 初始化每页显示数量选择器
function initializeMonitorPageSize() {
const pageSizeSelect = document.getElementById('monitor-page-size');
if (pageSizeSelect) {
pageSizeSelect.value = monitorState.pagination.pageSize;
}
}
// 改变每页显示数量
function changeMonitorPageSize() {
const pageSizeSelect = document.getElementById('monitor-page-size');
if (!pageSizeSelect) {
return;
}
const newPageSize = parseInt(pageSizeSelect.value, 10);
if (isNaN(newPageSize) || newPageSize <= 0) {
return;
}
// 保存到 localStorage
localStorage.setItem('monitorPageSize', newPageSize.toString());
// 更新状态
monitorState.pagination.pageSize = newPageSize;
monitorState.pagination.page = 1; // 重置到第一页
// 刷新数据
refreshMonitorPanel(1);
} }
function closeMonitorPanel() { function closeMonitorPanel() {
@@ -974,12 +1077,17 @@ async function refreshMonitorPanel(page = null) {
// 获取当前的筛选条件 // 获取当前的筛选条件
const statusFilter = document.getElementById('monitor-status-filter'); const statusFilter = document.getElementById('monitor-status-filter');
const currentFilter = statusFilter ? statusFilter.value : 'all'; const toolFilter = document.getElementById('monitor-tool-filter');
const currentStatusFilter = statusFilter ? statusFilter.value : 'all';
const currentToolFilter = toolFilter ? (toolFilter.value.trim() || 'all') : 'all';
// 构建请求 URL // 构建请求 URL
let url = `/api/monitor?page=${currentPage}&page_size=${pageSize}`; let url = `/api/monitor?page=${currentPage}&page_size=${pageSize}`;
if (currentFilter && currentFilter !== 'all') { if (currentStatusFilter && currentStatusFilter !== 'all') {
url += `&status=${encodeURIComponent(currentFilter)}`; url += `&status=${encodeURIComponent(currentStatusFilter)}`;
}
if (currentToolFilter && currentToolFilter !== 'all') {
url += `&tool=${encodeURIComponent(currentToolFilter)}`;
} }
const response = await apiFetch(url, { method: 'GET' }); const response = await apiFetch(url, { method: 'GET' });
@@ -1003,8 +1111,11 @@ async function refreshMonitorPanel(page = null) {
} }
renderMonitorStats(monitorState.stats, monitorState.lastFetchedAt); renderMonitorStats(monitorState.stats, monitorState.lastFetchedAt);
renderMonitorExecutions(monitorState.executions, currentFilter); renderMonitorExecutions(monitorState.executions, currentStatusFilter);
renderMonitorPagination(); renderMonitorPagination();
// 初始化每页显示数量选择器
initializeMonitorPageSize();
} catch (error) { } catch (error) {
console.error('刷新监控面板失败:', error); console.error('刷新监控面板失败:', error);
if (statsContainer) { if (statsContainer) {
@@ -1016,14 +1127,30 @@ async function refreshMonitorPanel(page = null) {
} }
} }
async function applyMonitorFilters() { // 处理工具搜索输入(防抖)
const statusFilter = document.getElementById('monitor-status-filter'); let toolFilterDebounceTimer = null;
const status = statusFilter ? statusFilter.value : 'all'; function handleToolFilterInput() {
// 当筛选条件改变时,从后端重新获取数据 // 清除之前的定时器
await refreshMonitorPanelWithFilter(status); if (toolFilterDebounceTimer) {
clearTimeout(toolFilterDebounceTimer);
}
// 设置新的定时器,500ms后执行筛选
toolFilterDebounceTimer = setTimeout(() => {
applyMonitorFilters();
}, 500);
} }
async function refreshMonitorPanelWithFilter(statusFilter = 'all') { async function applyMonitorFilters() {
const statusFilter = document.getElementById('monitor-status-filter');
const toolFilter = document.getElementById('monitor-tool-filter');
const status = statusFilter ? statusFilter.value : 'all';
const tool = toolFilter ? (toolFilter.value.trim() || 'all') : 'all';
// 当筛选条件改变时,从后端重新获取数据
await refreshMonitorPanelWithFilter(status, tool);
}
async function refreshMonitorPanelWithFilter(statusFilter = 'all', toolFilter = 'all') {
const statsContainer = document.getElementById('monitor-stats'); const statsContainer = document.getElementById('monitor-stats');
const execContainer = document.getElementById('monitor-executions'); const execContainer = document.getElementById('monitor-executions');
@@ -1036,6 +1163,9 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all') {
if (statusFilter && statusFilter !== 'all') { if (statusFilter && statusFilter !== 'all') {
url += `&status=${encodeURIComponent(statusFilter)}`; url += `&status=${encodeURIComponent(statusFilter)}`;
} }
if (toolFilter && toolFilter !== 'all') {
url += `&tool=${encodeURIComponent(toolFilter)}`;
}
const response = await apiFetch(url, { method: 'GET' }); const response = await apiFetch(url, { method: 'GET' });
const result = await response.json().catch(() => ({})); const result = await response.json().catch(() => ({}));
@@ -1060,6 +1190,9 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all') {
renderMonitorStats(monitorState.stats, monitorState.lastFetchedAt); renderMonitorStats(monitorState.stats, monitorState.lastFetchedAt);
renderMonitorExecutions(monitorState.executions, statusFilter); renderMonitorExecutions(monitorState.executions, statusFilter);
renderMonitorPagination(); renderMonitorPagination();
// 初始化每页显示数量选择器
initializeMonitorPageSize();
} catch (error) { } catch (error) {
console.error('刷新监控面板失败:', error); console.error('刷新监控面板失败:', error);
if (statsContainer) { if (statsContainer) {
@@ -1071,6 +1204,7 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all') {
} }
} }
function renderMonitorStats(statsMap = {}, lastFetchedAt = null) { function renderMonitorStats(statsMap = {}, lastFetchedAt = null) {
const container = document.getElementById('monitor-stats'); const container = document.getElementById('monitor-stats');
if (!container) { if (!container) {
@@ -1151,11 +1285,19 @@ function renderMonitorExecutions(executions = [], statusFilter = 'all') {
if (!Array.isArray(executions) || executions.length === 0) { if (!Array.isArray(executions) || executions.length === 0) {
// 根据是否有筛选条件显示不同的提示 // 根据是否有筛选条件显示不同的提示
if (statusFilter && statusFilter !== 'all') { const toolFilter = document.getElementById('monitor-tool-filter');
const currentToolFilter = toolFilter ? toolFilter.value : 'all';
const hasFilter = (statusFilter && statusFilter !== 'all') || (currentToolFilter && currentToolFilter !== 'all');
if (hasFilter) {
container.innerHTML = '<div class="monitor-empty">当前筛选条件下暂无记录</div>'; container.innerHTML = '<div class="monitor-empty">当前筛选条件下暂无记录</div>';
} else { } else {
container.innerHTML = '<div class="monitor-empty">暂无执行记录</div>'; container.innerHTML = '<div class="monitor-empty">暂无执行记录</div>';
} }
// 隐藏批量操作栏
const batchActions = document.getElementById('monitor-batch-actions');
if (batchActions) {
batchActions.style.display = 'none';
}
return; return;
} }
@@ -1172,6 +1314,9 @@ function renderMonitorExecutions(executions = [], statusFilter = 'all') {
const executionId = escapeHtml(exec.id || ''); const executionId = escapeHtml(exec.id || '');
return ` return `
<tr> <tr>
<td>
<input type="checkbox" class="monitor-execution-checkbox" value="${executionId}" onchange="updateBatchActionsState()" />
</td>
<td>${toolName}</td> <td>${toolName}</td>
<td><span class="${statusClass}">${statusLabel}</span></td> <td><span class="${statusClass}">${statusLabel}</span></td>
<td>${startTime}</td> <td>${startTime}</td>
@@ -1205,6 +1350,9 @@ function renderMonitorExecutions(executions = [], statusFilter = 'all') {
<table class="monitor-table"> <table class="monitor-table">
<thead> <thead>
<tr> <tr>
<th style="width: 40px;">
<input type="checkbox" id="monitor-select-all" onchange="toggleSelectAll(this)" />
</th>
<th>工具</th> <th>工具</th>
<th>状态</th> <th>状态</th>
<th>开始时间</th> <th>开始时间</th>
@@ -1223,6 +1371,9 @@ function renderMonitorExecutions(executions = [], statusFilter = 'all') {
} else { } else {
container.appendChild(tableContainer); container.appendChild(tableContainer);
} }
// 更新批量操作状态
updateBatchActionsState();
} }
// 渲染监控面板分页控件 // 渲染监控面板分页控件
@@ -1248,7 +1399,16 @@ function renderMonitorPagination() {
pagination.innerHTML = ` pagination.innerHTML = `
<div class="pagination-info"> <div class="pagination-info">
显示 ${startItem}-${endItem} / ${total} 条记录 <span>显示 ${startItem}-${endItem} / ${total} 条记录</span>
<label class="pagination-page-size">
每页显示
<select id="monitor-page-size" onchange="changeMonitorPageSize()">
<option value="10" ${pageSize === 10 ? 'selected' : ''}>10</option>
<option value="20" ${pageSize === 20 ? 'selected' : ''}>20</option>
<option value="50" ${pageSize === 50 ? 'selected' : ''}>50</option>
<option value="100" ${pageSize === 100 ? 'selected' : ''}>100</option>
</select>
</label>
</div> </div>
<div class="pagination-controls"> <div class="pagination-controls">
<button class="btn-secondary" onclick="refreshMonitorPanel(1)" ${page === 1 || total === 0 ? 'disabled' : ''}>首页</button> <button class="btn-secondary" onclick="refreshMonitorPanel(1)" ${page === 1 || total === 0 ? 'disabled' : ''}>首页</button>
@@ -1260,6 +1420,9 @@ function renderMonitorPagination() {
`; `;
container.appendChild(pagination); container.appendChild(pagination);
// 初始化每页显示数量选择器
initializeMonitorPageSize();
} }
// 删除执行记录 // 删除执行记录
@@ -1294,6 +1457,117 @@ async function deleteExecution(executionId) {
} }
} }
// 更新批量操作状态
function updateBatchActionsState() {
const checkboxes = document.querySelectorAll('.monitor-execution-checkbox:checked');
const selectedCount = checkboxes.length;
const batchActions = document.getElementById('monitor-batch-actions');
const selectedCountSpan = document.getElementById('monitor-selected-count');
if (selectedCount > 0) {
if (batchActions) {
batchActions.style.display = 'flex';
}
if (selectedCountSpan) {
selectedCountSpan.textContent = `已选择 ${selectedCount}`;
}
} else {
if (batchActions) {
batchActions.style.display = 'none';
}
}
// 更新全选复选框状态
const selectAllCheckbox = document.getElementById('monitor-select-all');
if (selectAllCheckbox) {
const allCheckboxes = document.querySelectorAll('.monitor-execution-checkbox');
const allChecked = allCheckboxes.length > 0 && Array.from(allCheckboxes).every(cb => cb.checked);
selectAllCheckbox.checked = allChecked;
selectAllCheckbox.indeterminate = selectedCount > 0 && selectedCount < allCheckboxes.length;
}
}
// 切换全选
function toggleSelectAll(checkbox) {
const checkboxes = document.querySelectorAll('.monitor-execution-checkbox');
checkboxes.forEach(cb => {
cb.checked = checkbox.checked;
});
updateBatchActionsState();
}
// 全选
function selectAllExecutions() {
const checkboxes = document.querySelectorAll('.monitor-execution-checkbox');
checkboxes.forEach(cb => {
cb.checked = true;
});
const selectAllCheckbox = document.getElementById('monitor-select-all');
if (selectAllCheckbox) {
selectAllCheckbox.checked = true;
selectAllCheckbox.indeterminate = false;
}
updateBatchActionsState();
}
// 取消全选
function deselectAllExecutions() {
const checkboxes = document.querySelectorAll('.monitor-execution-checkbox');
checkboxes.forEach(cb => {
cb.checked = false;
});
const selectAllCheckbox = document.getElementById('monitor-select-all');
if (selectAllCheckbox) {
selectAllCheckbox.checked = false;
selectAllCheckbox.indeterminate = false;
}
updateBatchActionsState();
}
// 批量删除执行记录
async function batchDeleteExecutions() {
const checkboxes = document.querySelectorAll('.monitor-execution-checkbox:checked');
if (checkboxes.length === 0) {
alert('请先选择要删除的执行记录');
return;
}
const ids = Array.from(checkboxes).map(cb => cb.value);
const count = ids.length;
// 确认删除
if (!confirm(`确定要删除选中的 ${count} 条执行记录吗?此操作不可恢复。`)) {
return;
}
try {
const response = await apiFetch('/api/monitor/executions', {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ ids: ids })
});
if (!response.ok) {
const error = await response.json().catch(() => ({}));
throw new Error(error.error || '批量删除执行记录失败');
}
const result = await response.json().catch(() => ({}));
const deletedCount = result.deleted || count;
// 删除成功后刷新当前页面
const currentPage = monitorState.pagination.page;
await refreshMonitorPanel(currentPage);
alert(`成功删除 ${deletedCount} 条执行记录`);
} catch (error) {
console.error('批量删除执行记录失败:', error);
alert('批量删除执行记录失败: ' + error.message);
}
}
function formatExecutionDuration(start, end) { function formatExecutionDuration(start, end) {
if (!start) { if (!start) {
return '未知'; return '未知';
File diff suppressed because it is too large Load Diff
+102 -7
View File
@@ -3,14 +3,37 @@ let currentPage = 'chat';
// 初始化路由 // 初始化路由
function initRouter() { function initRouter() {
// 默认显示对话页面
switchPage('chat');
// 从URL hash读取页面(如果有) // 从URL hash读取页面(如果有)
const hash = window.location.hash.slice(1); const hash = window.location.hash.slice(1);
if (hash && ['chat', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'settings'].includes(hash)) { if (hash) {
switchPage(hash); const hashParts = hash.split('?');
const pageId = hashParts[0];
if (pageId && ['chat', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'roles-management', 'settings', 'tasks'].includes(pageId)) {
switchPage(pageId);
// 如果是chat页面且带有conversation参数,加载对应对话
if (pageId === 'chat' && hashParts.length > 1) {
const params = new URLSearchParams(hashParts[1]);
const conversationId = params.get('conversation');
if (conversationId) {
setTimeout(() => {
// 尝试多种方式调用loadConversation
if (typeof loadConversation === 'function') {
loadConversation(conversationId);
} else if (typeof window.loadConversation === 'function') {
window.loadConversation(conversationId);
} else {
console.warn('loadConversation function not found');
}
}, 500);
}
}
return;
}
} }
// 默认显示对话页面
switchPage('chat');
} }
// 切换页面 // 切换页面
@@ -71,6 +94,19 @@ function updateNavState(pageId) {
knowledgeItem.classList.add('expanded'); knowledgeItem.classList.add('expanded');
} }
const submenuItem = document.querySelector(`.nav-submenu-item[data-page="${pageId}"]`);
if (submenuItem) {
submenuItem.classList.add('active');
}
} else if (pageId === 'roles-management') {
// 角色子菜单项
const rolesItem = document.querySelector('.nav-item[data-page="roles"]');
if (rolesItem) {
rolesItem.classList.add('active');
// 展开角色子菜单
rolesItem.classList.add('expanded');
}
const submenuItem = document.querySelector(`.nav-submenu-item[data-page="${pageId}"]`); const submenuItem = document.querySelector(`.nav-submenu-item[data-page="${pageId}"]`);
if (submenuItem) { if (submenuItem) {
submenuItem.classList.add('active'); submenuItem.classList.add('active');
@@ -178,6 +214,12 @@ function initPage(pageId) {
case 'chat': case 'chat':
// 对话页面已由chat.js初始化 // 对话页面已由chat.js初始化
break; break;
case 'tasks':
// 初始化任务管理页面
if (typeof initTasksPage === 'function') {
initTasksPage();
}
break;
case 'mcp-monitor': case 'mcp-monitor':
// 初始化监控面板 // 初始化监控面板
if (typeof refreshMonitorPanel === 'function') { if (typeof refreshMonitorPanel === 'function') {
@@ -210,6 +252,21 @@ function initPage(pageId) {
loadConfig(false); loadConfig(false);
} }
break; break;
case 'roles-management':
// 初始化角色管理页面
if (typeof loadRoles === 'function') {
loadRoles().then(() => {
if (typeof renderRolesList === 'function') {
renderRolesList();
}
});
}
break;
}
// 清理其他页面的定时器
if (pageId !== 'tasks' && typeof cleanupTasksPage === 'function') {
cleanupTasksPage();
} }
} }
@@ -221,10 +278,48 @@ document.addEventListener('DOMContentLoaded', function() {
// 监听hash变化 // 监听hash变化
window.addEventListener('hashchange', function() { window.addEventListener('hashchange', function() {
const hash = window.location.hash.slice(1); const hash = window.location.hash.slice(1);
if (hash && ['chat', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'settings'].includes(hash)) { // 处理带参数的hash(如 chat?conversation=xxx
switchPage(hash); const hashParts = hash.split('?');
const pageId = hashParts[0];
if (pageId && ['chat', 'tasks', 'vulnerabilities', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'settings'].includes(pageId)) {
switchPage(pageId);
// 如果是chat页面且带有conversation参数,加载对应对话
if (pageId === 'chat' && hashParts.length > 1) {
const params = new URLSearchParams(hashParts[1]);
const conversationId = params.get('conversation');
if (conversationId) {
setTimeout(() => {
// 尝试多种方式调用loadConversation
if (typeof loadConversation === 'function') {
loadConversation(conversationId);
} else if (typeof window.loadConversation === 'function') {
window.loadConversation(conversationId);
} else {
console.warn('loadConversation function not found');
}
}, 200);
}
}
} }
}); });
// 页面加载时也检查hash参数
const hash = window.location.hash.slice(1);
if (hash) {
const hashParts = hash.split('?');
const pageId = hashParts[0];
if (pageId === 'chat' && hashParts.length > 1) {
const params = new URLSearchParams(hashParts[1]);
const conversationId = params.get('conversation');
if (conversationId && typeof loadConversation === 'function') {
setTimeout(() => {
loadConversation(conversationId);
}, 500);
}
}
}
}); });
// 切换侧边栏折叠/展开 // 切换侧边栏折叠/展开
+91 -40
View File
@@ -2,8 +2,18 @@
let currentConfig = null; let currentConfig = null;
let allTools = []; let allTools = [];
// 全局工具状态映射,用于保存用户在所有页面的修改 // 全局工具状态映射,用于保存用户在所有页面的修改
// key: tool.name, value: { enabled: boolean, is_external: boolean, external_mcp: string } // key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string }
let toolStateMap = new Map(); let toolStateMap = new Map();
// 生成工具的唯一标识符,用于区分同名但来源不同的工具
function getToolKey(tool) {
// 如果是外部工具,使用 external_mcp::tool.name 作为唯一标识
// 如果是内部工具,使用 tool.name 作为标识
if (tool.is_external && tool.external_mcp) {
return `${tool.external_mcp}::${tool.name}`;
}
return tool.name;
}
// 从localStorage读取每页显示数量,默认为20 // 从localStorage读取每页显示数量,默认为20
const getToolsPageSize = () => { const getToolsPageSize = () => {
const saved = localStorage.getItem('toolsPageSize'); const saved = localStorage.getItem('toolsPageSize');
@@ -199,11 +209,13 @@ async function loadToolsList(page = 1, searchKeyword = '') {
// 初始化工具状态映射(如果工具不在映射中,使用服务器返回的状态) // 初始化工具状态映射(如果工具不在映射中,使用服务器返回的状态)
allTools.forEach(tool => { allTools.forEach(tool => {
if (!toolStateMap.has(tool.name)) { const toolKey = getToolKey(tool);
toolStateMap.set(tool.name, { if (!toolStateMap.has(toolKey)) {
toolStateMap.set(toolKey, {
enabled: tool.enabled, enabled: tool.enabled,
is_external: tool.is_external || false, is_external: tool.is_external || false,
external_mcp: tool.external_mcp || '' external_mcp: tool.external_mcp || '',
name: tool.name // 保存原始工具名称
}); });
} }
}); });
@@ -223,14 +235,16 @@ async function loadToolsList(page = 1, searchKeyword = '') {
function saveCurrentPageToolStates() { function saveCurrentPageToolStates() {
document.querySelectorAll('#tools-list .tool-item').forEach(item => { document.querySelectorAll('#tools-list .tool-item').forEach(item => {
const checkbox = item.querySelector('input[type="checkbox"]'); const checkbox = item.querySelector('input[type="checkbox"]');
const toolKey = item.dataset.toolKey; // 使用唯一标识符
const toolName = item.dataset.toolName; const toolName = item.dataset.toolName;
const isExternal = item.dataset.isExternal === 'true'; const isExternal = item.dataset.isExternal === 'true';
const externalMcp = item.dataset.externalMcp || ''; const externalMcp = item.dataset.externalMcp || '';
if (toolName && checkbox) { if (toolKey && checkbox) {
toolStateMap.set(toolName, { toolStateMap.set(toolKey, {
enabled: checkbox.checked, enabled: checkbox.checked,
is_external: isExternal, is_external: isExternal,
external_mcp: externalMcp external_mcp: externalMcp,
name: toolName // 保存原始工具名称
}); });
} }
}); });
@@ -283,14 +297,16 @@ function renderToolsList() {
} }
allTools.forEach(tool => { allTools.forEach(tool => {
const toolKey = getToolKey(tool); // 生成唯一标识符
const toolItem = document.createElement('div'); const toolItem = document.createElement('div');
toolItem.className = 'tool-item'; toolItem.className = 'tool-item';
toolItem.dataset.toolKey = toolKey; // 保存唯一标识符
toolItem.dataset.toolName = tool.name; // 保存原始工具名称 toolItem.dataset.toolName = tool.name; // 保存原始工具名称
toolItem.dataset.isExternal = tool.is_external ? 'true' : 'false'; toolItem.dataset.isExternal = tool.is_external ? 'true' : 'false';
toolItem.dataset.externalMcp = tool.external_mcp || ''; toolItem.dataset.externalMcp = tool.external_mcp || '';
// 从全局状态映射获取工具状态,如果不存在则使用服务器返回的状态 // 从全局状态映射获取工具状态,如果不存在则使用服务器返回的状态
const toolState = toolStateMap.get(tool.name) || { const toolState = toolStateMap.get(toolKey) || {
enabled: tool.enabled, enabled: tool.enabled,
is_external: tool.is_external || false, is_external: tool.is_external || false,
external_mcp: tool.external_mcp || '' external_mcp: tool.external_mcp || ''
@@ -298,15 +314,18 @@ function renderToolsList() {
// 外部工具标签,显示来源信息 // 外部工具标签,显示来源信息
let externalBadge = ''; let externalBadge = '';
if (toolState.is_external) { if (toolState.is_external || tool.is_external) {
const externalMcpName = toolState.external_mcp || ''; const externalMcpName = toolState.external_mcp || tool.external_mcp || '';
const badgeText = externalMcpName ? `外部 (${escapeHtml(externalMcpName)})` : '外部'; const badgeText = externalMcpName ? `外部 (${escapeHtml(externalMcpName)})` : '外部';
const badgeTitle = externalMcpName ? `外部MCP工具 - 来源:${escapeHtml(externalMcpName)}` : '外部MCP工具'; const badgeTitle = externalMcpName ? `外部MCP工具 - 来源:${escapeHtml(externalMcpName)}` : '外部MCP工具';
externalBadge = `<span class="external-tool-badge" title="${badgeTitle}">${badgeText}</span>`; externalBadge = `<span class="external-tool-badge" title="${badgeTitle}">${badgeText}</span>`;
} }
// 生成唯一的checkbox id,使用工具唯一标识符
const checkboxId = `tool-${escapeHtml(toolKey).replace(/::/g, '--')}`;
toolItem.innerHTML = ` toolItem.innerHTML = `
<input type="checkbox" id="tool-${tool.name}" ${toolState.enabled ? 'checked' : ''} ${toolState.is_external ? 'data-external="true"' : ''} onchange="handleToolCheckboxChange('${tool.name}', this.checked)" /> <input type="checkbox" id="${checkboxId}" ${toolState.enabled ? 'checked' : ''} ${toolState.is_external || tool.is_external ? 'data-external="true"' : ''} onchange="handleToolCheckboxChange('${escapeHtml(toolKey)}', this.checked)" />
<div class="tool-item-info"> <div class="tool-item-info">
<div class="tool-item-name"> <div class="tool-item-name">
${escapeHtml(tool.name)} ${escapeHtml(tool.name)}
@@ -376,16 +395,18 @@ function renderToolsPagination() {
} }
// 处理工具checkbox状态变化 // 处理工具checkbox状态变化
function handleToolCheckboxChange(toolName, enabled) { function handleToolCheckboxChange(toolKey, enabled) {
// 更新全局状态映射 // 更新全局状态映射
const toolItem = document.querySelector(`.tool-item[data-tool-name="${toolName}"]`); const toolItem = document.querySelector(`.tool-item[data-tool-key="${toolKey}"]`);
if (toolItem) { if (toolItem) {
const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true'; const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || ''; const externalMcp = toolItem.dataset.externalMcp || '';
toolStateMap.set(toolName, { toolStateMap.set(toolKey, {
enabled: enabled, enabled: enabled,
is_external: isExternal, is_external: isExternal,
external_mcp: externalMcp external_mcp: externalMcp,
name: toolName // 保存原始工具名称
}); });
} }
updateToolsStats(); updateToolsStats();
@@ -398,14 +419,16 @@ function selectAllTools() {
// 更新全局状态映射 // 更新全局状态映射
const toolItem = checkbox.closest('.tool-item'); const toolItem = checkbox.closest('.tool-item');
if (toolItem) { if (toolItem) {
const toolKey = toolItem.dataset.toolKey;
const toolName = toolItem.dataset.toolName; const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true'; const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || ''; const externalMcp = toolItem.dataset.externalMcp || '';
if (toolName) { if (toolKey) {
toolStateMap.set(toolName, { toolStateMap.set(toolKey, {
enabled: true, enabled: true,
is_external: isExternal, is_external: isExternal,
external_mcp: externalMcp external_mcp: externalMcp,
name: toolName // 保存原始工具名称
}); });
} }
} }
@@ -420,14 +443,16 @@ function deselectAllTools() {
// 更新全局状态映射 // 更新全局状态映射
const toolItem = checkbox.closest('.tool-item'); const toolItem = checkbox.closest('.tool-item');
if (toolItem) { if (toolItem) {
const toolKey = toolItem.dataset.toolKey;
const toolName = toolItem.dataset.toolName; const toolName = toolItem.dataset.toolName;
const isExternal = toolItem.dataset.isExternal === 'true'; const isExternal = toolItem.dataset.isExternal === 'true';
const externalMcp = toolItem.dataset.externalMcp || ''; const externalMcp = toolItem.dataset.externalMcp || '';
if (toolName) { if (toolKey) {
toolStateMap.set(toolName, { toolStateMap.set(toolKey, {
enabled: false, enabled: false,
is_external: isExternal, is_external: isExternal,
external_mcp: externalMcp external_mcp: externalMcp,
name: toolName // 保存原始工具名称
}); });
} }
} }
@@ -484,11 +509,13 @@ async function updateToolsStats() {
totalTools = allTools.length; totalTools = allTools.length;
totalEnabled = allTools.filter(tool => { totalEnabled = allTools.filter(tool => {
// 优先使用全局状态映射,否则使用checkbox状态,最后使用服务器返回的状态 // 优先使用全局状态映射,否则使用checkbox状态,最后使用服务器返回的状态
const savedState = toolStateMap.get(tool.name); const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
if (savedState !== undefined) { if (savedState !== undefined) {
return savedState.enabled; return savedState.enabled;
} }
const checkbox = document.getElementById(`tool-${tool.name}`); const checkboxId = `tool-${toolKey.replace(/::/g, '--')}`;
const checkbox = document.getElementById(checkboxId);
return checkbox ? checkbox.checked : tool.enabled; return checkbox ? checkbox.checked : tool.enabled;
}).length; }).length;
} else { } else {
@@ -498,16 +525,18 @@ async function updateToolsStats() {
// 从当前页的checkbox获取状态(如果全局映射中没有) // 从当前页的checkbox获取状态(如果全局映射中没有)
allTools.forEach(tool => { allTools.forEach(tool => {
const savedState = toolStateMap.get(tool.name); const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(toolKey);
if (savedState !== undefined) { if (savedState !== undefined) {
localStateMap.set(tool.name, savedState.enabled); localStateMap.set(toolKey, savedState.enabled);
} else { } else {
const checkbox = document.getElementById(`tool-${tool.name}`); const checkboxId = `tool-${toolKey.replace(/::/g, '--')}`;
const checkbox = document.getElementById(checkboxId);
if (checkbox) { if (checkbox) {
localStateMap.set(tool.name, checkbox.checked); localStateMap.set(toolKey, checkbox.checked);
} else { } else {
// 如果checkbox不存在(不在当前页),使用工具原始状态 // 如果checkbox不存在(不在当前页),使用工具原始状态
localStateMap.set(tool.name, tool.enabled); localStateMap.set(toolKey, tool.enabled);
} }
} }
}); });
@@ -527,9 +556,10 @@ async function updateToolsStats() {
const pageResult = await pageResponse.json(); const pageResult = await pageResponse.json();
pageResult.tools.forEach(tool => { pageResult.tools.forEach(tool => {
// 优先使用全局状态映射,否则使用服务器返回的状态 // 优先使用全局状态映射,否则使用服务器返回的状态
if (!localStateMap.has(tool.name)) { const toolKey = getToolKey(tool);
const savedState = toolStateMap.get(tool.name); if (!localStateMap.has(toolKey)) {
localStateMap.set(tool.name, savedState ? savedState.enabled : tool.enabled); const savedState = toolStateMap.get(toolKey);
localStateMap.set(toolKey, savedState ? savedState.enabled : tool.enabled);
} }
}); });
@@ -665,8 +695,9 @@ async function applySettings() {
// 将工具添加到映射中 // 将工具添加到映射中
// 优先使用全局状态映射中的状态(用户修改过的),否则使用服务器返回的状态 // 优先使用全局状态映射中的状态(用户修改过的),否则使用服务器返回的状态
pageResult.tools.forEach(tool => { pageResult.tools.forEach(tool => {
const savedState = toolStateMap.get(tool.name); const toolKey = getToolKey(tool);
allToolsMap.set(tool.name, { const savedState = toolStateMap.get(toolKey);
allToolsMap.set(toolKey, {
name: tool.name, name: tool.name,
enabled: savedState ? savedState.enabled : tool.enabled, enabled: savedState ? savedState.enabled : tool.enabled,
is_external: savedState ? savedState.is_external : (tool.is_external || false), is_external: savedState ? savedState.is_external : (tool.is_external || false),
@@ -683,7 +714,7 @@ async function applySettings() {
} }
// 将所有工具添加到配置中 // 将所有工具添加到配置中
allToolsMap.forEach(tool => { allToolsMap.forEach((tool, toolKey) => {
config.tools.push({ config.tools.push({
name: tool.name, name: tool.name,
enabled: tool.enabled, enabled: tool.enabled,
@@ -694,7 +725,9 @@ async function applySettings() {
} catch (error) { } catch (error) {
console.warn('获取所有工具列表失败,仅使用全局状态映射', error); console.warn('获取所有工具列表失败,仅使用全局状态映射', error);
// 如果获取失败,使用全局状态映射 // 如果获取失败,使用全局状态映射
toolStateMap.forEach((toolData, toolName) => { toolStateMap.forEach((toolData, toolKey) => {
// toolData.name 保存了原始工具名称
const toolName = toolData.name || toolKey.split('::').pop();
config.tools.push({ config.tools.push({
name: toolName, name: toolName,
enabled: toolData.enabled, enabled: toolData.enabled,
@@ -777,8 +810,9 @@ async function saveToolsConfig() {
// 将工具添加到映射中 // 将工具添加到映射中
pageResult.tools.forEach(tool => { pageResult.tools.forEach(tool => {
const savedState = toolStateMap.get(tool.name); const toolKey = getToolKey(tool);
allToolsMap.set(tool.name, { const savedState = toolStateMap.get(toolKey);
allToolsMap.set(toolKey, {
name: tool.name, name: tool.name,
enabled: savedState ? savedState.enabled : tool.enabled, enabled: savedState ? savedState.enabled : tool.enabled,
is_external: savedState ? savedState.is_external : (tool.is_external || false), is_external: savedState ? savedState.is_external : (tool.is_external || false),
@@ -795,7 +829,7 @@ async function saveToolsConfig() {
} }
// 将所有工具添加到配置中 // 将所有工具添加到配置中
allToolsMap.forEach(tool => { allToolsMap.forEach((tool, toolKey) => {
config.tools.push({ config.tools.push({
name: tool.name, name: tool.name,
enabled: tool.enabled, enabled: tool.enabled,
@@ -806,7 +840,9 @@ async function saveToolsConfig() {
} catch (error) { } catch (error) {
console.warn('获取所有工具列表失败,仅使用全局状态映射', error); console.warn('获取所有工具列表失败,仅使用全局状态映射', error);
// 如果获取失败,使用全局状态映射 // 如果获取失败,使用全局状态映射
toolStateMap.forEach((toolData, toolName) => { toolStateMap.forEach((toolData, toolKey) => {
// toolData.name 保存了原始工具名称
const toolName = toolData.name || toolKey.split('::').pop();
config.tools.push({ config.tools.push({
name: toolName, name: toolName,
enabled: toolData.enabled, enabled: toolData.enabled,
@@ -1158,6 +1194,14 @@ function loadExternalMCPExample() {
], ],
description: "示例描述", description: "示例描述",
timeout: 300 timeout: 300
},
"cyberstrike-ai-http": {
transport: "http",
url: "http://127.0.0.1:8081/mcp"
},
"cyberstrike-ai-sse": {
transport: "sse",
url: "http://127.0.0.1:8081/mcp/sse"
} }
}; };
@@ -1231,7 +1275,7 @@ async function saveExternalMCP() {
// 验证配置内容 // 验证配置内容
const transport = config.transport || (config.command ? 'stdio' : config.url ? 'http' : ''); const transport = config.transport || (config.command ? 'stdio' : config.url ? 'http' : '');
if (!transport) { if (!transport) {
errorDiv.textContent = `配置错误: "${name}" 需要指定commandstdio模式)或urlhttp模式)`; errorDiv.textContent = `配置错误: "${name}" 需要指定commandstdio模式)或urlhttp/sse模式)`;
errorDiv.style.display = 'block'; errorDiv.style.display = 'block';
jsonTextarea.classList.add('error'); jsonTextarea.classList.add('error');
return; return;
@@ -1250,6 +1294,13 @@ async function saveExternalMCP() {
jsonTextarea.classList.add('error'); jsonTextarea.classList.add('error');
return; return;
} }
if (transport === 'sse' && !config.url) {
errorDiv.textContent = `配置错误: "${name}" sse模式需要url字段`;
errorDiv.style.display = 'block';
jsonTextarea.classList.add('error');
return;
}
} }
// 清除错误提示 // 清除错误提示
File diff suppressed because it is too large Load Diff
+351 -9
View File
@@ -76,6 +76,17 @@
<span>对话</span> <span>对话</span>
</div> </div>
</div> </div>
<div class="nav-item" data-page="tasks">
<div class="nav-item-content" data-title="任务管理" onclick="switchPage('tasks')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M13 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V9z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<polyline points="13 2 13 9 20 9" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<line x1="9" y1="13" x2="15" y2="13" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
<line x1="9" y1="17" x2="15" y2="17" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
</svg>
<span>任务管理</span>
</div>
</div>
<div class="nav-item" data-page="vulnerabilities"> <div class="nav-item" data-page="vulnerabilities">
<div class="nav-item-content" data-title="漏洞管理" onclick="switchPage('vulnerabilities')"> <div class="nav-item-content" data-title="漏洞管理" onclick="switchPage('vulnerabilities')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -125,6 +136,23 @@
</div> </div>
</div> </div>
</div> </div>
<div class="nav-item nav-item-has-submenu" data-page="roles">
<div class="nav-item-content" data-title="角色" onclick="toggleSubmenu('roles')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M20 21v-2a4 4 0 0 0-4-4H8a4 4 0 0 0-4 4v2" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<circle cx="12" cy="7" r="4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span>角色</span>
<svg class="submenu-arrow" width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M9 18l6-6-6-6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</div>
<div class="nav-submenu">
<div class="nav-submenu-item" data-page="roles-management" onclick="switchPage('roles-management')">
<span>角色管理</span>
</div>
</div>
</div>
<div class="nav-item" data-page="settings"> <div class="nav-item" data-page="settings">
<div class="nav-item-content" data-title="系统设置" onclick="switchPage('settings')"> <div class="nav-item-content" data-title="系统设置" onclick="switchPage('settings')">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -151,7 +179,7 @@
</div> </div>
<div class="sidebar-content"> <div class="sidebar-content">
<!-- 全局搜索 --> <!-- 全局搜索 -->
<div class="conversation-search-box" style="margin-bottom: 16px;"> <div class="conversation-search-box" style="margin-bottom: 16px; margin-top: 16px;">
<input type="text" id="conversation-search-input" placeholder="搜索历史记录..." <input type="text" id="conversation-search-input" placeholder="搜索历史记录..."
oninput="handleConversationSearch(this.value)" oninput="handleConversationSearch(this.value)"
onkeypress="if(event.key === 'Enter') handleConversationSearch(this.value)" /> onkeypress="if(event.key === 'Enter') handleConversationSearch(this.value)" />
@@ -207,7 +235,7 @@
</button> </button>
<h2 id="group-detail-title" class="group-detail-title"></h2> <h2 id="group-detail-title" class="group-detail-title"></h2>
<div class="group-detail-actions"> <div class="group-detail-actions">
<button class="group-action-btn" onclick="searchInGroup()" title="搜索"> <button class="group-action-btn" onclick="toggleGroupSearch()" title="搜索" id="group-search-toggle-btn">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="11" cy="11" r="8" stroke="currentColor" stroke-width="2"/> <circle cx="11" cy="11" r="8" stroke="currentColor" stroke-width="2"/>
<path d="m21 21-4.35-4.35" stroke="currentColor" stroke-width="2" stroke-linecap="round"/> <path d="m21 21-4.35-4.35" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
@@ -226,6 +254,17 @@
</button> </button>
</div> </div>
</div> </div>
<div id="group-search-container" class="group-search-container" style="display: none;">
<div class="group-search-input-wrapper">
<input type="text" id="group-search-input" class="group-search-input" placeholder="搜索分组中的对话..." onkeyup="handleGroupSearchInput(event)" oninput="handleGroupSearchInput(event)">
<button class="group-search-clear-btn" onclick="clearGroupSearch()" title="清除搜索" id="group-search-clear-btn" style="display: none;">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="12" cy="12" r="10" stroke="currentColor" stroke-width="2"/>
<path d="m8 8 8 8M16 8l-8 8" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
</svg>
</button>
</div>
</div>
<div class="group-detail-content"> <div class="group-detail-content">
<div id="group-conversations-list" class="group-conversations-list"></div> <div id="group-conversations-list" class="group-conversations-list"></div>
</div> </div>
@@ -247,11 +286,37 @@
<div id="active-tasks-bar" class="active-tasks-bar"></div> <div id="active-tasks-bar" class="active-tasks-bar"></div>
<div id="chat-messages" class="chat-messages"></div> <div id="chat-messages" class="chat-messages"></div>
<div class="chat-input-container"> <div class="chat-input-container">
<div class="role-selector-wrapper">
<button id="role-selector-btn" class="role-selector-btn" onclick="toggleRoleSelectionPanel()" title="选择角色">
<span id="role-selector-icon" class="role-selector-icon">🔵</span>
<span id="role-selector-text" class="role-selector-text">默认</span>
<svg class="role-selector-arrow" width="10" height="10" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 9l6 6 6-6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
<!-- 角色选择下拉面板 -->
<div id="role-selection-panel" class="role-selection-panel" style="display: none;">
<div class="role-selection-panel-header">
<h3 class="role-selection-panel-title">选择角色</h3>
<button class="role-selection-panel-close" onclick="closeRoleSelectionPanel()" title="关闭">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18 6L6 18M6 6l12 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
</div>
<div id="role-selection-list" class="role-selection-list-main"></div>
</div>
</div>
<div class="chat-input-field"> <div class="chat-input-field">
<textarea id="chat-input" placeholder="输入测试目标或命令... (Shift+Enter 换行,Enter 发送)" rows="1"></textarea> <textarea id="chat-input" placeholder="输入测试目标或命令... (输入 @ 选择工具 | Shift+Enter 换行,Enter 发送)" rows="1"></textarea>
<div id="mention-suggestions" class="mention-suggestions" role="listbox" aria-label="工具提及候选"></div> <div id="mention-suggestions" class="mention-suggestions" role="listbox" aria-label="工具提及候选"></div>
</div> </div>
<button onclick="sendMessage()">发送</button> <button class="send-btn" onclick="sendMessage()">
<span>发送</span>
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5 12h14M12 5l7 7-7 7" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
</div> </div>
</div> </div>
</div> </div>
@@ -277,6 +342,10 @@
<div class="section-header"> <div class="section-header">
<h3>最新执行记录</h3> <h3>最新执行记录</h3>
<div class="section-actions"> <div class="section-actions">
<label>
工具搜索
<input type="text" id="monitor-tool-filter" placeholder="输入工具名称..." oninput="handleToolFilterInput()" onkeydown="if(event.key==='Enter') applyMonitorFilters()" />
</label>
<label> <label>
状态筛选 状态筛选
<select id="monitor-status-filter" onchange="applyMonitorFilters()"> <select id="monitor-status-filter" onchange="applyMonitorFilters()">
@@ -288,6 +357,16 @@
</label> </label>
</div> </div>
</div> </div>
<div id="monitor-batch-actions" class="monitor-batch-actions" style="display: none;">
<div class="batch-actions-info">
<span id="monitor-selected-count">已选择 0 项</span>
</div>
<div class="batch-actions-buttons">
<button class="btn-secondary" onclick="selectAllExecutions()">全选</button>
<button class="btn-secondary" onclick="deselectAllExecutions()">取消全选</button>
<button class="btn-secondary btn-delete" onclick="batchDeleteExecutions()">批量删除</button>
</div>
</div>
<div id="monitor-executions" class="monitor-table-container"> <div id="monitor-executions" class="monitor-table-container">
<div class="monitor-empty">加载中...</div> <div class="monitor-empty">加载中...</div>
</div> </div>
@@ -527,6 +606,76 @@
</div> </div>
</div> </div>
<!-- 任务管理页面 -->
<div id="page-tasks" class="page">
<div class="page-header">
<h2>任务管理</h2>
<div class="page-header-actions">
<button class="btn-primary" onclick="showBatchImportModal()">批量导入任务</button>
<label class="auto-refresh-toggle">
<input type="checkbox" id="tasks-auto-refresh" checked onchange="toggleTasksAutoRefresh(this.checked)">
<span>自动刷新</span>
</label>
<button class="btn-secondary" onclick="refreshBatchQueues()">刷新</button>
</div>
</div>
<div class="page-content">
<!-- 批量任务队列列表 -->
<div class="batch-queues-section" id="batch-queues-section" style="display: none;">
<!-- 筛选控件 -->
<div class="batch-queues-filters tasks-filters">
<label>
<span>状态筛选</span>
<select id="batch-queues-status-filter" onchange="filterBatchQueues()">
<option value="all">全部</option>
<option value="pending">待执行</option>
<option value="running">执行中</option>
<option value="paused">已暂停</option>
<option value="completed">已完成</option>
<option value="cancelled">已取消</option>
</select>
</label>
<label style="flex: 1; max-width: 300px;">
<span>搜索队列ID、标题或创建时间</span>
<input type="text" id="batch-queues-search" placeholder="输入关键字搜索..."
oninput="filterBatchQueues()">
</label>
</div>
<div id="batch-queues-list" class="batch-queues-list"></div>
<!-- 分页控件 -->
<div id="batch-queues-pagination"></div>
</div>
</div>
</div>
<!-- 角色管理页面 -->
<div id="page-roles-management" class="page">
<div class="page-header">
<h2>角色管理</h2>
<div class="page-header-actions">
<button class="btn-secondary" onclick="refreshRoles()">刷新</button>
<button class="btn-primary" onclick="showAddRoleModal()">添加角色</button>
</div>
</div>
<div class="page-content">
<div class="roles-controls">
<div class="roles-stats-bar" id="roles-stats">
<div class="role-stat-item">
<span class="role-stat-label">总角色数</span>
<span class="role-stat-value">-</span>
</div>
<div class="role-stat-item">
<span class="role-stat-label">已启用</span>
<span class="role-stat-value">-</span>
</div>
</div>
</div>
<div id="roles-list" class="roles-list">
<div class="loading-spinner">加载中...</div>
</div>
</div>
</div>
<!-- 系统设置页面 --> <!-- 系统设置页面 -->
<div id="page-settings" class="page"> <div id="page-settings" class="page">
<div class="page-header"> <div class="page-header">
@@ -609,10 +758,6 @@
<option value="openai">OpenAI</option> <option value="openai">OpenAI</option>
</select> </select>
</div> </div>
<div class="form-group">
<label for="knowledge-embedding-model">模型名称</label>
<input type="text" id="knowledge-embedding-model" placeholder="text-embedding-v4" />
</div>
<div class="form-group"> <div class="form-group">
<label for="knowledge-embedding-base-url">Base URL</label> <label for="knowledge-embedding-base-url">Base URL</label>
<input type="text" id="knowledge-embedding-base-url" placeholder="留空则使用OpenAI配置的base_url" /> <input type="text" id="knowledge-embedding-base-url" placeholder="留空则使用OpenAI配置的base_url" />
@@ -623,6 +768,10 @@
<input type="password" id="knowledge-embedding-api-key" placeholder="留空则使用OpenAI配置的api_key" /> <input type="password" id="knowledge-embedding-api-key" placeholder="留空则使用OpenAI配置的api_key" />
<small class="form-hint">留空则使用OpenAI配置的api_key</small> <small class="form-hint">留空则使用OpenAI配置的api_key</small>
</div> </div>
<div class="form-group">
<label for="knowledge-embedding-model">模型名称</label>
<input type="text" id="knowledge-embedding-model" placeholder="text-embedding-v4" />
</div>
<div class="settings-subsection-header"> <div class="settings-subsection-header">
<h5>检索配置</h5> <h5>检索配置</h5>
@@ -787,6 +936,13 @@
"transport": "http", "transport": "http",
"url": "http://127.0.0.1:8081/mcp" "url": "http://127.0.0.1:8081/mcp"
} }
}</code>
<strong>SSE模式:</strong><br>
<code style="display: block; margin: 8px 0; padding: 8px; background: var(--bg-secondary); border-radius: 4px; white-space: pre-wrap;">{
"cyberstrike-ai-sse": {
"transport": "sse",
"url": "http://127.0.0.1:8081/mcp/sse"
}
}</code> }</code>
</div> </div>
<div id="external-mcp-json-error" class="error-message" style="display: none; margin-top: 8px; padding: 8px; background: rgba(220, 53, 69, 0.1); border: 1px solid rgba(220, 53, 69, 0.3); border-radius: 4px; color: var(--error-color); font-size: 0.875rem;"></div> <div id="external-mcp-json-error" class="error-message" style="display: none; margin-top: 8px; padding: 8px; background: rgba(220, 53, 69, 0.1); border: 1px solid rgba(220, 53, 69, 0.3); border-radius: 4px; color: var(--error-color); font-size: 0.875rem;"></div>
@@ -1001,7 +1157,7 @@
<span class="modal-close" onclick="closeCreateGroupModal()">&times;</span> <span class="modal-close" onclick="closeCreateGroupModal()">&times;</span>
</div> </div>
<div class="modal-body create-group-body"> <div class="modal-body create-group-body">
<p class="create-group-description">分组功能可将对话集中归类管理,并支持自定义指令,让对话更加井然有序。</p> <p class="create-group-description">分组功能可将对话集中归类管理,让对话更加井然有序。</p>
<div class="create-group-input-wrapper"> <div class="create-group-input-wrapper">
<span class="group-icon-input">😊</span> <span class="group-icon-input">😊</span>
<input type="text" id="create-group-name-input" placeholder="请输入分组名称" /> <input type="text" id="create-group-name-input" placeholder="请输入分组名称" />
@@ -1082,6 +1238,100 @@
</div> </div>
</div> </div>
<!-- 批量导入任务模态框 -->
<div id="batch-import-modal" class="modal">
<div class="modal-content" style="max-width: 800px;">
<div class="modal-header">
<h2>批量导入任务</h2>
<span class="modal-close" onclick="closeBatchImportModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="batch-queue-title">任务标题</label>
<input type="text" id="batch-queue-title" placeholder="请输入任务标题(可选,用于标识和筛选)" />
<div class="form-hint" style="margin-top: 4px;">
为批量任务队列设置一个标题,方便后续查找和管理。
</div>
</div>
<div class="form-group">
<label for="batch-tasks-input">任务列表(每行一个任务)<span style="color: red;">*</span></label>
<textarea id="batch-tasks-input" rows="15" placeholder="请输入任务列表,每行一个任务,例如:&#10;扫描 192.168.1.1 的开放端口&#10;检查 https://example.com 是否存在SQL注入&#10;枚举 example.com 的子域名" style="font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; font-size: 0.875rem; line-height: 1.5;"></textarea>
<div class="form-hint" style="margin-top: 8px;">
<strong>提示:</strong>每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。
</div>
</div>
<div class="form-group">
<div id="batch-import-stats" class="batch-import-stats"></div>
</div>
</div>
<div class="modal-footer">
<button class="btn-secondary" onclick="closeBatchImportModal()">取消</button>
<button class="btn-primary" onclick="createBatchQueue()">创建队列</button>
</div>
</div>
</div>
<!-- 批量任务队列详情模态框 -->
<div id="batch-queue-detail-modal" class="modal">
<div class="modal-content" style="max-width: 900px;">
<div class="modal-header">
<h2 id="batch-queue-detail-title">批量任务队列详情</h2>
<div style="display: flex; align-items: center; gap: 12px;">
<div class="modal-header-actions">
<button class="btn-secondary" id="batch-queue-add-task-btn" onclick="showAddBatchTaskModal()" style="display: none;">添加任务</button>
<button class="btn-primary" id="batch-queue-start-btn" onclick="startBatchQueue()" style="display: none;">开始执行</button>
<button class="btn-secondary" id="batch-queue-pause-btn" onclick="pauseBatchQueue()" style="display: none;">暂停队列</button>
<button class="btn-secondary btn-danger" id="batch-queue-delete-btn" onclick="deleteBatchQueue()" style="display: none;">删除队列</button>
</div>
<span class="modal-close" onclick="closeBatchQueueDetailModal()">&times;</span>
</div>
</div>
<div class="modal-body">
<div id="batch-queue-detail-content"></div>
</div>
</div>
</div>
<!-- 编辑批量任务模态框 -->
<div id="edit-batch-task-modal" class="modal">
<div class="modal-content" style="max-width: 600px;">
<div class="modal-header">
<h2>编辑任务</h2>
<span class="modal-close" onclick="closeEditBatchTaskModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="edit-task-message">任务消息</label>
<textarea id="edit-task-message" class="form-control" rows="5" placeholder="请输入任务消息"></textarea>
</div>
<div class="form-actions">
<button class="btn-primary" onclick="saveBatchTask()">保存</button>
<button class="btn-secondary" onclick="closeEditBatchTaskModal()">取消</button>
</div>
</div>
</div>
</div>
<!-- 添加批量任务模态框 -->
<div id="add-batch-task-modal" class="modal">
<div class="modal-content" style="max-width: 600px;">
<div class="modal-header">
<h2>添加任务</h2>
<span class="modal-close" onclick="closeAddBatchTaskModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="add-task-message">任务消息</label>
<textarea id="add-task-message" class="form-control" rows="5" placeholder="请输入任务消息"></textarea>
</div>
<div class="form-actions">
<button class="btn-primary" onclick="saveAddBatchTask()">添加</button>
<button class="btn-secondary" onclick="closeAddBatchTaskModal()">取消</button>
</div>
</div>
</div>
</div>
<!-- 漏洞编辑模态框 --> <!-- 漏洞编辑模态框 -->
<div id="vulnerability-modal" class="modal"> <div id="vulnerability-modal" class="modal">
<div class="modal-content" style="max-width: 900px;"> <div class="modal-content" style="max-width: 900px;">
@@ -1150,6 +1400,96 @@
</div> </div>
</div> </div>
<!-- 角色选择弹窗 -->
<div id="role-select-modal" class="modal">
<div class="modal-content role-select-modal-content">
<div class="modal-header">
<h2>选择角色</h2>
<span class="modal-close" onclick="closeRoleSelectModal()">&times;</span>
</div>
<div class="modal-body role-select-body">
<div id="role-select-list" class="role-select-list"></div>
</div>
</div>
</div>
<!-- 角色编辑模态框 -->
<div id="role-modal" class="modal">
<div class="modal-content" style="max-width: 900px;">
<div class="modal-header">
<h2 id="role-modal-title">添加角色</h2>
<span class="modal-close" onclick="closeRoleModal()">&times;</span>
</div>
<div class="modal-body">
<div class="form-group">
<label for="role-name">角色名称 <span style="color: red;">*</span></label>
<input type="text" id="role-name" placeholder="输入角色名称" required />
</div>
<div class="form-group">
<label for="role-description">角色描述</label>
<input type="text" id="role-description" placeholder="输入角色描述" />
</div>
<div class="form-group">
<label for="role-icon">角色图标</label>
<input type="text" id="role-icon" placeholder="输入emoji图标,例如: 🏆" maxlength="10" />
<small class="form-hint">输入一个emoji作为角色的图标,将显示在角色选择器中。</small>
</div>
<div class="form-group">
<label for="role-user-prompt">用户提示词</label>
<textarea id="role-user-prompt" rows="10" placeholder="输入用户提示词,会在用户消息前追加此提示词..."></textarea>
<small class="form-hint">此提示词会追加到用户消息前,用于指导AI的行为。注意:这不会修改系统提示词。</small>
</div>
<div class="form-group" id="role-tools-section">
<label>关联的工具(可选)</label>
<div id="role-tools-default-hint" class="role-tools-default-hint" style="display: none;">
<div class="role-tools-default-info">
<span class="role-tools-default-icon"></span>
<div class="role-tools-default-content">
<div class="role-tools-default-title">默认角色使用所有工具</div>
<div class="role-tools-default-desc">默认角色会自动使用MCP管理中启用的所有工具,无需单独配置。</div>
</div>
</div>
</div>
<div class="role-tools-controls">
<div class="role-tools-actions">
<button type="button" class="btn-secondary" onclick="selectAllRoleTools()">全选</button>
<button type="button" class="btn-secondary" onclick="deselectAllRoleTools()">全不选</button>
<div class="role-tools-search-box">
<input type="text" id="role-tools-search" placeholder="搜索工具..."
oninput="searchRoleTools(this.value)"
onkeypress="if(event.key === 'Enter') searchRoleTools(this.value)" />
<button class="role-tools-search-clear" id="role-tools-search-clear"
onclick="clearRoleToolsSearch()" style="display: none;" title="清除搜索">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="12" cy="12" r="10" stroke="currentColor" stroke-width="2"/>
<path d="M15 9l-6 6M9 9l6 6" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
</svg>
</button>
</div>
</div>
<div id="role-tools-stats" class="role-tools-stats"></div>
</div>
<div id="role-tools-list" class="role-tools-list">
<div class="tools-loading">正在加载工具列表...</div>
</div>
<small class="form-hint">勾选要关联的工具,留空则使用MCP管理中的全部工具配置。</small>
</div>
<div class="form-group">
<label class="checkbox-label">
<input type="checkbox" id="role-enabled" class="modern-checkbox" checked />
<span class="checkbox-custom"></span>
<span class="checkbox-text">启用此角色</span>
</label>
</div>
</div>
<div class="modal-footer">
<button class="btn-secondary" onclick="closeRoleModal()">取消</button>
<button class="btn-primary" onclick="saveRole()">保存</button>
</div>
</div>
</div>
<script src="/static/js/builtin-tools.js"></script>
<script src="/static/js/auth.js"></script> <script src="/static/js/auth.js"></script>
<script src="/static/js/router.js"></script> <script src="/static/js/router.js"></script>
<script src="/static/js/monitor.js"></script> <script src="/static/js/monitor.js"></script>
@@ -1157,6 +1497,8 @@
<script src="/static/js/settings.js"></script> <script src="/static/js/settings.js"></script>
<script src="/static/js/knowledge.js"></script> <script src="/static/js/knowledge.js"></script>
<script src="/static/js/vulnerability.js?v=4"></script> <script src="/static/js/vulnerability.js?v=4"></script>
<script src="/static/js/tasks.js"></script>
<script src="/static/js/roles.js"></script>
</body> </body>
</html> </html>