mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-06 06:13:58 +02:00
Compare commits
273 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b8ebf023a0 | |||
| 604ce34d5e | |||
| b29b36bfd5 | |||
| 11bab83fc5 | |||
| dc750e3680 | |||
| 0236d1c155 | |||
| be59ddcab6 | |||
| 25464a68e6 | |||
| eabfed09c9 | |||
| cbcbd414cd | |||
| 0933f9365b | |||
| e792891ff3 | |||
| e14e5f15d3 | |||
| 4d5e0c5f21 | |||
| b3238304ce | |||
| 665e2ec73a | |||
| d63d9c25b8 | |||
| d1c63d0ba7 | |||
| 55d6d449cd | |||
| d4bc9646d9 | |||
| b941f5a8d9 | |||
| 97e2c0fd43 | |||
| bd3e48c2d0 | |||
| 8b0b91fddc | |||
| 2b38595b42 | |||
| 5c795439ee | |||
| df531910cf | |||
| 8a089a826c | |||
| 60b32ffc69 | |||
| 21c36fcce8 | |||
| 4d048f6da0 | |||
| 03a2707b83 | |||
| 9941f51b3e | |||
| 1553e896c5 | |||
| ea2184773e | |||
| 764d8110ec | |||
| e037f383f5 | |||
| e40f7cb468 | |||
| 72aca69204 | |||
| 133da1c640 | |||
| af78b47517 | |||
| f5fabc05a4 | |||
| 5cc53b1076 | |||
| f1be2064db | |||
| 0c9c2ec606 | |||
| cf09dd36d8 | |||
| c6e2701b30 | |||
| 42b5901d99 | |||
| 117bed6839 | |||
| bad323cd0e | |||
| 8138f8b576 | |||
| 74627d214b | |||
| f622efe245 | |||
| 3924b5285b | |||
| 21f641bbd7 | |||
| d913695303 | |||
| 6bb3a73f73 | |||
| f0a80a8e58 | |||
| 3f9dbb4214 | |||
| c0f0861b31 | |||
| 704137aa34 | |||
| c56bf36df0 | |||
| 5560f34c6c | |||
| 70e9a73fc0 | |||
| 12bc9d8ab6 | |||
| f8db82a065 | |||
| 8ce30d9072 | |||
| e6506d00e8 | |||
| b2308617b8 | |||
| cd17fdca33 | |||
| 1acaccd09f | |||
| 983fe650c1 | |||
| 52d03dc849 | |||
| 9de72d9ad5 | |||
| d95275ffae | |||
| 6cef93dbb7 | |||
| dd3b1ae219 | |||
| f42209682a | |||
| 1b1aed1699 | |||
| 44ced98863 | |||
| 97834c162e | |||
| 9276f2f144 | |||
| a454cada6a | |||
| 99b53d4fbc | |||
| a43a9deaea | |||
| ce88da84c9 | |||
| 15855c7073 | |||
| 43eb3e546b | |||
| 2d52c9b6ac | |||
| d5401b8b4c | |||
| 5fd4393a2e | |||
| a049f6b5c2 | |||
| acba8e5a39 | |||
| f826b91362 | |||
| 98c2de2a60 | |||
| 1c4d4b305b | |||
| f210ac9a03 | |||
| 6685076dfb | |||
| 7f322653f6 | |||
| 66ac2f1357 | |||
| c446e22d0c | |||
| 0358d3a67d | |||
| 9b82f265fd | |||
| 3d9cae58e4 | |||
| 1f1eadee5e | |||
| 0569255189 | |||
| 8ccf90d067 | |||
| b3be89f47d | |||
| b9bf8f62d4 | |||
| 05ca0c1480 | |||
| 47a4f3fc5b | |||
| a3b378ae9e | |||
| a904d26e78 | |||
| 7ba7476c4f | |||
| ae25a243ac | |||
| 23bd6288ff | |||
| fef21d3a24 | |||
| 933bba4517 | |||
| e1d65437cc | |||
| 9325aed1eb | |||
| dee2b3ab42 | |||
| a69bc93fa1 | |||
| b1a620bfce | |||
| 61b164eec2 | |||
| ba77e1837e | |||
| eacad60fd6 | |||
| 70bf5c93bf | |||
| 08bd278d8c | |||
| 22746d64a3 | |||
| 199392a5d5 | |||
| aafb4cb584 | |||
| 96e3dd397c | |||
| ec0f17145b | |||
| ed53da0999 | |||
| dc440fc511 | |||
| 009ae59033 | |||
| f348b3245a | |||
| 0018c5219c | |||
| 01a3e3677a | |||
| a12ecdb46f | |||
| 9f59230d74 | |||
| 085c6a1c72 | |||
| 7b3860971f | |||
| f6f7b7b237 | |||
| d5cf4b3b16 | |||
| 3e58d8355b | |||
| eb01ade63b | |||
| d1dc15fa44 | |||
| 73a39ef868 | |||
| a022baef03 | |||
| 59312d428e | |||
| 951d14ef14 | |||
| 0eb22da6e9 | |||
| 5fd9ef0514 | |||
| 9a4f3c7d35 | |||
| ead2ce3ecc | |||
| 8733f3a2d2 | |||
| 8642f3ba31 | |||
| 6a262a7367 | |||
| eb9192ddb3 | |||
| 5587e75628 | |||
| 74bbb453e2 | |||
| 66842f6206 | |||
| dc1779275d | |||
| 10dff937b1 | |||
| d4e1fe3bbe | |||
| 179976ae57 | |||
| 1c758bb98c | |||
| 17c4f38ee3 | |||
| cd7e57d121 | |||
| 0f2c3f65cc | |||
| 7779666e27 | |||
| c74bd4403b | |||
| 04d23ddb43 | |||
| 0874e84393 | |||
| 57f57f30b1 | |||
| f37d613a0c | |||
| 87d0ff9154 | |||
| b3418f39b8 | |||
| f9e1ca0e2d | |||
| 2c45879669 | |||
| 1cdcfa2c2d | |||
| eab5b73846 | |||
| d961ba1ec7 | |||
| 1ba5e57ec6 | |||
| 1216d25f96 | |||
| fde693408e | |||
| 352a81a869 | |||
| b2562b1010 | |||
| 0d8ba51087 | |||
| 0b847fcea3 | |||
| bf2f49fe62 | |||
| 75e64b1a86 | |||
| 2167735022 | |||
| 4ee292cc1f | |||
| 961205940f | |||
| ffe797bd06 | |||
| b6c864547e | |||
| da369c2edc | |||
| 54dc31a616 | |||
| 9e0b985221 | |||
| eb47077082 | |||
| f9a482857d | |||
| 679a68b12f | |||
| 840a26c7ef | |||
| 030e69c02d | |||
| d9683cdb44 | |||
| 60a063dd7d | |||
| 5f0c1805a7 | |||
| cb7e66001b | |||
| 4ea838f1d7 | |||
| 573648fc4b | |||
| f0e090abea | |||
| 549dcf518c | |||
| c74e20c54a | |||
| c94a9fd9e9 | |||
| ce9749a8ef | |||
| 145da12017 | |||
| 5111f4c311 | |||
| 8f6384a083 | |||
| 762f778e1e | |||
| 4a11ba8f14 | |||
| 86090af4df | |||
| 2dea6e36bd | |||
| 38ce695708 | |||
| 41fe90faa3 | |||
| 9f54bdb1bf | |||
| 08e727aa41 | |||
| 176c17d630 | |||
| 62710f6619 | |||
| e4dbb96b3e | |||
| 832532213a | |||
| eb04ac0c3a | |||
| 1946508325 | |||
| 89d1c5124f | |||
| 1e7a3299a5 | |||
| cae3a77331 | |||
| 2e1e57ce27 | |||
| 45b6ed2847 | |||
| 88eadf13a4 | |||
| dca5666b18 | |||
| e5d52cdf85 | |||
| 65e48826ff | |||
| 0cff507272 | |||
| 30afd71c05 | |||
| d2b6a154de | |||
| 278d5aa25c | |||
| 215f5a4a93 | |||
| 44185d748d | |||
| fe47f1f058 | |||
| 99ce183f41 | |||
| 2ed1947f36 | |||
| 97f3e8c179 | |||
| 38b0c31b87 | |||
| cb839da4d1 | |||
| 5ed730f17c | |||
| 30b1e5f820 | |||
| 8e5c70703e | |||
| 3cc3b25a7b | |||
| 44cf63fa52 | |||
| 12057c065b | |||
| c4e0b9735c | |||
| 218e9b9880 | |||
| 82d840966e | |||
| c62ff3bde9 | |||
| df2506b651 | |||
| efe9172f85 | |||
| b788bc6dab | |||
| 9134f2bbcb | |||
| d76cf2a162 | |||
| 2f96feb98f | |||
| a374c3950c | |||
| a93e3455fa |
@@ -113,6 +113,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||
@@ -174,9 +175,11 @@ The `run.sh` script will automatically:
|
||||
- ✅ Build the project
|
||||
- ✅ Start the server
|
||||
|
||||
**Networking defaults:** `run.sh` starts the server with **`--https`** and the repo **`config.yaml`** (local self-signed TLS; better for many concurrent streams). Use **`./run.sh --http`** for plain HTTP. In production, set **`server.tls_cert_path`** / **`server.tls_key_path`** in **`config.yaml`** (see comments there). For manual runs, add **`--https`** or **`CYBERSTRIKE_HTTPS=1`**; if **`-config`** is wrong, the binary prints a short usage hint on stderr.
|
||||
|
||||
**First-Time Configuration:**
|
||||
1. **Configure OpenAI-compatible API** (required before first use)
|
||||
- Open http://localhost:8080 after launch
|
||||
- After launch, open **`https://127.0.0.1:8080/`** (or **`https://localhost:8080/`**; replace **8080** with `server.port` in `config.yaml`) and accept the self-signed certificate warning once. If you used `./run.sh --http`, use **`http://`** instead.
|
||||
- Go to `Settings` → Fill in your API credentials:
|
||||
```yaml
|
||||
openai:
|
||||
@@ -197,21 +200,23 @@ The `run.sh` script will automatically:
|
||||
|
||||
**Alternative Launch Methods:**
|
||||
```bash
|
||||
# Direct Go run (requires manual setup)
|
||||
go run cmd/server/main.go
|
||||
# Direct Go run (set up env yourself); add --https to match run.sh defaults
|
||||
go run cmd/server/main.go --https
|
||||
|
||||
# Manual build
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
./cyberstrike-ai --https
|
||||
```
|
||||
|
||||
If server logs show `client sent an HTTP request to an HTTPS server`, a client is still using **`http://`** on a TLS-only port—switch the URL to **`https://`**.
|
||||
|
||||
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
||||
|
||||
### Version Update (No Breaking Changes)
|
||||
|
||||
**CyberStrikeAI one-click upgrade (recommended):**
|
||||
1. (First time) enable the script: `chmod +x upgrade.sh`
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`)
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
|
||||
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
||||
|
||||
Recommended one-liner:
|
||||
@@ -281,7 +286,7 @@ Requirements / tips:
|
||||
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||
|
||||
### Skills System (Agent Skills + Eino)
|
||||
@@ -532,7 +537,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||
|
||||
+12
-7
@@ -112,6 +112,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||
@@ -173,9 +174,11 @@ chmod +x run.sh && ./run.sh
|
||||
- ✅ 编译构建项目
|
||||
- ✅ 启动服务器
|
||||
|
||||
**网络默认:** `run.sh` 会以 **`--https`** 并传入项目根 **`config.yaml`** 启动(本机自签证书,多路流式场景更稳)。只要明文 HTTP 用 **`./run.sh --http`**。生产环境在 **`config.yaml`** 的 **`server.tls_cert_path` / `server.tls_key_path`** 配正式证书(见文件内注释)。手动启动可加 **`--https`** 或环境变量 **`CYBERSTRIKE_HTTPS=1`**;`-config` 写错时程序会在终端提示正确写法。
|
||||
|
||||
**首次配置:**
|
||||
1. **配置 AI 模型 API**(首次使用前必填)
|
||||
- 启动后访问 http://localhost:8080
|
||||
- 启动后在浏览器打开 **`https://127.0.0.1:8080/`**(或 **`https://localhost:8080/`**;端口以 `config.yaml` 中 **`server.port`** 为准,默认 8080),并按提示信任自签证书。若使用 **`./run.sh --http`**,则改用 **`http://`** 访问。
|
||||
- 进入 `设置` → 填写 API 配置信息:
|
||||
```yaml
|
||||
openai:
|
||||
@@ -196,20 +199,22 @@ chmod +x run.sh && ./run.sh
|
||||
|
||||
**其他启动方式:**
|
||||
```bash
|
||||
# 直接运行(需手动配置环境)
|
||||
go run cmd/server/main.go
|
||||
# 直接运行(需自行配环境);与 run.sh 默认一致可加 --https
|
||||
go run cmd/server/main.go --https
|
||||
|
||||
# 手动编译
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
./cyberstrike-ai --https
|
||||
```
|
||||
|
||||
若日志出现 `client sent an HTTP request to an HTTPS server`,说明仍有客户端用 **`http://`** 访问只提供 HTTPS 的端口,请改为 **`https://`**。
|
||||
|
||||
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
||||
|
||||
### CyberStrikeAI 版本更新(无兼容性问题)
|
||||
|
||||
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes`)
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
|
||||
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
||||
|
||||
推荐的一键指令:
|
||||
@@ -279,7 +284,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||
- **配置项**:`multi_agent`:`enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||
|
||||
### Skills 技能系统(Agent Skills + Eino)
|
||||
@@ -530,7 +535,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||
|
||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
||||
5) Follow-up Verification Plan(后续验证建议)
|
||||
- 对每个优先条目:建议由哪个阶段子代理接手、需要补测的最小证据集
|
||||
|
||||
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 可能仍残留的风险类别与建议监控方式(只做高层建议)
|
||||
|
||||
4) Handoff to Reporting(交接给报告的要点)
|
||||
- 报告里应包含哪些字段以证明“合规清理”。
|
||||
- 报告里应包含哪些字段以证明“合规清理”。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
||||
5) Open Questions(待澄清问题)
|
||||
- 不足以继续的关键问题(尽量少而关键)
|
||||
|
||||
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
||||
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -50,4 +50,8 @@ max_iterations: 0
|
||||
- 你要求执行的最小化原则(如不导出明文敏感字段、不保留原始样本等,用描述性语言)
|
||||
|
||||
4) Recommended Next Agent(下一步建议)
|
||||
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
||||
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
||||
- 优先用工具拿可验证事实,标注信息来源与置信度;避免无依据推测。
|
||||
- 输出结构化(目标、发现项、证据摘要、建议后续动作),便于协调者合并进总报告。
|
||||
- 不执行未授权的入侵或社工骚扰;双用途技术仅用于甲方书面授权场景。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
||||
- 聚焦:内网拓扑与关键资产推断、凭据与令牌利用、常见横向协议与服务、权限路径与域/云环境注意事项(在工具与可见数据范围内)。
|
||||
- 每一步说明假设前提与证据;禁止对未授权网段、生产无关系统或真实用户数据进行操作。
|
||||
- 输出结构化:当前据点能力、发现的主机/服务、建议的下一步(可交给其他子代理或主代理编排)、风险与回滚注意点。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 建议记录哪些证据字段(时间戳、目标、请求摘要、响应摘要、变更清单、回滚确认)
|
||||
|
||||
4) Stop & Rollback Criteria(停止与回滚标准)
|
||||
- 触发阈值/不可控情况(用描述性语言即可)
|
||||
- 触发阈值/不可控情况(用描述性语言即可)
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -102,10 +102,34 @@ description: plan_execute 模式下的规划/重规划侧主代理:拆解目
|
||||
|
||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||
|
||||
## 证据与漏洞
|
||||
## 证据、黑板与漏洞
|
||||
|
||||
- 要求结论有证据支撑(请求/响应、命令输出、可复现步骤);禁止无依据的确定断言。
|
||||
- 发现有效漏洞时,在后续轮次通过 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、POC、影响、修复建议;级别 critical / high / medium / low / info)。
|
||||
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
- **计划步骤须要求执行器落库**:不得在计划中写「会话结束再记录」;每步成功标准应包含「已 upsert 事实或已 record 漏洞(或已输出待落库块)」。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
|
||||
## 执行器对用户输出(重要)
|
||||
|
||||
|
||||
@@ -117,9 +117,29 @@ description: supervisor 模式下的协调者:通过 transfer 委派专家子
|
||||
3. 期望交付物是否可验收(例如:可复现命令、截图要点、结论段落)?
|
||||
4. 是否已明确写出 URL/IP:Port/域名路径与 in-scope 边界(而非“按上文继续”)?
|
||||
|
||||
## 漏洞
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
有效漏洞应通过 **`record_vulnerability`** 记录(含 POC 与严重性)。
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
|
||||
## 表达
|
||||
|
||||
|
||||
+23
-1
@@ -127,7 +127,29 @@ description: 多代理模式下的 Deep 编排者:在已授权安全场景中
|
||||
## 工具与 MCP
|
||||
|
||||
- **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。
|
||||
- **漏洞记录**:发现**有效漏洞**时,必须使用 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度使用 critical / high / medium / low / info。记录后可在授权范围内继续测试。
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
- **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。
|
||||
- **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。
|
||||
- **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。
|
||||
|
||||
@@ -31,5 +31,9 @@ max_iterations: 0
|
||||
- 禁止自行猜测目标、替换为历史目标或擅自发起全量探索。
|
||||
|
||||
- 以证据为中心:请求/响应、Payload、命令输出、截图说明等,便于审计与复现。
|
||||
- 先确认边界与禁止项(如拒绝 DoS、数据破坏);发现有效漏洞时按协调者要求使用 `record_vulnerability` 等流程(若你的工具集中包含)。
|
||||
- 先确认边界与禁止项(如拒绝 DoS、数据破坏)。
|
||||
- 输出包含:攻击路径摘要、关键步骤、影响评估、修复与缓解建议;语言简洁,便于主代理汇总。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 列出需要清理/验证的痕迹类型(配置、会话、日志、服务变更等层级描述即可)
|
||||
|
||||
4) Recommended Next Steps(下一步建议)
|
||||
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
||||
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -53,4 +53,8 @@ max_iterations: 0
|
||||
4) Recommended Next Agent(下一步建议)
|
||||
- 明确建议由哪个子代理接手(例如 `lateral-movement` / `persistence-maintenance` / `impact-exfiltration` / `reporting-remediation`)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
@@ -34,3 +34,7 @@ max_iterations: 0
|
||||
|
||||
- 若 **`description` / 用户消息 / 上文交接包** 中已给出资产列表、枚举结论或明确写「跳过全量枚举 / 仅做增量 / 从端口扫描或验证开始」,则**不得**为走完整流程而重新执行等价的广域子域爆破或相同参数集的枚举;仅在交接包声明的**缺口**上补充侦察。
|
||||
- 若子目标实为**漏洞验证、协议利用、权限提升**等而非攻击面扩展,应**极短说明**「当前角色为侦察;建议协调者改派专项代理」并仅提供与侦察相关的最小补充信息,避免擅自把任务扩写成新一轮全盘资产收集。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -55,4 +55,8 @@ max_iterations: 0
|
||||
5) Appendix(附录)
|
||||
- 术语、假设、证据清单索引(按证据类型列出即可)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
@@ -57,4 +57,8 @@ max_iterations: 0
|
||||
4) Uncertainties & Missing Evidence(不确定性与缺口)
|
||||
- 列出最关键的缺口(尽量少,但要关键)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -32,6 +33,23 @@ func main() {
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
|
||||
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
||||
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
||||
resultStorageDir := "tmp"
|
||||
if cfg.Agent.ResultStorageDir != "" {
|
||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
||||
}
|
||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
executor.SetResultStorage(resultStorage)
|
||||
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
|
||||
+43
-3
@@ -9,22 +9,62 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
||||
var httpsBootstrap = flag.Bool("https", false, "启用主站 HTTPS:未配置 tls_cert_path/tls_key_path 时使用内存自签证书(本地测试);与 run.sh 默认行为一致")
|
||||
flag.Parse()
|
||||
|
||||
// 环境变量兼容(便于 systemd/docker 等不传参场景)
|
||||
if !*httpsBootstrap {
|
||||
v := strings.TrimSpace(os.Getenv("CYBERSTRIKE_HTTPS"))
|
||||
if v == "1" || strings.EqualFold(v, "true") || strings.EqualFold(v, "yes") {
|
||||
*httpsBootstrap = true
|
||||
}
|
||||
}
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(*configPath)
|
||||
cp := strings.TrimSpace(*configPath)
|
||||
if cp == "" {
|
||||
cp = "config.yaml"
|
||||
}
|
||||
if strings.HasPrefix(cp, "-") {
|
||||
fmt.Fprintf(os.Stderr, "无效的 -config 路径 %q。\n若同时需要 HTTPS,请写成: ./cyberstrike-ai --https -config config.yaml(-config 后必须是 yaml 文件路径)。\n", cp)
|
||||
os.Exit(2)
|
||||
}
|
||||
cfg, err := config.Load(cp)
|
||||
if err != nil {
|
||||
fmt.Printf("加载配置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *httpsBootstrap {
|
||||
config.ApplyDevHTTPSBootstrap(cfg)
|
||||
}
|
||||
|
||||
port := cfg.Server.Port
|
||||
if port <= 0 {
|
||||
port = 8080
|
||||
}
|
||||
scheme := "http"
|
||||
if config.MainWebUIUsesHTTPS(&cfg.Server) {
|
||||
scheme = "https"
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Printf("→ Web 界面: %s://127.0.0.1:%d/\n", scheme, port)
|
||||
if scheme == "https" && cfg.Server.TLSAutoSelfSign {
|
||||
fmt.Println(" (内存自签证书:浏览器首次需确认「继续访问」)")
|
||||
}
|
||||
if scheme == "https" && config.ServerHTTPRedirectEnabled(&cfg.Server) {
|
||||
fmt.Printf(" (http://127.0.0.1:%d/ 将自动跳转到 HTTPS)\n", port)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
||||
if err := config.EnsureMCPAuth(*configPath, cfg); err != nil {
|
||||
if err := config.EnsureMCPAuth(cp, cfg); err != nil {
|
||||
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
@@ -44,7 +84,7 @@ func main() {
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 创建应用
|
||||
application, err := app.New(cfg, log)
|
||||
application, err := app.New(cfg, log, cp)
|
||||
if err != nil {
|
||||
log.Fatal("应用初始化失败", "error", err)
|
||||
}
|
||||
|
||||
+71
-7
@@ -10,11 +10,22 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.5"
|
||||
version: "v1.6.29"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
||||
port: 8080 # 服务端口;未启用 TLS 时为 http://localhost:8080
|
||||
# --- 可选:HTTPS + HTTP/2(缓解浏览器对同源 HTTP/1.1 的并发连接数限制,多路 Deep 流式更稳)---
|
||||
# 启用 TLS 的条件(满足其一即可):tls_enabled: true,或 tls_auto_self_sign: true,或同时配置了 tls_cert_path + tls_key_path。
|
||||
# 启用后请用 https://127.0.0.1:<本端口>/ 访问;若仍用 http:// 访问同端口,将自动 308 跳转到 HTTPS(可用 tls_http_redirect: false 关闭)。
|
||||
tls_enabled: true
|
||||
# 启用 HTTPS 时,明文 HTTP 是否自动跳转到 HTTPS(默认 true;同端口嗅探 TLS/HTTP 后分流)
|
||||
# tls_http_redirect: true
|
||||
# 方式 A(推荐生产):PEM 证书与私钥路径
|
||||
# tls_cert_path: /path/to/fullchain.pem
|
||||
# tls_key_path: /path/to/privkey.pem
|
||||
# 方式 B(仅本地/测试):无证书文件时内存自签(浏览器会提示不受信任;SAN 含 localhost / 127.0.0.1)
|
||||
tls_auto_self_sign: true
|
||||
# 认证配置
|
||||
auth:
|
||||
password: # Web 登录密码,请修改为强密码
|
||||
@@ -23,6 +34,12 @@ auth:
|
||||
log:
|
||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
||||
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
|
||||
audit:
|
||||
enabled: true
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
@@ -41,6 +58,13 @@ openai:
|
||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||
model: qwen3-max # 模型名称(必填)
|
||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||
reasoning:
|
||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||
# ============================================
|
||||
# 信息收集(FOFA)配置(可选)
|
||||
# ============================================
|
||||
@@ -53,21 +77,23 @@ fofa:
|
||||
# Agent 配置
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
agent:
|
||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
max_iterations: 12000 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||
|
||||
system_prompt_path: ""
|
||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||
hitl:
|
||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
|
||||
multi_agent:
|
||||
enabled: true
|
||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
||||
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
|
||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||
@@ -90,7 +116,7 @@ multi_agent:
|
||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
||||
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||
@@ -107,9 +133,26 @@ multi_agent:
|
||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||
eino_callbacks:
|
||||
enabled: true
|
||||
# log_only=仅 Zap+OTel(推荐默认)| sse/full=才启用流式回调副本关闭等(full 含 stream hooks)
|
||||
mode: log_only
|
||||
sse_trace_to_client: false # true:且 mode 为 sse/full 时,向前端时间线推送 eino_trace_*(排障/内网演示用)
|
||||
max_input_summary_runes: 400
|
||||
max_output_summary_runes: 400
|
||||
zap_verbose: false # true:Debug 附带 input/output 摘要
|
||||
otel:
|
||||
enabled: true
|
||||
service_name: cyberstrike-ai
|
||||
exporter: stdout # none | stdout(开发/本机)| otlphttp(生产接 Collector)
|
||||
otlp_endpoint: localhost:4318 # otlphttp 时使用,host:port,路径固定 /v1/traces
|
||||
sample_ratio: 1.0 # 0~1,ParentBased+TraceIDRatio
|
||||
# 数据库配置
|
||||
database:
|
||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||
@@ -202,6 +245,14 @@ knowledge:
|
||||
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
||||
# 在系统设置 -> 机器人设置 中可配置
|
||||
robots:
|
||||
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
|
||||
enabled: false
|
||||
bot_token: ""
|
||||
ilink_bot_id: ""
|
||||
ilink_user_id: ""
|
||||
base_url: https://ilinkai.weixin.qq.com
|
||||
bot_type: "3"
|
||||
bot_agent: CyberStrikeAI/1.0
|
||||
wecom: # 企业微信
|
||||
enabled: false
|
||||
token: ""
|
||||
@@ -213,11 +264,13 @@ robots:
|
||||
enabled: false
|
||||
client_id: ""
|
||||
client_secret: ""
|
||||
allow_conversation_id_fallback: false
|
||||
lark: # 飞书
|
||||
enabled: false
|
||||
app_id: ""
|
||||
app_secret: ""
|
||||
verify_token: ""
|
||||
allow_chat_id_fallback: false
|
||||
# ============================================
|
||||
# Skills 相关配置
|
||||
# ============================================
|
||||
@@ -239,3 +292,14 @@ agents_dir: agents
|
||||
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
|
||||
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
|
||||
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
|
||||
|
||||
# ============================================
|
||||
# 项目管理与事实黑板
|
||||
# ============================================
|
||||
project:
|
||||
enabled: true
|
||||
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||
fact_index_max_runes: 3500
|
||||
fact_summary_max_runes: 240
|
||||
default_inject_deprecated: false
|
||||
|
||||
|
||||
@@ -27,7 +27,14 @@ require (
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.8
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
go.opentelemetry.io/otel v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||
go.opentelemetry.io/otel/sdk v1.34.0
|
||||
go.opentelemetry.io/otel/trace v1.34.0
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/text v0.26.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -39,6 +46,7 @@ require (
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
@@ -46,6 +54,8 @@ require (
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
@@ -53,6 +63,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
@@ -71,14 +82,20 @@ require (
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/net v0.24.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/grpc v1.69.4 // indirect
|
||||
google.golang.org/protobuf v1.36.3 // indirect
|
||||
)
|
||||
|
||||
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
|
||||
|
||||
@@ -17,6 +17,8 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
@@ -59,6 +61,11 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -75,8 +82,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -90,6 +97,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -154,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
||||
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
||||
@@ -191,6 +202,26 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
|
||||
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
|
||||
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
|
||||
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
|
||||
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
@@ -216,8 +247,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -251,9 +282,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50=
|
||||
google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A=
|
||||
google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
|
||||
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
|
||||
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 182 KiB After Width: | Height: | Size: 178 KiB |
+59
-13
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
@@ -193,6 +194,10 @@ type ChatMessage struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||
@@ -206,11 +211,17 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
||||
if cm.Content != "" {
|
||||
aux["content"] = cm.Content
|
||||
}
|
||||
if cm.ReasoningContent != "" {
|
||||
aux["reasoning_content"] = cm.ReasoningContent
|
||||
}
|
||||
|
||||
// 添加tool_call_id(如果存在)
|
||||
if cm.ToolCallID != "" {
|
||||
aux["tool_call_id"] = cm.ToolCallID
|
||||
}
|
||||
if cm.ToolName != "" {
|
||||
aux["tool_name"] = cm.ToolName
|
||||
}
|
||||
|
||||
// 转换tool_calls,将arguments转换为JSON字符串
|
||||
if len(cm.ToolCalls) > 0 {
|
||||
@@ -355,12 +366,12 @@ type ProgressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// AgentLoop 执行Agent循环
|
||||
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, "")
|
||||
}
|
||||
|
||||
// AgentLoopWithConversationID 执行Agent循环(带对话ID)
|
||||
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, "")
|
||||
}
|
||||
|
||||
// EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。
|
||||
@@ -386,7 +397,7 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string {
|
||||
}
|
||||
|
||||
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, systemPromptExtra string) (*AgentLoopResult, error) {
|
||||
ctx = withAgentConversationID(ctx, conversationID)
|
||||
// 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准)
|
||||
a.mu.Lock()
|
||||
@@ -416,6 +427,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
}
|
||||
}
|
||||
}
|
||||
systemPrompt = project.AppendSystemPromptBlock(systemPrompt, systemPromptExtra)
|
||||
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
@@ -438,6 +450,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: msg.Content,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolName: msg.ToolName,
|
||||
})
|
||||
addedCount++
|
||||
contentPreview := msg.Content
|
||||
@@ -587,11 +600,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
thinkingStreamSeq++
|
||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||
thinkingStreamStarted := false
|
||||
var thinkingWire string
|
||||
|
||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
var deltaOut string
|
||||
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
if !thinkingStreamStarted {
|
||||
thinkingStreamStarted = true
|
||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||
@@ -600,10 +619,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"toolStream": false,
|
||||
})
|
||||
}
|
||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
||||
sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
})
|
||||
}, thinkingWire))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -657,8 +676,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
// ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。
|
||||
// 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
@@ -816,10 +835,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -863,10 +888,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -910,10 +941,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "max_iter_summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -1909,6 +1946,15 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
||||
return a.executeToolViaMCP(ctx, toolName, args)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseTraceMessages 解析落库的 last_react_input(OpenAI 风格 messages JSON 数组)。
|
||||
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
|
||||
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||
if traceInputJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var raw []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]ChatMessage, 0, len(raw))
|
||||
for _, msgMap := range raw {
|
||||
msg := ChatMessage{}
|
||||
role, _ := msgMap["role"].(string)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
msg.Role = role
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||
msg.ReasoningContent = rc
|
||||
}
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||
for _, tcRaw := range toolCallsArray {
|
||||
tcMap, ok := tcRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolCall := ToolCall{}
|
||||
if id, ok := tcMap["id"].(string); ok {
|
||||
toolCall.ID = id
|
||||
}
|
||||
if toolType, ok := tcMap["type"].(string); ok {
|
||||
toolCall.Type = toolType
|
||||
}
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
toolCall.Function = FunctionCall{}
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
toolCall.Function.Name = name
|
||||
}
|
||||
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||
if argsStr, ok := argsRaw.(string); ok {
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
if toolCall.ID != "" {
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
|
||||
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
|
||||
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
|
||||
if len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
lastUser := -1
|
||||
for i, m := range msgs {
|
||||
if strings.EqualFold(m.Role, "user") {
|
||||
lastUser = i
|
||||
}
|
||||
}
|
||||
if lastUser < 0 {
|
||||
return msgs
|
||||
}
|
||||
trimmed := msgs[lastUser:]
|
||||
out := make([]ChatMessage, 0, len(trimmed))
|
||||
for _, m := range trimmed {
|
||||
if strings.EqualFold(m.Role, "system") {
|
||||
continue
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
|
||||
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
|
||||
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||
if traceInputJSON == "" {
|
||||
return traceInputJSON
|
||||
}
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
|
||||
return traceInputJSON
|
||||
}
|
||||
lastUser := -1
|
||||
for i, m := range arr {
|
||||
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
|
||||
lastUser = i
|
||||
}
|
||||
}
|
||||
if lastUser <= 0 {
|
||||
return traceInputJSON
|
||||
}
|
||||
trimmed := arr[lastUser:]
|
||||
b, err := json.Marshal(trimmed)
|
||||
if err != nil {
|
||||
return traceInputJSON
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
|
||||
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
|
||||
assistantOut = strings.TrimSpace(assistantOut)
|
||||
if assistantOut == "" || len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
out := append([]ChatMessage(nil), msgs...)
|
||||
last := &out[len(out)-1]
|
||||
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
|
||||
last.Content = assistantOut
|
||||
return out
|
||||
}
|
||||
out = append(out, ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: assistantOut,
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
|
||||
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
|
||||
filtered := make([]ChatMessage, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
if strings.EqualFold(m.Role, "system") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
b, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
|
||||
raw := []map[string]interface{}{
|
||||
{"role": "user", "content": "old question"},
|
||||
{"role": "assistant", "content": "old answer"},
|
||||
{"role": "user", "content": "new target 1.1.1.1"},
|
||||
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
|
||||
"id": "c1", "type": "function",
|
||||
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
|
||||
}}},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
|
||||
}
|
||||
b, _ := json.Marshal(raw)
|
||||
out := ExtractLastUserTurnTraceJSON(string(b))
|
||||
var trimmed []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(trimmed) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(trimmed))
|
||||
}
|
||||
if trimmed[0]["content"] != "new target 1.1.1.1" {
|
||||
t.Fatalf("unexpected first message: %v", trimmed[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
|
||||
msgs := []ChatMessage{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: "q"},
|
||||
{Role: "assistant", Content: "a"},
|
||||
}
|
||||
out := ExtractLastUserTurnMessages(msgs)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(out))
|
||||
}
|
||||
if out[0].Role != "user" {
|
||||
t.Fatal("expected user first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAssistantTraceOutput(t *testing.T) {
|
||||
msgs := []ChatMessage{
|
||||
{Role: "user", Content: "q"},
|
||||
{Role: "assistant", Content: "draft"},
|
||||
}
|
||||
out := MergeAssistantTraceOutput(msgs, "final summary")
|
||||
if out[len(out)-1].Content != "final summary" {
|
||||
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package agent
|
||||
|
||||
import "cyberstrike-ai/internal/mcp/builtin"
|
||||
import (
|
||||
"cyberstrike-ai/internal/project"
|
||||
)
|
||||
|
||||
// DefaultSingleAgentSystemPrompt 单代理(ReAct / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||
func DefaultSingleAgentSystemPrompt() string {
|
||||
@@ -105,11 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||
|
||||
## 漏洞记录
|
||||
|
||||
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。
|
||||
` + project.FactRecordingBlackboardSection(false) + `
|
||||
|
||||
## 技能库(Skills)与知识库
|
||||
|
||||
|
||||
+160
-202
@@ -3,8 +3,10 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,9 +15,11 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
@@ -29,6 +33,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// App 应用
|
||||
@@ -52,14 +57,16 @@ type App struct {
|
||||
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
||||
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
||||
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
||||
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
|
||||
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
||||
auditSvc *audit.Service
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
|
||||
@@ -88,8 +95,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
auditSvc := audit.NewService(db, cfg, log.Logger)
|
||||
audit.RegisterConversationCreateHook(auditSvc)
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
@@ -98,7 +111,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
// 注册漏洞记录工具
|
||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
||||
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||
|
||||
if cfg.Auth.GeneratedPassword != "" {
|
||||
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
|
||||
@@ -216,6 +230,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
|
||||
knowledgeHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 扫描知识库并建立索引(异步)
|
||||
@@ -290,10 +305,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// 获取配置文件路径
|
||||
configPath := "config.yaml"
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
// 配置文件路径必须由入口传入(与 flag -config 一致)。勿再用 os.Args[1],否则 ./cyberstrike-ai --https 会把 --https 当成路径。
|
||||
configPath = strings.TrimSpace(configPath)
|
||||
if configPath == "" {
|
||||
configPath = "config.yaml"
|
||||
}
|
||||
|
||||
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
||||
@@ -312,31 +327,43 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
|
||||
}
|
||||
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
|
||||
markdownAgentsHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||
agentHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetAgentsMarkdownDir(agentsDir)
|
||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||
if knowledgeManager != nil {
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
authHandler.SetAudit(auditSvc)
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
projectHandler := handler.NewProjectHandler(db, log.Logger)
|
||||
vulnerabilityHandler.SetAudit(auditSvc)
|
||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||
webshellHandler.SetAudit(auditSvc)
|
||||
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
|
||||
chatUploadsHandler.SetAudit(auditSvc)
|
||||
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
configHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
externalMCPHandler.SetAudit(auditSvc)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
roleHandler.SetAudit(auditSvc)
|
||||
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
||||
skillsHandler.SetAudit(auditSvc)
|
||||
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
||||
terminalHandler := handler.NewTerminalHandler(log.Logger)
|
||||
if db != nil {
|
||||
@@ -351,9 +378,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
||||
}
|
||||
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
||||
c2Handler.SetAudit(auditSvc)
|
||||
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
||||
|
||||
@@ -379,13 +409,15 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
c2Watchdog: c2Watchdog,
|
||||
c2WatchdogCancel: watchdogCancel,
|
||||
c2Handler: c2Handler,
|
||||
auditSvc: auditSvc,
|
||||
}
|
||||
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
||||
app.startRobotConnections()
|
||||
|
||||
// 设置漏洞工具注册器(内置工具,必须设置)
|
||||
vulnerabilityRegistrar := func() error {
|
||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
||||
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||
return nil
|
||||
}
|
||||
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
|
||||
@@ -444,9 +476,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||
}
|
||||
|
||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
|
||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
|
||||
configHandler.SetRobotRestarter(app)
|
||||
|
||||
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
|
||||
|
||||
configHandler.SetC2Runtime(app)
|
||||
configHandler.SetC2ToolRegistrar(func() error {
|
||||
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
||||
@@ -464,12 +498,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
notificationHandler,
|
||||
conversationHandler,
|
||||
robotHandler,
|
||||
wechatRobotHandler,
|
||||
groupHandler,
|
||||
configHandler,
|
||||
externalMCPHandler,
|
||||
attackChainHandler,
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
projectHandler,
|
||||
webshellHandler,
|
||||
chatUploadsHandler,
|
||||
roleHandler,
|
||||
@@ -478,6 +514,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
fofaHandler,
|
||||
terminalHandler,
|
||||
app.c2Handler,
|
||||
auditHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
openAPIHandler,
|
||||
@@ -528,18 +565,49 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
}()
|
||||
}
|
||||
|
||||
// 启动主服务器
|
||||
// 启动主服务器(可选 HTTPS + HTTP/2,见 config server.tls_*)
|
||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
||||
tlsMode, tlsConf, certFile, keyFile, tlsErr := prepareMainServerTLS(&a.config.Server)
|
||||
if tlsErr != nil {
|
||||
return tlsErr
|
||||
}
|
||||
|
||||
srv := &http.Server{Addr: addr, Handler: a.router}
|
||||
var mainMux *mainServerMux
|
||||
httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server)
|
||||
if tlsMode != mainTLSOff {
|
||||
srv.TLSConfig = tlsConf
|
||||
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||
return fmt.Errorf("主服务 HTTP/2 配置失败: %w", err)
|
||||
}
|
||||
switch tlsMode {
|
||||
case mainTLSFromFiles:
|
||||
a.logger.Info("启动 HTTPS 主服务(已启用 HTTP/2 协商)",
|
||||
zap.String("address", addr),
|
||||
zap.String("cert", certFile),
|
||||
)
|
||||
case mainTLSInMemorySelfSigned:
|
||||
a.logger.Info("启动 HTTPS 主服务(内存自签证书,仅测试;已启用 HTTP/2 协商)",
|
||||
zap.String("address", addr),
|
||||
)
|
||||
}
|
||||
if httpRedirect {
|
||||
a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr))
|
||||
}
|
||||
} else {
|
||||
a.logger.Info("启动 HTTP 主服务", zap.String("address", addr))
|
||||
}
|
||||
|
||||
// 监听 context 取消,优雅关闭 HTTP 服务器
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
if mainMux != nil {
|
||||
if err := mainMux.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
} else if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
if mcpServer != nil {
|
||||
@@ -549,7 +617,36 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
var err error
|
||||
switch {
|
||||
case tlsMode != mainTLSOff && httpRedirect:
|
||||
var tlsConfReady *tls.Config
|
||||
tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载 TLS 证书: %w", err)
|
||||
}
|
||||
srv.TLSConfig = tlsConfReady
|
||||
var ln net.Listener
|
||||
ln, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger)
|
||||
err = mainMux.Serve()
|
||||
case tlsMode == mainTLSOff:
|
||||
err = srv.ListenAndServe()
|
||||
case tlsMode == mainTLSFromFiles:
|
||||
err = srv.ListenAndServeTLS(certFile, keyFile)
|
||||
case tlsMode == mainTLSInMemorySelfSigned:
|
||||
var ln net.Listener
|
||||
ln, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||
if err == nil {
|
||||
err = srv.Serve(ln)
|
||||
}
|
||||
default:
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -557,6 +654,10 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
|
||||
// Shutdown 关闭应用
|
||||
func (a *App) Shutdown() {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = einoobserve.ShutdownOtel(shutdownCtx)
|
||||
shutdownCancel()
|
||||
|
||||
// 停止钉钉/飞书长连接
|
||||
a.robotMu.Lock()
|
||||
if a.dingCancel != nil {
|
||||
@@ -606,9 +707,14 @@ func (a *App) startRobotConnections() {
|
||||
a.dingCancel = cancel
|
||||
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||
}
|
||||
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
a.wechatCancel = cancel
|
||||
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
|
||||
}
|
||||
}
|
||||
|
||||
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||
// RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||
func (a *App) RestartRobotConnections() {
|
||||
a.robotMu.Lock()
|
||||
if a.dingCancel != nil {
|
||||
@@ -619,6 +725,10 @@ func (a *App) RestartRobotConnections() {
|
||||
a.larkCancel()
|
||||
a.larkCancel = nil
|
||||
}
|
||||
if a.wechatCancel != nil {
|
||||
a.wechatCancel()
|
||||
a.wechatCancel = nil
|
||||
}
|
||||
a.robotMu.Unlock()
|
||||
// 给旧 goroutine 一点时间退出
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
@@ -634,12 +744,14 @@ func setupRoutes(
|
||||
notificationHandler *handler.NotificationHandler,
|
||||
conversationHandler *handler.ConversationHandler,
|
||||
robotHandler *handler.RobotHandler,
|
||||
wechatRobotHandler *handler.WechatRobotHandler,
|
||||
groupHandler *handler.GroupHandler,
|
||||
configHandler *handler.ConfigHandler,
|
||||
externalMCPHandler *handler.ExternalMCPHandler,
|
||||
attackChainHandler *handler.AttackChainHandler,
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
projectHandler *handler.ProjectHandler,
|
||||
webshellHandler *handler.WebShellHandler,
|
||||
chatUploadsHandler *handler.ChatUploadsHandler,
|
||||
roleHandler *handler.RoleHandler,
|
||||
@@ -648,6 +760,7 @@ func setupRoutes(
|
||||
fofaHandler *handler.FofaHandler,
|
||||
terminalHandler *handler.TerminalHandler,
|
||||
c2Handler *handler.C2Handler,
|
||||
auditHandler *handler.AuditHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
openAPIHandler *handler.OpenAPIHandler,
|
||||
@@ -682,6 +795,12 @@ func setupRoutes(
|
||||
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
||||
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
||||
|
||||
// 微信 iLink 扫码绑定(需登录)
|
||||
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
|
||||
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
|
||||
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
||||
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
||||
|
||||
// Agent Loop
|
||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
||||
// Agent Loop 流式输出
|
||||
@@ -737,6 +856,7 @@ func setupRoutes(
|
||||
protected.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
|
||||
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
|
||||
protected.PUT("/conversations/:id/project", conversationHandler.SetConversationProject)
|
||||
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
|
||||
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
|
||||
@@ -778,6 +898,13 @@ func setupRoutes(
|
||||
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
|
||||
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
|
||||
|
||||
// 平台审计日志
|
||||
protected.GET("/audit/meta", auditHandler.Meta)
|
||||
protected.GET("/audit/summary", auditHandler.Summary)
|
||||
protected.GET("/audit/logs", auditHandler.ListLogs)
|
||||
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
|
||||
protected.GET("/audit/logs/:id", auditHandler.GetLog)
|
||||
|
||||
// 外部MCP管理
|
||||
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
|
||||
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
|
||||
@@ -946,6 +1073,23 @@ func setupRoutes(
|
||||
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||
|
||||
// 项目管理与事实黑板
|
||||
protected.GET("/projects", projectHandler.ListProjects)
|
||||
protected.POST("/projects", projectHandler.CreateProject)
|
||||
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
|
||||
protected.GET("/projects/:id/conversations", projectHandler.ListProjectConversations)
|
||||
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||
protected.GET("/projects/:id/facts/:factId/previous-version", projectHandler.GetFactPreviousVersion)
|
||||
protected.GET("/projects/:id/facts/:factId/versions", projectHandler.ListFactVersions)
|
||||
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
|
||||
protected.POST("/projects/:id/facts/deprecate", projectHandler.DeprecateFact)
|
||||
protected.POST("/projects/:id/facts/restore", projectHandler.RestoreFact)
|
||||
|
||||
// WebShell 管理(代理执行 + 连接配置存 SQLite)
|
||||
protected.GET("/webshell/connections", webshellHandler.ListConnections)
|
||||
protected.POST("/webshell/connections", webshellHandler.CreateConnection)
|
||||
@@ -1066,195 +1210,6 @@ func setupRoutes(
|
||||
})
|
||||
}
|
||||
|
||||
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
|
||||
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞标题(必需)",
|
||||
},
|
||||
"description": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞详细描述",
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"vulnerability_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||
},
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||
},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞影响说明",
|
||||
},
|
||||
"recommendation": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "severity"},
|
||||
},
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
// 从参数中获取conversation_id(由Agent自动添加)
|
||||
conversationID, _ := args["conversation_id"].(string)
|
||||
if conversationID == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
title, ok := args["title"].(string)
|
||||
if !ok || title == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: title 参数必需且不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
severity, ok := args["severity"].(string)
|
||||
if !ok || severity == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: severity 参数必需且不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 验证严重程度
|
||||
validSeverities := map[string]bool{
|
||||
"critical": true,
|
||||
"high": true,
|
||||
"medium": true,
|
||||
"low": true,
|
||||
"info": true,
|
||||
}
|
||||
if !validSeverities[severity] {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 获取可选参数
|
||||
description := ""
|
||||
if d, ok := args["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
|
||||
vulnType := ""
|
||||
if t, ok := args["vulnerability_type"].(string); ok {
|
||||
vulnType = t
|
||||
}
|
||||
|
||||
target := ""
|
||||
if t, ok := args["target"].(string); ok {
|
||||
target = t
|
||||
}
|
||||
|
||||
proof := ""
|
||||
if p, ok := args["proof"].(string); ok {
|
||||
proof = p
|
||||
}
|
||||
|
||||
impact := ""
|
||||
if i, ok := args["impact"].(string); ok {
|
||||
impact = i
|
||||
}
|
||||
|
||||
recommendation := ""
|
||||
if r, ok := args["recommendation"].(string); ok {
|
||||
recommendation = r
|
||||
}
|
||||
|
||||
// 创建漏洞记录
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: conversationID,
|
||||
Title: title,
|
||||
Description: description,
|
||||
Severity: severity,
|
||||
Status: "open",
|
||||
Type: vulnType,
|
||||
Target: target,
|
||||
Proof: proof,
|
||||
Impact: impact,
|
||||
Recommendation: recommendation,
|
||||
}
|
||||
|
||||
created, err := db.CreateVulnerability(vuln)
|
||||
if err != nil {
|
||||
logger.Error("记录漏洞失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("记录漏洞失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Info("漏洞记录成功",
|
||||
zap.String("id", created.ID),
|
||||
zap.String("title", created.Title),
|
||||
zap.String("severity", created.Severity),
|
||||
zap.String("conversation_id", conversationID),
|
||||
)
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("漏洞记录工具注册成功")
|
||||
}
|
||||
|
||||
// registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作
|
||||
func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) {
|
||||
if db == nil || webshellHandler == nil {
|
||||
@@ -1839,6 +1794,9 @@ func initializeKnowledge(
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
||||
if app != nil && app.auditSvc != nil {
|
||||
knowledgeHandler.SetAudit(app.auditSvc)
|
||||
}
|
||||
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 设置知识库管理器到AgentHandler以便记录检索日志
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。
|
||||
type peekedConn struct {
|
||||
net.Conn
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *peekedConn) Read(p []byte) (int, error) {
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
||||
// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。
|
||||
type oneConnListener struct {
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Accept() (net.Conn, error) {
|
||||
var c net.Conn
|
||||
l.once.Do(func() {
|
||||
c = l.conn
|
||||
l.conn = nil
|
||||
})
|
||||
if c == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Close() error { return nil }
|
||||
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
||||
|
||||
func isTLSHandshakeRecord(b byte) bool {
|
||||
return b == 0x16
|
||||
}
|
||||
|
||||
func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
var target string
|
||||
if httpsPort == 443 {
|
||||
target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI())
|
||||
} else {
|
||||
target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI())
|
||||
}
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
})
|
||||
}
|
||||
|
||||
func portFromListenAddr(addr string) int {
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return 443
|
||||
}
|
||||
p, err := strconv.Atoi(portStr)
|
||||
if err != nil || p <= 0 {
|
||||
return 443
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) {
|
||||
if mode != mainTLSFromFiles {
|
||||
return tlsConf, nil
|
||||
}
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
if len(tlsConf.Certificates) > 0 {
|
||||
return tlsConf, nil
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConf.Certificates = []tls.Certificate{cert}
|
||||
return tlsConf, nil
|
||||
}
|
||||
|
||||
type mainServerMux struct {
|
||||
ln net.Listener
|
||||
httpsSrv *http.Server
|
||||
redirectSrv *http.Server
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux {
|
||||
return &mainServerMux{
|
||||
ln: ln,
|
||||
httpsSrv: httpsSrv,
|
||||
redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) Serve() error {
|
||||
for {
|
||||
conn, err := m.ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
go m.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) handleConn(raw net.Conn) {
|
||||
if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
_ = raw.Close()
|
||||
return
|
||||
}
|
||||
br := bufio.NewReader(raw)
|
||||
b, err := br.Peek(1)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
return
|
||||
}
|
||||
_ = raw.SetReadDeadline(time.Time{})
|
||||
|
||||
pc := &peekedConn{Conn: raw, r: br}
|
||||
ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()}
|
||||
|
||||
if isTLSHandshakeRecord(b[0]) {
|
||||
m.serveHTTPS(pc, raw.LocalAddr())
|
||||
return
|
||||
}
|
||||
if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。
|
||||
// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。
|
||||
func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
|
||||
tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig)
|
||||
handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
if err := tlsConn.HandshakeContext(handCtx); err != nil {
|
||||
m.logger.Debug("TLS 握手失败", zap.Error(err))
|
||||
_ = pc.Close()
|
||||
return
|
||||
}
|
||||
|
||||
srv := m.httpsSrv
|
||||
if srv.TLSNextProto != nil {
|
||||
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||
if fn := srv.TLSNextProto[proto]; fn != nil {
|
||||
fn(srv, tlsConn, srv.Handler)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
plain := *srv
|
||||
plain.TLSConfig = nil
|
||||
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
||||
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) Shutdown(ctx context.Context) error {
|
||||
_ = m.ln.Close()
|
||||
var err1, err2 error
|
||||
if m.httpsSrv != nil {
|
||||
err1 = m.httpsSrv.Shutdown(ctx)
|
||||
}
|
||||
if m.redirectSrv != nil {
|
||||
err2 = m.redirectSrv.Shutdown(ctx)
|
||||
}
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
httpsPort int
|
||||
host string
|
||||
uri string
|
||||
wantTarget string
|
||||
}{
|
||||
{
|
||||
name: "non standard port",
|
||||
httpsPort: 8080,
|
||||
host: "127.0.0.1:8080",
|
||||
uri: "/login?next=/",
|
||||
wantTarget: "https://127.0.0.1:8080/login?next=/",
|
||||
},
|
||||
{
|
||||
name: "standard port",
|
||||
httpsPort: 443,
|
||||
host: "example.com:80",
|
||||
uri: "/",
|
||||
wantTarget: "https://example.com/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newHTTPToHTTPSRedirectHandler(tt.httpsPort)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil)
|
||||
req.Host = tt.host
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusPermanentRedirect {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect)
|
||||
}
|
||||
if got := rec.Header().Get("Location"); got != tt.wantTarget {
|
||||
t.Fatalf("Location = %q, want %q", got, tt.wantTarget)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTLSHandshakeRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !isTLSHandshakeRecord(0x16) {
|
||||
t.Fatal("expected TLS handshake record")
|
||||
}
|
||||
if isTLSHandshakeRecord('G') {
|
||||
t.Fatal("GET should not be TLS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerHTTPRedirectEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
disabled := false
|
||||
enabled := true
|
||||
if config.ServerHTTPRedirectEnabled(nil) {
|
||||
t.Fatal("nil config should disable redirect")
|
||||
}
|
||||
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) {
|
||||
t.Fatal("HTTPS without explicit flag should enable redirect")
|
||||
}
|
||||
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) {
|
||||
t.Fatal("explicit false should disable redirect")
|
||||
}
|
||||
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) {
|
||||
t.Fatal("explicit true should enable redirect")
|
||||
}
|
||||
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) {
|
||||
t.Fatal("plain HTTP should not redirect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) {
|
||||
cert, err := generateMainServerSelfSignedCert()
|
||||
if err != nil {
|
||||
t.Fatalf("generate cert: %v", err)
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
})
|
||||
srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}}
|
||||
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||
t.Fatalf("configure http2: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil)
|
||||
go func() { _ = mux.Serve() }()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12},
|
||||
},
|
||||
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
|
||||
httpResp, err := client.Get("http://" + addr + "/")
|
||||
if err != nil {
|
||||
t.Fatalf("http get: %v", err)
|
||||
}
|
||||
_ = httpResp.Body.Close()
|
||||
if httpResp.StatusCode != http.StatusPermanentRedirect {
|
||||
t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect)
|
||||
}
|
||||
if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" {
|
||||
t.Fatalf("Location = %q", got)
|
||||
}
|
||||
|
||||
httpsResp, err := client.Get("https://" + addr + "/")
|
||||
if err != nil {
|
||||
t.Fatalf("https get: %v", err)
|
||||
}
|
||||
defer httpsResp.Body.Close()
|
||||
if httpsResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK)
|
||||
}
|
||||
body, _ := io.ReadAll(httpsResp.Body)
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("body = %q, want ok", body)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// mainTLSMode 主 Web 服务 TLS 启动方式。
|
||||
type mainTLSMode int
|
||||
|
||||
const (
|
||||
mainTLSOff mainTLSMode = iota
|
||||
mainTLSFromFiles
|
||||
mainTLSInMemorySelfSigned
|
||||
)
|
||||
|
||||
// prepareMainServerTLS 根据 server 配置决定主站是否启用 HTTPS(及 HTTP/2 协商)。
|
||||
// fromFiles:使用 tls_cert_path + tls_key_path,由 http.Server.ListenAndServeTLS 加载 PEM。
|
||||
// inMemory:tls_auto_self_sign 生成的自签证书,仅用于本地/测试。
|
||||
func prepareMainServerTLS(cfg *config.ServerConfig) (mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string, err error) {
|
||||
if cfg == nil || !config.MainWebUIUsesHTTPS(cfg) {
|
||||
return mainTLSOff, nil, "", "", nil
|
||||
}
|
||||
certFile = strings.TrimSpace(cfg.TLSCertPath)
|
||||
keyFile = strings.TrimSpace(cfg.TLSKeyPath)
|
||||
if certFile != "" && keyFile != "" {
|
||||
// 证书由 ListenAndServeTLS 从文件加载;此处仅提供最小 TLS 配置供 http2.ConfigureServer 合并 ALPN。
|
||||
return mainTLSFromFiles, &tls.Config{MinVersion: tls.VersionTLS12}, certFile, keyFile, nil
|
||||
}
|
||||
if cfg.TLSAutoSelfSign {
|
||||
cert, genErr := generateMainServerSelfSignedCert()
|
||||
if genErr != nil {
|
||||
return mainTLSOff, nil, "", "", fmt.Errorf("生成自签 TLS 证书: %w", genErr)
|
||||
}
|
||||
tlsConf = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
return mainTLSInMemorySelfSigned, tlsConf, "", "", nil
|
||||
}
|
||||
return mainTLSOff, nil, "", "", fmt.Errorf("server: 已启用 TLS(tls_enabled / tls_auto_self_sign / 证书路径),请设置 tls_cert_path 与 tls_key_path,或将 tls_auto_self_sign 设为 true(仅测试环境)")
|
||||
}
|
||||
|
||||
func generateMainServerSelfSignedCert() (tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: "CyberStrikeAI"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func projectIDFromConversation(db *database.DB, ctx context.Context) (string, error) {
|
||||
convID := agent.ConversationIDFromContext(ctx)
|
||||
if convID == "" {
|
||||
return "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用项目事实工具")
|
||||
}
|
||||
pid, err := db.GetConversationProjectID(convID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(pid) == "" {
|
||||
return "", fmt.Errorf("当前对话未绑定项目,请先在对话中选择项目或创建带项目的对话")
|
||||
}
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func textResult(msg string, isErr bool) *mcp.ToolResult {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: msg}},
|
||||
IsError: isErr,
|
||||
}
|
||||
}
|
||||
|
||||
// registerProjectFactTools 注册项目黑板 MCP 工具。
|
||||
func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *config.Config, logger *zap.Logger) {
|
||||
if db == nil || cfg == nil || !cfg.Project.Enabled {
|
||||
if logger != nil {
|
||||
logger.Info("项目黑板工具未注册(未启用)")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
upsertTool := mcp.Tool{
|
||||
Name: builtin.ToolUpsertProjectFact,
|
||||
Description: "写入或更新项目黑板事实,用于跨会话沉淀可复现上下文(非正式漏洞条目;可交付漏洞另用 record_vulnerability)。" +
|
||||
"边渗透边记录:每确认新认知(端口/入口/凭据/可利用点)后立即调用,同 fact_key 覆盖更新,勿等会话结束。" +
|
||||
"禁止仅写结论:summary 须含什么+在哪+如何验证;body 须含攻击链/请求响应/命令等复现细节。" +
|
||||
"发现类建议 fact_key 为 finding|chain|exploit|poc/<slug>,category 对应 finding|chain|exploit|poc,body 按攻击链模板填写。" +
|
||||
"环境类用 target|auth|infra|business/<slug>。同 fact_key 覆盖更新。需当前对话已绑定项目。",
|
||||
ShortDescription: "写入/更新项目事实(含攻击链 body)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "项目内唯一 key:target/primary_domain、finding/sqli-login、exploit/upload-rce 等",
|
||||
},
|
||||
"category": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "target | auth | infra | business | finding | chain | exploit | poc | note",
|
||||
"enum": []string{"target", "auth", "infra", "business", "finding", "chain", "exploit", "poc", "note"},
|
||||
},
|
||||
"summary": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "索引用一行:结论 + 位置 + 触发/验证要点(勿仅写「存在 XSS」等空话)",
|
||||
},
|
||||
"body": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "完整可复现详情(仅 get_project_fact 返回):须含攻击链步骤、原始 HTTP/命令、响应现象、证据与关联。" +
|
||||
"发现/利用类首次写入必填;环境类建议含来源证据。攻击链类可参考模板章节:结论、目标与入口、攻击链、Exploit/POC、关键证据、关联、备注。" +
|
||||
"更新已有 fact_key 时若省略或留空 body,将保留库中已有 body(可只改 summary)。",
|
||||
},
|
||||
"confidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "confirmed | tentative | deprecated",
|
||||
"enum": []string{"confirmed", "tentative", "deprecated"},
|
||||
},
|
||||
"pinned": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "是否优先出现在黑板索引",
|
||||
},
|
||||
"related_vulnerability_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:关联的漏洞记录 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"fact_key", "summary"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(upsertTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
factKey, _ := args["fact_key"].(string)
|
||||
summary, _ := args["summary"].(string)
|
||||
if strings.TrimSpace(factKey) == "" || strings.TrimSpace(summary) == "" {
|
||||
return textResult("错误: fact_key 与 summary 必填", true), nil
|
||||
}
|
||||
if len([]rune(summary)) > cfg.Project.FactSummaryMaxRunesEffective() {
|
||||
return textResult(fmt.Sprintf("错误: summary 过长(最多 %d 字)", cfg.Project.FactSummaryMaxRunesEffective()), true), nil
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: projectID,
|
||||
FactKey: factKey,
|
||||
Category: strArg(args, "category"),
|
||||
Summary: summary,
|
||||
Body: strArg(args, "body"),
|
||||
Confidence: strArg(args, "confidence"),
|
||||
Pinned: boolArg(args, "pinned"),
|
||||
RelatedVulnerabilityID: strArg(args, "related_vulnerability_id"),
|
||||
}
|
||||
if convID := agent.ConversationIDFromContext(ctx); convID != "" {
|
||||
f.SourceConversationID = convID
|
||||
}
|
||||
created, err := db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
}
|
||||
return textResult(msg, false), nil
|
||||
})
|
||||
|
||||
getTool := mcp.Tool{
|
||||
Name: builtin.ToolGetProjectFact,
|
||||
Description: "按 fact_key 获取项目事实完整 body 与元数据。摘要不足时必须调用本工具,禁止臆造细节。",
|
||||
ShortDescription: "按 key 获取事实详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string", "description": "事实 key"},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(getTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if key == "" {
|
||||
return textResult("错误: fact_key 必填", true), nil
|
||||
}
|
||||
f, err := db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
msg := fmt.Sprintf("fact_key: %s\ncategory: %s\nconfidence: %s\nsummary: %s\nupdated_at: %s",
|
||||
f.FactKey, f.Category, f.Confidence, f.Summary, f.UpdatedAt.Format("2006-01-02 15:04:05"))
|
||||
if f.RelatedVulnerabilityID != "" {
|
||||
msg += fmt.Sprintf("\nrelated_vulnerability_id: %s", f.RelatedVulnerabilityID)
|
||||
}
|
||||
if f.SourceConversationID != "" {
|
||||
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||
}
|
||||
msg += "\n\n--- body ---\n" + f.Body
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
}
|
||||
return textResult(msg, false), nil
|
||||
})
|
||||
|
||||
listTool := mcp.Tool{
|
||||
Name: builtin.ToolListProjectFacts,
|
||||
Description: "列出当前项目的事实(分页)。",
|
||||
ShortDescription: "列出项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"category": map[string]interface{}{"type": "string"},
|
||||
"confidence": map[string]interface{}{"type": "string"},
|
||||
"limit": map[string]interface{}{"type": "integer"},
|
||||
"offset": map[string]interface{}{"type": "integer"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(listTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
limit := intArg(args, "limit", 50)
|
||||
offset := intArg(args, "offset", 0)
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: strArg(args, "category"),
|
||||
Confidence: strArg(args, "confidence"),
|
||||
}
|
||||
list, err := db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("共 %d 条(limit=%d offset=%d):\n", len(list), limit, offset))
|
||||
for _, f := range list {
|
||||
b.WriteString(fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, f.Summary, f.Confidence))
|
||||
}
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
|
||||
searchTool := mcp.Tool{
|
||||
Name: builtin.ToolSearchProjectFacts,
|
||||
Description: "按关键词搜索项目事实(summary/body/fact_key)。",
|
||||
ShortDescription: "搜索项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{"type": "string"},
|
||||
"limit": map[string]interface{}{"type": "integer"},
|
||||
"offset": map[string]interface{}{"type": "integer"},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(searchTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
q := strings.TrimSpace(strArg(args, "query"))
|
||||
if q == "" {
|
||||
return textResult("错误: query 必填", true), nil
|
||||
}
|
||||
list, err := db.ListProjectFacts(projectID, database.ProjectFactListFilter{Search: q}, intArg(args, "limit", 30), intArg(args, "offset", 0))
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("搜索 \"%s\" 命中 %d 条:\n", q, len(list)))
|
||||
for _, f := range list {
|
||||
b.WriteString(fmt.Sprintf("- [%s] %s — %s\n", f.FactKey, f.Category, f.Summary))
|
||||
}
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
|
||||
deprecateTool := mcp.Tool{
|
||||
Name: builtin.ToolDeprecateProjectFact,
|
||||
Description: "将事实标记为 deprecated,从黑板索引中排除。",
|
||||
ShortDescription: "废弃项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(deprecateTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if err := db.DeprecateProjectFact(projectID, key); err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
return textResult("事实已标记为 deprecated: "+key, false), nil
|
||||
})
|
||||
|
||||
restoreTool := mcp.Tool{
|
||||
Name: builtin.ToolRestoreProjectFact,
|
||||
Description: "将已废弃(deprecated)的事实恢复为 tentative 或 confirmed,重新参与黑板索引。",
|
||||
ShortDescription: "恢复已废弃的项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string"},
|
||||
"confidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "恢复后的置信度:tentative(默认)或 confirmed",
|
||||
"enum": []string{"tentative", "confirmed"},
|
||||
},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(restoreTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if key == "" {
|
||||
return textResult("错误: fact_key 必填", true), nil
|
||||
}
|
||||
conf := strArg(args, "confidence")
|
||||
if err := db.RestoreProjectFact(projectID, key, conf); err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
if conf == "" {
|
||||
conf = "tentative"
|
||||
}
|
||||
return textResult(fmt.Sprintf("事实已恢复为 %s: %s", conf, key), false), nil
|
||||
})
|
||||
|
||||
if logger != nil {
|
||||
logger.Info("项目黑板 MCP 工具注册成功")
|
||||
}
|
||||
}
|
||||
|
||||
func strArg(args map[string]interface{}, key string) string {
|
||||
if v, ok := args[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolArg(args map[string]interface{}, key string) bool {
|
||||
if v, ok := args[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func intArg(args map[string]interface{}, key string, def int) int {
|
||||
switch v := args[key].(type) {
|
||||
case float64:
|
||||
return int(v)
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func conversationIDFromToolCtx(ctx context.Context) string {
|
||||
if id := agent.ConversationIDFromContext(ctx); id != "" {
|
||||
return id
|
||||
}
|
||||
return mcp.MCPConversationIDFromContext(ctx)
|
||||
}
|
||||
|
||||
// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。
|
||||
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
|
||||
if vuln == nil || convID == "" {
|
||||
return false
|
||||
}
|
||||
if projectID != "" {
|
||||
if strings.TrimSpace(vuln.ProjectID) == projectID {
|
||||
return true
|
||||
}
|
||||
// 历史记录:写入时尚未绑定 project_id,但属于同一会话
|
||||
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return vuln.ConversationID == convID
|
||||
}
|
||||
|
||||
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
|
||||
convID := conversationIDFromToolCtx(ctx)
|
||||
if convID == "" {
|
||||
return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, err := db.GetConversationProjectID(convID); err == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
scope := strings.TrimSpace(strArg(args, "scope"))
|
||||
if scope == "" {
|
||||
if projectID != "" {
|
||||
scope = "project"
|
||||
} else {
|
||||
scope = "conversation"
|
||||
}
|
||||
}
|
||||
|
||||
filter := database.VulnerabilityListFilter{
|
||||
Severity: strings.TrimSpace(strArg(args, "severity")),
|
||||
Status: strings.TrimSpace(strArg(args, "status")),
|
||||
}
|
||||
if q := strings.TrimSpace(strArg(args, "q")); q != "" {
|
||||
filter.Search = q
|
||||
} else {
|
||||
filter.Search = strings.TrimSpace(strArg(args, "search"))
|
||||
}
|
||||
|
||||
var scopeLabel string
|
||||
switch scope {
|
||||
case "project":
|
||||
if projectID == "" {
|
||||
return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
|
||||
}
|
||||
filter.ProjectID = projectID
|
||||
scopeLabel = fmt.Sprintf("项目 %s", projectID)
|
||||
case "conversation":
|
||||
filter.ConversationID = convID
|
||||
scopeLabel = fmt.Sprintf("会话 %s", convID)
|
||||
default:
|
||||
return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
|
||||
}
|
||||
return filter, scopeLabel, nil
|
||||
}
|
||||
|
||||
func formatVulnerabilityListItem(v *database.Vulnerability) string {
|
||||
line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
|
||||
if v.Type != "" {
|
||||
line += fmt.Sprintf(" | type=%s", v.Type)
|
||||
}
|
||||
if v.Target != "" {
|
||||
line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80))
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID))
|
||||
b.WriteString(fmt.Sprintf("标题: %s\n", v.Title))
|
||||
b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity))
|
||||
b.WriteString(fmt.Sprintf("状态: %s\n", v.Status))
|
||||
if v.Type != "" {
|
||||
b.WriteString(fmt.Sprintf("类型: %s\n", v.Type))
|
||||
}
|
||||
if v.Target != "" {
|
||||
b.WriteString(fmt.Sprintf("目标: %s\n", v.Target))
|
||||
}
|
||||
if v.ProjectID != "" {
|
||||
b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID))
|
||||
if !v.CreatedAt.IsZero() {
|
||||
b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
if v.Description != "" {
|
||||
b.WriteString("\n--- 描述 ---\n")
|
||||
b.WriteString(v.Description)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Proof != "" {
|
||||
b.WriteString("\n--- 证明(POC) ---\n")
|
||||
b.WriteString(v.Proof)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Impact != "" {
|
||||
b.WriteString("\n--- 影响 ---\n")
|
||||
b.WriteString(v.Impact)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Recommendation != "" {
|
||||
b.WriteString("\n--- 修复建议 ---\n")
|
||||
b.WriteString(v.Recommendation)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncateRunes(s string, max int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= max {
|
||||
return s
|
||||
}
|
||||
return string(r[:max]) + "…"
|
||||
}
|
||||
|
||||
// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。
|
||||
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
registerRecordVulnerabilityTool(mcpServer, db, logger)
|
||||
registerListVulnerabilitiesTool(mcpServer, db, logger)
|
||||
registerGetVulnerabilityTool(mcpServer, db, logger)
|
||||
if logger != nil {
|
||||
logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。边渗透边记录:每验证出一条可复现漏洞(含 POC/影响)后立即调用,勿等会话结束。包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞标题(必需)",
|
||||
},
|
||||
"description": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞详细描述",
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"vulnerability_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||
},
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||
},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞影响说明",
|
||||
},
|
||||
"recommendation": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "severity"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
|
||||
if conversationID == "" {
|
||||
conversationID = conversationIDFromToolCtx(ctx)
|
||||
}
|
||||
if conversationID == "" {
|
||||
return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(strArg(args, "title"))
|
||||
if title == "" {
|
||||
return textResult("错误: title 参数必需且不能为空", true), nil
|
||||
}
|
||||
|
||||
severity := strings.TrimSpace(strArg(args, "severity"))
|
||||
if severity == "" {
|
||||
return textResult("错误: severity 参数必需且不能为空", true), nil
|
||||
}
|
||||
|
||||
validSeverities := map[string]bool{
|
||||
"critical": true, "high": true, "medium": true, "low": true, "info": true,
|
||||
}
|
||||
if !validSeverities[severity] {
|
||||
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: conversationID,
|
||||
ProjectID: projectID,
|
||||
Title: title,
|
||||
Description: strArg(args, "description"),
|
||||
Severity: severity,
|
||||
Status: "open",
|
||||
Type: strArg(args, "vulnerability_type"),
|
||||
Target: strArg(args, "target"),
|
||||
Proof: strArg(args, "proof"),
|
||||
Impact: strArg(args, "impact"),
|
||||
Recommendation: strArg(args, "recommendation"),
|
||||
}
|
||||
|
||||
created, err := db.CreateVulnerability(vuln)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Error("记录漏洞失败", zap.Error(err))
|
||||
}
|
||||
return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Info("漏洞记录成功",
|
||||
zap.String("id", created.ID),
|
||||
zap.String("title", created.Title),
|
||||
zap.String("severity", created.Severity),
|
||||
zap.String("conversation_id", conversationID),
|
||||
)
|
||||
}
|
||||
|
||||
return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
|
||||
created.ID, created.Title, created.Severity, created.Status), false), nil
|
||||
})
|
||||
}
|
||||
|
||||
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolListVulnerabilities,
|
||||
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||||
ShortDescription: "列出漏洞(默认当前项目)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
|
||||
"enum": []string{"project", "conversation"},
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按严重程度筛选:critical、high、medium、low、info",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"status": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按状态筛选:open、confirmed、fixed、false_positive",
|
||||
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
|
||||
},
|
||||
"q": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "关键词搜索(标题、描述、类型、目标等)",
|
||||
},
|
||||
"limit": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "返回条数上限,默认 30,最大 100",
|
||||
},
|
||||
"offset": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "分页偏移,默认 0",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
|
||||
limit := intArg(args, "limit", 30)
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 30
|
||||
}
|
||||
offset := intArg(args, "offset", 0)
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
total, err := db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("统计漏洞失败", zap.Error(err))
|
||||
}
|
||||
total = 0
|
||||
}
|
||||
|
||||
list, err := db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset))
|
||||
if len(list) == 0 {
|
||||
b.WriteString("(暂无漏洞记录)\n")
|
||||
} else {
|
||||
for _, v := range list {
|
||||
b.WriteString(formatVulnerabilityListItem(v))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if total > offset+len(list) {
|
||||
b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n"))
|
||||
}
|
||||
}
|
||||
b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
}
|
||||
|
||||
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolGetVulnerability,
|
||||
Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
|
||||
ShortDescription: "按 ID 获取漏洞详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞 ID(list_vulnerabilities 返回的 id)",
|
||||
},
|
||||
},
|
||||
"required": []string{"id"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
convID := conversationIDFromToolCtx(ctx)
|
||||
if convID == "" {
|
||||
return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
|
||||
}
|
||||
|
||||
id := strings.TrimSpace(strArg(args, "id"))
|
||||
if id == "" {
|
||||
return textResult("错误: id 必填", true), nil
|
||||
}
|
||||
|
||||
vuln, err := db.GetVulnerability(id)
|
||||
if err != nil {
|
||||
return textResult("错误: 漏洞不存在或查询失败", true), nil
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, perr := db.GetConversationProjectID(convID); perr == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
if !canAccessVulnerability(vuln, convID, projectID) {
|
||||
return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
|
||||
}
|
||||
|
||||
return textResult(formatVulnerabilityDetail(vuln), false), nil
|
||||
})
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
}
|
||||
}
|
||||
|
||||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
||||
// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
|
||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||
|
||||
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
var reactInputFinal string
|
||||
var dataSource string // 记录数据来源
|
||||
|
||||
// 如果成功获取到保存的ReAct数据,直接使用
|
||||
if reactInputJSON != "" && modelOutput != "" {
|
||||
// 计算 ReAct 输入的哈希值,用于追踪
|
||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
||||
// 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
|
||||
if reactInputJSON != "" {
|
||||
trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||
hash := sha256.Sum256([]byte(trimmedJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16]
|
||||
|
||||
// 统计消息数量
|
||||
var messageCount int
|
||||
var tempMessages []interface{}
|
||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
||||
messageCount = len(tempMessages)
|
||||
if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
|
||||
messageCount = len(msgs)
|
||||
msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
|
||||
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
|
||||
} else {
|
||||
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
|
||||
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
|
||||
if strings.TrimSpace(modelOutput) != "" {
|
||||
reactInputFinal += "\n\n## 助手结论(last_react_output)\n\n" + modelOutput
|
||||
}
|
||||
}
|
||||
|
||||
dataSource = "database_last_agent_trace"
|
||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||||
dataSource = "last_user_turn_agent_trace"
|
||||
b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
|
||||
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.String("reactInputHash", reactInputHash),
|
||||
zap.Int("modelOutputSize", len(modelOutput)))
|
||||
|
||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
|
||||
} else {
|
||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||
dataSource = "messages_table"
|
||||
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
||||
// 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
|
||||
reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
|
||||
|
||||
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
|
||||
promptAssistantOut := modelOutput
|
||||
if reactInputJSON != "" {
|
||||
promptAssistantOut = ""
|
||||
}
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
|
||||
// fmt.Println(prompt)
|
||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||
@@ -301,7 +309,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
// 目标:以主 agent(编排器)视角输出整轮迭代
|
||||
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
||||
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "reasoning_chain" || d.EventType == "planning" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
|
||||
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||
start := 0
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
start = i
|
||||
break
|
||||
}
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
for _, msg := range messages[start:] {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||
}
|
||||
return builder.String()
|
||||
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
|
||||
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||
msgs, err := agent.ParseTraceMessages(trimmed)
|
||||
if err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
||||
return trimmed
|
||||
}
|
||||
return b.formatAgentTraceFromChatMessages(msgs)
|
||||
}
|
||||
|
||||
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
|
||||
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
role, _ := msg["role"].(string)
|
||||
content, _ := msg["content"].(string)
|
||||
for _, msg := range msgs {
|
||||
role := msg.Role
|
||||
content := msg.Content
|
||||
|
||||
// 处理assistant消息:提取tool_calls信息
|
||||
if role == "assistant" {
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||||
// 如果有文本内容,先显示
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
// 详细显示每个工具调用
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
||||
for i, toolCall := range toolCalls {
|
||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
||||
toolCallID, _ := tc["id"].(string)
|
||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
||||
toolName, _ := funcData["name"].(string)
|
||||
arguments, _ := funcData["arguments"].(string)
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
||||
}
|
||||
if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
args := ""
|
||||
if tc.Function.Arguments != nil {
|
||||
if b, err := json.Marshal(tc.Function.Arguments); err == nil {
|
||||
args = string(b)
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理tool消息:显示tool_call_id和完整内容
|
||||
if role == "tool" {
|
||||
toolCallID, _ := msg["tool_call_id"].(string)
|
||||
if toolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
||||
if strings.EqualFold(role, "tool") {
|
||||
if msg.ToolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他消息类型(system, user等)正常显示
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// buildSimplePrompt 构建简化的prompt
|
||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)。
|
||||
|
||||
## 输入范围(与「继续对话」续跑一致)
|
||||
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
|
||||
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
|
||||
|
||||
## 核心目标
|
||||
|
||||
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||||
|
||||
## 最后一轮ReAct输入
|
||||
## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
|
||||
|
||||
%s
|
||||
|
||||
## 大模型输出
|
||||
|
||||
%s
|
||||
|
||||
## 输出格式
|
||||
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||
|
||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
||||
现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
|
||||
}
|
||||
|
||||
func assistantOutSection(modelOutput string) string {
|
||||
modelOutput = strings.TrimSpace(modelOutput)
|
||||
if modelOutput == "" {
|
||||
return ""
|
||||
}
|
||||
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
|
||||
}
|
||||
|
||||
// saveChain 保存攻击链到数据库
|
||||
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
||||
},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"max_completion_tokens": 80000,
|
||||
"max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
|
||||
attackChainSystemReserve = 256
|
||||
attackChainSafetyReserve = 2048
|
||||
)
|
||||
|
||||
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
|
||||
func attackChainMaxCompletionTokens(maxTotal int) int {
|
||||
const capTokens = 16384
|
||||
if maxTotal <= 0 {
|
||||
return 8192
|
||||
}
|
||||
v := maxTotal / 8
|
||||
if v < 4096 {
|
||||
v = 4096
|
||||
}
|
||||
if v > capTokens {
|
||||
v = capTokens
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (b *Builder) modelName() string {
|
||||
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
|
||||
return b.openAIConfig.Model
|
||||
}
|
||||
return "gpt-4"
|
||||
}
|
||||
|
||||
func (b *Builder) countTokens(text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
n, err := b.tokenCounter.Count(b.modelName(), text)
|
||||
if err != nil {
|
||||
return utf8.RuneCountInString(text) / 4
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
|
||||
func (b *Builder) attackChainPayloadTokenBudget() int {
|
||||
maxTotal := b.maxTokens
|
||||
if maxTotal <= 0 {
|
||||
maxTotal = 100000
|
||||
}
|
||||
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
|
||||
completion := attackChainMaxCompletionTokens(maxTotal)
|
||||
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
|
||||
budget := maxTotal - reserve
|
||||
minBudget := maxTotal * 35 / 100
|
||||
if budget < minBudget {
|
||||
budget = minBudget
|
||||
}
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
|
||||
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
|
||||
budget := b.attackChainPayloadTokenBudget()
|
||||
modelBudget := budget * 15 / 100
|
||||
if modelBudget < 512 {
|
||||
modelBudget = 512
|
||||
}
|
||||
reactBudget := budget - modelBudget
|
||||
|
||||
origReactTok := b.countTokens(reactInput)
|
||||
origModelTok := b.countTokens(modelOutput)
|
||||
truncated := false
|
||||
|
||||
outModel := modelOutput
|
||||
if origModelTok > modelBudget {
|
||||
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
|
||||
truncated = true
|
||||
}
|
||||
|
||||
outReact := reactInput
|
||||
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
|
||||
for _, lim := range perToolLimits {
|
||||
compact := compactFormattedToolBodies(outReact, lim)
|
||||
if compact != outReact {
|
||||
outReact = compact
|
||||
truncated = true
|
||||
}
|
||||
if b.countTokens(outReact) <= reactBudget {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if b.countTokens(outReact) > reactBudget {
|
||||
outReact = truncateTextByTokens(b, outReact, reactBudget)
|
||||
truncated = true
|
||||
}
|
||||
|
||||
if truncated {
|
||||
b.logger.Info("攻击链输入已按 token 预算截断",
|
||||
zap.Int("maxTotalTokens", b.maxTokens),
|
||||
zap.Int("payloadBudget", budget),
|
||||
zap.Int("reactBudget", reactBudget),
|
||||
zap.Int("modelBudget", modelBudget),
|
||||
zap.Int("reactInputTokensBefore", origReactTok),
|
||||
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
|
||||
zap.Int("modelOutputTokensBefore", origModelTok),
|
||||
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
|
||||
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
|
||||
)
|
||||
}
|
||||
|
||||
return outReact, outModel, truncated
|
||||
}
|
||||
|
||||
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
|
||||
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
|
||||
if maxRunesPerBody <= 0 || s == "" {
|
||||
return s
|
||||
}
|
||||
const marker = "[tool]"
|
||||
var out strings.Builder
|
||||
remaining := s
|
||||
changed := false
|
||||
for {
|
||||
idx := strings.Index(remaining, marker)
|
||||
if idx < 0 {
|
||||
out.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
out.WriteString(remaining[:idx])
|
||||
remaining = remaining[idx:]
|
||||
nl := strings.IndexByte(remaining, '\n')
|
||||
if nl < 0 {
|
||||
out.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
header := remaining[:nl+1]
|
||||
remaining = remaining[nl+1:]
|
||||
bodyEnd := strings.Index(remaining, "\n\n[")
|
||||
var body, rest string
|
||||
if bodyEnd < 0 {
|
||||
body = remaining
|
||||
rest = ""
|
||||
} else {
|
||||
body = remaining[:bodyEnd]
|
||||
rest = remaining[bodyEnd:]
|
||||
}
|
||||
if runeLen(body) > maxRunesPerBody {
|
||||
body = truncateRunesWithNotice(body, maxRunesPerBody)
|
||||
changed = true
|
||||
}
|
||||
out.WriteString(header)
|
||||
out.WriteString(body)
|
||||
remaining = rest
|
||||
if rest == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return s
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
|
||||
if maxTokens <= 0 || text == "" {
|
||||
return ""
|
||||
}
|
||||
if b.countTokens(text) <= maxTokens {
|
||||
return text
|
||||
}
|
||||
markerTok := b.countTokens(attackChainTruncationMarker)
|
||||
usable := maxTokens - markerTok
|
||||
if usable < 256 {
|
||||
usable = maxTokens / 2
|
||||
}
|
||||
headBudget := usable * 60 / 100
|
||||
tailBudget := usable - headBudget
|
||||
head := takeTokensFromStart(b, text, headBudget)
|
||||
tail := takeTokensFromEnd(b, text, tailBudget)
|
||||
return head + attackChainTruncationMarker + tail
|
||||
}
|
||||
|
||||
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
|
||||
rs := []rune(text)
|
||||
if len(rs) == 0 || maxTokens <= 0 {
|
||||
return ""
|
||||
}
|
||||
lo, hi := 0, len(rs)
|
||||
for lo < hi {
|
||||
mid := (lo + hi + 1) / 2
|
||||
if b.countTokens(string(rs[:mid])) <= maxTokens {
|
||||
lo = mid
|
||||
} else {
|
||||
hi = mid - 1
|
||||
}
|
||||
}
|
||||
return string(rs[:lo])
|
||||
}
|
||||
|
||||
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
|
||||
rs := []rune(text)
|
||||
if len(rs) == 0 || maxTokens <= 0 {
|
||||
return ""
|
||||
}
|
||||
lo, hi := 0, len(rs)
|
||||
for lo < hi {
|
||||
mid := (lo + hi) / 2
|
||||
if b.countTokens(string(rs[mid:])) <= maxTokens {
|
||||
hi = mid
|
||||
} else {
|
||||
lo = mid + 1
|
||||
}
|
||||
}
|
||||
return string(rs[lo:])
|
||||
}
|
||||
|
||||
func truncateRunesWithNotice(s string, maxRunes int) string {
|
||||
rs := []rune(s)
|
||||
if len(rs) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
|
||||
noticeRunes := []rune(notice)
|
||||
keep := maxRunes - len(noticeRunes)
|
||||
if keep < 200 {
|
||||
keep = maxRunes * 2 / 3
|
||||
}
|
||||
if keep < 1 {
|
||||
return notice
|
||||
}
|
||||
head := keep * 70 / 100
|
||||
tail := keep - head
|
||||
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
|
||||
}
|
||||
|
||||
func runeLen(s string) int {
|
||||
return len([]rune(s))
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func testBuilder(maxTotal int) *Builder {
|
||||
return &Builder{
|
||||
logger: zap.NewNop(),
|
||||
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
|
||||
tokenCounter: agent.NewTikTokenCounter(),
|
||||
maxTokens: maxTotal,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactFormattedToolBodies(t *testing.T) {
|
||||
long := strings.Repeat("x", 20000)
|
||||
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
|
||||
out := compactFormattedToolBodies(in, 500)
|
||||
if strings.Contains(out, strings.Repeat("x", 10000)) {
|
||||
t.Fatal("expected tool body to be truncated")
|
||||
}
|
||||
if !strings.Contains(out, "[user]: hi") {
|
||||
t.Fatal("expected user header preserved")
|
||||
}
|
||||
if !strings.Contains(out, "[assistant]: done") {
|
||||
t.Fatal("expected assistant header preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
|
||||
b := testBuilder(32000)
|
||||
react := strings.Repeat("scan ", 50000)
|
||||
model := strings.Repeat("result ", 10000)
|
||||
r, m, truncated := b.fitAttackChainPayload(react, model)
|
||||
if !truncated {
|
||||
t.Fatal("expected truncation for large payload")
|
||||
}
|
||||
prompt := b.buildSimplePrompt(r, m)
|
||||
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
|
||||
if total > b.maxTokens+attackChainSafetyReserve {
|
||||
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
|
||||
}
|
||||
_ = m
|
||||
}
|
||||
|
||||
func TestAttackChainMaxCompletionTokens(t *testing.T) {
|
||||
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
|
||||
// 120000/8 = 15000
|
||||
if got < 4096 || got > 16384 {
|
||||
t.Fatalf("unexpected completion cap: %d", got)
|
||||
}
|
||||
}
|
||||
if got := attackChainMaxCompletionTokens(0); got != 8192 {
|
||||
t.Fatalf("expected default 8192, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterConversationCreateHook records platform audit rows for every new conversation.
|
||||
func RegisterConversationCreateHook(s *Service) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
|
||||
detail := map[string]interface{}{
|
||||
"title": conv.Title,
|
||||
"source": meta.Source,
|
||||
}
|
||||
if meta.WebShellConnectionID != "" {
|
||||
detail["webshell_connection_id"] = meta.WebShellConnectionID
|
||||
}
|
||||
s.Record(nil, Entry{
|
||||
Category: "conversation",
|
||||
Action: "create",
|
||||
Result: "success",
|
||||
Message: "创建对话",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: conv.ID,
|
||||
Detail: detail,
|
||||
ClientIP: meta.ClientIP,
|
||||
SessionHint: meta.SessionHint,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ConversationCreateMeta builds audit metadata for conversation creation.
|
||||
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
|
||||
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
|
||||
}
|
||||
|
||||
// ConversationCreateMetaFromGin includes client IP and session hint when available.
|
||||
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
|
||||
m := ConversationCreateMeta(source)
|
||||
if c == nil {
|
||||
return m
|
||||
}
|
||||
m.ClientIP = c.ClientIP()
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
m.SessionHint = sessionHint(token)
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package audit
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
return s.cfg.Audit.RetentionDaysEffective()
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package audit
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// RecordAction writes a platform audit row with common defaults.
|
||||
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.Record(c, Entry{
|
||||
Category: category,
|
||||
Action: action,
|
||||
Result: result,
|
||||
Message: message,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
Detail: detail,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordOK is a shorthand for successful operations.
|
||||
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
|
||||
}
|
||||
|
||||
// RecordFail is a shorthand for failed operations.
|
||||
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "failure", message, "", "", detail)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
var auditActionsResourceRemoved = map[string]bool{
|
||||
"delete": true,
|
||||
"item_delete": true,
|
||||
"connection_delete": true,
|
||||
"listener_delete": true,
|
||||
"session_delete": true,
|
||||
"task_delete": true,
|
||||
"execution_delete": true,
|
||||
"execution_delete_batch": true,
|
||||
"delete_queue": true,
|
||||
"delete_batch_task": true,
|
||||
"markdown_delete": true,
|
||||
}
|
||||
|
||||
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
|
||||
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
|
||||
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
|
||||
return
|
||||
}
|
||||
if auditActionsResourceRemoved[log.Action] {
|
||||
f := false
|
||||
log.ResourceAvailable = &f
|
||||
return
|
||||
}
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
|
||||
if known {
|
||||
log.ResourceAvailable = &available
|
||||
}
|
||||
}
|
||||
|
||||
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
|
||||
resourceID = strings.TrimSpace(resourceID)
|
||||
if resourceID == "" {
|
||||
return false, false
|
||||
}
|
||||
t := strings.TrimSpace(resourceType)
|
||||
if t == "" {
|
||||
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
|
||||
t = "conversation"
|
||||
} else {
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
switch t {
|
||||
case "conversation":
|
||||
ok, err := db.ConversationExists(resourceID)
|
||||
return ok, err == nil
|
||||
case "vulnerability":
|
||||
_, err := db.GetVulnerability(resourceID)
|
||||
if err != nil {
|
||||
return false, strings.Contains(err.Error(), "不存在")
|
||||
}
|
||||
return true, true
|
||||
case "batch_queue":
|
||||
_, err := db.GetBatchQueue(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_listener":
|
||||
_, err := db.GetC2Listener(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_session":
|
||||
_, err := db.GetC2Session(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_task":
|
||||
_, err := db.GetC2Task(resourceID)
|
||||
return err == nil, true
|
||||
case "webshell_connection":
|
||||
c, err := db.GetWebshellConnection(resourceID)
|
||||
return err == nil && c != nil, true
|
||||
case "tool_execution":
|
||||
_, err := db.GetToolExecution(resourceID)
|
||||
return err == nil, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
|
||||
const auditRetentionPurgeInterval = time.Hour
|
||||
|
||||
// StartRetentionLoop periodically purges expired audit rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(auditRetentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("audit retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveKeySubstrings = []string{
|
||||
"password", "api_key", "apikey", "secret", "token", "authorization",
|
||||
"credential", "private_key", "access_key",
|
||||
}
|
||||
|
||||
// SanitizeDetail redacts sensitive keys and truncates serialized size.
|
||||
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
|
||||
if detail == nil {
|
||||
return nil
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 8192
|
||||
}
|
||||
out := sanitizeValue("", detail)
|
||||
if m, ok := out.(map[string]interface{}); ok {
|
||||
b, _ := json.Marshal(m)
|
||||
if len(b) > maxBytes {
|
||||
return map[string]interface{}{
|
||||
"_truncated": true,
|
||||
"_preview": string(b[:maxBytes]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
return map[string]interface{}{"value": out}
|
||||
}
|
||||
|
||||
func sanitizeValue(key string, v interface{}) interface{} {
|
||||
kl := strings.ToLower(key)
|
||||
for _, sub := range sensitiveKeySubstrings {
|
||||
if strings.Contains(kl, sub) {
|
||||
return "***"
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case map[string]interface{}:
|
||||
m := make(map[string]interface{}, len(t))
|
||||
for k, val := range t {
|
||||
m[k] = sanitizeValue(k, val)
|
||||
}
|
||||
return m
|
||||
case []interface{}:
|
||||
arr := make([]interface{}, len(t))
|
||||
for i, val := range t {
|
||||
arr[i] = sanitizeValue(key, val)
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service persists platform audit logs.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
failThrottle *failureThrottle
|
||||
}
|
||||
|
||||
// NewService creates an audit service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
failThrottle: newFailureThrottle(),
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled reports whether audit persistence is on.
|
||||
func (s *Service) Enabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return s.cfg.Audit.EnabledEffective()
|
||||
}
|
||||
|
||||
// Record writes one audit row from a Gin request context.
|
||||
func (s *Service) Record(c *gin.Context, e Entry) {
|
||||
if s == nil || !s.Enabled() || s.db == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
|
||||
return
|
||||
}
|
||||
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Result) == "" {
|
||||
e.Result = "success"
|
||||
}
|
||||
if strings.TrimSpace(e.Level) == "" {
|
||||
if e.Result == "failure" {
|
||||
e.Level = "warn"
|
||||
} else {
|
||||
e.Level = "info"
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(e.Actor) == "" {
|
||||
e.Actor = "admin"
|
||||
}
|
||||
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
|
||||
detail := SanitizeDetail(e.Detail, maxDetail)
|
||||
|
||||
sessionHintVal := e.SessionHint
|
||||
if sessionHintVal == "" && c != nil {
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
sessionHintVal = sessionHint(token)
|
||||
}
|
||||
}
|
||||
clientIPVal := e.ClientIP
|
||||
if clientIPVal == "" {
|
||||
clientIPVal = clientIP(c)
|
||||
}
|
||||
|
||||
row := &database.AuditLog{
|
||||
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
CreatedAt: time.Now(),
|
||||
Level: e.Level,
|
||||
Category: e.Category,
|
||||
Action: e.Action,
|
||||
Result: e.Result,
|
||||
Actor: e.Actor,
|
||||
SessionHint: sessionHintVal,
|
||||
ClientIP: clientIPVal,
|
||||
UserAgent: userAgent(c),
|
||||
ResourceType: e.ResourceType,
|
||||
ResourceID: e.ResourceID,
|
||||
Message: e.Message,
|
||||
Detail: detail,
|
||||
}
|
||||
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
|
||||
s.logger.Warn("写入审计日志失败",
|
||||
zap.String("action", e.Action),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
|
||||
func (s *Service) RecordSystem(e Entry) {
|
||||
s.Record(nil, e)
|
||||
}
|
||||
|
||||
// PurgeExpired deletes rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Audit.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.DeleteAuditLogsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
|
||||
}
|
||||
}
|
||||
|
||||
// HintFromToken returns a short stable hash prefix for a session token.
|
||||
func HintFromToken(token string) string {
|
||||
return sessionHint(token)
|
||||
}
|
||||
|
||||
func sessionHint(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:4])
|
||||
}
|
||||
|
||||
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
|
||||
if !isAuthFailureThrottled(e.Category, e.Action) {
|
||||
return true
|
||||
}
|
||||
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
|
||||
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
|
||||
return s.failThrottle.allow(key, cooldown)
|
||||
}
|
||||
|
||||
func clientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func userAgent(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
ua := c.GetHeader("User-Agent")
|
||||
if len(ua) > 512 {
|
||||
return ua[:512]
|
||||
}
|
||||
return ua
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
|
||||
type failureThrottle struct {
|
||||
mu sync.Mutex
|
||||
last map[string]time.Time
|
||||
}
|
||||
|
||||
func newFailureThrottle() *failureThrottle {
|
||||
return &failureThrottle{last: make(map[string]time.Time)}
|
||||
}
|
||||
|
||||
// allow reports whether a row with the given key may be written now.
|
||||
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
|
||||
if t == nil || cooldown <= 0 || key == "" {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
|
||||
return false
|
||||
}
|
||||
t.last[key] = now
|
||||
if len(t.last) > 4096 {
|
||||
for k, ts := range t.last {
|
||||
if now.Sub(ts) > cooldown*2 {
|
||||
delete(t.last, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
|
||||
func authFailureThrottleKey(category, action, clientIP string) string {
|
||||
return category + ":" + action + ":" + clientIP
|
||||
}
|
||||
|
||||
func isAuthFailureThrottled(category, action string) bool {
|
||||
if category != "auth" {
|
||||
return false
|
||||
}
|
||||
switch action {
|
||||
case "login", "change_password":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package audit
|
||||
|
||||
// Entry describes one platform audit record (not chat/tool execution bodies).
|
||||
type Entry struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string // success | failure
|
||||
Actor string
|
||||
SessionHint string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Message string
|
||||
Detail map[string]interface{}
|
||||
ClientIP string // optional when c is nil (robot, batch, DB hook)
|
||||
}
|
||||
@@ -239,13 +239,15 @@ func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 通过工厂创建具体实现
|
||||
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
|
||||
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
|
||||
listenerRec := *rec
|
||||
factory := m.registry.Get(rec.Type)
|
||||
if factory == nil {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
inst, err := factory(ListenerCreationCtx{
|
||||
Listener: rec,
|
||||
Listener: &listenerRec,
|
||||
Config: cfg,
|
||||
Manager: m,
|
||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
|
||||
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||
_ = lnPick.Close()
|
||||
|
||||
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||
rec, err := mgr.CreateListener(CreateListenerInput{
|
||||
Name: "t",
|
||||
Type: string(ListenerTypeHTTPBeacon),
|
||||
BindHost: "127.0.0.1",
|
||||
BindPort: port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := rec.ImplantToken
|
||||
|
||||
rec, err = mgr.StartListener(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
|
||||
rec.ImplantToken = ""
|
||||
rec.EncryptionKey = ""
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
|
||||
req.Header.Set("X-Implant-Token", token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "session_id") {
|
||||
t.Fatalf("expected session_id in body: %s", b)
|
||||
}
|
||||
_ = mgr.StopListener(rec.ID)
|
||||
}
|
||||
+299
-14
@@ -26,6 +26,7 @@ type Config struct {
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -35,13 +36,39 @@ type Config struct {
|
||||
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
||||
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter)
|
||||
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
|
||||
Project ProjectConfig `yaml:"project,omitempty" json:"project,omitempty"`
|
||||
}
|
||||
|
||||
// ProjectConfig 项目黑板(跨对话共享事实)配置。
|
||||
type ProjectConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||
}
|
||||
|
||||
// FactIndexMaxRunesEffective 自动注入黑板索引的最大 rune 数。
|
||||
func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
||||
if c.FactIndexMaxRunes <= 0 {
|
||||
return 3500
|
||||
}
|
||||
return c.FactIndexMaxRunes
|
||||
}
|
||||
|
||||
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||
if c.FactSummaryMaxRunes <= 0 {
|
||||
return 200
|
||||
}
|
||||
return c.FactSummaryMaxRunes
|
||||
}
|
||||
|
||||
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
||||
type MultiAgentConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
||||
@@ -63,6 +90,126 @@ type MultiAgentConfig struct {
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
|
||||
// EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace).
|
||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||
type MultiAgentEinoCallbacksConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only
|
||||
// SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production).
|
||||
SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"`
|
||||
// Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled).
|
||||
Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"`
|
||||
// MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads).
|
||||
MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"`
|
||||
MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"`
|
||||
// ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only.
|
||||
ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
|
||||
type MultiAgentEinoCallbacksOtelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||
}
|
||||
|
||||
// EinoCallbacksModeEffective returns off | log_only | sse | full.
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string {
|
||||
if !c.Enabled {
|
||||
return "off"
|
||||
}
|
||||
m := strings.TrimSpace(strings.ToLower(c.Mode))
|
||||
switch m {
|
||||
case "log_only":
|
||||
return "log_only"
|
||||
case "sse":
|
||||
return "sse"
|
||||
case "full":
|
||||
return "full"
|
||||
case "":
|
||||
return "log_only"
|
||||
default:
|
||||
return "log_only"
|
||||
}
|
||||
}
|
||||
|
||||
// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default).
|
||||
func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool {
|
||||
if c.SseTraceToClient == nil {
|
||||
return false
|
||||
}
|
||||
return *c.SseTraceToClient
|
||||
}
|
||||
|
||||
// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE.
|
||||
func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool {
|
||||
if !c.SseTraceToClientEffective() {
|
||||
return false
|
||||
}
|
||||
return mode == "sse" || mode == "full"
|
||||
}
|
||||
|
||||
// OtelExporterEffective returns none | stdout | otlphttp.
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string {
|
||||
e := strings.TrimSpace(strings.ToLower(c.Exporter))
|
||||
switch e {
|
||||
case "none", "stdout", "otlphttp":
|
||||
return e
|
||||
case "":
|
||||
if c.Enabled {
|
||||
return "stdout"
|
||||
}
|
||||
return "none"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
// OtelTracingActive is true when spans should be started (enabled + non-none exporter).
|
||||
func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool {
|
||||
if !c.Otel.Enabled {
|
||||
return false
|
||||
}
|
||||
return c.Otel.OtelExporterEffective() != "none"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string {
|
||||
s := strings.TrimSpace(c.ServiceName)
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return "cyberstrike-ai"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 {
|
||||
r := c.SampleRatio
|
||||
if r <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
if r > 1 {
|
||||
return 1.0
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int {
|
||||
if c.MaxInputSummaryRunes > 0 {
|
||||
return c.MaxInputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int {
|
||||
if c.MaxOutputSummaryRunes > 0 {
|
||||
return c.MaxOutputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
|
||||
@@ -90,7 +237,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
|
||||
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
|
||||
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
|
||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -106,6 +254,10 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||
}
|
||||
@@ -241,9 +393,9 @@ type MultiAgentSubConfig struct {
|
||||
|
||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||
type MultiAgentPublic struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
@@ -251,6 +403,18 @@ type MultiAgentPublic struct {
|
||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
|
||||
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
|
||||
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
|
||||
if s == "" || s == "single" || s == "react" {
|
||||
return "react"
|
||||
}
|
||||
if s == "eino_single" {
|
||||
return "eino_single"
|
||||
}
|
||||
return NormalizeMultiAgentOrchestration(s)
|
||||
}
|
||||
|
||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||
func NormalizeMultiAgentOrchestration(s string) string {
|
||||
v := strings.TrimSpace(strings.ToLower(s))
|
||||
@@ -266,21 +430,35 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
||||
|
||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||
type MultiAgentAPIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
}
|
||||
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
|
||||
type RobotsConfig struct {
|
||||
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
||||
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
|
||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||
}
|
||||
|
||||
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
|
||||
type RobotWechatConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||
}
|
||||
|
||||
// RobotSessionConfig 机器人会话隔离策略
|
||||
type RobotSessionConfig struct {
|
||||
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
||||
@@ -322,8 +500,17 @@ type RobotLarkConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port int `yaml:"port" json:"port"`
|
||||
// TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。
|
||||
TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"`
|
||||
// TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。
|
||||
TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"`
|
||||
TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"`
|
||||
// TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。
|
||||
TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"`
|
||||
// TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。
|
||||
TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
@@ -345,6 +532,48 @@ type OpenAIConfig struct {
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
|
||||
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。
|
||||
type OpenAIReasoningConfig struct {
|
||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max | xhigh;max/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度。
|
||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||
// Profile: auto | deepseek_compat | openai_compat | output_config_effort
|
||||
Profile string `yaml:"profile,omitempty" json:"profile,omitempty"`
|
||||
// ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。
|
||||
ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"`
|
||||
}
|
||||
|
||||
// ModeEffective returns auto when empty or default.
|
||||
func (c OpenAIReasoningConfig) ModeEffective() string {
|
||||
m := strings.ToLower(strings.TrimSpace(c.Mode))
|
||||
if m == "" || m == "default" {
|
||||
return "auto"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProfileEffective returns auto when empty.
|
||||
func (c OpenAIReasoningConfig) ProfileEffective() string {
|
||||
p := strings.ToLower(strings.TrimSpace(c.Profile))
|
||||
if p == "" {
|
||||
return "auto"
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AllowClientReasoningEffective true when client may send ChatRequest.reasoning.
|
||||
func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool {
|
||||
if c.AllowClientReasoning == nil {
|
||||
return true
|
||||
}
|
||||
return *c.AllowClientReasoning
|
||||
}
|
||||
|
||||
type FofaConfig struct {
|
||||
@@ -389,6 +618,51 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
|
||||
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// EnabledEffective returns true unless audit.enabled is explicitly false.
|
||||
func (a AuditConfig) EnabledEffective() bool {
|
||||
if a.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *a.Enabled
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever.
|
||||
func (a AuditConfig) RetentionDaysEffective() int {
|
||||
if a.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return a.RetentionDays
|
||||
}
|
||||
|
||||
// MaxDetailBytesEffective caps serialized detail JSON size.
|
||||
func (a AuditConfig) MaxDetailBytesEffective() int {
|
||||
if a.MaxDetailBytes <= 0 {
|
||||
return 8192
|
||||
}
|
||||
return a.MaxDetailBytes
|
||||
}
|
||||
|
||||
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
|
||||
func (a AuditConfig) AuthFailureCooldownEffective() int {
|
||||
if a.AuthFailureCooldownSeconds < 0 {
|
||||
return 0
|
||||
}
|
||||
if a.AuthFailureCooldownSeconds == 0 {
|
||||
return 60
|
||||
}
|
||||
return a.AuthFailureCooldownSeconds
|
||||
}
|
||||
|
||||
// ExternalMCPConfig 外部MCP配置
|
||||
type ExternalMCPConfig struct {
|
||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||
@@ -481,6 +755,9 @@ func Load(path string) (*Config, error) {
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
if cfg.Audit.MaxDetailBytes <= 0 {
|
||||
cfg.Audit.MaxDetailBytes = 8192
|
||||
}
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
@@ -984,6 +1261,14 @@ func Default() *Config {
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
Audit: func() AuditConfig {
|
||||
on := true
|
||||
return AuditConfig{
|
||||
RetentionDays: 90,
|
||||
MaxDetailBytes: 8192,
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
|
||||
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.TLSEnabled {
|
||||
return true
|
||||
}
|
||||
if s.TLSAutoSelfSign {
|
||||
return true
|
||||
}
|
||||
cert := strings.TrimSpace(s.TLSCertPath)
|
||||
key := strings.TrimSpace(s.TLSKeyPath)
|
||||
return cert != "" && key != ""
|
||||
}
|
||||
|
||||
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
|
||||
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
|
||||
if s == nil || !MainWebUIUsesHTTPS(s) {
|
||||
return false
|
||||
}
|
||||
if s.TLSHTTPRedirect == nil {
|
||||
return true
|
||||
}
|
||||
return *s.TLSHTTPRedirect
|
||||
}
|
||||
|
||||
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
|
||||
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
|
||||
func ApplyDevHTTPSBootstrap(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSEnabled = true
|
||||
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
|
||||
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
|
||||
if cert != "" && key != "" {
|
||||
cfg.Server.TLSAutoSelfSign = false
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSAutoSelfSign = true
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog platform operation audit record.
|
||||
type AuditLog struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Level string `json:"level"`
|
||||
Category string `json:"category"`
|
||||
Action string `json:"action"`
|
||||
Result string `json:"result"`
|
||||
Actor string `json:"actor"`
|
||||
SessionHint string `json:"sessionHint,omitempty"`
|
||||
ClientIP string `json:"clientIp,omitempty"`
|
||||
UserAgent string `json:"userAgent,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
ResourceID string `json:"resourceId,omitempty"`
|
||||
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
|
||||
Message string `json:"message"`
|
||||
Detail map[string]interface{} `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
// ListAuditLogsFilter query parameters.
|
||||
type ListAuditLogsFilter struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string
|
||||
Query string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
||||
conditions := []string{"1=1"}
|
||||
args := []interface{}{}
|
||||
if filter.Level != "" {
|
||||
conditions = append(conditions, "level = ?")
|
||||
args = append(args, filter.Level)
|
||||
}
|
||||
if filter.Category != "" {
|
||||
conditions = append(conditions, "category = ?")
|
||||
args = append(args, filter.Category)
|
||||
}
|
||||
if filter.Action != "" {
|
||||
conditions = append(conditions, "action = ?")
|
||||
args = append(args, filter.Action)
|
||||
}
|
||||
if filter.Result != "" {
|
||||
conditions = append(conditions, "result = ?")
|
||||
args = append(args, filter.Result)
|
||||
}
|
||||
if filter.ResourceType != "" {
|
||||
conditions = append(conditions, "resource_type = ?")
|
||||
args = append(args, filter.ResourceType)
|
||||
}
|
||||
if filter.ResourceID != "" {
|
||||
conditions = append(conditions, "resource_id = ?")
|
||||
args = append(args, filter.ResourceID)
|
||||
}
|
||||
if filter.Since != nil {
|
||||
conditions = append(conditions, "created_at >= ?")
|
||||
args = append(args, *filter.Since)
|
||||
}
|
||||
if filter.Until != nil {
|
||||
conditions = append(conditions, "created_at <= ?")
|
||||
args = append(args, *filter.Until)
|
||||
}
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
|
||||
args = append(args, like, like, like, like)
|
||||
}
|
||||
return strings.Join(conditions, " AND "), args
|
||||
}
|
||||
|
||||
// AppendAuditLog inserts one audit row.
|
||||
func (db *DB) AppendAuditLog(row *AuditLog) error {
|
||||
if row == nil {
|
||||
return errors.New("audit log is nil")
|
||||
}
|
||||
if strings.TrimSpace(row.ID) == "" {
|
||||
return errors.New("audit id is required")
|
||||
}
|
||||
if row.CreatedAt.IsZero() {
|
||||
row.CreatedAt = time.Now()
|
||||
}
|
||||
if strings.TrimSpace(row.Level) == "" {
|
||||
row.Level = "info"
|
||||
}
|
||||
detailJSON := ""
|
||||
if len(row.Detail) > 0 {
|
||||
if b, err := json.Marshal(row.Detail); err == nil {
|
||||
detailJSON = string(b)
|
||||
}
|
||||
}
|
||||
query := `
|
||||
INSERT INTO audit_logs (
|
||||
id, created_at, level, category, action, result, actor, session_hint,
|
||||
client_ip, user_agent, resource_type, resource_id, message, detail_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, err := db.Exec(query,
|
||||
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
||||
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAuditLogByID returns one row.
|
||||
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return nil, errors.New("id is required")
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs WHERE id = ?
|
||||
`
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// CountAuditLogs counts rows matching filter.
|
||||
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
|
||||
var n int64
|
||||
err := db.QueryRow(query, args...).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ListAuditLogs lists audit rows newest first.
|
||||
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
limit := filter.Limit
|
||||
if limit <= 0 || limit > 500 {
|
||||
limit = 50
|
||||
}
|
||||
offset := filter.Offset
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs
|
||||
WHERE ` + where + `
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
args = append(args, limit, offset)
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []*AuditLog
|
||||
for rows.Next() {
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
); err != nil {
|
||||
continue
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
list = append(list, &row)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
@@ -22,6 +22,7 @@ type BatchTaskQueueRow struct {
|
||||
LastScheduleTriggerAt sql.NullTime
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
ProjectID sql.NullString
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
@@ -51,6 +52,7 @@ func (db *DB) CreateBatchQueue(
|
||||
scheduleMode string,
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
projectID string,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
@@ -65,9 +67,13 @@ func (db *DB) CreateBatchQueue(
|
||||
nextRunAtValue = *nextRunAt
|
||||
}
|
||||
|
||||
var projectIDVal interface{}
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
projectIDVal = strings.TrimSpace(projectID)
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0,
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
@@ -101,9 +107,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -127,7 +133,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
@@ -138,7 +144,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -158,7 +164,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
@@ -186,7 +192,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
@@ -25,33 +26,53 @@ type Conversation struct {
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title)
|
||||
func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title, meta)
|
||||
}
|
||||
|
||||
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
projectID := strings.TrimSpace(meta.ProjectID)
|
||||
if projectID != "" {
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if webshellConnectionID != "" {
|
||||
wsID := strings.TrimSpace(webshellConnectionID)
|
||||
switch {
|
||||
case wsID != "" && projectID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id, project_id) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, title, now, now, wsID, projectID,
|
||||
)
|
||||
case wsID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
|
||||
id, title, now, now, webshellConnectionID,
|
||||
id, title, now, now, wsID,
|
||||
)
|
||||
} else {
|
||||
case projectID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, project_id) VALUES (?, ?, ?, ?, ?)",
|
||||
id, title, now, now, projectID,
|
||||
)
|
||||
default:
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||
id, title, now, now,
|
||||
@@ -61,12 +82,18 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
ProjectID: projectID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
if wsID != "" {
|
||||
meta.WebShellConnectionID = wsID
|
||||
}
|
||||
notifyConversationCreated(conv, meta)
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
||||
@@ -116,6 +143,7 @@ func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conve
|
||||
}
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
details = DedupeConsecutiveProcessDetails(details)
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
var data interface{}
|
||||
@@ -180,22 +208,43 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
|
||||
func (db *DB) ConversationExists(id string) (bool, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return false, nil
|
||||
}
|
||||
var one int
|
||||
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
var projectID sql.NullString
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
@@ -234,6 +283,7 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
// 将过程详情附加到对应的消息上
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
details = DedupeConsecutiveProcessDetails(details)
|
||||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
@@ -267,16 +317,20 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
var projectID sql.NullString
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
@@ -316,7 +370,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
rows, err = db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||
FROM conversations c
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||
@@ -326,7 +380,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
)
|
||||
} else {
|
||||
rows, err = db.Query(
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
)
|
||||
}
|
||||
@@ -341,10 +395,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var projectID sql.NullString
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
@@ -498,8 +556,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, now, now,
|
||||
"INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, "", mcpIDsJSON, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
@@ -523,10 +581,30 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
|
||||
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
|
||||
}
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
|
||||
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新助手消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -537,13 +615,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var reasoning sql.NullString
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
if reasoning.Valid {
|
||||
msg.ReasoningContent = reasoning.String
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
@@ -683,7 +765,7 @@ type ProcessDetail struct {
|
||||
ID string `json:"id"`
|
||||
MessageID string `json:"messageId"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"` // JSON格式的数据
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package database
|
||||
|
||||
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
|
||||
type ConversationCreateMeta struct {
|
||||
Source string
|
||||
WebShellConnectionID string
|
||||
ProjectID string
|
||||
ClientIP string
|
||||
SessionHint string
|
||||
}
|
||||
|
||||
// ConversationCreateHook is invoked after a conversation row is inserted.
|
||||
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
|
||||
|
||||
var conversationCreateHook ConversationCreateHook
|
||||
|
||||
// SetConversationCreateHook registers a global hook (e.g. platform audit).
|
||||
func SetConversationCreateHook(h ConversationCreateHook) {
|
||||
conversationCreateHook = h
|
||||
}
|
||||
|
||||
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
|
||||
if conversationCreateHook == nil || conv == nil {
|
||||
return
|
||||
}
|
||||
if meta.Source == "" {
|
||||
meta.Source = "unknown"
|
||||
}
|
||||
conversationCreateHook(conv, meta)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -12,19 +13,106 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// SQLite 在 WAL 模式下建议使用较保守的连接数,降低长读快照导致 checkpoint 饥饿的概率。
|
||||
sqliteMaxOpenConns = 25
|
||||
sqliteMaxIdleConns = 5
|
||||
// 以页为单位的自动 checkpoint 触发阈值(默认 1000 页,约 4MB @ 4KB/page)。
|
||||
sqliteWALAutoCheckpointPages = 1000
|
||||
// 控制 WAL 目标上限,避免异常场景持续膨胀(256MB)。
|
||||
sqliteJournalSizeLimitBytes = 256 * 1024 * 1024
|
||||
// 定时执行 PASSIVE checkpoint,平滑推进 WAL 回收。
|
||||
sqlitePassiveCheckpointInterval = 300 * time.Second
|
||||
)
|
||||
|
||||
// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性
|
||||
func configureDBPool(db *sql.DB) {
|
||||
// SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
// SQLite 同一时间只允许一个写入者;过高连接数会放大锁竞争和 WAL 回收延迟。
|
||||
db.SetMaxOpenConns(sqliteMaxOpenConns)
|
||||
db.SetMaxIdleConns(sqliteMaxIdleConns)
|
||||
db.SetConnMaxLifetime(30 * time.Minute)
|
||||
}
|
||||
|
||||
// configureSQLitePragmas 调整 WAL 回收行为,降低 -wal 文件长期膨胀风险。
|
||||
func configureSQLitePragmas(db *sql.DB) error {
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA wal_autocheckpoint=%d", sqliteWALAutoCheckpointPages)); err != nil {
|
||||
return fmt.Errorf("设置 wal_autocheckpoint 失败: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA journal_size_limit=%d", sqliteJournalSizeLimitBytes)); err != nil {
|
||||
return fmt.Errorf("设置 journal_size_limit 失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
conversationArtifactsDir string
|
||||
checkpointLoopName string
|
||||
checkpointStop chan struct{}
|
||||
checkpointDone chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// startPassiveCheckpointLoop 启动后台 PASSIVE checkpoint 循环。
|
||||
func (db *DB) startPassiveCheckpointLoop(name string) {
|
||||
if sqlitePassiveCheckpointInterval <= 0 || db == nil || db.DB == nil {
|
||||
return
|
||||
}
|
||||
db.checkpointLoopName = strings.TrimSpace(name)
|
||||
db.checkpointStop = make(chan struct{})
|
||||
db.checkpointDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(db.checkpointDone)
|
||||
ticker := time.NewTicker(sqlitePassiveCheckpointInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动后先尝试一次,尽快回收已有 WAL 堆积。
|
||||
db.runPassiveCheckpoint("startup")
|
||||
for {
|
||||
select {
|
||||
case <-db.checkpointStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.runPassiveCheckpoint("ticker")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// runPassiveCheckpoint 执行一次 PRAGMA wal_checkpoint(PASSIVE)。
|
||||
func (db *DB) runPassiveCheckpoint(trigger string) {
|
||||
if db == nil || db.DB == nil {
|
||||
return
|
||||
}
|
||||
startAt := time.Now()
|
||||
var busy, logFrames, checkpointed int
|
||||
err := db.QueryRow("PRAGMA wal_checkpoint(PASSIVE)").Scan(&busy, &logFrames, &checkpointed)
|
||||
if db.logger == nil {
|
||||
return
|
||||
}
|
||||
fields := []zap.Field{
|
||||
zap.String("db", db.checkpointLoopName),
|
||||
zap.String("trigger", trigger),
|
||||
zap.Int("busy", busy),
|
||||
zap.Int("log_frames", logFrames),
|
||||
zap.Int("checkpointed_frames", checkpointed),
|
||||
zap.Int64("elapsed_ms", time.Since(startAt).Milliseconds()),
|
||||
}
|
||||
if err != nil {
|
||||
db.logger.Warn("SQLite PASSIVE checkpoint 完成(失败)",
|
||||
append(fields, zap.Error(err))...,
|
||||
)
|
||||
return
|
||||
}
|
||||
if busy > 0 {
|
||||
db.logger.Info("SQLite PASSIVE checkpoint 完成(部分推进)", fields...)
|
||||
return
|
||||
}
|
||||
db.logger.Info("SQLite PASSIVE checkpoint 完成(成功)", fields...)
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
@@ -37,8 +125,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
configureDBPool(db)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
if err := configureSQLitePragmas(db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("配置数据库 PRAGMA 失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
@@ -54,8 +147,10 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||
}
|
||||
database.startPassiveCheckpointLoop("conversations")
|
||||
|
||||
return database, nil
|
||||
}
|
||||
@@ -213,6 +308,59 @@ func (db *DB) initTables() error {
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建项目表
|
||||
createProjectsTable := `
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
scope_json TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建项目事实表(黑板)
|
||||
createProjectFactsTable := `
|
||||
CREATE TABLE IF NOT EXISTS project_facts (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
supersedes_fact_id TEXT,
|
||||
related_vulnerability_id TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
UNIQUE(project_id, fact_key)
|
||||
);`
|
||||
|
||||
createProjectFactVersionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
fact_id TEXT NOT NULL,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
related_vulnerability_id TEXT,
|
||||
archived_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建漏洞表
|
||||
createVulnerabilitiesTable := `
|
||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||
@@ -387,6 +535,24 @@ func (db *DB) initTables() error {
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
createAuditLogsTable := `
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at DATETIME NOT NULL,
|
||||
level TEXT NOT NULL DEFAULT 'info',
|
||||
category TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
result TEXT NOT NULL,
|
||||
actor TEXT NOT NULL DEFAULT 'admin',
|
||||
session_hint TEXT,
|
||||
client_ip TEXT,
|
||||
user_agent TEXT,
|
||||
resource_type TEXT,
|
||||
resource_id TEXT,
|
||||
message TEXT NOT NULL,
|
||||
detail_json TEXT
|
||||
);`
|
||||
|
||||
createC2ProfilesTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_profiles (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -427,6 +593,14 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||
@@ -445,6 +619,10 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -494,6 +672,18 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectsTable); err != nil {
|
||||
return fmt.Errorf("创建projects表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectFactsTable); err != nil {
|
||||
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectFactVersionsTable); err != nil {
|
||||
return fmt.Errorf("创建project_fact_versions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
}
|
||||
@@ -514,6 +704,10 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAuditLogsTable); err != nil {
|
||||
return fmt.Errorf("创建audit_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"c2_listeners": createC2ListenersTable,
|
||||
"c2_sessions": createC2SessionsTable,
|
||||
@@ -557,6 +751,13 @@ func (db *DB) initTables() error {
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateProjectsTable(); err != nil {
|
||||
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
|
||||
}
|
||||
if err := db.migrateProjectFactVersionsTable(); err != nil {
|
||||
db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err))
|
||||
}
|
||||
|
||||
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
@@ -594,6 +795,25 @@ func (db *DB) migrateMessagesTable() error {
|
||||
|
||||
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
||||
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
||||
|
||||
// reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
|
||||
var rcColCount int
|
||||
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
|
||||
if errRC != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
|
||||
}
|
||||
}
|
||||
} else if rcColCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -885,6 +1105,79 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
}
|
||||
}
|
||||
|
||||
var projectIDCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='project_id'").Scan(&projectIDCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if projectIDCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); err != nil {
|
||||
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateProjectsTable 迁移 projects / conversations / vulnerabilities 的项目关联字段。
|
||||
func (db *DB) migrateProjectsTable() error {
|
||||
for _, col := range []struct {
|
||||
table string
|
||||
name string
|
||||
stmt string
|
||||
}{
|
||||
{"conversations", "project_id", "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE SET NULL"},
|
||||
{"vulnerabilities", "project_id", "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||
} {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info(?) WHERE name=?", col.table, col.name).Scan(&count)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count == 0 {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateProjectFactVersionsTable 为已有库创建事实版本表。
|
||||
func (db *DB) migrateProjectFactVersionsTable() error {
|
||||
ddl := `
|
||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
fact_id TEXT NOT NULL,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
related_vulnerability_id TEXT,
|
||||
archived_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);`
|
||||
if _, err := db.Exec(ddl); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`)
|
||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -896,6 +1189,7 @@ func (db *DB) migrateVulnerabilitiesTable() error {
|
||||
}{
|
||||
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
|
||||
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
|
||||
{name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
@@ -960,8 +1254,13 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
configureDBPool(sqlDB)
|
||||
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
||||
}
|
||||
if err := configureSQLitePragmas(sqlDB); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("配置知识库数据库 PRAGMA 失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: sqlDB,
|
||||
@@ -970,8 +1269,10 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// 初始化知识库表
|
||||
if err := database.initKnowledgeTables(); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
|
||||
}
|
||||
database.startPassiveCheckpointLoop("knowledge")
|
||||
|
||||
return database, nil
|
||||
}
|
||||
@@ -1085,5 +1386,19 @@ func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
db.closeOnce.Do(func() {
|
||||
if db.checkpointStop != nil {
|
||||
close(db.checkpointStop)
|
||||
if db.checkpointDone != nil {
|
||||
<-db.checkpointDone
|
||||
}
|
||||
}
|
||||
if db.DB != nil {
|
||||
db.closeErr = db.DB.Close()
|
||||
}
|
||||
})
|
||||
return db.closeErr
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
|
||||
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
|
||||
if len(rows) < 2 {
|
||||
return rows
|
||||
}
|
||||
out := make([]ProcessDetail, 0, len(rows))
|
||||
var lastKey string
|
||||
for _, d := range rows {
|
||||
key := processDetailRowKey(d)
|
||||
if len(out) > 0 && key != "" && key == lastKey {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
lastKey = key
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func processDetailRowKey(d ProcessDetail) string {
|
||||
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
|
||||
}
|
||||
@@ -0,0 +1,513 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)
|
||||
|
||||
// ValidateFactKey 校验事实 key(项目内唯一标识)。
|
||||
func ValidateFactKey(key string) error {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return fmt.Errorf("fact_key 不能为空")
|
||||
}
|
||||
if len(key) > 128 {
|
||||
return fmt.Errorf("fact_key 过长(最多 128 字符)")
|
||||
}
|
||||
if !factKeyPattern.MatchString(key) {
|
||||
return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Project 渗透测试项目(跨对话共享黑板)。
|
||||
type Project struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
ScopeJSON string `json:"scope_json,omitempty"`
|
||||
Status string `json:"status"` // active | archived
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ProjectFact 项目事实(黑板条目)。
|
||||
type ProjectFact struct {
|
||||
ID string `json:"id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
FactKey string `json:"fact_key"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ProjectFactListFilter 事实列表筛选。
|
||||
type ProjectFactListFilter struct {
|
||||
Category string
|
||||
Confidence string
|
||||
Search string
|
||||
RelatedVulnerabilityID string
|
||||
ExcludeDeprecated bool // 为 true 时排除 confidence=deprecated
|
||||
}
|
||||
|
||||
// CreateProject 创建项目。
|
||||
func (db *DB) CreateProject(p *Project) (*Project, error) {
|
||||
if p.ID == "" {
|
||||
p.ID = uuid.New().String()
|
||||
}
|
||||
if strings.TrimSpace(p.Status) == "" {
|
||||
p.Status = "active"
|
||||
}
|
||||
now := time.Now()
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = now
|
||||
}
|
||||
p.UpdatedAt = now
|
||||
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建项目失败: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetProject 获取项目。
|
||||
func (db *DB) GetProject(id string) (*Project, error) {
|
||||
var p Project
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := db.QueryRow(
|
||||
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||
FROM projects WHERE id = ?`, id,
|
||||
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("项目不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取项目失败: %w", err)
|
||||
}
|
||||
p.Pinned = pinned != 0
|
||||
p.CreatedAt = parseDBTime(createdAt)
|
||||
p.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// ListProjects 列出项目。
|
||||
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||
FROM projects WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if s := strings.TrimSpace(status); s != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, s)
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("列出项目失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*Project
|
||||
for rows.Next() {
|
||||
var p Project
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Pinned = pinned != 0
|
||||
p.CreatedAt = parseDBTime(createdAt)
|
||||
p.UpdatedAt = parseDBTime(updatedAt)
|
||||
out = append(out, &p)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateProject 更新项目。
|
||||
func (db *DB) UpdateProject(p *Project) error {
|
||||
p.UpdatedAt = time.Now()
|
||||
_, err := db.Exec(
|
||||
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
|
||||
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理;漏洞 project_id 置空)。
|
||||
func (db *DB) DeleteProject(id string) error {
|
||||
if _, err := db.Exec(`UPDATE vulnerabilities SET project_id = NULL WHERE project_id = ?`, id); err != nil {
|
||||
return fmt.Errorf("解除漏洞项目关联失败: %w", err)
|
||||
}
|
||||
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConversationProjectID 返回对话绑定的项目 ID。
|
||||
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
|
||||
var pid sql.NullString
|
||||
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if pid.Valid {
|
||||
return strings.TrimSpace(pid.String), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
|
||||
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID != "" {
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var val interface{}
|
||||
if projectID == "" {
|
||||
val = nil
|
||||
} else {
|
||||
val = projectID
|
||||
}
|
||||
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置对话项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
|
||||
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ?`
|
||||
args := []interface{}{projectID}
|
||||
if !includeDeprecated {
|
||||
query += " AND confidence != 'deprecated'"
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC"
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFacts(rows)
|
||||
}
|
||||
|
||||
// ListProjectFacts 分页列出项目事实。
|
||||
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ?`
|
||||
args := []interface{}{projectID}
|
||||
if c := strings.TrimSpace(filter.Category); c != "" {
|
||||
query += " AND category = ?"
|
||||
args = append(args, c)
|
||||
}
|
||||
if c := strings.TrimSpace(filter.Confidence); c != "" {
|
||||
query += " AND confidence = ?"
|
||||
args = append(args, c)
|
||||
}
|
||||
if filter.ExcludeDeprecated {
|
||||
query += " AND confidence != 'deprecated'"
|
||||
}
|
||||
if rid := strings.TrimSpace(filter.RelatedVulnerabilityID); rid != "" {
|
||||
query += " AND related_vulnerability_id = ?"
|
||||
args = append(args, rid)
|
||||
}
|
||||
if s := strings.TrimSpace(filter.Search); s != "" {
|
||||
pat := "%" + s + "%"
|
||||
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
|
||||
args = append(args, pat, pat, pat)
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFacts(rows)
|
||||
}
|
||||
|
||||
// GetProjectFactByKey 按 key 获取事实。
|
||||
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
||||
projectID, factKey,
|
||||
)
|
||||
return scanProjectFactRow(row)
|
||||
}
|
||||
|
||||
// GetProjectFact 按 ID 获取事实。
|
||||
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE id = ?`, id,
|
||||
)
|
||||
return scanProjectFactRow(row)
|
||||
}
|
||||
|
||||
// mergeFactBodyOnUpdate 更新时若 incoming body 为空则保留已有内容,避免仅改 summary 时丢失攻击链。
|
||||
func mergeFactBodyOnUpdate(incoming, existing string) string {
|
||||
if strings.TrimSpace(incoming) == "" {
|
||||
return existing
|
||||
}
|
||||
return incoming
|
||||
}
|
||||
|
||||
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
|
||||
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
||||
if err := ValidateFactKey(f.FactKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(f.Category) == "" {
|
||||
f.Category = "note"
|
||||
}
|
||||
if strings.TrimSpace(f.Confidence) == "" {
|
||||
f.Confidence = "tentative"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
|
||||
if err == nil && existing != nil {
|
||||
f.ID = existing.ID
|
||||
f.CreatedAt = existing.CreatedAt
|
||||
f.UpdatedAt = now
|
||||
f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body)
|
||||
if strings.TrimSpace(f.Category) == "" {
|
||||
f.Category = existing.Category
|
||||
}
|
||||
if strings.TrimSpace(f.Confidence) == "" {
|
||||
f.Confidence = existing.Confidence
|
||||
}
|
||||
if projectFactContentChanged(existing, f) {
|
||||
versionID, verr := db.InsertProjectFactVersion(existing)
|
||||
if verr != nil {
|
||||
return nil, verr
|
||||
}
|
||||
f.SupersedesFactID = versionID
|
||||
} else if f.SupersedesFactID == "" {
|
||||
f.SupersedesFactID = existing.SupersedesFactID
|
||||
}
|
||||
_, err = db.Exec(
|
||||
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
||||
source_conversation_id = COALESCE(?, source_conversation_id),
|
||||
source_message_id = COALESCE(?, source_message_id),
|
||||
pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新事实失败: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
if f.ID == "" {
|
||||
f.ID = uuid.New().String()
|
||||
}
|
||||
f.CreatedAt = now
|
||||
f.UpdatedAt = now
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO project_facts (
|
||||
id, project_id, fact_key, category, summary, body, confidence,
|
||||
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
|
||||
f.CreatedAt, f.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建事实失败: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// DeprecateProjectFact 将事实标记为 deprecated。
|
||||
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||
res, err := db.Exec(
|
||||
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||
time.Now(), projectID, factKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("事实不存在")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||
func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
||||
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||
if confidence == "" {
|
||||
confidence = "tentative"
|
||||
}
|
||||
if confidence != "confirmed" && confidence != "tentative" {
|
||||
return fmt.Errorf("confidence 须为 confirmed 或 tentative")
|
||||
}
|
||||
|
||||
existing, err := db.GetProjectFactByKey(projectID, factKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("事实不存在")
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(existing.Confidence)) != "deprecated" {
|
||||
return fmt.Errorf("事实未处于废弃状态")
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
`UPDATE project_facts SET confidence = ?, updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||
confidence, time.Now(), projectID, factKey,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteProjectFact 删除事实。
|
||||
func (db *DB) DeleteProjectFact(id string) error {
|
||||
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
|
||||
var out []*ProjectFact
|
||||
for rows.Next() {
|
||||
f, err := scanProjectFactFromRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, f)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
|
||||
var f ProjectFact
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := row.Scan(
|
||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("事实不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
f.Pinned = pinned != 0
|
||||
f.CreatedAt = parseDBTime(createdAt)
|
||||
f.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
|
||||
var f ProjectFact
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Pinned = pinned != 0
|
||||
f.CreatedAt = parseDBTime(createdAt)
|
||||
f.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func nullIfEmpty(s string) interface{} {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func parseDBTime(s string) time.Time {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
// go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态
|
||||
layouts := []string{
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02T15:04:05",
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, e := time.Parse(layout, s); e == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestUpsertProjectFact_preservesBodyOnEmptyUpdate(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const body = "## 攻击链\n1. step\n```http\nGET / HTTP/1.1\n```\n"
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/sqli-login",
|
||||
Category: "finding",
|
||||
Summary: "SQLi on /login",
|
||||
Body: body,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/sqli-login",
|
||||
Summary: "SQLi on /login (confirmed)",
|
||||
Body: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.Summary != "SQLi on /login (confirmed)" {
|
||||
t.Fatalf("summary=%q", updated.Summary)
|
||||
}
|
||||
if updated.Body != body {
|
||||
t.Fatalf("returned body=%q want preserved attack chain", updated.Body)
|
||||
}
|
||||
|
||||
fromDB, err := db.GetProjectFactByKey(proj.ID, "finding/sqli-login")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fromDB.Body != body {
|
||||
t.Fatalf("stored body=%q want preserved", fromDB.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertProjectFact_replacesBodyWhenProvided(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "target/primary",
|
||||
Summary: "v1",
|
||||
Body: "old body",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const newBody = "new body with evidence"
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "target/primary",
|
||||
Summary: "v2",
|
||||
Body: newBody,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.Body != newBody {
|
||||
t.Fatalf("body=%q want %q", updated.Body, newBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreProjectFact(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "restore-test"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key := "target/restore-me"
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: key,
|
||||
Summary: "s",
|
||||
Confidence: "confirmed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.DeprecateProjectFact(proj.ID, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.RestoreProjectFact(proj.ID, key, "confirmed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f, err := db.GetProjectFactByKey(proj.ID, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if f.Confidence != "confirmed" {
|
||||
t.Fatalf("confidence=%q want confirmed", f.Confidence)
|
||||
}
|
||||
if err := db.RestoreProjectFact(proj.ID, key, ""); err == nil {
|
||||
t.Fatal("expected error when not deprecated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "version-test"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
created, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/xss",
|
||||
Category: "finding",
|
||||
Summary: "v1",
|
||||
Body: "body v1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if created.SupersedesFactID != "" {
|
||||
t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID)
|
||||
}
|
||||
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/xss",
|
||||
Summary: "v2",
|
||||
Body: "body v2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.SupersedesFactID == "" {
|
||||
t.Fatal("expected supersedes_fact_id after content change")
|
||||
}
|
||||
prev, err := db.GetProjectFactVersion(updated.SupersedesFactID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if prev.Summary != "v1" || prev.Body != "body v1" {
|
||||
t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeFactBodyOnUpdate(t *testing.T) {
|
||||
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
|
||||
t.Fatalf("empty incoming: got %q", got)
|
||||
}
|
||||
if got := mergeFactBodyOnUpdate(" ", "keep"); got != "keep" {
|
||||
t.Fatalf("whitespace incoming: got %q", got)
|
||||
}
|
||||
if got := mergeFactBodyOnUpdate("new", "old"); got != "new" {
|
||||
t.Fatalf("non-empty incoming: got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。
|
||||
type ProjectFactVersion struct {
|
||||
ID string `json:"id"`
|
||||
FactID string `json:"fact_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
FactKey string `json:"fact_key"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||
ArchivedAt time.Time `json:"archived_at"`
|
||||
}
|
||||
|
||||
// InsertProjectFactVersion 将当前事实行快照写入版本表。
|
||||
func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) {
|
||||
if f == nil || f.ID == "" {
|
||||
return "", fmt.Errorf("无效的事实记录")
|
||||
}
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO project_fact_versions (
|
||||
id, fact_id, project_id, fact_key, category, summary, body, confidence,
|
||||
source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.RelatedVulnerabilityID), now,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("归档事实版本失败: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetProjectFactVersion 按版本 ID 获取快照。
|
||||
func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(related_vulnerability_id,''), archived_at
|
||||
FROM project_fact_versions WHERE id = ?`, versionID,
|
||||
)
|
||||
return scanProjectFactVersionRow(row)
|
||||
}
|
||||
|
||||
// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。
|
||||
func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
rows, err := db.Query(
|
||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(related_vulnerability_id,''), archived_at
|
||||
FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`,
|
||||
factID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*ProjectFactVersion
|
||||
for rows.Next() {
|
||||
v, err := scanProjectFactVersionFromRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, v)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func projectFactContentChanged(existing, incoming *ProjectFact) bool {
|
||||
if existing == nil || incoming == nil {
|
||||
return false
|
||||
}
|
||||
mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body)
|
||||
inCat := stringsTrimDefault(incoming.Category, existing.Category)
|
||||
inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence)
|
||||
return existing.Summary != incoming.Summary ||
|
||||
existing.Body != mergedBody ||
|
||||
existing.Category != inCat ||
|
||||
existing.Confidence != inConf
|
||||
}
|
||||
|
||||
func stringsTrimDefault(s, fallback string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) {
|
||||
var v ProjectFactVersion
|
||||
var pinned int
|
||||
var archivedAt string
|
||||
err := row.Scan(
|
||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||
&v.RelatedVulnerabilityID, &archivedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("事实版本不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
v.Pinned = pinned != 0
|
||||
v.ArchivedAt = parseDBTime(archivedAt)
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) {
|
||||
var v ProjectFactVersion
|
||||
var pinned int
|
||||
var archivedAt string
|
||||
err := rows.Scan(
|
||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||
&v.RelatedVulnerabilityID, &archivedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.Pinned = pinned != 0
|
||||
v.ArchivedAt = parseDBTime(archivedAt)
|
||||
return &v, nil
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProjectStats 项目聚合统计。
|
||||
type ProjectStats struct {
|
||||
FactCount int `json:"fact_count"`
|
||||
VulnCount int `json:"vuln_count"`
|
||||
ConversationCount int `json:"conversation_count"`
|
||||
SparseFactCount int `json:"sparse_fact_count"`
|
||||
}
|
||||
|
||||
// GetProjectStatsCounts 统计项目下事实、漏洞、对话数量(不含 sparse,由 project 包补全)。
|
||||
func (db *DB) GetProjectStatsCounts(projectID string) (*ProjectStats, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("project_id 不能为空")
|
||||
}
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats := &ProjectStats{}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||
projectID,
|
||||
).Scan(&stats.FactCount); err != nil {
|
||||
return nil, fmt.Errorf("统计事实失败: %w", err)
|
||||
}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM vulnerabilities WHERE project_id = ?`,
|
||||
projectID,
|
||||
).Scan(&stats.VulnCount); err != nil {
|
||||
return nil, fmt.Errorf("统计漏洞失败: %w", err)
|
||||
}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM conversations WHERE project_id = ?`,
|
||||
projectID,
|
||||
).Scan(&stats.ConversationCount); err != nil {
|
||||
return nil, fmt.Errorf("统计对话失败: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ListProjectFactsForSparseCheck 返回用于待补全检测的事实字段(非 deprecated)。
|
||||
func (db *DB) ListProjectFactsForSparseCheck(projectID string) ([]struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT category, fact_key, COALESCE(body,'') FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||
projectID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}
|
||||
for rows.Next() {
|
||||
var row struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}
|
||||
if err := rows.Scan(&row.Category, &row.FactKey, &row.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// ListConversationsByProjectID 列出绑定到项目的对话。
|
||||
func (db *DB) ListConversationsByProjectID(projectID string, limit, offset int) ([]*Conversation, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := db.Query(
|
||||
`SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id
|
||||
FROM conversations WHERE project_id = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?`,
|
||||
projectID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询项目对话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var pid sql.NullString
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &pid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pid.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(pid.String)
|
||||
}
|
||||
conv.CreatedAt = parseDBTime(createdAt)
|
||||
conv.UpdatedAt = parseDBTime(updatedAt)
|
||||
conv.Pinned = pinned != 0
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
return conversations, rows.Err()
|
||||
}
|
||||
|
||||
// CountConversationsByProjectID 统计项目绑定对话数。
|
||||
func (db *DB) CountConversationsByProjectID(projectID string) (int, error) {
|
||||
var n int
|
||||
err := db.QueryRow(`SELECT COUNT(*) FROM conversations WHERE project_id = ?`, projectID).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestParseDBTime_projectFactFormats(t *testing.T) {
|
||||
cases := []string{
|
||||
"2026-05-26 11:13:07.442143+08:00",
|
||||
"2026-05-26 11:13:07",
|
||||
"2026-05-26T11:13:07.442143+08:00",
|
||||
}
|
||||
for _, s := range cases {
|
||||
got := parseDBTime(s)
|
||||
if got.IsZero() {
|
||||
t.Fatalf("parseDBTime(%q) returned zero", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListProjectFacts_updatedAtJSON(t *testing.T) {
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Skip(err)
|
||||
}
|
||||
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
|
||||
if _, err := os.Stat(dbPath); err != nil {
|
||||
t.Skip("conversations.db not found")
|
||||
}
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
projects, err := db.ListProjects("", 1, 0)
|
||||
if err != nil || len(projects) == 0 {
|
||||
t.Skip("no projects")
|
||||
}
|
||||
pid := projects[0].ID
|
||||
|
||||
list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(list) == 0 {
|
||||
t.Skip("no facts")
|
||||
}
|
||||
for _, f := range list {
|
||||
if f.UpdatedAt.IsZero() {
|
||||
t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey)
|
||||
}
|
||||
b, err := json.Marshal(f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw, ok := m["updated_at"].(string)
|
||||
if !ok || raw == "" || raw[:4] == "0001" {
|
||||
t.Fatalf("bad updated_at in JSON: %v", m["updated_at"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDBTime_zeroOnGarbage(t *testing.T) {
|
||||
if !parseDBTime("").IsZero() {
|
||||
t.Fatal("expected zero for empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure RFC3339 round-trip used by API is after year 2000.
|
||||
func TestParseDBTime_marshalRoundTrip(t *testing.T) {
|
||||
s := "2026-05-26 11:13:07.442143+08:00"
|
||||
tm := parseDBTime(s)
|
||||
b, err := json.Marshal(tm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var back time.Time
|
||||
if err := json.Unmarshal(b, &back); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if back.IsZero() {
|
||||
t.Fatalf("unmarshal zero from %s", string(b))
|
||||
}
|
||||
}
|
||||
@@ -3,16 +3,94 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
|
||||
type VulnerabilityListFilter struct {
|
||||
ID string
|
||||
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
|
||||
ConversationID string
|
||||
ProjectID string
|
||||
Severity string
|
||||
Status string
|
||||
TaskID string
|
||||
ConversationTag string
|
||||
TaskTag string
|
||||
}
|
||||
|
||||
func escapeVulnerabilityLikePattern(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return "%" + s + "%"
|
||||
}
|
||||
|
||||
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
|
||||
if f.ID != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, f.ID)
|
||||
}
|
||||
if f.ConversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, f.ConversationID)
|
||||
}
|
||||
if f.ProjectID != "" {
|
||||
query += " AND project_id = ?"
|
||||
args = append(args, f.ProjectID)
|
||||
}
|
||||
if f.TaskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, f.TaskID, f.TaskID)
|
||||
}
|
||||
if f.ConversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, f.ConversationTag)
|
||||
}
|
||||
if f.TaskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, f.TaskTag)
|
||||
}
|
||||
if f.Severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, f.Severity)
|
||||
}
|
||||
if f.Status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, f.Status)
|
||||
}
|
||||
search := strings.TrimSpace(f.Search)
|
||||
if search != "" {
|
||||
pattern := escapeVulnerabilityLikePattern(search)
|
||||
query += ` AND (
|
||||
LOWER(id) LIKE LOWER(?) OR
|
||||
LOWER(title) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||
)`
|
||||
for i := 0; i < 11; i++ {
|
||||
args = append(args, pattern)
|
||||
}
|
||||
}
|
||||
return query, args
|
||||
}
|
||||
|
||||
// Vulnerability 漏洞
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ConversationTag string `json:"conversation_tag,omitempty"`
|
||||
TaskTag string `json:"task_tag,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
@@ -44,17 +122,23 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
}
|
||||
vuln.UpdatedAt = now
|
||||
|
||||
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" {
|
||||
if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil {
|
||||
vuln.ProjectID = pid
|
||||
}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO vulnerabilities (
|
||||
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
|
||||
id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||
vuln.CreatedAt, vuln.UpdatedAt,
|
||||
@@ -70,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
var vuln Vulnerability
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status,
|
||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status,
|
||||
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
@@ -80,7 +164,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
`
|
||||
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
@@ -97,9 +181,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
@@ -108,35 +192,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
WHERE 1=1
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
@@ -151,7 +207,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
for rows.Next() {
|
||||
var vuln Vulnerability
|
||||
err := rows.Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
@@ -168,38 +224,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
}
|
||||
|
||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
||||
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
@@ -216,7 +244,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
query := `
|
||||
UPDATE vulnerabilities
|
||||
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||
SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
||||
recommendation = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
@@ -224,7 +252,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
||||
vuln.Recommendation, vuln.UpdatedAt, id,
|
||||
)
|
||||
@@ -237,27 +265,32 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
// DeleteVulnerability 删除漏洞
|
||||
func (db *DB) DeleteVulnerability(id string) error {
|
||||
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开启事务失败: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 删除漏洞前先解除项目事实中的关联,避免前端继续显示已删除漏洞的短 ID。
|
||||
if _, err := tx.Exec("UPDATE project_facts SET related_vulnerability_id = NULL WHERE related_vulnerability_id = ?", id); err != nil {
|
||||
return fmt.Errorf("清理事实漏洞关联失败: %w", err)
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM vulnerabilities WHERE id = ?", id); err != nil {
|
||||
return fmt.Errorf("删除漏洞失败: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
||||
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
|
||||
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
where := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
where += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
where, args = filter.appendWhere(where, args)
|
||||
|
||||
// 总漏洞数
|
||||
var totalCount int
|
||||
@@ -357,10 +390,15 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
|
||||
}
|
||||
projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询项目ID建议失败: %w", err)
|
||||
}
|
||||
|
||||
return map[string][]string{
|
||||
"vulnerability_ids": vulnIDs,
|
||||
"conversation_ids": conversationIDs,
|
||||
"project_ids": projectIDs,
|
||||
"task_ids": taskIDs,
|
||||
"queue_ids": queueIDs,
|
||||
"conversation_tags": conversationTags,
|
||||
|
||||
@@ -23,12 +23,16 @@ type ExecutionRecorder func(executionID string)
|
||||
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
|
||||
|
||||
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
|
||||
// invokeNotify 可选:与 runEinoADKAgentLoop 共享,在 InvokableRun 返回时触发 UI 与 pending 清理(与 ADK Tool 事件去重)。
|
||||
// einoAgentName 为该套工具所属 ChatModelAgent 的 Name(主代理或子代理 id),用于 SSE 上的 einoAgent 字段。
|
||||
func ToolsFromDefinitions(
|
||||
ag *agent.Agent,
|
||||
holder *ConversationHolder,
|
||||
defs []agent.Tool,
|
||||
rec ExecutionRecorder,
|
||||
toolOutputChunk func(toolName, toolCallID, chunk string),
|
||||
invokeNotify *ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
) ([]tool.BaseTool, error) {
|
||||
out := make([]tool.BaseTool, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
@@ -40,12 +44,14 @@ func ToolsFromDefinitions(
|
||||
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
|
||||
}
|
||||
out = append(out, &mcpBridgeTool{
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
@@ -77,12 +83,14 @@ func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
|
||||
}
|
||||
|
||||
type mcpBridgeTool struct {
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
invokeNotify *ToolInvokeNotifyHolder
|
||||
einoAgentName string
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
@@ -90,8 +98,27 @@ func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return m.info, nil
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (out string, err error) {
|
||||
_ = opts
|
||||
toolCallID := compose.GetToolCallID(ctx)
|
||||
defer func() {
|
||||
if m.invokeNotify == nil {
|
||||
return
|
||||
}
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
if tid == "" {
|
||||
return
|
||||
}
|
||||
success := err == nil && !strings.HasPrefix(out, ToolErrorPrefix)
|
||||
body := out
|
||||
if err != nil {
|
||||
success = false
|
||||
} else if strings.HasPrefix(out, ToolErrorPrefix) {
|
||||
success = false
|
||||
body = strings.TrimPrefix(out, ToolErrorPrefix)
|
||||
}
|
||||
m.invokeNotify.Fire(tid, m.name, m.einoAgentName, success, body, err)
|
||||
}()
|
||||
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package einomcp
|
||||
|
||||
import "sync"
|
||||
|
||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
||||
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
||||
type ToolInvokeNotifyHolder struct {
|
||||
mu sync.RWMutex
|
||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||
}
|
||||
|
||||
// NewToolInvokeNotifyHolder 创建可在 ToolsFromDefinitions 与 run loop 之间共享的 holder。
|
||||
func NewToolInvokeNotifyHolder() *ToolInvokeNotifyHolder {
|
||||
return &ToolInvokeNotifyHolder{}
|
||||
}
|
||||
|
||||
// Set 由 runEinoADKAgentLoop 在开始消费 iter 之前调用;可多次覆盖(通常仅一次)。
|
||||
func (h *ToolInvokeNotifyHolder) Set(fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.fn = fn
|
||||
}
|
||||
|
||||
// Fire 由 mcpBridgeTool 在工具调用返回时调用;若尚未 Set 或 toolCallID 为空则忽略。
|
||||
func (h *ToolInvokeNotifyHolder) Fire(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.RLock()
|
||||
fn := h.fn
|
||||
h.mu.RUnlock()
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
fn(toolCallID, toolName, einoAgent, success, content, invokeErr)
|
||||
}
|
||||
@@ -0,0 +1,451 @@
|
||||
// Package einoobserve attaches CloudWeGo Eino [callbacks.Handler] to ADK Runner contexts for
|
||||
// structured logging and optional SSE trace events (eino_trace_*).
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ctxSpanKey struct{}
|
||||
|
||||
type ctxOtelSpanKey struct{}
|
||||
|
||||
// Params for attaching per-run callback instrumentation.
|
||||
type Params struct {
|
||||
Logger *zap.Logger
|
||||
Progress func(eventType, message string, data interface{})
|
||||
ConversationID string
|
||||
OrchMode string
|
||||
OrchestratorName string
|
||||
}
|
||||
|
||||
// AttachAgentRunCallbacks returns ctx wrapped with callbacks.InitCallbacks when enabled.
|
||||
// Safe to call with nil cfg or disabled cfg (returns ctx unchanged).
|
||||
func AttachAgentRunCallbacks(ctx context.Context, cfg *config.MultiAgentEinoCallbacksConfig, p Params) context.Context {
|
||||
if ctx == nil {
|
||||
return ctx
|
||||
}
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return ctx
|
||||
}
|
||||
mode := cfg.EinoCallbacksModeEffective()
|
||||
if mode == "off" {
|
||||
return ctx
|
||||
}
|
||||
runID := uuid.New().String()
|
||||
if p.Progress != nil && cfg.ShouldEmitEinoTraceSSE(mode) {
|
||||
p.Progress("eino_trace_run", "Eino callbacks session", map[string]interface{}{
|
||||
"runId": runID,
|
||||
"conversationId": strings.TrimSpace(p.ConversationID),
|
||||
"orchestration": strings.TrimSpace(p.OrchMode),
|
||||
"orchestratorName": strings.TrimSpace(p.OrchestratorName),
|
||||
"observeMode": mode,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
h := &runHandler{
|
||||
cfg: *cfg,
|
||||
mode: mode,
|
||||
params: p,
|
||||
runID: runID,
|
||||
}
|
||||
b := callbacks.NewHandlerBuilder().
|
||||
OnStartFn(h.onStart).
|
||||
OnEndFn(h.onEnd).
|
||||
OnErrorFn(h.onError)
|
||||
if mode == "full" {
|
||||
b = b.OnStartWithStreamInputFn(h.onStartStreamIn).OnEndWithStreamOutputFn(h.onEndStreamOut)
|
||||
}
|
||||
ri := &callbacks.RunInfo{
|
||||
Name: "CyberStrikeADKRun",
|
||||
Type: strings.TrimSpace(p.OrchMode),
|
||||
Component: components.Component("AgentSession"),
|
||||
}
|
||||
return callbacks.InitCallbacks(ctx, ri, b.Build())
|
||||
}
|
||||
|
||||
type runHandler struct {
|
||||
cfg config.MultiAgentEinoCallbacksConfig
|
||||
mode string
|
||||
params Params
|
||||
runID string
|
||||
|
||||
mu sync.Mutex
|
||||
spanStack []string
|
||||
seq atomic.Uint64
|
||||
}
|
||||
|
||||
func safeRunInfo(info *callbacks.RunInfo) callbacks.RunInfo {
|
||||
if info == nil {
|
||||
return callbacks.RunInfo{
|
||||
Name: "unknown",
|
||||
Type: "unknown",
|
||||
Component: components.Component("unknown"),
|
||||
}
|
||||
}
|
||||
return *info
|
||||
}
|
||||
|
||||
func (h *runHandler) genSpanID() string {
|
||||
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||
}
|
||||
|
||||
func (h *runHandler) popSpan() (id string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id = h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
|
||||
// popMatching removes the given id from the stack top if it matches; otherwise pops until empty or match (rare ordering mismatch).
|
||||
func (h *runHandler) popMatching(want string) string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if want == "" {
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
for len(h.spanStack) > 0 {
|
||||
top := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
if top == want {
|
||||
return top
|
||||
}
|
||||
}
|
||||
return want
|
||||
}
|
||||
|
||||
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
var parentID string
|
||||
h.mu.Lock()
|
||||
if len(h.spanStack) > 0 {
|
||||
parentID = h.spanStack[len(h.spanStack)-1]
|
||||
}
|
||||
spanID := h.genSpanID()
|
||||
h.spanStack = append(h.spanStack, spanID)
|
||||
h.mu.Unlock()
|
||||
|
||||
inSum := summarizeCallbackInput(input, h.cfg.EinoCallbacksMaxInputSummaryRunes())
|
||||
if h.cfg.OtelTracingActive() {
|
||||
tracer := otel.Tracer("cyberstrike/eino")
|
||||
spanName := callbackSpanName(info)
|
||||
var sp trace.Span
|
||||
ctx, sp = tracer.Start(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindInternal),
|
||||
trace.WithAttributes(
|
||||
attribute.String("eino.component", string(ri.Component)),
|
||||
attribute.String("eino.name", ri.Name),
|
||||
attribute.String("eino.type", ri.Type),
|
||||
attribute.String("cyberstrike.run_id", h.runID),
|
||||
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||
),
|
||||
)
|
||||
if inSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.input.summary", truncateForAttr(inSum, 256)))
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxOtelSpanKey{}, sp)
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("parentSpanId", parentID),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.String("phase", "start"),
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if sc := sp.SpanContext(); sc.IsValid() {
|
||||
fields = append(fields,
|
||||
zap.String("trace_id", sc.TraceID().String()),
|
||||
zap.String("otel_span_id", sc.SpanID().String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("inputSummary", inSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_start", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"parentSpanId": parentID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"inputSummary": inSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxSpanKey{}, spanID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
outSum := summarizeCallbackOutput(output, h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if outSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.output.summary", truncateForAttr(outSum, 256)))
|
||||
}
|
||||
sp.SetStatus(codes.Ok, "")
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.String("phase", "end"),
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("outputSummary", outSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_end", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"outputSummary": outSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = truncateRunes(err.Error(), h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if err != nil {
|
||||
sp.RecordError(err)
|
||||
}
|
||||
sp.SetStatus(codes.Error, msg)
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Warn("eino_callback_error",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_error", msg, map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"error": msg,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
if input != nil {
|
||||
input.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_in",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
if output != nil {
|
||||
output.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_out",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func callbackSpanName(info *callbacks.RunInfo) string {
|
||||
if info == nil {
|
||||
return "eino.callback"
|
||||
}
|
||||
comp := strings.TrimSpace(string(info.Component))
|
||||
name := strings.TrimSpace(info.Name)
|
||||
typ := strings.TrimSpace(info.Type)
|
||||
if name != "" && comp != "" {
|
||||
return comp + "/" + name
|
||||
}
|
||||
if typ != "" && comp != "" {
|
||||
return comp + "[" + typ + "]"
|
||||
}
|
||||
if comp != "" {
|
||||
return comp
|
||||
}
|
||||
return "eino.callback"
|
||||
}
|
||||
|
||||
func truncateForAttr(s string, maxRunes int) string {
|
||||
return truncateRunes(s, maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackInput(in callbacks.CallbackInput, maxRunes int) string {
|
||||
if in == nil {
|
||||
return ""
|
||||
}
|
||||
if ai := adk.ConvAgentCallbackInput(in); ai != nil {
|
||||
parts := []string{"agent"}
|
||||
if ai.Input != nil {
|
||||
parts = append(parts, fmt.Sprintf("messages=%d", len(ai.Input.Messages)))
|
||||
}
|
||||
if ai.ResumeInfo != nil {
|
||||
parts = append(parts, "resume=true")
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
if mi := model.ConvCallbackInput(in); mi != nil {
|
||||
return fmt.Sprintf("chatModel messages=%d tools=%d", len(mi.Messages), len(mi.Tools))
|
||||
}
|
||||
if ti := tool.ConvCallbackInput(in); ti != nil {
|
||||
raw := ti.ArgumentsInJSON
|
||||
return "tool args=" + truncateRunes(raw, maxRunes)
|
||||
}
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", in)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackOutput(out callbacks.CallbackOutput, maxRunes int) string {
|
||||
if out == nil {
|
||||
return ""
|
||||
}
|
||||
if ao := adk.ConvAgentCallbackOutput(out); ao != nil {
|
||||
return "agent_events=stream"
|
||||
}
|
||||
if mo := model.ConvCallbackOutput(out); mo != nil && mo.Message != nil {
|
||||
s := ""
|
||||
if mo.Message.Content != "" {
|
||||
s = mo.Message.Content
|
||||
}
|
||||
if mo.TokenUsage != nil {
|
||||
return fmt.Sprintf("tokens total=%d completion=%d prompt=%d text=%s",
|
||||
mo.TokenUsage.TotalTokens, mo.TokenUsage.CompletionTokens, mo.TokenUsage.PromptTokens,
|
||||
truncateRunes(s, minInt(120, maxRunes)))
|
||||
}
|
||||
return "assistant len=" + itoa(len(s))
|
||||
}
|
||||
if to := tool.ConvCallbackOutput(out); to != nil {
|
||||
if to.Response != "" {
|
||||
return truncateRunes(to.Response, maxRunes)
|
||||
}
|
||||
if to.ToolOutput != nil {
|
||||
return "tool_result multimodal"
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", out)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
return fmt.Sprintf("%d", n)
|
||||
}
|
||||
|
||||
func truncateRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
return string(r[:maxRunes]) + "…"
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
func TestAttachAgentRunCallbacks_Disabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := &config.MultiAgentEinoCallbacksConfig{Enabled: false}
|
||||
out := AttachAgentRunCallbacks(ctx, cfg, Params{})
|
||||
if out != ctx {
|
||||
t.Fatalf("expected same ctx when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateRunes(t *testing.T) {
|
||||
if got := truncateRunes("abc", 10); got != "abc" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := truncateRunes("abcdefghij", 4); got != "abcd…" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
otelMu sync.Mutex
|
||||
otelShutdown func(context.Context) error
|
||||
otelInitialized bool
|
||||
)
|
||||
|
||||
// InitOtelFromConfig installs the global OpenTelemetry TracerProvider when
|
||||
// eino_callbacks.otel is enabled and exporter is not none. Safe to call multiple times.
|
||||
func InitOtelFromConfig(cfg *config.MultiAgentEinoCallbacksConfig, log *zap.Logger) (shutdown func(context.Context) error, err error) {
|
||||
shutdown = func(context.Context) error { return nil }
|
||||
if cfg == nil || !cfg.OtelTracingActive() {
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
otelMu.Lock()
|
||||
defer otelMu.Unlock()
|
||||
if otelInitialized {
|
||||
if otelShutdown != nil {
|
||||
return otelShutdown, nil
|
||||
}
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
oc := cfg.Otel
|
||||
expKind := oc.OtelExporterEffective()
|
||||
ctx := context.Background()
|
||||
|
||||
var exporter sdktrace.SpanExporter
|
||||
switch expKind {
|
||||
case "stdout":
|
||||
exporter, err = stdouttrace.New()
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel stdout exporter: %w", err)
|
||||
}
|
||||
case "otlphttp":
|
||||
ep := strings.TrimSpace(oc.OTLPEndpoint)
|
||||
if ep == "" {
|
||||
ep = "localhost:4318"
|
||||
}
|
||||
exporter, err = otlptracehttp.New(ctx,
|
||||
otlptracehttp.WithEndpoint(ep),
|
||||
otlptracehttp.WithURLPath("/v1/traces"),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel otlphttp exporter: %w", err)
|
||||
}
|
||||
default:
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName(oc.ServiceNameEffective()),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel resource: %w", err)
|
||||
}
|
||||
|
||||
sampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(oc.SampleRatioEffective()))
|
||||
tp := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithResource(res),
|
||||
sdktrace.WithSampler(sampler),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
otelShutdown = tp.Shutdown
|
||||
otelInitialized = true
|
||||
if log != nil {
|
||||
log.Info("eino otel: tracer provider initialized",
|
||||
zap.String("exporter", expKind),
|
||||
zap.String("service", oc.ServiceNameEffective()),
|
||||
zap.Float64("sample_ratio", oc.SampleRatioEffective()),
|
||||
)
|
||||
}
|
||||
return otelShutdown, nil
|
||||
}
|
||||
|
||||
// ShutdownOtel flushes and shuts down the global TracerProvider if it was installed.
|
||||
func ShutdownOtel(ctx context.Context) error {
|
||||
otelMu.Lock()
|
||||
fn := otelShutdown
|
||||
otelShutdown = nil
|
||||
inited := otelInitialized
|
||||
otelInitialized = false
|
||||
otelMu.Unlock()
|
||||
if !inited || fn == nil {
|
||||
return nil
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
+386
-145
@@ -17,11 +17,14 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/robfig/cron/v3"
|
||||
@@ -84,6 +87,23 @@ func normalizeProcessDetailText(s string) string {
|
||||
// discardPlanningIfEchoesToolResult drops buffered planning text when it only repeats the
|
||||
// upcoming tool_result body. Streaming models often echo tool stdout in chunk.Content; flushing
|
||||
// that into "planning" before persisting tool_result duplicates the output after page refresh.
|
||||
// sameResponseStreamMeta 判断是否为同一段主通道流(Eino ADK 可能对同一 MessageStream 重复发 response_start)。
|
||||
func sameResponseStreamMeta(a, b map[string]interface{}) bool {
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
agentA, _ := a["einoAgent"].(string)
|
||||
agentB, _ := b["einoAgent"].(string)
|
||||
agentA = strings.TrimSpace(agentA)
|
||||
agentB = strings.TrimSpace(agentB)
|
||||
if agentA == "" || !strings.EqualFold(agentA, agentB) {
|
||||
return false
|
||||
}
|
||||
orchA, _ := a["orchestration"].(string)
|
||||
orchB, _ := b["orchestration"].(string)
|
||||
return strings.TrimSpace(orchA) == strings.TrimSpace(orchB)
|
||||
}
|
||||
|
||||
func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) {
|
||||
if respPlan == nil {
|
||||
return
|
||||
@@ -129,6 +149,12 @@ type AgentHandler struct {
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
@@ -201,18 +227,35 @@ type ChatAttachment struct {
|
||||
ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回)
|
||||
}
|
||||
|
||||
// ChatReasoningRequest 对话页「模型推理」意图(仅 Eino 路径消费;原生 agent-loop 忽略)。
|
||||
type ChatReasoningRequest struct {
|
||||
// Mode: default(跟随系统)| off | on | auto
|
||||
Mode string `json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id)
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
Hitl *HITLRequest `json:"hitl,omitempty"`
|
||||
Reasoning *ChatReasoningRequest `json:"reasoning,omitempty"`
|
||||
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
}
|
||||
|
||||
func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientIntent {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &reasoning.ClientIntent{Mode: r.Mode, Effort: r.Effort}
|
||||
}
|
||||
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
@@ -535,7 +578,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -567,14 +612,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -617,6 +655,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -664,7 +704,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil)
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -706,11 +746,45 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
|
||||
if assistantMessageID != "" {
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
}
|
||||
|
||||
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(message, 50)
|
||||
conv, createErr := h.db.CreateConversation(title)
|
||||
src := "robot"
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
src = "robot:" + strings.TrimSpace(platform)
|
||||
}
|
||||
meta := audit.ConversationCreateMeta(src)
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, createErr := h.db.CreateConversation(title, meta)
|
||||
if createErr != nil {
|
||||
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
||||
}
|
||||
@@ -758,61 +832,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
|
||||
|
||||
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
|
||||
if useRobotMulti {
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
finalMessage,
|
||||
agentHistoryMessages,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
"deep",
|
||||
)
|
||||
if errMA != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
// 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
|
||||
taskCtx, cancelWithCause := context.WithCancelCause(ctx)
|
||||
defer cancelWithCause(nil)
|
||||
taskStatus := "completed"
|
||||
defer func() {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}()
|
||||
if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(resultMA.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(resultMA.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
resultMA.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
||||
return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
|
||||
}
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||
|
||||
robotMode := "react"
|
||||
if h.config != nil {
|
||||
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
|
||||
}
|
||||
switch robotMode {
|
||||
case "eino_single":
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
case "deep", "plan_execute", "supervisor":
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
|
||||
zap.String("robot_mode", robotMode))
|
||||
break
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
}
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
taskStatus = "failed"
|
||||
errMsg := "执行失败: " + err.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
@@ -823,17 +928,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
|
||||
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
|
||||
@@ -853,6 +949,23 @@ type StreamEvent struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
|
||||
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
|
||||
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
|
||||
return
|
||||
}
|
||||
event := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sseLine := make([]byte, 0, len(eventJSON)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, eventJSON...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, sseLine)
|
||||
}
|
||||
|
||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
@@ -891,13 +1004,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return ""
|
||||
}
|
||||
|
||||
// thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking
|
||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
seenToolResultSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
|
||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||
@@ -948,17 +1065,46 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
data[k] = v
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "thinking", content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "thinking"))
|
||||
persist := tb.persistAs
|
||||
if persist != "reasoning_chain" {
|
||||
persist = "thinking"
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, persist, content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", persist))
|
||||
}
|
||||
flushedThinking[sid] = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(eventType, message string, data interface{}) {
|
||||
// 如果提供了sendEventFunc,发送流式事件
|
||||
// 上游在重试/补偿时可能重复回调相同 tool_call/tool_result。
|
||||
// 这里做幂等过滤,保证前端展示和 process_details 都以唯一事件为准。
|
||||
if (eventType == "tool_call" || eventType == "tool_result") && data != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolCallID := strings.TrimSpace(fmt.Sprint(dataMap["toolCallId"]))
|
||||
if toolCallID != "" && toolCallID != "<nil>" {
|
||||
payloadJSON, _ := json.Marshal(dataMap)
|
||||
sig := eventType + "|" + message + "|" + string(payloadJSON)
|
||||
seen := seenToolCallSigs
|
||||
if eventType == "tool_result" {
|
||||
seen = seenToolResultSigs
|
||||
}
|
||||
if prev, exists := seen[toolCallID]; exists && prev == sig {
|
||||
h.logger.Debug("跳过重复工具进度事件",
|
||||
zap.String("eventType", eventType),
|
||||
zap.String("toolCallId", toolCallID))
|
||||
return
|
||||
}
|
||||
seen[toolCallID] = sig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc(eventType, message, data)
|
||||
} else {
|
||||
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
|
||||
}
|
||||
|
||||
// 保存tool_call事件中的参数
|
||||
@@ -1147,6 +1293,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
|
||||
// 多代理主代理「规划中」:response_start / response_delta 仅用于 SSE,聚合落一条 planning
|
||||
if eventType == "response_start" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sameResponseStreamMeta(respPlan.meta, dataMap) {
|
||||
if respPlan.meta == nil {
|
||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||
}
|
||||
for k, v := range dataMap {
|
||||
respPlan.meta[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
flushResponsePlan()
|
||||
respPlan.meta = nil
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
@@ -1159,7 +1316,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
if eventType == "response_delta" {
|
||||
respPlan.b.WriteString(message)
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||
respPlan.b.Reset()
|
||||
respPlan.b.WriteString(acc)
|
||||
} else {
|
||||
respPlan.b.WriteString(message)
|
||||
}
|
||||
} else {
|
||||
respPlan.b.WriteString(message)
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||
for k, v := range dataMap {
|
||||
@@ -1177,14 +1343,20 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合 thinking_stream_*(ReasoningContent),不逐条落库
|
||||
if eventType == "thinking_stream_start" {
|
||||
// 聚合 thinking_stream_* / reasoning_chain_stream_*,不逐条落库
|
||||
if eventType == "thinking_stream_start" || eventType == "reasoning_chain_stream_start" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_start" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
// 记录元信息(source/einoAgent/einoRole/iteration 等)
|
||||
for k, v := range dataMap {
|
||||
@@ -1194,16 +1366,26 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
return
|
||||
}
|
||||
if eventType == "thinking_stream_delta" {
|
||||
if eventType == "thinking_stream_delta" || eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else if tb.persistAs == "" {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||
tb.b.Reset()
|
||||
tb.b.WriteString(acc)
|
||||
} else {
|
||||
tb.b.WriteString(message)
|
||||
}
|
||||
// delta 片段直接拼接;message 本身就是 reasoning content
|
||||
tb.b.WriteString(message)
|
||||
// 有时 delta 先到 start 未到,补充元信息
|
||||
for k, v := range dataMap {
|
||||
tb.meta[k] = v
|
||||
@@ -1213,10 +1395,9 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时,
|
||||
// thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库;
|
||||
// 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。
|
||||
if eventType == "thinking" {
|
||||
// 当 Agent 同时发送 *_stream_* 与同名 streamId 的 thinking/reasoning_chain 时,
|
||||
// 流式聚合已会在 flushThinkingStreams() 落库;此处跳过逐条重复。
|
||||
if eventType == "thinking" || eventType == "reasoning_chain" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
||||
@@ -1239,13 +1420,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
eventType != "response_start" &&
|
||||
eventType != "response_delta" &&
|
||||
eventType != "tool_result_delta" &&
|
||||
eventType != "eino_trace_run" &&
|
||||
eventType != "eino_trace_start" &&
|
||||
eventType != "eino_trace_end" &&
|
||||
eventType != "eino_trace_error" &&
|
||||
eventType != "eino_agent_reply_stream_start" &&
|
||||
eventType != "eino_agent_reply_stream_delta" &&
|
||||
eventType != "eino_agent_reply_stream_end" {
|
||||
if eventType == "tool_result" {
|
||||
discardPlanningIfEchoesToolResult(&respPlan, data)
|
||||
}
|
||||
// 在关键过程事件落库前,先把「规划中」与 thinking_stream 落库
|
||||
// 在关键过程事件落库前,先把「规划中」与聚合中的 thinking / reasoning_chain 流落库
|
||||
flushResponsePlan()
|
||||
flushThinkingStreams()
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||
@@ -1260,7 +1445,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 对于流式请求,也发送SSE格式的错误
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
event := StreamEvent{
|
||||
@@ -1282,7 +1467,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
|
||||
@@ -1392,10 +1577,13 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if req.WebShellConnectionID != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = "webshell_chat"
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
@@ -1427,14 +1615,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -1475,6 +1656,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -1605,7 +1788,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -1727,20 +1910,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMsg != nil {
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
func() string {
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
return string(jsonData)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
time.Now(), assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
// 如果之前创建失败,现在创建
|
||||
@@ -1789,27 +1960,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
||||
if execID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前没有正在执行的 MCP 工具(例如模型尚在推理、尚未发起工具调用)。请等待工具开始执行后再试,或使用「彻底停止」结束整轮任务。"})
|
||||
return
|
||||
}
|
||||
note := strings.TrimSpace(req.Reason)
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
if execID != "" {
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||
if err != nil {
|
||||
h.logger.Error("中断并继续(无工具)失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"status": "interrupt_continue_scheduled",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -1853,7 +2048,7 @@ func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
@@ -1905,6 +2100,7 @@ type BatchTaskRequest struct {
|
||||
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||
}
|
||||
|
||||
func normalizeBatchQueueAgentMode(mode string) string {
|
||||
@@ -1985,7 +2181,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks)
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
||||
if createErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||
return
|
||||
@@ -2006,6 +2202,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
queue = refreshed
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
|
||||
"task_count": len(validTasks), "started": started,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"queueId": queue.ID,
|
||||
"queue": queue,
|
||||
@@ -2113,6 +2314,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2141,6 +2345,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2152,6 +2359,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||
}
|
||||
|
||||
@@ -2247,6 +2457,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "task",
|
||||
Action: "delete_queue",
|
||||
Result: "success",
|
||||
ResourceType: "batch_queue",
|
||||
ResourceID: queueID,
|
||||
Message: "删除批量任务队列",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
||||
}
|
||||
|
||||
@@ -2332,6 +2552,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||
"batch_queue_id": queueID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
@@ -2490,7 +2715,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
@@ -2640,15 +2867,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch)
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback)
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
@@ -2744,17 +2971,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(mcpIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(mcpIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
resText,
|
||||
mcpIDsJSON,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
@@ -2846,6 +3063,10 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
// DeepSeek 思考模式:含工具调用的 assistant 须在后续请求中回传 reasoning_content
|
||||
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||
msg.ReasoningContent = rc
|
||||
}
|
||||
|
||||
// 解析tool_calls(如果存在)
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
@@ -2901,6 +3122,11 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
}
|
||||
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
@@ -2946,3 +3172,18 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
)
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
// dbMessagesToAgentChatMessages maps DB rows to agent ChatMessage for history fallback
|
||||
// (includes reasoning_content for DeepSeek thinking + tool replay).
|
||||
func dbMessagesToAgentChatMessages(msgs []database.Message) []agent.ChatMessage {
|
||||
out := make([]agent.ChatMessage, 0, len(msgs))
|
||||
for i := range msgs {
|
||||
m := msgs[i]
|
||||
out = append(out, agent.ChatMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditHandler serves platform audit log APIs.
|
||||
type AuditHandler struct {
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuditHandler creates an audit log handler.
|
||||
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
|
||||
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
|
||||
}
|
||||
|
||||
// Meta GET /api/audit/meta
|
||||
func (h *AuditHandler) Meta(c *gin.Context) {
|
||||
enabled := false
|
||||
retentionDays := 0
|
||||
if h.audit != nil {
|
||||
enabled = h.audit.Enabled()
|
||||
retentionDays = h.audit.RetentionDays()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": enabled,
|
||||
"retention_days": retentionDays,
|
||||
"default_page_size": 20,
|
||||
"max_page_size": 100,
|
||||
"max_export": 5000,
|
||||
})
|
||||
}
|
||||
|
||||
// Summary GET /api/audit/summary
|
||||
func (h *AuditHandler) Summary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
base := auditFilterFromQuery(c)
|
||||
total, err := h.db.CountAuditLogs(base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
failFilter := base
|
||||
failFilter.Result = "failure"
|
||||
failures, err := h.db.CountAuditLogs(failFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
since := time.Now().AddDate(0, 0, -7)
|
||||
recentFilter := base
|
||||
recentFilter.Since = &since
|
||||
recent7d, err := h.db.CountAuditLogs(recentFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total": total,
|
||||
"failures": failures,
|
||||
"recent_7d": recent7d,
|
||||
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
|
||||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs GET /api/audit/logs
|
||||
func (h *AuditHandler) ListLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
page, pageSize := auditPaginationFromQuery(c)
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
total, err := h.db.CountAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLog GET /api/audit/logs/:id
|
||||
func (h *AuditHandler) GetLog(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
row, err := h.db.GetAuditLogByID(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
|
||||
return
|
||||
}
|
||||
audit.ApplyResourceAvailability(h.db, row)
|
||||
c.JSON(http.StatusOK, gin.H{"log": row})
|
||||
}
|
||||
|
||||
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
|
||||
func (h *AuditHandler) ExportLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
filter.Limit = 5000
|
||||
filter.Offset = 0
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if c.Query("format") == "csv" {
|
||||
writeAuditLogsCSV(c, logs)
|
||||
return
|
||||
}
|
||||
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"exported_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"logs": logs,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
|
||||
c.Header("Content-Type", "text/csv; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
|
||||
|
||||
w := csv.NewWriter(c.Writer)
|
||||
_ = w.Write([]string{
|
||||
"id", "created_at", "level", "category", "action", "result", "actor",
|
||||
"session_hint", "client_ip", "resource_type", "resource_id", "message",
|
||||
})
|
||||
for _, row := range logs {
|
||||
if row == nil {
|
||||
continue
|
||||
}
|
||||
_ = w.Write([]string{
|
||||
row.ID,
|
||||
row.CreatedAt.UTC().Format(time.RFC3339),
|
||||
row.Level,
|
||||
row.Category,
|
||||
row.Action,
|
||||
row.Result,
|
||||
row.Actor,
|
||||
row.SessionHint,
|
||||
row.ClientIP,
|
||||
row.ResourceType,
|
||||
row.ResourceID,
|
||||
row.Message,
|
||||
})
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
||||
filter := database.ListAuditLogsFilter{
|
||||
Level: c.Query("level"),
|
||||
Category: c.Query("category"),
|
||||
Action: c.Query("action"),
|
||||
Result: c.Query("result"),
|
||||
Query: c.Query("q"),
|
||||
ResourceType: c.Query("resource_type"),
|
||||
ResourceID: c.Query("resource_id"),
|
||||
}
|
||||
if since := c.Query("since"); since != "" {
|
||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
if until := c.Query("until"); until != "" {
|
||||
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
||||
filter.Until = &t
|
||||
}
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
return page, pageSize
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
@@ -18,6 +19,12 @@ type AuthHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AuthHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler.
|
||||
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
|
||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||
if err != nil {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "failure",
|
||||
Message: "登录失败:密码错误",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "success",
|
||||
SessionHint: audit.HintFromToken(token),
|
||||
Message: "登录成功",
|
||||
Detail: map[string]interface{}{
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.manager.RevokeToken(token)
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "logout",
|
||||
Result: "success",
|
||||
Message: "退出登录",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||
}
|
||||
|
||||
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "failure",
|
||||
Message: "修改密码失败:当前密码不正确",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "success",
|
||||
Message: "登录密码已修改",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ type BatchTaskQueue struct {
|
||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr string,
|
||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||
nextRunAt *time.Time,
|
||||
tasks []string,
|
||||
) (*BatchTaskQueue, error) {
|
||||
@@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
AgentMode: normalizeBatchQueueAgentMode(agentMode),
|
||||
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
@@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
queue.ScheduleMode,
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
queue.ProjectID,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
@@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
|
||||
@@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"type": "boolean",
|
||||
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
||||
},
|
||||
"project_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
@@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
if !ok {
|
||||
executeNow = false
|
||||
}
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
@@ -25,6 +26,12 @@ import (
|
||||
type C2Handler struct {
|
||||
mgrPtr atomic.Pointer[c2.Manager]
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *C2Handler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
||||
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
|
||||
implantToken := listener.ImplantToken
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
|
||||
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
||||
}
|
||||
|
||||
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
||||
}
|
||||
|
||||
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
|
||||
"count": n, "ids": req.IDs,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||
}
|
||||
|
||||
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
|
||||
"session_id": req.SessionID, "task_type": req.TaskType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
}
|
||||
|
||||
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -24,6 +26,12 @@ const (
|
||||
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
||||
type ChatUploadsHandler struct {
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewChatUploadsHandler 创建处理器
|
||||
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
|
||||
}
|
||||
rel, _ := filepath.Rel(root, fullPath)
|
||||
absSaved, _ := filepath.Abs(fullPath)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
|
||||
"name": unique,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"relativePath": filepath.ToSlash(rel),
|
||||
|
||||
+154
-19
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -87,6 +88,7 @@ type ConfigHandler struct {
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
@@ -206,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
|
||||
h.robotRestarter = restarter
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConfigHandler) SetAudit(s *audit.Service) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
|
||||
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
|
||||
h.mu.Lock()
|
||||
wc.Enabled = true
|
||||
h.config.Robots.Wechat = wc
|
||||
h.mu.Unlock()
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.robotRestarter != nil {
|
||||
h.robotRestarter.RestartRobotConnections()
|
||||
}
|
||||
h.logger.Info("微信机器人绑定已保存",
|
||||
zap.String("ilink_bot_id", wc.ILinkBotID),
|
||||
zap.Bool("enabled", wc.Enabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
@@ -291,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
multiPub := config.MultiAgentPublic{
|
||||
Enabled: h.config.MultiAgent.Enabled,
|
||||
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
|
||||
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
|
||||
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
||||
SubAgentCount: subAgentCount,
|
||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||
@@ -609,15 +637,46 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
// UpdateConfigRequest 更新配置请求
|
||||
type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *AgentConfigUpdate `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||
}
|
||||
|
||||
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
||||
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||
type AgentConfigUpdate struct {
|
||||
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
||||
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
||||
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||
}
|
||||
|
||||
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if src.MaxIterations != nil {
|
||||
dst.MaxIterations = *src.MaxIterations
|
||||
}
|
||||
if src.LargeResultThreshold != nil {
|
||||
dst.LargeResultThreshold = *src.LargeResultThreshold
|
||||
}
|
||||
if src.ResultStorageDir != nil {
|
||||
dst.ResultStorageDir = *src.ResultStorageDir
|
||||
}
|
||||
if src.ToolTimeoutMinutes != nil {
|
||||
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||
}
|
||||
if src.SystemPromptPath != nil {
|
||||
dst.SystemPromptPath = *src.SystemPromptPath
|
||||
}
|
||||
}
|
||||
|
||||
// ToolEnableStatus 工具启用状态
|
||||
@@ -664,12 +723,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// 更新Agent配置
|
||||
// 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0)
|
||||
if req.Agent != nil {
|
||||
h.config.Agent = *req.Agent
|
||||
applyAgentConfigUpdate(&h.config.Agent, req.Agent)
|
||||
h.logger.Info("更新Agent配置",
|
||||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||||
zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes),
|
||||
)
|
||||
if h.agent != nil && req.Agent.MaxIterations != nil {
|
||||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新Knowledge配置
|
||||
@@ -697,6 +763,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
if req.Robots != nil {
|
||||
h.config.Robots = *req.Robots
|
||||
h.logger.Info("更新机器人配置",
|
||||
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
||||
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
||||
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
||||
@@ -712,15 +779,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||
if req.MultiAgent != nil {
|
||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
|
||||
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
||||
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = mode
|
||||
} else {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = "react"
|
||||
}
|
||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||
}
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
}
|
||||
h.logger.Info("更新多代理配置",
|
||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
|
||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||
@@ -843,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -973,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||
if _, err := knowledgeInitializer(); err != nil {
|
||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1007,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
||||
if _, err := reinitKnowledgeInitializer(); err != nil {
|
||||
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1020,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
if c2Rt != nil {
|
||||
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
||||
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1116,6 +1201,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
|
||||
// 更新AttackChainHandler的OpenAI配置
|
||||
if h.attackChainHandler != nil {
|
||||
@@ -1158,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "config",
|
||||
Action: "apply",
|
||||
Result: "success",
|
||||
Message: "配置已应用",
|
||||
Detail: map[string]interface{}{
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
"knowledge_enabled": h.config.Knowledge.Enabled,
|
||||
"c2_enabled": h.config.C2.EnabledEffective(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
@@ -1181,7 +1283,7 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
||||
updateAgentConfig(root, h.config.Agent)
|
||||
updateMCPConfig(root, h.config.MCP)
|
||||
updateOpenAIConfig(root, h.config.OpenAI)
|
||||
updateFOFAConfig(root, h.config.FOFA)
|
||||
@@ -1286,10 +1388,14 @@ func writeYAMLDocument(path string, doc *yaml.Node) error {
|
||||
return os.WriteFile(path, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
|
||||
func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
||||
root := doc.Content[0]
|
||||
agentNode := ensureMap(root, "agent")
|
||||
setIntInMap(agentNode, "max_iterations", maxIterations)
|
||||
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
||||
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
||||
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||
}
|
||||
|
||||
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
||||
@@ -1312,6 +1418,19 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||||
if cfg.MaxTotalTokens > 0 {
|
||||
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
|
||||
}
|
||||
rn := ensureMap(openaiNode, "reasoning")
|
||||
if strings.TrimSpace(cfg.Reasoning.Mode) != "" {
|
||||
setStringInMap(rn, "mode", cfg.Reasoning.Mode)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Effort) != "" {
|
||||
setStringInMap(rn, "effort", cfg.Reasoning.Effort)
|
||||
}
|
||||
if cfg.Reasoning.AllowClientReasoning != nil {
|
||||
setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Profile) != "" {
|
||||
setStringInMap(rn, "profile", cfg.Reasoning.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
|
||||
@@ -1416,6 +1535,20 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
root := doc.Content[0]
|
||||
robotsNode := ensureMap(root, "robots")
|
||||
|
||||
if cfg.Session.StrictUserIdentity != nil {
|
||||
sessionNode := ensureMap(robotsNode, "session")
|
||||
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
|
||||
}
|
||||
|
||||
wechatNode := ensureMap(robotsNode, "wechat")
|
||||
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
|
||||
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
|
||||
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
|
||||
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
|
||||
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
|
||||
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
|
||||
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
|
||||
|
||||
wecomNode := ensureMap(robotsNode, "wecom")
|
||||
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
||||
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
||||
@@ -1428,19 +1561,21 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
||||
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
||||
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
||||
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
|
||||
|
||||
larkNode := ensureMap(robotsNode, "lark")
|
||||
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
||||
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
||||
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
||||
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
||||
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
|
||||
}
|
||||
|
||||
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||
root := doc.Content[0]
|
||||
maNode := ensureMap(root, "multi_agent")
|
||||
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
||||
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
|
||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||
mwNode := ensureMap(maNode, "eino_middleware")
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -14,6 +16,12 @@ import (
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
@@ -26,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
}
|
||||
|
||||
// SetConversationProjectRequest 设置对话所属项目
|
||||
type SetConversationProjectRequest struct {
|
||||
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
@@ -42,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "api")
|
||||
meta.ProjectID = strings.TrimSpace(req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -52,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// SetConversationProject 设置或清除对话绑定的项目
|
||||
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req SetConversationProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetConversation(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
@@ -117,6 +152,8 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
|
||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||
out := make([]map[string]interface{}, 0, len(details))
|
||||
for _, d := range details {
|
||||
@@ -187,6 +224,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "conversation",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: id,
|
||||
Message: "删除对话",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -225,6 +273,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
|
||||
"message_id": req.MessageID,
|
||||
"deleted": len(deletedIDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"deletedMessageIds": deletedIDs,
|
||||
"message": "ok",
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if segmentUserMessage != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
|
||||
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -175,29 +176,126 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
)
|
||||
timeoutCancel()
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
@@ -221,6 +319,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -238,6 +337,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,22 +354,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -279,7 +371,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
@@ -297,7 +389,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -337,6 +429,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -347,18 +441,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewExternalMCPHandler 创建外部MCP处理器
|
||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||
return &ExternalMCPHandler{
|
||||
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "upsert",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "更新外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "删除外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
|
||||
|
||||
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
|
||||
"decision": req.Decision,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
|
||||
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建新的知识库处理器
|
||||
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
|
||||
|
||||
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
||||
type MarkdownAgentsHandler struct {
|
||||
dir string
|
||||
dir string
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
||||
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
|
||||
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
||||
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
||||
}
|
||||
|
||||
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
||||
}
|
||||
|
||||
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
@@ -23,6 +24,12 @@ type MonitorHandler struct {
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MonitorHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
|
||||
"tool_name": exec.ToolName,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
||||
return
|
||||
}
|
||||
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
|
||||
"count": len(request.IDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||
return
|
||||
}
|
||||
|
||||
+181
-51
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
|
||||
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
@@ -63,7 +63,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -184,31 +185,129 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
)
|
||||
timeoutCancel()
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
@@ -232,6 +331,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -249,6 +349,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -265,22 +366,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -294,7 +387,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_" + effectiveOrch,
|
||||
@@ -317,7 +410,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
|
||||
if err != nil {
|
||||
status, msg := multiAgentHTTPErrorStatus(err)
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
@@ -350,6 +443,8 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -365,18 +460,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -406,6 +490,52 @@ func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, res
|
||||
}
|
||||
}
|
||||
|
||||
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
|
||||
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
|
||||
seen := make(map[string]struct{}, len(dst)+len(more))
|
||||
out := make([]string, 0, len(dst)+len(more))
|
||||
add := func(ids []string) {
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
add(dst)
|
||||
add(more)
|
||||
return out
|
||||
}
|
||||
|
||||
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
|
||||
func interruptContinueTimelineSummary(note string) string {
|
||||
note = strings.TrimSpace(note)
|
||||
if note == "" {
|
||||
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
|
||||
}
|
||||
return "用户中断说明(原文):\n\n" + note
|
||||
}
|
||||
|
||||
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
|
||||
func formatInterruptContinueUserMessage(note string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("【用户补充 / 中断后继续】\n")
|
||||
if s := strings.TrimSpace(note); s != "" {
|
||||
b.WriteString(s)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString("【请在本轮落实】\n")
|
||||
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
|
||||
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
|
||||
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||
msg := err.Error()
|
||||
switch {
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
|
||||
UserMessageID string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
|
||||
if len(req.Attachments) > maxAttachments {
|
||||
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
||||
}
|
||||
@@ -33,10 +35,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = source + "_webshell"
|
||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
@@ -55,13 +61,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
if getErr != nil {
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +91,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolUpsertProjectFact,
|
||||
builtin.ToolGetProjectFact,
|
||||
builtin.ToolListProjectFacts,
|
||||
builtin.ToolSearchProjectFacts,
|
||||
builtin.ToolDeprecateProjectFact,
|
||||
builtin.ToolRestoreProjectFact,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
|
||||
+139
-1
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"description": "对话标题",
|
||||
"example": "Web应用安全测试",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选,共享事实黑板)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"SetConversationProjectRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "项目 ID;空字符串表示解除绑定",
|
||||
},
|
||||
},
|
||||
"required": []string{"projectId"},
|
||||
},
|
||||
"Conversation": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"format": "date-time",
|
||||
"description": "更新时间",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ConversationDetail": map[string]interface{}{
|
||||
@@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/project": map[string]interface{}{
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
"summary": "设置对话所属项目",
|
||||
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
|
||||
"operationId": "setConversationProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id", "in": "path", "required": true,
|
||||
"description": "对话ID",
|
||||
"schema": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/components/schemas/SetConversationProjectRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "设置成功"},
|
||||
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
|
||||
"404": map[string]interface{}{"description": "对话不存在"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/results": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
@@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "列出项目",
|
||||
"operationId": "listProjects",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
|
||||
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "项目列表"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "创建项目",
|
||||
"operationId": "createProject",
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{"type": "string"},
|
||||
"description": map[string]interface{}{"type": "string"},
|
||||
"scope_json": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "创建成功"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
|
||||
},
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
|
||||
},
|
||||
"delete": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/facts": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||
},
|
||||
},
|
||||
"/api/vulnerabilities": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"漏洞管理"},
|
||||
@@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "project_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "项目ID",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "severity",
|
||||
"in": "query",
|
||||
@@ -6254,7 +6392,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
|
||||
if err != nil {
|
||||
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
||||
vulnList = []*database.Vulnerability{}
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProjectHandler 项目管理处理器。
|
||||
type ProjectHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProjectHandler 创建项目管理处理器。
|
||||
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
|
||||
return &ProjectHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
type createProjectRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ScopeJSON string `json:"scope_json"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
|
||||
type updateProjectRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
ScopeJSON *string `json:"scope_json"`
|
||||
Status *string `json:"status"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// CreateProject POST /api/projects
|
||||
func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
||||
var req createProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
p := &database.Project{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Description: req.Description,
|
||||
ScopeJSON: req.ScopeJSON,
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
h.logger.Error("创建项目失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// ListProjects GET /api/projects
|
||||
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListProjects(status, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Project{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// GetProjectStats GET /api/projects/:id/stats
|
||||
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
|
||||
stats, err := project.GetProjectStats(h.db, c.Param("id"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "不存在") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// ListProjectConversations GET /api/projects/:id/conversations
|
||||
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if _, err := h.db.GetProject(projectID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Conversation{}
|
||||
}
|
||||
total, _ := h.db.CountConversationsByProjectID(projectID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversations": list,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProject GET /api/projects/:id
|
||||
func (h *ProjectHandler) GetProject(c *gin.Context) {
|
||||
p, err := h.db.GetProject(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// UpdateProject PUT /api/projects/:id
|
||||
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
p, err := h.db.GetProject(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
var req updateProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.Name != nil {
|
||||
if s := strings.TrimSpace(*req.Name); s != "" {
|
||||
p.Name = s
|
||||
}
|
||||
}
|
||||
if req.Description != nil {
|
||||
p.Description = *req.Description
|
||||
}
|
||||
if req.ScopeJSON != nil {
|
||||
p.ScopeJSON = *req.ScopeJSON
|
||||
}
|
||||
if req.Status != nil {
|
||||
if s := strings.TrimSpace(*req.Status); s != "" {
|
||||
p.Status = s
|
||||
}
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
p.Pinned = *req.Pinned
|
||||
}
|
||||
if err := h.db.UpdateProject(p); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// DeleteProject DELETE /api/projects/:id
|
||||
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||
if err := h.db.DeleteProject(c.Param("id")); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type upsertFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
}
|
||||
|
||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||
type updateFactRequest struct {
|
||||
FactKey *string `json:"fact_key"`
|
||||
Category *string `json:"category"`
|
||||
Summary *string `json:"summary"`
|
||||
Body *string `json:"body"`
|
||||
Confidence *string `json:"confidence"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||
ClearBody bool `json:"clear_body"`
|
||||
}
|
||||
|
||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||
func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
|
||||
f, err := h.db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, f)
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: c.Query("category"),
|
||||
Confidence: c.Query("confidence"),
|
||||
Search: c.Query("search"),
|
||||
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
|
||||
}
|
||||
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
|
||||
filter.ExcludeDeprecated = true
|
||||
}
|
||||
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFact{}
|
||||
}
|
||||
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
|
||||
filtered := make([]*database.ProjectFact, 0, len(list))
|
||||
for _, f := range list {
|
||||
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
|
||||
filtered = append(filtered, f)
|
||||
}
|
||||
}
|
||||
list = filtered
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
|
||||
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(existing.SupersedesFactID) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
|
||||
return
|
||||
}
|
||||
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, v)
|
||||
}
|
||||
|
||||
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
|
||||
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFactVersion{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// CreateFact POST /api/projects/:id/facts
|
||||
func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
var req upsertFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: c.Param("id"),
|
||||
FactKey: req.FactKey,
|
||||
Category: req.Category,
|
||||
Summary: req.Summary,
|
||||
Body: req.Body,
|
||||
Confidence: req.Confidence,
|
||||
Pinned: req.Pinned,
|
||||
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
|
||||
}
|
||||
created, err := h.db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
var req updateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.FactKey != nil {
|
||||
if k := strings.TrimSpace(*req.FactKey); k != "" {
|
||||
existing.FactKey = k
|
||||
}
|
||||
}
|
||||
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
|
||||
existing.Category = *req.Category
|
||||
}
|
||||
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
|
||||
existing.Summary = *req.Summary
|
||||
}
|
||||
if req.ClearBody {
|
||||
existing.Body = ""
|
||||
} else if req.Body != nil {
|
||||
existing.Body = *req.Body
|
||||
}
|
||||
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
|
||||
existing.Confidence = *req.Confidence
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
existing.Pinned = *req.Pinned
|
||||
}
|
||||
if req.RelatedVulnerabilityID != nil {
|
||||
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
|
||||
}
|
||||
updated, err := h.db.UpsertProjectFact(existing)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type deprecateFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
}
|
||||
|
||||
// DeprecateFact POST /api/projects/:id/facts/deprecate
|
||||
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
|
||||
var req deprecateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type restoreFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative
|
||||
}
|
||||
|
||||
// RestoreFact POST /api/projects/:id/facts/restore
|
||||
func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
||||
var req restoreFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/project"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
if !h.config.Project.Enabled {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||
if err != nil || projectID == "" {
|
||||
return ""
|
||||
}
|
||||
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
|
||||
if err != nil {
|
||||
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
|
||||
func effectiveProjectID(cfg *config.Config, explicit string) string {
|
||||
if pid := strings.TrimSpace(explicit); pid != "" {
|
||||
return pid
|
||||
}
|
||||
if cfg != nil {
|
||||
return strings.TrimSpace(cfg.Project.DefaultProjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+219
-22
@@ -40,8 +40,13 @@ const (
|
||||
robotCmdRoles = "角色"
|
||||
robotCmdRolesList = "角色列表"
|
||||
robotCmdSwitchRole = "切换角色"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
robotCmdProjects = "项目"
|
||||
robotCmdProjectsList = "项目列表"
|
||||
robotCmdBindProject = "绑定项目"
|
||||
robotCmdNewProject = "新建项目"
|
||||
robotCmdUnbindProject = "解除项目"
|
||||
)
|
||||
|
||||
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
||||
@@ -133,7 +138,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
||||
} else {
|
||||
t = safeTruncateString(t, 50)
|
||||
}
|
||||
conv, err := h.db.CreateConversation(t)
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(t, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||
return "", false
|
||||
@@ -188,7 +195,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
||||
// clearConversation 清空当前会话(切换到新对话)
|
||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||
return ""
|
||||
@@ -230,7 +239,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
_ = h.db.UpdateConversationTitle(convID, newTitle)
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.robotMessageTimeout())
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.cancelMu.Lock()
|
||||
h.runningCancels[sk] = cancel
|
||||
@@ -242,12 +251,15 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
h.cancelMu.Unlock()
|
||||
}()
|
||||
role := h.getRole(platform, userID)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return "任务已取消。"
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return "任务执行超时,请稍后重试或精简本次请求范围。"
|
||||
}
|
||||
return "处理失败: " + err.Error()
|
||||
}
|
||||
if newConvID != convID {
|
||||
@@ -256,22 +268,182 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
return resp
|
||||
}
|
||||
|
||||
func (h *RobotHandler) robotMessageTimeout() time.Duration {
|
||||
// 机器人整次消息处理超时(与单次工具超时 agent.tool_timeout_minutes 解耦)。
|
||||
return 10 * time.Hour
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdHelp() string {
|
||||
return "**【CyberStrikeAI 机器人命令】**\n\n" +
|
||||
"- `帮助` `help` — 显示本帮助 | Show this help\n" +
|
||||
"- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" +
|
||||
"- `切换 <ID>` `switch <ID>` — 指定对话继续 | Switch to conversation\n" +
|
||||
"- `新对话` `new` — 开启新对话 | Start new conversation\n" +
|
||||
"- `清空` `clear` — 清空当前上下文 | Clear context\n" +
|
||||
"- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" +
|
||||
"- `停止` `stop` — 中断当前任务 | Stop running task\n" +
|
||||
"- `角色` `roles` — 列出所有可用角色 | List roles\n" +
|
||||
"- `角色 <名>` `role <name>` — 切换当前角色 | Switch role\n" +
|
||||
"- `删除 <ID>` `delete <ID>` — 删除指定对话 | Delete conversation\n" +
|
||||
"- `版本` `version` — 显示当前版本号 | Show version\n\n" +
|
||||
"---\n" +
|
||||
"除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" +
|
||||
"Otherwise, send any text for AI penetration testing / security analysis."
|
||||
var b strings.Builder
|
||||
b.WriteString("【CyberStrikeAI 机器人命令】\n\n")
|
||||
b.WriteString("【通用 General】\n")
|
||||
b.WriteString("· 帮助 / help — 显示本帮助\n")
|
||||
b.WriteString("· 版本 / version — 显示当前版本号\n")
|
||||
b.WriteString("\n【对话 Conversation】\n")
|
||||
b.WriteString("· 列表 / list — 列出所有对话标题与 ID\n")
|
||||
b.WriteString("· 切换 <ID> / switch <ID> — 指定对话继续\n")
|
||||
b.WriteString("· 新对话 / new — 开启新对话\n")
|
||||
b.WriteString("· 清空 / clear — 清空当前上下文\n")
|
||||
b.WriteString("· 当前 / current — 显示当前对话、角色与项目\n")
|
||||
b.WriteString("· 停止 / stop — 中断当前任务\n")
|
||||
b.WriteString("· 删除 <ID> / delete <ID> — 删除指定对话\n")
|
||||
b.WriteString("\n【角色 Role】\n")
|
||||
b.WriteString("· 角色 / roles — 列出所有可用角色\n")
|
||||
b.WriteString("· 角色 <名> / role <name> — 切换当前角色\n")
|
||||
if h.projectsEnabled() {
|
||||
b.WriteString("\n【项目 Project】\n")
|
||||
b.WriteString("· 项目 / projects — 列出所有项目\n")
|
||||
b.WriteString("· 新建项目 <名称> / new project <name> — 创建并绑定当前对话\n")
|
||||
b.WriteString("· 绑定项目 <ID或名称> / bind project <ID|name> — 绑定到已有项目\n")
|
||||
b.WriteString("· 解除项目 / unbind project — 解除项目绑定\n")
|
||||
}
|
||||
b.WriteString("\n──────────────\n")
|
||||
b.WriteString("除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (h *RobotHandler) projectsEnabled() bool {
|
||||
return h.config != nil && h.config.Project.Enabled
|
||||
}
|
||||
|
||||
func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Project, string) {
|
||||
idOrName = strings.TrimSpace(idOrName)
|
||||
if idOrName == "" {
|
||||
return nil, "请指定项目 ID 或名称,例如:绑定项目 xxx-xxx"
|
||||
}
|
||||
if p, err := h.db.GetProject(idOrName); err == nil {
|
||||
return p, ""
|
||||
}
|
||||
list, err := h.db.ListProjects("", 200, 0)
|
||||
if err != nil {
|
||||
return nil, "查询项目失败: " + err.Error()
|
||||
}
|
||||
var matches []*database.Project
|
||||
for _, p := range list {
|
||||
if p.Name == idOrName {
|
||||
matches = append(matches, p)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return nil, fmt.Sprintf("项目「%s」不存在。发送「项目」查看列表。", idOrName)
|
||||
case 1:
|
||||
return matches[0], ""
|
||||
default:
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("名称「%s」匹配到多个项目,请使用 ID 绑定:\n", idOrName))
|
||||
for _, p := range matches {
|
||||
b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", p.Name, p.ID))
|
||||
}
|
||||
return nil, strings.TrimSuffix(b.String(), "\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RobotHandler) formatProjectLabel(projectID string) string {
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return "未绑定"
|
||||
}
|
||||
if p, err := h.db.GetProject(projectID); err == nil {
|
||||
return fmt.Sprintf("「%s」 (%s)", p.Name, p.ID)
|
||||
}
|
||||
return projectID
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdProjects() string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
list, err := h.db.ListProjects("", 50, 0)
|
||||
if err != nil {
|
||||
return "获取项目列表失败: " + err.Error()
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return "暂无项目。发送「新建项目 <名称>」创建并绑定到当前对话。"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("【项目列表】\n")
|
||||
for i, p := range list {
|
||||
if i >= 20 {
|
||||
b.WriteString("… 仅显示前 20 条\n")
|
||||
break
|
||||
}
|
||||
status := p.Status
|
||||
if status == "" {
|
||||
status = "active"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("· %s [%s]\n ID: %s\n", p.Name, status, p.ID))
|
||||
}
|
||||
return strings.TrimSuffix(b.String(), "\n")
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdBindProject(platform, userID, idOrName string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
p, errMsg := h.resolveProjectByIDOrName(idOrName)
|
||||
if p == nil {
|
||||
return errMsg
|
||||
}
|
||||
convID, _ := h.getOrCreateConversation(platform, userID, "")
|
||||
if convID == "" {
|
||||
return "无法获取当前对话,请稍后再试。"
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, p.ID); err != nil {
|
||||
return "绑定失败: " + err.Error()
|
||||
}
|
||||
return fmt.Sprintf("已将当前对话绑定到项目:「%s」\nID: %s", p.Name, p.ID)
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdNewProject(platform, userID, name string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "请指定项目名称,例如:新建项目 某目标渗透"
|
||||
}
|
||||
p := &database.Project{Name: name, Status: "active"}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
return "创建项目失败: " + err.Error()
|
||||
}
|
||||
convID, _ := h.getOrCreateConversation(platform, userID, name)
|
||||
if convID == "" {
|
||||
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n(绑定当前对话失败,请手动发送「绑定项目 %s」)", created.Name, created.ID, created.ID)
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, created.ID); err != nil {
|
||||
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n绑定失败: %s", created.Name, created.ID, err.Error())
|
||||
}
|
||||
return fmt.Sprintf("已创建项目并绑定当前对话:「%s」\nID: %s", created.Name, created.ID)
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.RLock()
|
||||
convID := h.sessions[sk]
|
||||
h.mu.RUnlock()
|
||||
if convID == "" {
|
||||
if persistedConvID, _ := h.loadSessionBinding(sk); persistedConvID != "" {
|
||||
convID = persistedConvID
|
||||
}
|
||||
}
|
||||
if convID == "" {
|
||||
return "当前没有进行中的对话,无需解除绑定。"
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(convID)
|
||||
if err != nil {
|
||||
return "获取对话项目失败: " + err.Error()
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return "当前对话未绑定项目。"
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, ""); err != nil {
|
||||
return "解除绑定失败: " + err.Error()
|
||||
}
|
||||
return "已解除当前对话的项目绑定。"
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdList() string {
|
||||
@@ -345,7 +517,12 @@ func (h *RobotHandler) cmdCurrent(platform, userID string) string {
|
||||
return "当前对话 ID: " + convID + "(获取标题失败)"
|
||||
}
|
||||
role := h.getRole(platform, userID)
|
||||
return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
||||
reply := fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
||||
if h.projectsEnabled() {
|
||||
projectID, _ := h.db.GetConversationProjectID(conv.ID)
|
||||
reply += "\n当前项目: " + h.formatProjectLabel(projectID)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdRoles() string {
|
||||
@@ -482,6 +659,26 @@ func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string
|
||||
return h.cmdDelete(platform, userID, convID), true
|
||||
case text == robotCmdVersion || text == "version":
|
||||
return h.cmdVersion(), true
|
||||
case text == robotCmdProjects || text == robotCmdProjectsList || text == "projects":
|
||||
return h.cmdProjects(), true
|
||||
case text == robotCmdUnbindProject || text == "unbind project":
|
||||
return h.cmdUnbindProject(platform, userID), true
|
||||
case strings.HasPrefix(text, robotCmdNewProject+" ") || strings.HasPrefix(text, "new project "):
|
||||
var name string
|
||||
if strings.HasPrefix(text, robotCmdNewProject+" ") {
|
||||
name = strings.TrimSpace(text[len(robotCmdNewProject)+1:])
|
||||
} else {
|
||||
name = strings.TrimSpace(text[len("new project "):])
|
||||
}
|
||||
return h.cmdNewProject(platform, userID, name), true
|
||||
case strings.HasPrefix(text, robotCmdBindProject+" ") || strings.HasPrefix(text, "bind project "):
|
||||
var idOrName string
|
||||
if strings.HasPrefix(text, robotCmdBindProject+" ") {
|
||||
idOrName = strings.TrimSpace(text[len(robotCmdBindProject)+1:])
|
||||
} else {
|
||||
idOrName = strings.TrimSpace(text[len("bind project "):])
|
||||
}
|
||||
return h.cmdBindProject(platform, userID, idOrName), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -21,6 +22,12 @@ type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *RoleHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -23,6 +24,12 @@ type SkillsHandler struct {
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *SkillsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
// ErrTaskCancelled 用户取消任务的错误
|
||||
@@ -32,6 +34,9 @@ type AgentTask struct {
|
||||
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
||||
ActiveMCPExecutionID string `json:"-"`
|
||||
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -65,6 +70,50 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.InterruptContinueNote = note
|
||||
}
|
||||
}
|
||||
|
||||
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
|
||||
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.InterruptContinueNote
|
||||
t.InterruptContinueNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
|
||||
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.cancel = func(err error) {
|
||||
cancel(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
||||
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
@@ -210,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
||||
return true, nil
|
||||
}
|
||||
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
|
||||
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
task.Status = "running"
|
||||
} else {
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
}
|
||||
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
|
||||
task.InterruptContinueNote = ""
|
||||
}
|
||||
cancel := task.cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
|
||||
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ const ptyCols = 256
|
||||
const ptyRows = 40
|
||||
|
||||
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
)
|
||||
|
||||
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
|
||||
normalize := func(s string) string {
|
||||
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -16,6 +17,12 @@ import (
|
||||
type VulnerabilityHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewVulnerabilityHandler 创建新的漏洞处理器
|
||||
@@ -29,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
||||
// CreateVulnerabilityRequest 创建漏洞请求
|
||||
type CreateVulnerabilityRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ProjectID string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
@@ -52,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: req.ConversationID,
|
||||
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||
ConversationTag: req.ConversationTag,
|
||||
TaskTag: req.TaskTag,
|
||||
Title: req.Title,
|
||||
@@ -72,6 +81,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
|
||||
"severity": created.Severity, "title": created.Title,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
@@ -98,18 +112,30 @@ type ListVulnerabilitiesResponse struct {
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||
q := strings.TrimSpace(c.Query("q"))
|
||||
if q == "" {
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ProjectID: c.Query("project_id"),
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
Severity: c.Query("severity"),
|
||||
Status: c.Query("status"),
|
||||
TaskID: c.Query("task_id"),
|
||||
ConversationTag: c.Query("conversation_tag"),
|
||||
TaskTag: c.Query("task_tag"),
|
||||
}
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "20")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
pageStr := c.Query("page")
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
@@ -131,7 +157,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||
// 继续执行,使用0作为总数
|
||||
@@ -139,7 +165,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -170,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
|
||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
ProjectID *string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
@@ -201,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.ProjectID != nil {
|
||||
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||
}
|
||||
if req.ConversationTag != "" {
|
||||
existing.ConversationTag = req.ConversationTag
|
||||
}
|
||||
@@ -249,6 +279,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
@@ -262,15 +297,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "vulnerability",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "vulnerability",
|
||||
ResourceID: id,
|
||||
Message: "删除漏洞记录",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计
|
||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||
conversationID := c.Query("conversation_id")
|
||||
taskID := c.Query("task_id")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
|
||||
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -304,15 +349,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -322,7 +361,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -132,6 +134,16 @@ func quoteCmdPath(p string) string {
|
||||
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
||||
}
|
||||
|
||||
// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。
|
||||
// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。
|
||||
func normalizeWindowsCmdPath(p string) string {
|
||||
s := strings.TrimSpace(p)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ReplaceAll(s, "/", "\\")
|
||||
}
|
||||
|
||||
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
||||
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
||||
func quotePsSingle(s string) string {
|
||||
@@ -196,6 +208,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
p = "."
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
p = normalizeWindowsCmdPath(p)
|
||||
return "dir /a " + quoteCmdPath(p), nil
|
||||
}
|
||||
return "ls -la " + quoteShellSinglePosix(p), nil
|
||||
@@ -205,6 +218,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "type " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "cat " + quoteShellSinglePosix(path), nil
|
||||
@@ -214,6 +228,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "del /q /f " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "rm -f " + quoteShellSinglePosix(path), nil
|
||||
@@ -223,6 +238,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
||||
return "md " + quoteCmdPath(path), nil
|
||||
}
|
||||
@@ -235,6 +251,8 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpRenameNeedsBothPaths
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
oldPath = normalizeWindowsCmdPath(oldPath)
|
||||
newPath = normalizeWindowsCmdPath(newPath)
|
||||
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
||||
}
|
||||
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
||||
@@ -247,6 +265,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, b64), nil
|
||||
}
|
||||
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
@@ -259,6 +278,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpUploadTooLarge
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
@@ -268,6 +288,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
if in.ChunkIndex == 0 {
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
@@ -304,6 +325,12 @@ type WebShellHandler struct {
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *WebShellHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
||||
@@ -311,8 +338,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
||||
return &WebShellHandler{
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{DisableKeepAlives: false},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||
},
|
||||
},
|
||||
db: db,
|
||||
}
|
||||
@@ -403,6 +434,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
host := req.URL
|
||||
if u, err := url.Parse(req.URL); err == nil {
|
||||
host = u.Host
|
||||
}
|
||||
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
|
||||
"host": host, "type": shellType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, conn)
|
||||
}
|
||||
|
||||
@@ -485,6 +525,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -714,8 +757,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
ok := resp.StatusCode == http.StatusOK
|
||||
c.JSON(http.StatusOK, ExecResponse{
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
OK: ok,
|
||||
Output: output,
|
||||
HTTPCode: httpCode,
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user