mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-06 06:13:58 +02:00
Compare commits
129 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d9eebffe6 | |||
| 403d4421d2 | |||
| e606369e31 | |||
| da8fdafe59 | |||
| 0492365430 | |||
| 3a6bc60276 | |||
| 3a401ade68 | |||
| 71aade5bd9 | |||
| a5f11cc003 | |||
| dcea95968b | |||
| 7db0294d5c | |||
| b4d85c5a77 | |||
| fcbc7b9226 | |||
| b8b1e8431b | |||
| 203a99bed4 | |||
| 449781c029 | |||
| 924f59015d | |||
| f0fb634a6b | |||
| b8dfb9556a | |||
| 9c1d3ae85e | |||
| b8ebf023a0 | |||
| 604ce34d5e | |||
| b29b36bfd5 | |||
| 11bab83fc5 | |||
| dc750e3680 | |||
| 0236d1c155 | |||
| be59ddcab6 | |||
| 25464a68e6 | |||
| eabfed09c9 | |||
| cbcbd414cd | |||
| 0933f9365b | |||
| e792891ff3 | |||
| e14e5f15d3 | |||
| 4d5e0c5f21 | |||
| b3238304ce | |||
| 665e2ec73a | |||
| d63d9c25b8 | |||
| d1c63d0ba7 | |||
| 55d6d449cd | |||
| d4bc9646d9 | |||
| b941f5a8d9 | |||
| 97e2c0fd43 | |||
| bd3e48c2d0 | |||
| 8b0b91fddc | |||
| 2b38595b42 | |||
| 5c795439ee | |||
| df531910cf | |||
| 8a089a826c | |||
| 60b32ffc69 | |||
| 21c36fcce8 | |||
| 4d048f6da0 | |||
| 03a2707b83 | |||
| 9941f51b3e | |||
| 1553e896c5 | |||
| ea2184773e | |||
| 764d8110ec | |||
| e037f383f5 | |||
| e40f7cb468 | |||
| 72aca69204 | |||
| 133da1c640 | |||
| af78b47517 | |||
| f5fabc05a4 | |||
| 5cc53b1076 | |||
| f1be2064db | |||
| 0c9c2ec606 | |||
| cf09dd36d8 | |||
| c6e2701b30 | |||
| 42b5901d99 | |||
| 117bed6839 | |||
| bad323cd0e | |||
| 8138f8b576 | |||
| 74627d214b | |||
| f622efe245 | |||
| 3924b5285b | |||
| 21f641bbd7 | |||
| d913695303 | |||
| 6bb3a73f73 | |||
| f0a80a8e58 | |||
| 3f9dbb4214 | |||
| c0f0861b31 | |||
| 704137aa34 | |||
| c56bf36df0 | |||
| 5560f34c6c | |||
| 70e9a73fc0 | |||
| 12bc9d8ab6 | |||
| f8db82a065 | |||
| 8ce30d9072 | |||
| e6506d00e8 | |||
| b2308617b8 | |||
| cd17fdca33 | |||
| 1acaccd09f | |||
| 983fe650c1 | |||
| 52d03dc849 | |||
| 9de72d9ad5 | |||
| d95275ffae | |||
| 6cef93dbb7 | |||
| dd3b1ae219 | |||
| f42209682a | |||
| 1b1aed1699 | |||
| 44ced98863 | |||
| 97834c162e | |||
| 9276f2f144 | |||
| a454cada6a | |||
| 99b53d4fbc | |||
| a43a9deaea | |||
| ce88da84c9 | |||
| 15855c7073 | |||
| 43eb3e546b | |||
| 2d52c9b6ac | |||
| d5401b8b4c | |||
| 5fd4393a2e | |||
| a049f6b5c2 | |||
| acba8e5a39 | |||
| f826b91362 | |||
| 98c2de2a60 | |||
| 1c4d4b305b | |||
| f210ac9a03 | |||
| 6685076dfb | |||
| 7f322653f6 | |||
| 66ac2f1357 | |||
| c446e22d0c | |||
| 0358d3a67d | |||
| 9b82f265fd | |||
| 3d9cae58e4 | |||
| 1f1eadee5e | |||
| 0569255189 | |||
| 8ccf90d067 | |||
| b3be89f47d | |||
| b9bf8f62d4 |
@@ -113,10 +113,12 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
|||||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||||
|
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
||||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||||
- 🧩 **Multi-agent (CloudWeGo Eino)**: alongside **single-agent ReAct** (`/api/agent-loop`), **multi mode** (`/api/multi-agent/stream`) offers **`deep`** (coordinator + `task` sub-agents), **`plan_execute`** (planner / executor / replanner), and **`supervisor`** (orchestrator + `transfer` / `exit`); chosen per request via **`orchestration`**. Markdown under `agents/`: `orchestrator.md` (Deep), `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` where applicable (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
|
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
|
||||||
|
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
|
||||||
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, plantask, reduction, checkpoints, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
|
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, plantask, reduction, checkpoints, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
|
||||||
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
|
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
|
||||||
- 🧑⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
|
- 🧑⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
|
||||||
@@ -234,7 +236,7 @@ Requirements / tips:
|
|||||||
|
|
||||||
### Core Workflows
|
### Core Workflows
|
||||||
- **Conversation testing** – Natural-language prompts trigger toolchains with streaming SSE output.
|
- **Conversation testing** – Natural-language prompts trigger toolchains with streaming SSE output.
|
||||||
- **Single vs multi-agent** – With `multi_agent.enabled: true`, the chat UI can switch between **single** (classic **ReAct** loop, `/api/agent-loop/stream`) and **multi** (`/api/multi-agent/stream`). Multi mode keeps **`deep`** as the baseline coordinator + **`task`** sub-agents, and adds **`plan_execute`** and **`supervisor`** orchestrations via the request body **`orchestration`** field. MCP tools are bridged the same way as single-agent.
|
- **Single vs multi-agent** – Chat UI switches between **Eino single-agent** (`/api/eino-agent/stream`) and **multi-agent** (`/api/multi-agent/stream` with `orchestration`: `deep` | `plan_execute` | `supervisor`). Multi mode requires `multi_agent.enabled: true`. MCP tools are bridged the same way for both paths.
|
||||||
- **Role-based testing** – Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
|
- **Role-based testing** – Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
|
||||||
- **Tool monitor** – Inspect running jobs, execution logs, and large-result attachments.
|
- **Tool monitor** – Inspect running jobs, execution logs, and large-result attachments.
|
||||||
- **History & audit** – Every conversation and tool invocation is stored in SQLite with replay.
|
- **History & audit** – Every conversation and tool invocation is stored in SQLite with replay.
|
||||||
@@ -258,7 +260,7 @@ Requirements / tips:
|
|||||||
- **Predefined roles** – System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
|
- **Predefined roles** – System includes 12+ predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, Binary Analysis, Cloud Security Audit, etc.) in the `roles/` directory.
|
||||||
- **Custom prompts** – Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
|
- **Custom prompts** – Each role can define a `user_prompt` that prepends to user messages, guiding the AI to adopt specialized testing methodologies and focus areas.
|
||||||
- **Tool restrictions** – Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
|
- **Tool restrictions** – Roles can specify a `tools` list to limit available tools, ensuring focused testing workflows (e.g., CTF role restricts to CTF-specific utilities).
|
||||||
- **Skills** – Skill packs live under `skills_dir` and are loaded in **multi-agent / Eino** sessions via the ADK **`skill`** tool (**progressive disclosure**). Configure **`multi_agent.eino_skills`** for middleware, tool name override, and optional host **read_file / glob / grep / write / edit / execute** (**Deep / Supervisor** when enabled; **plan_execute** differs—see docs). Single-agent ReAct does not mount this Eino skill stack today.
|
- **Skills** – Skill packs live under `skills_dir` and load via the Eino ADK **`skill`** tool (**progressive disclosure**) in both **single- and multi-agent** sessions when **`multi_agent.eino_skills`** is enabled. Optional host **read_file / glob / grep / write / edit / execute** and **`eino_middleware`** (tool_search, reduction, checkpoints, etc.) apply per mode—see docs.
|
||||||
- **Easy role creation** – Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
|
- **Easy role creation** – Create custom roles by adding YAML files to the `roles/` directory. Each role defines `name`, `description`, `user_prompt`, `icon`, `tools`, and `enabled` fields.
|
||||||
- **Web UI integration** – Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
|
- **Web UI integration** – Select roles from a dropdown in the chat interface. Role selection affects both AI behavior and available tool suggestions.
|
||||||
|
|
||||||
@@ -278,14 +280,14 @@ Requirements / tips:
|
|||||||
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
|
2. Restart the server or reload configuration; the role appears in the role selector dropdown.
|
||||||
|
|
||||||
### Multi-Agent Mode (Eino: Deep, Plan-Execute, Supervisor)
|
### Multi-Agent Mode (Eino: Deep, Plan-Execute, Supervisor)
|
||||||
- **What it is** – An optional execution path beside **single-agent ReAct**, built on CloudWeGo **Eino** `adk/prebuilt`: **`deep`** — coordinator + **`task`** sub-agents; **`plan_execute`** — planner / executor / replanner loop (no YAML/Markdown sub-agent list); **`supervisor`** — orchestrator with **`transfer`** and **`exit`** over Markdown-defined specialists. The client sends **`orchestration`**: `deep` | `plan_execute` | `supervisor` (default `deep`).
|
- **What it is** – Multi-agent orchestration on CloudWeGo **Eino** `adk/prebuilt` (alongside **Eino single-agent** on `/api/eino-agent*`): **`deep`** — coordinator + **`task`** sub-agents; **`plan_execute`** — planner / executor / replanner; **`supervisor`** — orchestrator with **`transfer`** / **`exit`**. Client sends **`orchestration`**: `deep` | `plan_execute` | `supervisor` (default `deep`).
|
||||||
- **Markdown agents** – Under `agents_dir` (default `agents/`):
|
- **Markdown agents** – Under `agents_dir` (default `agents/`):
|
||||||
- **Deep orchestrator**: `orchestrator.md` *or* one `.md` with `kind: orchestrator`. Body or `multi_agent.orchestrator_instruction`, then Eino defaults.
|
- **Deep orchestrator**: `orchestrator.md` *or* one `.md` with `kind: orchestrator`. Body or `multi_agent.orchestrator_instruction`, then Eino defaults.
|
||||||
- **Plan-Execute orchestrator**: fixed name **`orchestrator-plan-execute.md`** (plus optional `orchestrator_instruction_plan_execute` in YAML).
|
- **Plan-Execute orchestrator**: fixed name **`orchestrator-plan-execute.md`** (plus optional `orchestrator_instruction_plan_execute` in YAML).
|
||||||
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
||||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||||
|
|
||||||
### Skills System (Agent Skills + Eino)
|
### Skills System (Agent Skills + Eino)
|
||||||
@@ -535,8 +537,8 @@ skills_dir: "skills" # Skills directory (relative to config file)
|
|||||||
agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-agents)
|
agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-agents)
|
||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: false
|
enabled: false
|
||||||
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
default_mode: "eino_single" # eino_single | multi (UI default when multi-agent is enabled)
|
||||||
robot_use_multi_agent: false
|
robot_default_agent_mode: eino_single
|
||||||
batch_use_multi_agent: false
|
batch_use_multi_agent: false
|
||||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||||
|
|||||||
+9
-7
@@ -112,10 +112,12 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
|||||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||||
|
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
||||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||||
- 🧩 **多代理(CloudWeGo Eino)**:在 **单代理 ReAct**(`/api/agent-loop`)之外,**多代理**(`/api/multi-agent/stream`)提供 **`deep`**(协调主代理 + `task` 子代理)、**`plan_execute`**(规划 / 执行 / 重规划)、**`supervisor`**(主代理 `transfer` / `exit` 监督子代理);由请求体 **`orchestration`** 选择。`agents/` 下分模式主代理:`orchestrator.md`(Deep)、`orchestrator-plan-execute.md`、`orchestrator-supervisor.md`,及适用的子代理 `*.md`(详见 [多代理说明](docs/MULTI_AGENT_EINO.md))
|
- 🧩 **Agent 编排(CloudWeGo Eino)**:**单代理** `POST /api/eino-agent/stream`(Eino ADK);**多代理** `POST /api/multi-agent/stream`,`orchestration` 选 **`deep`** / **`plan_execute`** / **`supervisor`**。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
|
||||||
|
- 🖼️ **视觉分析(`analyze_image`)**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml` → `vision` 与 [视觉分析说明](docs/VISION.md)
|
||||||
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、plantask、reduction、断点目录及 Deep 调参。20+ 领域示例仍可绑定角色
|
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、plantask、reduction、断点目录及 Deep 调参。20+ 领域示例仍可绑定角色
|
||||||
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md))
|
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md))
|
||||||
- 🧑⚖️ **人机协同(HITL)**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml` 的 `hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
|
- 🧑⚖️ **人机协同(HITL)**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml` 的 `hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
|
||||||
@@ -232,7 +234,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
|
|
||||||
### 常用流程
|
### 常用流程
|
||||||
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
|
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
|
||||||
- **单代理 / 多代理**:`multi_agent.enabled: true` 后可在聊天中切换 **单代理**(原有 **ReAct**,`/api/agent-loop/stream`)与 **多代理**(`/api/multi-agent/stream`)。多代理在既有 **`deep`**(`task` 子代理)基础上,新增 **`plan_execute`**、**`supervisor`**,由 **`orchestration`** 指定。MCP 工具与单代理同源桥接。
|
- **单代理 / 多代理**:聊天可选 **Eino 单代理**(`/api/eino-agent/stream`)与 **多代理**(`/api/multi-agent/stream` + `orchestration`)。多代理需 `multi_agent.enabled: true`。MCP 工具桥接一致。
|
||||||
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
|
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
|
||||||
- **工具监控**:查看任务队列、执行日志、大文件附件。
|
- **工具监控**:查看任务队列、执行日志、大文件附件。
|
||||||
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
|
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
|
||||||
@@ -256,7 +258,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
|
- **预设角色**:系统内置 12+ 个预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试、二进制分析、云安全审计等),位于 `roles/` 目录。
|
||||||
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
|
- **自定义提示词**:每个角色可定义 `user_prompt`,会在用户消息前自动添加,引导 AI 采用特定的测试方法和关注重点。
|
||||||
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
|
- **工具限制**:角色可指定 `tools` 列表,限制可用工具,实现聚焦的测试流程(如 CTF 角色限制为 CTF 专用工具)。
|
||||||
- **Skills**:技能包位于 `skills_dir`;**多代理 / Eino** 下由 **`skill`** 工具 **按需加载**(渐进式披露)。**`multi_agent.eino_skills`** 控制中间件与本机 read_file/glob/grep/write/edit/execute(**Deep / Supervisor** 主/子代理;**plan_execute** 执行器无独立 skill 中间件,见文档)。**单代理 ReAct** 当前不挂载该 Eino skill 链。
|
- **Skills**:技能包位于 `skills_dir`;启用 **`multi_agent.eino_skills`** 后,**单代理与多代理**均可通过 Eino **`skill`** 工具按需加载。中间件与本机 read_file/glob/grep 等见文档。
|
||||||
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
|
- **轻松创建角色**:通过在 `roles/` 目录添加 YAML 文件即可创建自定义角色。每个角色定义 `name`、`description`、`user_prompt`、`icon`、`tools`、`enabled` 字段。
|
||||||
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
|
- **Web 界面集成**:在聊天界面通过下拉菜单选择角色。角色选择会影响 AI 行为和可用工具建议。
|
||||||
|
|
||||||
@@ -276,14 +278,14 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
|
2. 重启服务或重新加载配置,角色会出现在角色选择下拉菜单中。
|
||||||
|
|
||||||
### 多代理模式(Eino:Deep / Plan-Execute / Supervisor)
|
### 多代理模式(Eino:Deep / Plan-Execute / Supervisor)
|
||||||
- **能力说明**:与 **单代理 ReAct** 并存的可选路径,基于 CloudWeGo **Eino** `adk/prebuilt`:**`deep`** — 协调主代理 + **`task`** 子代理;**`plan_execute`** — 规划 / 执行 / 重规划闭环(不使用 YAML/Markdown 子代理列表);**`supervisor`** — 主代理 **`transfer`** / **`exit`** 调度 Markdown 专家。客户端通过 **`orchestration`** 选 `deep` | `plan_execute` | `supervisor`(缺省 `deep`)。
|
- **能力说明**:在 **Eino 单代理**(`/api/eino-agent*`)之外,多代理基于 CloudWeGo **Eino** `adk/prebuilt`:**`deep`**、**`plan_execute`**、**`supervisor`**;客户端 **`orchestration`** 选择(缺省 `deep`)。
|
||||||
- **Markdown 定义**(`agents_dir`,默认 `agents/`):
|
- **Markdown 定义**(`agents_dir`,默认 `agents/`):
|
||||||
- **Deep 主代理**:`orchestrator.md` 或唯一 `kind: orchestrator` 的 `.md`;正文或 `multi_agent.orchestrator_instruction`,再回退 Eino 默认。
|
- **Deep 主代理**:`orchestrator.md` 或唯一 `kind: orchestrator` 的 `.md`;正文或 `multi_agent.orchestrator_instruction`,再回退 Eino 默认。
|
||||||
- **Plan-Execute 主代理**:固定 **`orchestrator-plan-execute.md`**(另可配 `orchestrator_instruction_plan_execute`)。
|
- **Plan-Execute 主代理**:固定 **`orchestrator-plan-execute.md`**(另可配 `orchestrator_instruction_plan_execute`)。
|
||||||
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
||||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||||
- **配置项**:`multi_agent`:`enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||||
|
|
||||||
### Skills 技能系统(Agent Skills + Eino)
|
### Skills 技能系统(Agent Skills + Eino)
|
||||||
@@ -533,8 +535,8 @@ skills_dir: "skills" # Skills 目录(相对于配置文件所在目录)
|
|||||||
agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代理 *.md)
|
agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代理 *.md)
|
||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: false
|
enabled: false
|
||||||
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
default_mode: "eino_single" # eino_single | multi(开启多代理时的界面默认模式)
|
||||||
robot_use_multi_agent: false
|
robot_default_agent_mode: eino_single
|
||||||
batch_use_multi_agent: false
|
batch_use_multi_agent: false
|
||||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||||
|
|||||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
|||||||
5) Follow-up Verification Plan(后续验证建议)
|
5) Follow-up Verification Plan(后续验证建议)
|
||||||
- 对每个优先条目:建议由哪个阶段子代理接手、需要补测的最小证据集
|
- 对每个优先条目:建议由哪个阶段子代理接手、需要补测的最小证据集
|
||||||
|
|
||||||
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|
||||||
|
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
||||||
|
|||||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
|||||||
- 可能仍残留的风险类别与建议监控方式(只做高层建议)
|
- 可能仍残留的风险类别与建议监控方式(只做高层建议)
|
||||||
|
|
||||||
4) Handoff to Reporting(交接给报告的要点)
|
4) Handoff to Reporting(交接给报告的要点)
|
||||||
- 报告里应包含哪些字段以证明“合规清理”。
|
- 报告里应包含哪些字段以证明“合规清理”。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
|||||||
5) Open Questions(待澄清问题)
|
5) Open Questions(待澄清问题)
|
||||||
- 不足以继续的关键问题(尽量少而关键)
|
- 不足以继续的关键问题(尽量少而关键)
|
||||||
|
|
||||||
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -50,4 +50,8 @@ max_iterations: 0
|
|||||||
- 你要求执行的最小化原则(如不导出明文敏感字段、不保留原始样本等,用描述性语言)
|
- 你要求执行的最小化原则(如不导出明文敏感字段、不保留原始样本等,用描述性语言)
|
||||||
|
|
||||||
4) Recommended Next Agent(下一步建议)
|
4) Recommended Next Agent(下一步建议)
|
||||||
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
|||||||
- 优先用工具拿可验证事实,标注信息来源与置信度;避免无依据推测。
|
- 优先用工具拿可验证事实,标注信息来源与置信度;避免无依据推测。
|
||||||
- 输出结构化(目标、发现项、证据摘要、建议后续动作),便于协调者合并进总报告。
|
- 输出结构化(目标、发现项、证据摘要、建议后续动作),便于协调者合并进总报告。
|
||||||
- 不执行未授权的入侵或社工骚扰;双用途技术仅用于甲方书面授权场景。
|
- 不执行未授权的入侵或社工骚扰;双用途技术仅用于甲方书面授权场景。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
|||||||
- 聚焦:内网拓扑与关键资产推断、凭据与令牌利用、常见横向协议与服务、权限路径与域/云环境注意事项(在工具与可见数据范围内)。
|
- 聚焦:内网拓扑与关键资产推断、凭据与令牌利用、常见横向协议与服务、权限路径与域/云环境注意事项(在工具与可见数据范围内)。
|
||||||
- 每一步说明假设前提与证据;禁止对未授权网段、生产无关系统或真实用户数据进行操作。
|
- 每一步说明假设前提与证据;禁止对未授权网段、生产无关系统或真实用户数据进行操作。
|
||||||
- 输出结构化:当前据点能力、发现的主机/服务、建议的下一步(可交给其他子代理或主代理编排)、风险与回滚注意点。
|
- 输出结构化:当前据点能力、发现的主机/服务、建议的下一步(可交给其他子代理或主代理编排)、风险与回滚注意点。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
|||||||
- 建议记录哪些证据字段(时间戳、目标、请求摘要、响应摘要、变更清单、回滚确认)
|
- 建议记录哪些证据字段(时间戳、目标、请求摘要、响应摘要、变更清单、回滚确认)
|
||||||
|
|
||||||
4) Stop & Rollback Criteria(停止与回滚标准)
|
4) Stop & Rollback Criteria(停止与回滚标准)
|
||||||
- 触发阈值/不可控情况(用描述性语言即可)
|
- 触发阈值/不可控情况(用描述性语言即可)
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -102,10 +102,34 @@ description: plan_execute 模式下的规划/重规划侧主代理:拆解目
|
|||||||
|
|
||||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||||
|
|
||||||
## 证据与漏洞
|
## 证据、黑板与漏洞
|
||||||
|
|
||||||
- 要求结论有证据支撑(请求/响应、命令输出、可复现步骤);禁止无依据的确定断言。
|
- 要求结论有证据支撑(请求/响应、命令输出、可复现步骤);禁止无依据的确定断言。
|
||||||
- 发现有效漏洞时,在后续轮次通过 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、POC、影响、修复建议;级别 critical / high / medium / low / info)。
|
|
||||||
|
## 项目黑板(事实)与漏洞记录(分离)
|
||||||
|
|
||||||
|
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||||
|
|
||||||
|
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||||
|
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||||
|
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||||
|
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||||
|
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||||
|
- **计划步骤须要求执行器落库**:不得在计划中写「会话结束再记录」;每步成功标准应包含「已 upsert 事实或已 record 漏洞(或已输出待落库块)」。
|
||||||
|
|
||||||
|
### 事实写入规范(审计复现 / 知识沉淀)
|
||||||
|
|
||||||
|
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||||
|
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||||
|
- **category / fact_key 建议**:
|
||||||
|
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||||
|
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||||
|
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||||
|
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||||
|
|
||||||
|
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||||
|
|
||||||
## 执行器对用户输出(重要)
|
## 执行器对用户输出(重要)
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ description: supervisor 模式下的协调者:通过 transfer 委派专家子
|
|||||||
- **`transfer` 交接包(强制,避免专家重复侦察)**:**把专家当作刚走进房间的同事——它没看过你的对话,不知道你做了什么,也不了解这个任务为什么重要。** 在触发 `transfer` 的**同一条助手正文**中写清(勿仅依赖历史里的长工具输出;摘要后专家可能看不到细节):
|
- **`transfer` 交接包(强制,避免专家重复侦察)**:**把专家当作刚走进房间的同事——它没看过你的对话,不知道你做了什么,也不了解这个任务为什么重要。** 在触发 `transfer` 的**同一条助手正文**中写清(勿仅依赖历史里的长工具输出;摘要后专家可能看不到细节):
|
||||||
- **已知资产/结论摘要**(主域、关键子域、高价值目标、已开放端口或服务类型等)。
|
- **已知资产/结论摘要**(主域、关键子域、高价值目标、已开放端口或服务类型等)。
|
||||||
- **本轮唯一任务**与 **禁止项**(例如:「不得再做全量子域枚举;仅对下列主机做 MQTT 验证」)。
|
- **本轮唯一任务**与 **禁止项**(例如:「不得再做全量子域枚举;仅对下列主机做 MQTT 验证」)。
|
||||||
|
- **图片/验证码(若有)**:本地绝对路径 + 期望输出格式(如验证码「只输出字符」);专家默认看不到父对话识图结果,须在交接正文中写明。
|
||||||
- **专家类型**:验证/利用/协议分析派对应专家,**避免**把「仅差验证」的工作交给 `recon` 导致其按习惯从侦察阶段重来。
|
- **专家类型**:验证/利用/协议分析派对应专家,**避免**把「仅差验证」的工作交给 `recon` 导致其按习惯从侦察阶段重来。
|
||||||
- **transfer 前目标完整性校验(强制)**:在 `transfer` 前必须具备并显式写入:
|
- **transfer 前目标完整性校验(强制)**:在 `transfer` 前必须具备并显式写入:
|
||||||
- 目标标识:`URL` 或 `IP:Port` 或 `域名 + 具体路径/API 基址`
|
- 目标标识:`URL` 或 `IP:Port` 或 `域名 + 具体路径/API 基址`
|
||||||
@@ -117,9 +118,29 @@ description: supervisor 模式下的协调者:通过 transfer 委派专家子
|
|||||||
3. 期望交付物是否可验收(例如:可复现命令、截图要点、结论段落)?
|
3. 期望交付物是否可验收(例如:可复现命令、截图要点、结论段落)?
|
||||||
4. 是否已明确写出 URL/IP:Port/域名路径与 in-scope 边界(而非“按上文继续”)?
|
4. 是否已明确写出 URL/IP:Port/域名路径与 in-scope 边界(而非“按上文继续”)?
|
||||||
|
|
||||||
## 漏洞
|
## 项目黑板(事实)与漏洞记录(分离)
|
||||||
|
|
||||||
有效漏洞应通过 **`record_vulnerability`** 记录(含 POC 与严重性)。
|
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||||
|
|
||||||
|
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||||
|
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||||
|
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||||
|
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||||
|
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||||
|
|
||||||
|
### 事实写入规范(审计复现 / 知识沉淀)
|
||||||
|
|
||||||
|
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||||
|
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||||
|
- **category / fact_key 建议**:
|
||||||
|
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||||
|
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||||
|
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||||
|
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||||
|
|
||||||
|
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||||
|
|
||||||
## 表达
|
## 表达
|
||||||
|
|
||||||
|
|||||||
+24
-1
@@ -33,6 +33,7 @@ description: 多代理模式下的 Deep 编排者:在已授权安全场景中
|
|||||||
- **`task` 上下文交接(强制,避免重复劳动)**:**把子代理当作刚走进房间的同事——它没看过你的对话,不知道你做了什么,也不了解这个任务为什么重要。** 框架下子代理默认**只看到**你传入的 `description` 文本,**看不到**你在父对话里已跑过的工具输出全文。因此每次 `task` 的 `description` 必须自带**交接包**(可精简,但不可省略关键事实):
|
- **`task` 上下文交接(强制,避免重复劳动)**:**把子代理当作刚走进房间的同事——它没看过你的对话,不知道你做了什么,也不了解这个任务为什么重要。** 框架下子代理默认**只看到**你传入的 `description` 文本,**看不到**你在父对话里已跑过的工具输出全文。因此每次 `task` 的 `description` 必须自带**交接包**(可精简,但不可省略关键事实):
|
||||||
- **已完成**:已枚举的主域/子域要点、已扫端口或服务结论、已确认 IP/URL、协调者已知的漏洞假设等(用列表或短段落即可)。
|
- **已完成**:已枚举的主域/子域要点、已扫端口或服务结论、已确认 IP/URL、协调者已知的漏洞假设等(用列表或短段落即可)。
|
||||||
- **本轮只做**:明确写「本轮禁止重复全量子域爆破 / 禁止重复相同 subfinder 参数集」等(若确实需要增量,写清增量范围)。
|
- **本轮只做**:明确写「本轮禁止重复全量子域爆破 / 禁止重复相同 subfinder 参数集」等(若确实需要增量,写清增量范围)。
|
||||||
|
- **图片/验证码(若有)**:本地绝对路径 + 期望输出格式(如验证码「只输出字符」、登录页 UI 要素列表);子代理默认看不到父对话里的识图结果,须在 description 中写明路径与格式。
|
||||||
- **专家匹配**:验证、利用、协议深挖(如 MQTT)等应委派给**对应专项子代理**;不要把此类子目标交给纯侦察(`recon`)角色除非任务仅为补充攻击面。
|
- **专家匹配**:验证、利用、协议深挖(如 MQTT)等应委派给**对应专项子代理**;不要把此类子目标交给纯侦察(`recon`)角色除非任务仅为补充攻击面。
|
||||||
- **派单前目标完整性校验(强制)**:在调用 `task` 前,你必须检查并写入最小必需字段;任一缺失时**禁止委派**,先向用户澄清或先自行补充证据:
|
- **派单前目标完整性校验(强制)**:在调用 `task` 前,你必须检查并写入最小必需字段;任一缺失时**禁止委派**,先向用户澄清或先自行补充证据:
|
||||||
- **目标标识**:`URL` 或 `IP:Port` 或 `域名 + 具体路径/API 基址`
|
- **目标标识**:`URL` 或 `IP:Port` 或 `域名 + 具体路径/API 基址`
|
||||||
@@ -127,7 +128,29 @@ description: 多代理模式下的 Deep 编排者:在已授权安全场景中
|
|||||||
## 工具与 MCP
|
## 工具与 MCP
|
||||||
|
|
||||||
- **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。
|
- **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。
|
||||||
- **漏洞记录**:发现**有效漏洞**时,必须使用 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度使用 critical / high / medium / low / info。记录后可在授权范围内继续测试。
|
## 项目黑板(事实)与漏洞记录(分离)
|
||||||
|
|
||||||
|
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||||
|
|
||||||
|
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||||
|
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||||
|
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||||
|
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||||
|
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||||
|
|
||||||
|
### 事实写入规范(审计复现 / 知识沉淀)
|
||||||
|
|
||||||
|
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||||
|
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||||
|
- **category / fact_key 建议**:
|
||||||
|
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||||
|
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||||
|
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||||
|
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||||
|
|
||||||
|
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||||
- **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。
|
- **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。
|
||||||
- **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。
|
- **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。
|
||||||
- **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。
|
- **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。
|
||||||
|
|||||||
@@ -31,5 +31,9 @@ max_iterations: 0
|
|||||||
- 禁止自行猜测目标、替换为历史目标或擅自发起全量探索。
|
- 禁止自行猜测目标、替换为历史目标或擅自发起全量探索。
|
||||||
|
|
||||||
- 以证据为中心:请求/响应、Payload、命令输出、截图说明等,便于审计与复现。
|
- 以证据为中心:请求/响应、Payload、命令输出、截图说明等,便于审计与复现。
|
||||||
- 先确认边界与禁止项(如拒绝 DoS、数据破坏);发现有效漏洞时按协调者要求使用 `record_vulnerability` 等流程(若你的工具集中包含)。
|
- 先确认边界与禁止项(如拒绝 DoS、数据破坏)。
|
||||||
- 输出包含:攻击路径摘要、关键步骤、影响评估、修复与缓解建议;语言简洁,便于主代理汇总。
|
- 输出包含:攻击路径摘要、关键步骤、影响评估、修复与缓解建议;语言简洁,便于主代理汇总。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
|||||||
- 列出需要清理/验证的痕迹类型(配置、会话、日志、服务变更等层级描述即可)
|
- 列出需要清理/验证的痕迹类型(配置、会话、日志、服务变更等层级描述即可)
|
||||||
|
|
||||||
4) Recommended Next Steps(下一步建议)
|
4) Recommended Next Steps(下一步建议)
|
||||||
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -53,4 +53,8 @@ max_iterations: 0
|
|||||||
4) Recommended Next Agent(下一步建议)
|
4) Recommended Next Agent(下一步建议)
|
||||||
- 明确建议由哪个子代理接手(例如 `lateral-movement` / `persistence-maintenance` / `impact-exfiltration` / `reporting-remediation`)
|
- 明确建议由哪个子代理接手(例如 `lateral-movement` / `persistence-maintenance` / `impact-exfiltration` / `reporting-remediation`)
|
||||||
|
|
||||||
输出后直接结束。
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|
||||||
|
输出后直接结束。
|
||||||
|
|||||||
@@ -34,3 +34,7 @@ max_iterations: 0
|
|||||||
|
|
||||||
- 若 **`description` / 用户消息 / 上文交接包** 中已给出资产列表、枚举结论或明确写「跳过全量枚举 / 仅做增量 / 从端口扫描或验证开始」,则**不得**为走完整流程而重新执行等价的广域子域爆破或相同参数集的枚举;仅在交接包声明的**缺口**上补充侦察。
|
- 若 **`description` / 用户消息 / 上文交接包** 中已给出资产列表、枚举结论或明确写「跳过全量枚举 / 仅做增量 / 从端口扫描或验证开始」,则**不得**为走完整流程而重新执行等价的广域子域爆破或相同参数集的枚举;仅在交接包声明的**缺口**上补充侦察。
|
||||||
- 若子目标实为**漏洞验证、协议利用、权限提升**等而非攻击面扩展,应**极短说明**「当前角色为侦察;建议协调者改派专项代理」并仅提供与侦察相关的最小补充信息,避免擅自把任务扩写成新一轮全盘资产收集。
|
- 若子目标实为**漏洞验证、协议利用、权限提升**等而非攻击面扩展,应**极短说明**「当前角色为侦察;建议协调者改派专项代理」并仅提供与侦察相关的最小补充信息,避免擅自把任务扩写成新一轮全盘资产收集。
|
||||||
|
|
||||||
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|||||||
@@ -55,4 +55,8 @@ max_iterations: 0
|
|||||||
5) Appendix(附录)
|
5) Appendix(附录)
|
||||||
- 术语、假设、证据清单索引(按证据类型列出即可)
|
- 术语、假设、证据清单索引(按证据类型列出即可)
|
||||||
|
|
||||||
输出后直接结束。
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|
||||||
|
输出后直接结束。
|
||||||
|
|||||||
@@ -57,4 +57,8 @@ max_iterations: 0
|
|||||||
4) Uncertainties & Missing Evidence(不确定性与缺口)
|
4) Uncertainties & Missing Evidence(不确定性与缺口)
|
||||||
- 列出最关键的缺口(尽量少,但要关键)
|
- 列出最关键的缺口(尽量少,但要关键)
|
||||||
|
|
||||||
输出后直接结束。
|
## 边渗透边记录
|
||||||
|
|
||||||
|
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||||
|
|
||||||
|
输出后直接结束。
|
||||||
|
|||||||
+40
-9
@@ -10,7 +10,7 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.19"
|
version: "v1.6.30"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
@@ -61,10 +61,25 @@ openai:
|
|||||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||||
reasoning:
|
reasoning:
|
||||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||||
effort: high # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
|
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||||
|
# 视觉分析(analyze_image MCP 工具;图片仅在单次 VL 调用中出现,Agent 上下文只保留文字摘要)
|
||||||
|
vision:
|
||||||
|
enabled: false # true 且 model 非空时注册 analyze_image
|
||||||
|
model: qwen-vl # VL 模型名(enabled 时必填)
|
||||||
|
api_key: "" # 留空则复用 openai.api_key
|
||||||
|
base_url: "" # 留空则复用 openai.base_url
|
||||||
|
provider: # 留空则复用 openai.provider(openai | claude)
|
||||||
|
max_image_bytes: 5242880 # 原始文件上限(字节),默认 5MB
|
||||||
|
max_dimension: 2048 # 长边缩放像素
|
||||||
|
jpeg_quality: 82
|
||||||
|
max_payload_bytes: 524288 # 编码后送 VL API 上限,默认 512KB
|
||||||
|
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 且<=max_payload 时原图直传;0=始终压缩
|
||||||
|
detail: auto # low | high | auto(Eino ImageURLDetail)
|
||||||
|
timeout_seconds: 60
|
||||||
|
# allowed_roots: [] # 额外允许的绝对路径根目录
|
||||||
# ============================================
|
# ============================================
|
||||||
# 信息收集(FOFA)配置(可选)
|
# 信息收集(FOFA)配置(可选)
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -77,21 +92,23 @@ fofa:
|
|||||||
# Agent 配置
|
# Agent 配置
|
||||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
max_iterations: 12000 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
|
system_prompt_path: ""
|
||||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||||
hitl:
|
hitl:
|
||||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||||
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
# 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent*)
|
||||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
|
# Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode
|
||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: true
|
enabled: true
|
||||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认:eino_single | deep | plan_execute | supervisor
|
||||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||||
@@ -114,7 +131,7 @@ multi_agent:
|
|||||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||||
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
||||||
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
||||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||||
@@ -125,12 +142,13 @@ multi_agent:
|
|||||||
reduction_sub_agents: true # true:子代理也挂 reduction;false:仅编排主代理使用 reduction
|
reduction_sub_agents: true # true:子代理也挂 reduction;false:仅编排主代理使用 reduction
|
||||||
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
|
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
|
||||||
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
|
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
|
||||||
history_input_budget_ratio: 0.35 # 历史入队预算比例(max_total_tokens * ratio)
|
|
||||||
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
|
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
|
||||||
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
|
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
|
||||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||||
|
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
||||||
|
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
@@ -260,11 +278,13 @@ robots:
|
|||||||
enabled: false
|
enabled: false
|
||||||
client_id: ""
|
client_id: ""
|
||||||
client_secret: ""
|
client_secret: ""
|
||||||
|
allow_conversation_id_fallback: false
|
||||||
lark: # 飞书
|
lark: # 飞书
|
||||||
enabled: false
|
enabled: false
|
||||||
app_id: ""
|
app_id: ""
|
||||||
app_secret: ""
|
app_secret: ""
|
||||||
verify_token: ""
|
verify_token: ""
|
||||||
|
allow_chat_id_fallback: false
|
||||||
# ============================================
|
# ============================================
|
||||||
# Skills 相关配置
|
# Skills 相关配置
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -286,3 +306,14 @@ agents_dir: agents
|
|||||||
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
|
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
|
||||||
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
|
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
|
||||||
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
|
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 项目管理与事实黑板
|
||||||
|
# ============================================
|
||||||
|
project:
|
||||||
|
enabled: true
|
||||||
|
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||||
|
fact_index_max_runes: 3500
|
||||||
|
fact_summary_max_runes: 240
|
||||||
|
default_inject_deprecated: false
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
# Eino 多代理改造说明(DeepAgent)
|
# Eino 多代理改造说明(DeepAgent)
|
||||||
|
|
||||||
本文档记录 **单 Agent(原有 ReAct)** 与 **多 Agent(CloudWeGo Eino `adk/prebuilt/deep`)** 并存的改造范围、进度与后续事项。
|
本文档记录 **Eino 单代理(ADK)** 与 **多 Agent(CloudWeGo Eino `adk/prebuilt`)** 的改造范围、进度与后续事项。原生 ReAct 执行路径已移除。
|
||||||
|
|
||||||
## 总体结论
|
## 总体结论
|
||||||
|
|
||||||
- **改造已可用于生产试验**:流式对话、MCP 工具桥接、配置开关、前端模式切换均已落地。
|
- **改造已可用于生产试验**:流式对话、MCP 工具桥接、配置开关、前端模式切换均已落地。
|
||||||
- **入口策略**:主聊天与 WebShell 在开启多代理且用户选择 **Deep / Plan-Execute / Supervisor** 时走 `/api/multi-agent/stream`,请求体字段 **`orchestration`** 指定当次编排(与界面一致);**原生 ReAct** 走 `/api/agent-loop/stream`。机器人、批量任务无该请求体时服务端按 **`deep`** 执行。均需 `multi_agent.enabled`。
|
- **入口策略**:**单代理** 走 `/api/eino-agent/stream`;多代理(**Deep / Plan-Execute / Supervisor**)走 `/api/multi-agent/stream`,请求体 **`orchestration`** 指定编排。机器人默认 `robot_default_agent_mode: eino_single`;批量队列默认 `eino_single`,多代理模式需 `multi_agent.enabled`。
|
||||||
|
|
||||||
## 已完成项
|
## 已完成项
|
||||||
|
|
||||||
@@ -18,13 +18,13 @@
|
|||||||
| 编排 | `internal/multiagent/runner.go`:`deep.New` + 子 `ChatModelAgent` + `adk.NewRunner`(`EnableStreaming: true`,可选 `CheckPointStore`),事件映射为现有 SSE `tool_call` / `response_delta` 等。 |
|
| 编排 | `internal/multiagent/runner.go`:`deep.New` + 子 `ChatModelAgent` + `adk.NewRunner`(`EnableStreaming: true`,可选 `CheckPointStore`),事件映射为现有 SSE `tool_call` / `response_delta` 等。 |
|
||||||
| HTTP | `POST /api/multi-agent`(非流式)、`POST /api/multi-agent/stream`(SSE);路由**常注册**,是否可用由运行时 `multi_agent.enabled` 决定(流式未启用时 SSE 内 `error` + `done`)。 |
|
| HTTP | `POST /api/multi-agent`(非流式)、`POST /api/multi-agent/stream`(SSE);路由**常注册**,是否可用由运行时 `multi_agent.enabled` 决定(流式未启用时 SSE 内 `error` + `done`)。 |
|
||||||
| 会话准备 | `internal/handler/multi_agent_prepare.go`:`prepareMultiAgentSession`(含 **WebShell** `CreateConversationWithWebshell`、工具白名单与单代理一致)。 |
|
| 会话准备 | `internal/handler/multi_agent_prepare.go`:`prepareMultiAgentSession`(含 **WebShell** `CreateConversationWithWebshell`、工具白名单与单代理一致)。 |
|
||||||
| 单 Agent | `internal/agent` 增加 `ToolsForRole`、`ExecuteMCPToolForConversation`;原 `/api/agent-loop` 未删改语义。 |
|
| 单 Agent | `internal/agent` 为 MCP/工具层(`ToolsForRole`、`ExecuteMCPToolForConversation`);单代理编排走 `RunEinoSingleChatModelAgent`(`/api/eino-agent*`)。 |
|
||||||
| 前端 | 主聊天 / WebShell:`multi_agent.enabled` 时可选 **原生 ReAct** 与三种 Eino 命名,多代理路径在 JSON 中带 `orchestration`。设置页不再配置预置编排项;`plan_execute` 外层循环上限等仍可在设置中保存。 |
|
| 前端 | 主聊天 / WebShell:**Eino 单代理**(`/api/eino-agent/stream`)与 **Deep / Plan-Execute / Supervisor**(`/api/multi-agent/stream` + `orchestration`);`multi_agent.enabled` 控制多代理选项是否展示。 |
|
||||||
| 流式兼容 | 与 `/api/agent-loop/stream` 共用 `handleStreamEvent`:`conversation`、`progress`、`response_start` / `response_delta`、`thinking` / `thinking_stream_*`(模型 `ReasoningContent`)、`tool_*`、`response`、`done` 等;`tool_result` 带 `toolCallId` 与 `tool_call` 联动;`data.mcpExecutionIds` 与进度 i18n 已对齐。 |
|
| 流式兼容 | Eino 单/多代理与 Web UI 共用 `handleStreamEvent`:`conversation`、`progress`、`response_start` / `response_delta`、`thinking` / `thinking_stream_*`、`tool_*`、`response`、`done` 等。 |
|
||||||
| 批量任务 | 队列 `agentMode` 为 `deep` / `plan_execute` / `supervisor` 时子任务带对应 `orchestration` 调用 `RunDeepAgent`;旧值 `multi` 与「`agentMode` 为空且 `batch_use_multi_agent: true`」均按 `deep`。 |
|
| 批量任务 | 队列 `agentMode` 为 `deep` / `plan_execute` / `supervisor` 时子任务带对应 `orchestration` 调用 `RunDeepAgent`;旧值 `multi` 与「`agentMode` 为空且 `batch_use_multi_agent: true`」均按 `deep`。 |
|
||||||
| 配置 API | `GET /api/config` 返回 `multi_agent: { enabled, robot_use_multi_agent, sub_agent_count }`;`PUT /api/config` 可更新 `enabled`、`robot_use_multi_agent`(不覆盖 `sub_agents`)。 |
|
| 配置 API | `GET /api/config` 返回 `multi_agent: { enabled, robot_use_multi_agent, sub_agent_count }`;`PUT /api/config` 可更新 `enabled`、`robot_use_multi_agent`(不覆盖 `sub_agents`)。 |
|
||||||
| OpenAPI | 多代理路径说明已更新(流式未启用为 SSE 错误事件)。 |
|
| OpenAPI | 多代理路径说明已更新(流式未启用为 SSE 错误事件)。 |
|
||||||
| 机器人 | `ProcessMessageForRobot` 在 `enabled && robot_use_multi_agent` 时调用 `multiagent.RunDeepAgent`。 |
|
| 机器人 | `ProcessMessageForRobot` 按 `robot_default_agent_mode`(默认 `eino_single`)调用 `RunEinoSingleChatModelAgent` 或 `RunDeepAgent`。 |
|
||||||
| 预置编排 | 聊天 / WebShell:`POST /api/multi-agent*` 请求体 `orchestration`:`deep` \| `plan_execute` \| `supervisor`(缺省 `deep`)。`plan_execute` 不构建 YAML/Markdown 子代理;`plan_execute_loop_max_iterations` 仍来自配置。`supervisor` 至少需一个子代理。 |
|
| 预置编排 | 聊天 / WebShell:`POST /api/multi-agent*` 请求体 `orchestration`:`deep` \| `plan_execute` \| `supervisor`(缺省 `deep`)。`plan_execute` 不构建 YAML/Markdown 子代理;`plan_execute_loop_max_iterations` 仍来自配置。`supervisor` 至少需一个子代理。 |
|
||||||
| Eino 中间件 | `multi_agent.eino_middleware`(可选):`patchtoolcalls`(默认开)、`toolsearch`(按阈值拆分 MCP 工具列表)、`plantask`(需 `eino_skills`)、`reduction`(大工具输出截断/落盘)、`checkpoint_dir`(Runner 断点)、`deep_output_key` / `deep_model_retry_max_retries` / `task_tool_description_prefix`(Deep 与 supervisor 主代理共享其中模型重试与 OutputKey)。`plan_execute` 的 Executor 无 Handlers:仅继承 **ToolsConfig** 侧效果(如 `tool_search` 列表拆分),不挂载 patch/plantask/reduction 中间件。 |
|
| Eino 中间件 | `multi_agent.eino_middleware`(可选):`patchtoolcalls`(默认开)、`toolsearch`(按阈值拆分 MCP 工具列表)、`plantask`(需 `eino_skills`)、`reduction`(大工具输出截断/落盘)、`checkpoint_dir`(Runner 断点)、`deep_output_key` / `deep_model_retry_max_retries` / `task_tool_description_prefix`(Deep 与 supervisor 主代理共享其中模型重试与 OutputKey)。`plan_execute` 的 Executor 无 Handlers:仅继承 **ToolsConfig** 侧效果(如 `tool_search` 列表拆分),不挂载 patch/plantask/reduction 中间件。 |
|
||||||
|
|
||||||
@@ -59,3 +59,4 @@
|
|||||||
| 2026-03-22 | `orchestrator.md` / `kind: orchestrator` 主代理、列表主/子标记、与 `orchestrator_instruction` 优先级。 |
|
| 2026-03-22 | `orchestrator.md` / `kind: orchestrator` 主代理、列表主/子标记、与 `orchestrator_instruction` 优先级。 |
|
||||||
| 2026-04-19 | 主聊天「对话模式」:原生 ReAct 与 Deep / Plan-Execute / Supervisor;`POST /api/multi-agent*` 请求体 `orchestration` 与界面一致;`config.yaml` / 设置页不再维护预置编排字段(机器人/批量默认 `deep`)。 |
|
| 2026-04-19 | 主聊天「对话模式」:原生 ReAct 与 Deep / Plan-Execute / Supervisor;`POST /api/multi-agent*` 请求体 `orchestration` 与界面一致;`config.yaml` / 设置页不再维护预置编排字段(机器人/批量默认 `deep`)。 |
|
||||||
| 2026-04-21 | 移除角色 `skills` 与 `/api/roles/skills/list`;`bind_role` 仅继承 tools;Skills 仅通过 Eino `skill` 工具按需加载。 |
|
| 2026-04-21 | 移除角色 `skills` 与 `/api/roles/skills/list`;`bind_role` 仅继承 tools;Skills 仅通过 Eino `skill` 工具按需加载。 |
|
||||||
|
| 2026-06-02 | **移除原生 ReAct**:删除 `/api/agent-loop*` 执行入口与 `AgentLoopWithProgress`;统一 Eino ADK(单代理 `/api/eino-agent*`,多代理 `/api/multi-agent*`);任务 cancel/tasks API 保留。 |
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# 视觉分析(analyze_image)
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
- **工具名**:`analyze_image`(MCP 内置)
|
||||||
|
- **行为**:读取本地图片 → `imaging` 缩放/JPEG 压缩 → 调用独立 **Vision** 模型 → 返回**纯文本**给 Agent
|
||||||
|
- **上下文**:图片字节**不会**写入对话历史;仅路径与文字摘要进入 Agent 上下文
|
||||||
|
|
||||||
|
## 配置(`config.yaml` → `vision`)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
vision:
|
||||||
|
enabled: true
|
||||||
|
model: qwen-vl-max # 必填
|
||||||
|
api_key: # 留空 → openai.api_key
|
||||||
|
base_url: # 留空 → openai.base_url
|
||||||
|
provider: # 留空 → openai.provider
|
||||||
|
max_image_bytes: 5242880
|
||||||
|
max_dimension: 2048
|
||||||
|
jpeg_quality: 82
|
||||||
|
max_payload_bytes: 524288
|
||||||
|
skip_preprocess_below_bytes: 2097152 # 低于 2MB 且长边<=max_dimension 时原图直传;0=始终 JPEG 压缩
|
||||||
|
detail: low # low | high | auto
|
||||||
|
timeout_seconds: 60
|
||||||
|
# allowed_roots: [] # 额外绝对路径根
|
||||||
|
```
|
||||||
|
|
||||||
|
`enabled: false` 时不注册工具。
|
||||||
|
|
||||||
|
## Web 设置
|
||||||
|
|
||||||
|
**系统设置 → 基本设置 → 视觉分析(analyze_image)** 可配置启用开关、视觉模型、API Key/Base URL(留空复用 OpenAI)、预处理参数;**保存并应用** 后写入 `config.yaml` 并重新注册 MCP 工具。
|
||||||
|
|
||||||
|
## 路径白名单
|
||||||
|
|
||||||
|
默认可读:
|
||||||
|
|
||||||
|
- 进程工作目录(`cwd`)及其子路径
|
||||||
|
- `chat_uploads/`
|
||||||
|
- `agent.result_storage_dir`(默认 `tmp/`)
|
||||||
|
- `vision.allowed_roots` 中配置的绝对路径
|
||||||
|
|
||||||
|
## Agent 使用
|
||||||
|
|
||||||
|
系统提示已说明:遇图片调用 `analyze_image`,勿用 `read_file` 读二进制图。
|
||||||
|
|
||||||
|
`multi_agent.eino_middleware.tool_search_always_visible_tools` 建议包含 `analyze_image`。
|
||||||
|
|
||||||
|
## 合规
|
||||||
|
|
||||||
|
启用后图片会发往 Vision API 配置的上游;敏感环境请使用可信网关或保持 `enabled: false`。
|
||||||
+1
-1
@@ -272,4 +272,4 @@ curl -X POST "http://localhost:8080/api/robot/test" \
|
|||||||
|
|
||||||
- 钉钉、飞书均**仅处理文本消息**;其他类型(如图片、语音)会提示暂不支持或忽略。
|
- 钉钉、飞书均**仅处理文本消息**;其他类型(如图片、语音)会提示暂不支持或忽略。
|
||||||
- 会话与 Web 端共用同一套对话数据:在机器人里创建的对话会在 Web 端「对话」列表中看到,反之亦然。
|
- 会话与 Web 端共用同一套对话数据:在机器人里创建的对话会在 Web 端「对话」列表中看到,反之亦然。
|
||||||
- 机器人执行逻辑与 **`/api/agent-loop/stream`** 一致(含进度回调、过程详情写入数据库),仅不向客户端推送 SSE,最后将完整回复一次性发回钉钉/飞书/企业微信。
|
- 机器人执行与 **Eino 单/多代理** 相同逻辑(`ProcessMessageForRobot`,含进度回调与过程详情入库),仅不向客户端推送 SSE,最后一次性回复钉钉/飞书/企业微信。默认 `robot_default_agent_mode: eino_single`。
|
||||||
|
|||||||
+1
-1
@@ -269,4 +269,4 @@ Check in this order:
|
|||||||
|
|
||||||
- DingTalk and Lark: **text messages only**; other types (e.g. image, voice) are not supported and may be ignored.
|
- DingTalk and Lark: **text messages only**; other types (e.g. image, voice) are not supported and may be ignored.
|
||||||
- Conversations are shared with the web UI: conversations created from the bot appear in the web “Conversations” list and vice versa.
|
- Conversations are shared with the web UI: conversations created from the bot appear in the web “Conversations” list and vice versa.
|
||||||
- Bot execution uses the same logic as **`/api/agent-loop/stream`** (progress callbacks, process details stored in the DB); only the final reply is sent back to DingTalk/Lark in one message (no SSE to the client).
|
- Bot execution uses the same **Eino single/multi-agent** path as the web UI (`ProcessMessageForRobot`, with progress callbacks and process details stored in the DB); only the final reply is sent back to DingTalk/Lark in one message (no SSE). Default: `robot_default_agent_mode: eino_single`.
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ require (
|
|||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8
|
github.com/pkoukk/tiktoken-go v0.1.8
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
go.opentelemetry.io/otel v1.34.0
|
go.opentelemetry.io/otel v1.34.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||||
@@ -48,6 +49,7 @@ require (
|
|||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||||
|
github.com/disintegration/imaging v1.6.2 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||||
@@ -75,7 +77,6 @@ require (
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect
|
|
||||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
|
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
@@ -90,6 +91,7 @@ require (
|
|||||||
golang.org/x/arch v0.15.0 // indirect
|
golang.org/x/arch v0.15.0 // indirect
|
||||||
golang.org/x/crypto v0.39.0 // indirect
|
golang.org/x/crypto v0.39.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||||
|
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8 // indirect
|
||||||
golang.org/x/oauth2 v0.30.0 // indirect
|
golang.org/x/oauth2 v0.30.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
|
||||||
|
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
@@ -240,6 +242,8 @@ golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
|||||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
||||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
||||||
|
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8 h1:hVwzHzIUGRjiF7EcUjqNxk3NCfkPxbDKRdnNE1Rpg0U=
|
||||||
|
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 182 KiB After Width: | Height: | Size: 178 KiB |
+3
-1033
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,10 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import "cyberstrike-ai/internal/mcp/builtin"
|
import (
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
)
|
||||||
|
|
||||||
// DefaultSingleAgentSystemPrompt 单代理(ReAct / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||||
func DefaultSingleAgentSystemPrompt() string {
|
func DefaultSingleAgentSystemPrompt() string {
|
||||||
return `你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。
|
return `你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。
|
||||||
|
|
||||||
@@ -105,15 +107,11 @@ func DefaultSingleAgentSystemPrompt() string {
|
|||||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||||
|
|
||||||
## 漏洞记录
|
` + project.FactRecordingBlackboardSection(false) + `
|
||||||
|
|
||||||
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
|
|
||||||
|
|
||||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。
|
|
||||||
|
|
||||||
## 技能库(Skills)与知识库
|
## 技能库(Skills)与知识库
|
||||||
|
|
||||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||||
- 单代理本会话通过 MCP 使用知识库与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」中由内置 skill 工具完成(需在配置中启用 multi_agent.eino_skills)。
|
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
||||||
- 若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话(亦可选 Eino ADK 单代理路径 /api/eino-agent)。`
|
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,491 +0,0 @@
|
|||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
|
||||||
"cyberstrike-ai/internal/openai"
|
|
||||||
|
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultMinRecentMessage 压缩历史消息时保留的最近消息数量,确保最近的对话上下文不被压缩
|
|
||||||
DefaultMinRecentMessage = 5
|
|
||||||
// defaultChunkSize 压缩历史消息时每次处理的消息块大小,将旧消息分成多个块进行摘要
|
|
||||||
defaultChunkSize = 10
|
|
||||||
// defaultMaxImages 压缩时最多保留的图片数量,超过此数量的图片会被移除以节省上下文空间
|
|
||||||
defaultMaxImages = 3
|
|
||||||
// defaultSummaryTimeout 生成消息摘要时的超时时间
|
|
||||||
defaultSummaryTimeout = 10 * time.Minute
|
|
||||||
|
|
||||||
summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。
|
|
||||||
|
|
||||||
必须保留的关键信息:
|
|
||||||
- 已发现的漏洞与潜在攻击路径
|
|
||||||
- 扫描结果与工具输出(可压缩,但需保留核心发现)
|
|
||||||
- 获取到的访问凭证、令牌或认证细节
|
|
||||||
- 系统架构洞察与潜在薄弱点
|
|
||||||
- 当前评估进展
|
|
||||||
- 失败尝试与死路(避免重复劳动)
|
|
||||||
- 关于测试策略的所有决策记录
|
|
||||||
|
|
||||||
压缩指南:
|
|
||||||
- 保留精确技术细节(URL、路径、参数、Payload 等)
|
|
||||||
- 将冗长的工具输出压缩成概述,但保留关键发现
|
|
||||||
- 记录版本号与识别出的技术/组件信息
|
|
||||||
- 保留可能暗示漏洞的原始报错
|
|
||||||
- 将重复或相似发现整合成一条带有共性说明的结论
|
|
||||||
|
|
||||||
请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。
|
|
||||||
|
|
||||||
需要压缩的对话片段:
|
|
||||||
%s
|
|
||||||
|
|
||||||
请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。`
|
|
||||||
)
|
|
||||||
|
|
||||||
// MemoryCompressor 负责在调用LLM前压缩历史上下文,以避免Token爆炸。
|
|
||||||
type MemoryCompressor struct {
|
|
||||||
maxTotalTokens int
|
|
||||||
minRecentMessage int
|
|
||||||
maxImages int
|
|
||||||
chunkSize int
|
|
||||||
summaryModel string
|
|
||||||
timeout time.Duration
|
|
||||||
|
|
||||||
tokenCounter TokenCounter
|
|
||||||
completionClient CompletionClient
|
|
||||||
logger *zap.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// MemoryCompressorConfig 用于初始化 MemoryCompressor。
|
|
||||||
type MemoryCompressorConfig struct {
|
|
||||||
MaxTotalTokens int
|
|
||||||
MinRecentMessage int
|
|
||||||
MaxImages int
|
|
||||||
ChunkSize int
|
|
||||||
SummaryModel string
|
|
||||||
Timeout time.Duration
|
|
||||||
TokenCounter TokenCounter
|
|
||||||
CompletionClient CompletionClient
|
|
||||||
Logger *zap.Logger
|
|
||||||
|
|
||||||
// 当 CompletionClient 为空时,可以通过 OpenAIConfig + HTTPClient 构造默认的客户端。
|
|
||||||
OpenAIConfig *config.OpenAIConfig
|
|
||||||
HTTPClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMemoryCompressor 创建新的 MemoryCompressor。
|
|
||||||
func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) {
|
|
||||||
if cfg.Logger == nil {
|
|
||||||
cfg.Logger = zap.NewNop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果没有显式配置 MaxTotalTokens,则后续逻辑会根据模型的最大上下文长度进行控制;
|
|
||||||
// 优先推荐在 config.yaml 的 openai.max_total_tokens 中统一配置。
|
|
||||||
if cfg.MinRecentMessage <= 0 {
|
|
||||||
cfg.MinRecentMessage = DefaultMinRecentMessage
|
|
||||||
}
|
|
||||||
if cfg.MaxImages <= 0 {
|
|
||||||
cfg.MaxImages = defaultMaxImages
|
|
||||||
}
|
|
||||||
if cfg.ChunkSize <= 0 {
|
|
||||||
cfg.ChunkSize = defaultChunkSize
|
|
||||||
}
|
|
||||||
if cfg.Timeout <= 0 {
|
|
||||||
cfg.Timeout = defaultSummaryTimeout
|
|
||||||
}
|
|
||||||
if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
|
|
||||||
cfg.SummaryModel = cfg.OpenAIConfig.Model
|
|
||||||
}
|
|
||||||
if cfg.SummaryModel == "" {
|
|
||||||
return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)")
|
|
||||||
}
|
|
||||||
if cfg.TokenCounter == nil {
|
|
||||||
cfg.TokenCounter = NewTikTokenCounter()
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.CompletionClient == nil {
|
|
||||||
if cfg.OpenAIConfig == nil {
|
|
||||||
return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig")
|
|
||||||
}
|
|
||||||
if cfg.HTTPClient == nil {
|
|
||||||
cfg.HTTPClient = &http.Client{
|
|
||||||
Timeout: 5 * time.Minute,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &MemoryCompressor{
|
|
||||||
maxTotalTokens: cfg.MaxTotalTokens,
|
|
||||||
minRecentMessage: cfg.MinRecentMessage,
|
|
||||||
maxImages: cfg.MaxImages,
|
|
||||||
chunkSize: cfg.ChunkSize,
|
|
||||||
summaryModel: cfg.SummaryModel,
|
|
||||||
timeout: cfg.Timeout,
|
|
||||||
tokenCounter: cfg.TokenCounter,
|
|
||||||
completionClient: cfg.CompletionClient,
|
|
||||||
logger: cfg.Logger,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig 更新OpenAI配置(用于动态更新模型配置)
|
|
||||||
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
|
|
||||||
if cfg == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新summaryModel字段
|
|
||||||
if cfg.Model != "" {
|
|
||||||
mc.summaryModel = cfg.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新completionClient中的配置(如果是OpenAICompletionClient)
|
|
||||||
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
|
|
||||||
openAIClient.UpdateConfig(cfg)
|
|
||||||
mc.logger.Info("MemoryCompressor配置已更新",
|
|
||||||
zap.String("model", cfg.Model),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。
|
|
||||||
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return messages, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mc.handleImages(messages)
|
|
||||||
|
|
||||||
systemMsgs, regularMsgs := mc.splitMessages(messages)
|
|
||||||
if len(regularMsgs) <= mc.minRecentMessage {
|
|
||||||
return messages, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
effectiveMax := mc.maxTotalTokens
|
|
||||||
if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
|
|
||||||
effectiveMax = mc.maxTotalTokens - reservedTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
|
|
||||||
if totalTokens <= int(float64(effectiveMax)*0.9) {
|
|
||||||
return messages, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
recentStart := len(regularMsgs) - mc.minRecentMessage
|
|
||||||
recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart)
|
|
||||||
oldMsgs := regularMsgs[:recentStart]
|
|
||||||
recentMsgs := regularMsgs[recentStart:]
|
|
||||||
|
|
||||||
mc.logger.Info("memory compression triggered",
|
|
||||||
zap.Int("total_tokens", totalTokens),
|
|
||||||
zap.Int("max_total_tokens", mc.maxTotalTokens),
|
|
||||||
zap.Int("reserved_tokens", reservedTokens),
|
|
||||||
zap.Int("effective_max", effectiveMax),
|
|
||||||
zap.Int("system_messages", len(systemMsgs)),
|
|
||||||
zap.Int("regular_messages", len(regularMsgs)),
|
|
||||||
zap.Int("old_messages", len(oldMsgs)),
|
|
||||||
zap.Int("recent_messages", len(recentMsgs)))
|
|
||||||
|
|
||||||
var compressed []ChatMessage
|
|
||||||
for i := 0; i < len(oldMsgs); i += mc.chunkSize {
|
|
||||||
end := i + mc.chunkSize
|
|
||||||
if end > len(oldMsgs) {
|
|
||||||
end = len(oldMsgs)
|
|
||||||
}
|
|
||||||
chunk := oldMsgs[i:end]
|
|
||||||
if len(chunk) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
summary, err := mc.summarizeChunk(ctx, chunk)
|
|
||||||
if err != nil {
|
|
||||||
mc.logger.Warn("chunk summary failed, fallback to raw chunk",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.Int("start", i),
|
|
||||||
zap.Int("end", end))
|
|
||||||
compressed = append(compressed, chunk...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
compressed = append(compressed, summary)
|
|
||||||
}
|
|
||||||
|
|
||||||
finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs))
|
|
||||||
finalMessages = append(finalMessages, systemMsgs...)
|
|
||||||
finalMessages = append(finalMessages, compressed...)
|
|
||||||
finalMessages = append(finalMessages, recentMsgs...)
|
|
||||||
|
|
||||||
return finalMessages, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) handleImages(messages []ChatMessage) {
|
|
||||||
if mc.maxImages <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
count := 0
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
content := messages[i].Content
|
|
||||||
if !strings.Contains(content, "[IMAGE]") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
if count > mc.maxImages {
|
|
||||||
messages[i].Content = "[Previously attached image removed to preserve context]"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) {
|
|
||||||
for _, msg := range messages {
|
|
||||||
if strings.EqualFold(msg.Role, "system") {
|
|
||||||
systemMsgs = append(systemMsgs, msg)
|
|
||||||
} else {
|
|
||||||
regularMsgs = append(regularMsgs, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int {
|
|
||||||
total := 0
|
|
||||||
for _, msg := range systemMsgs {
|
|
||||||
total += mc.countTokens(msg.Content)
|
|
||||||
}
|
|
||||||
for _, msg := range regularMsgs {
|
|
||||||
total += mc.countTokens(msg.Content)
|
|
||||||
}
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置)
|
|
||||||
func (mc *MemoryCompressor) getModelName() string {
|
|
||||||
// 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称
|
|
||||||
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
|
|
||||||
if openAIClient.config != nil && openAIClient.config.Model != "" {
|
|
||||||
return openAIClient.config.Model
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 否则使用保存的summaryModel
|
|
||||||
return mc.summaryModel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) countTokens(text string) int {
|
|
||||||
if mc.tokenCounter == nil {
|
|
||||||
return len(text) / 4
|
|
||||||
}
|
|
||||||
modelName := mc.getModelName()
|
|
||||||
count, err := mc.tokenCounter.Count(modelName, text)
|
|
||||||
if err != nil {
|
|
||||||
return len(text) / 4
|
|
||||||
}
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。
|
|
||||||
func (mc *MemoryCompressor) CountTextTokens(text string) int {
|
|
||||||
return mc.countTokens(text)
|
|
||||||
}
|
|
||||||
|
|
||||||
// totalTokensFor provides token statistics without mutating the message list.
|
|
||||||
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return 0, 0, 0
|
|
||||||
}
|
|
||||||
systemMsgs, regularMsgs := mc.splitMessages(messages)
|
|
||||||
return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
|
|
||||||
if len(chunk) == 0 {
|
|
||||||
return ChatMessage{}, errors.New("chunk is empty")
|
|
||||||
}
|
|
||||||
formatted := make([]string, 0, len(chunk))
|
|
||||||
for _, msg := range chunk {
|
|
||||||
formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
|
|
||||||
}
|
|
||||||
conversation := strings.Join(formatted, "\n")
|
|
||||||
prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
|
|
||||||
|
|
||||||
// 使用动态获取的模型名称,而不是保存的summaryModel
|
|
||||||
modelName := mc.getModelName()
|
|
||||||
summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
|
|
||||||
if err != nil {
|
|
||||||
return ChatMessage{}, err
|
|
||||||
}
|
|
||||||
summary = strings.TrimSpace(summary)
|
|
||||||
if summary == "" {
|
|
||||||
return chunk[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ChatMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: fmt.Sprintf("<context_summary message_count='%d'>%s</context_summary>", len(chunk), summary),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
|
|
||||||
return msg.Content
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
|
|
||||||
if recentStart <= 0 || recentStart >= len(msgs) {
|
|
||||||
return recentStart
|
|
||||||
}
|
|
||||||
|
|
||||||
adjusted := recentStart
|
|
||||||
for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
|
|
||||||
adjusted--
|
|
||||||
}
|
|
||||||
|
|
||||||
if adjusted != recentStart {
|
|
||||||
mc.logger.Debug("adjusted recent window to keep tool call context",
|
|
||||||
zap.Int("original_recent_start", recentStart),
|
|
||||||
zap.Int("adjusted_recent_start", adjusted),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return adjusted
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenCounter 用于计算文本Token数量。
|
|
||||||
type TokenCounter interface {
|
|
||||||
Count(model, text string) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TikTokenCounter 基于 tiktoken 的 Token 统计器。
|
|
||||||
type TikTokenCounter struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
cache map[string]*tiktoken.Tiktoken
|
|
||||||
fallbackEncoding *tiktoken.Tiktoken
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTikTokenCounter 创建新的 TikTokenCounter。
|
|
||||||
func NewTikTokenCounter() *TikTokenCounter {
|
|
||||||
return &TikTokenCounter{
|
|
||||||
cache: make(map[string]*tiktoken.Tiktoken),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count 实现 TokenCounter 接口。
|
|
||||||
func (tc *TikTokenCounter) Count(model, text string) (int, error) {
|
|
||||||
enc, err := tc.encodingForModel(model)
|
|
||||||
if err != nil {
|
|
||||||
return len(text) / 4, err
|
|
||||||
}
|
|
||||||
tokens := enc.Encode(text, nil, nil)
|
|
||||||
return len(tokens), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) {
|
|
||||||
tc.mu.RLock()
|
|
||||||
if enc, ok := tc.cache[model]; ok {
|
|
||||||
tc.mu.RUnlock()
|
|
||||||
return enc, nil
|
|
||||||
}
|
|
||||||
tc.mu.RUnlock()
|
|
||||||
|
|
||||||
tc.mu.Lock()
|
|
||||||
defer tc.mu.Unlock()
|
|
||||||
|
|
||||||
if enc, ok := tc.cache[model]; ok {
|
|
||||||
return enc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
enc, err := tiktoken.EncodingForModel(model)
|
|
||||||
if err != nil {
|
|
||||||
if tc.fallbackEncoding == nil {
|
|
||||||
tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tc.cache[model] = tc.fallbackEncoding
|
|
||||||
return tc.fallbackEncoding, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
tc.cache[model] = enc
|
|
||||||
return enc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CompletionClient 对话压缩时使用的补全接口。
|
|
||||||
type CompletionClient interface {
|
|
||||||
Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAICompletionClient 基于 OpenAI Chat Completion。
|
|
||||||
type OpenAICompletionClient struct {
|
|
||||||
config *config.OpenAIConfig
|
|
||||||
client *openai.Client
|
|
||||||
logger *zap.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOpenAICompletionClient 创建 OpenAICompletionClient。
|
|
||||||
func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient {
|
|
||||||
if logger == nil {
|
|
||||||
logger = zap.NewNop()
|
|
||||||
}
|
|
||||||
return &OpenAICompletionClient{
|
|
||||||
config: cfg,
|
|
||||||
client: openai.NewClient(cfg, client, logger),
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig 更新底层配置。
|
|
||||||
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
|
|
||||||
c.config = cfg
|
|
||||||
if c.client != nil {
|
|
||||||
c.client.UpdateConfig(cfg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Complete 调用OpenAI获取摘要。
|
|
||||||
func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
|
|
||||||
if c.config == nil {
|
|
||||||
return "", errors.New("openai config is required")
|
|
||||||
}
|
|
||||||
if model == "" {
|
|
||||||
return "", errors.New("model name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBody := OpenAIRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: []ChatMessage{
|
|
||||||
{Role: "user", Content: prompt},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
requestCtx := ctx
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if timeout > 0 {
|
|
||||||
requestCtx, cancel = context.WithTimeout(ctx, timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
var completion OpenAIResponse
|
|
||||||
if c.client == nil {
|
|
||||||
return "", errors.New("openai completion client not initialized")
|
|
||||||
}
|
|
||||||
if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
|
|
||||||
if apiErr, ok := err.(*openai.APIError); ok {
|
|
||||||
return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
|
|
||||||
}
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if completion.Error != nil {
|
|
||||||
return "", errors.New(completion.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
|
|
||||||
return "", errors.New("empty completion response")
|
|
||||||
}
|
|
||||||
return completion.Choices[0].Message.Content, nil
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenCounter 估算文本 token 数(tiktoken;模型未知时回退 cl100k_base)。
|
||||||
|
type TokenCounter interface {
|
||||||
|
Count(model, text string) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type tikTokenCounter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cache map[string]*tiktoken.Tiktoken
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTikTokenCounter 创建基于 tiktoken 的 TokenCounter。
|
||||||
|
func NewTikTokenCounter() TokenCounter {
|
||||||
|
return &tikTokenCounter{cache: make(map[string]*tiktoken.Tiktoken)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tikTokenCounter) encoding(model string) (*tiktoken.Tiktoken, error) {
|
||||||
|
key := model
|
||||||
|
if key == "" {
|
||||||
|
key = "cl100k_base"
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if enc, ok := c.cache[key]; ok {
|
||||||
|
return enc, nil
|
||||||
|
}
|
||||||
|
enc, err := tiktoken.EncodingForModel(key)
|
||||||
|
if err != nil {
|
||||||
|
enc, err = tiktoken.GetEncoding("cl100k_base")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.cache[key] = enc
|
||||||
|
return enc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tikTokenCounter) Count(model, text string) (int, error) {
|
||||||
|
if text == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
enc, err := c.encoding(model)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(enc.Encode(text, nil, nil)), nil
|
||||||
|
}
|
||||||
+28
-195
@@ -111,7 +111,9 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
|
|
||||||
// 注册漏洞记录工具
|
// 注册漏洞记录工具
|
||||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||||
|
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||||
|
registerVisionTools(mcpServer, cfg, log.Logger)
|
||||||
|
|
||||||
if cfg.Auth.GeneratedPassword != "" {
|
if cfg.Auth.GeneratedPassword != "" {
|
||||||
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
|
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
|
||||||
@@ -346,6 +348,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
authHandler.SetAudit(auditSvc)
|
authHandler.SetAudit(auditSvc)
|
||||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||||
|
projectHandler := handler.NewProjectHandler(db, log.Logger)
|
||||||
vulnerabilityHandler.SetAudit(auditSvc)
|
vulnerabilityHandler.SetAudit(auditSvc)
|
||||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||||
webshellHandler.SetAudit(auditSvc)
|
webshellHandler.SetAudit(auditSvc)
|
||||||
@@ -414,7 +417,9 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
|
|
||||||
// 设置漏洞工具注册器(内置工具,必须设置)
|
// 设置漏洞工具注册器(内置工具,必须设置)
|
||||||
vulnerabilityRegistrar := func() error {
|
vulnerabilityRegistrar := func() error {
|
||||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||||
|
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||||
|
registerVisionTools(mcpServer, cfg, log.Logger)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
|
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
|
||||||
@@ -502,6 +507,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
attackChainHandler,
|
attackChainHandler,
|
||||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||||
vulnerabilityHandler,
|
vulnerabilityHandler,
|
||||||
|
projectHandler,
|
||||||
webshellHandler,
|
webshellHandler,
|
||||||
chatUploadsHandler,
|
chatUploadsHandler,
|
||||||
roleHandler,
|
roleHandler,
|
||||||
@@ -747,6 +753,7 @@ func setupRoutes(
|
|||||||
attackChainHandler *handler.AttackChainHandler,
|
attackChainHandler *handler.AttackChainHandler,
|
||||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||||
|
projectHandler *handler.ProjectHandler,
|
||||||
webshellHandler *handler.WebShellHandler,
|
webshellHandler *handler.WebShellHandler,
|
||||||
chatUploadsHandler *handler.ChatUploadsHandler,
|
chatUploadsHandler *handler.ChatUploadsHandler,
|
||||||
roleHandler *handler.RoleHandler,
|
roleHandler *handler.RoleHandler,
|
||||||
@@ -796,10 +803,6 @@ func setupRoutes(
|
|||||||
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
||||||
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
||||||
|
|
||||||
// Agent Loop
|
|
||||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
|
||||||
// Agent Loop 流式输出
|
|
||||||
protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
|
|
||||||
// Eino ADK 单代理(ChatModelAgent + Runner;不依赖 multi_agent.enabled)
|
// Eino ADK 单代理(ChatModelAgent + Runner;不依赖 multi_agent.enabled)
|
||||||
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
||||||
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
||||||
@@ -851,6 +854,7 @@ func setupRoutes(
|
|||||||
protected.GET("/conversations/:id", conversationHandler.GetConversation)
|
protected.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||||
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
|
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
|
||||||
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
|
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
|
||||||
|
protected.PUT("/conversations/:id/project", conversationHandler.SetConversationProject)
|
||||||
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||||
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
|
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
|
||||||
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
|
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
|
||||||
@@ -886,6 +890,7 @@ func setupRoutes(
|
|||||||
protected.PUT("/config", configHandler.UpdateConfig)
|
protected.PUT("/config", configHandler.UpdateConfig)
|
||||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||||
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
||||||
|
protected.POST("/config/test-vision", configHandler.TestVision)
|
||||||
|
|
||||||
// 系统设置 - 终端(执行命令,提高运维效率)
|
// 系统设置 - 终端(执行命令,提高运维效率)
|
||||||
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
||||||
@@ -1067,6 +1072,23 @@ func setupRoutes(
|
|||||||
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
||||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||||
|
|
||||||
|
// 项目管理与事实黑板
|
||||||
|
protected.GET("/projects", projectHandler.ListProjects)
|
||||||
|
protected.POST("/projects", projectHandler.CreateProject)
|
||||||
|
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
|
||||||
|
protected.GET("/projects/:id/conversations", projectHandler.ListProjectConversations)
|
||||||
|
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||||
|
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||||
|
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||||
|
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||||
|
protected.GET("/projects/:id/facts/:factId/previous-version", projectHandler.GetFactPreviousVersion)
|
||||||
|
protected.GET("/projects/:id/facts/:factId/versions", projectHandler.ListFactVersions)
|
||||||
|
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||||
|
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||||
|
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
|
||||||
|
protected.POST("/projects/:id/facts/deprecate", projectHandler.DeprecateFact)
|
||||||
|
protected.POST("/projects/:id/facts/restore", projectHandler.RestoreFact)
|
||||||
|
|
||||||
// WebShell 管理(代理执行 + 连接配置存 SQLite)
|
// WebShell 管理(代理执行 + 连接配置存 SQLite)
|
||||||
protected.GET("/webshell/connections", webshellHandler.ListConnections)
|
protected.GET("/webshell/connections", webshellHandler.ListConnections)
|
||||||
protected.POST("/webshell/connections", webshellHandler.CreateConnection)
|
protected.POST("/webshell/connections", webshellHandler.CreateConnection)
|
||||||
@@ -1187,195 +1209,6 @@ func setupRoutes(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
|
|
||||||
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
|
||||||
tool := mcp.Tool{
|
|
||||||
Name: builtin.ToolRecordVulnerability,
|
|
||||||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
|
|
||||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
|
||||||
InputSchema: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"title": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞标题(必需)",
|
|
||||||
},
|
|
||||||
"description": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞详细描述",
|
|
||||||
},
|
|
||||||
"severity": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
|
||||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
|
||||||
},
|
|
||||||
"vulnerability_type": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
|
||||||
},
|
|
||||||
"target": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
|
||||||
},
|
|
||||||
"proof": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
|
||||||
},
|
|
||||||
"impact": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "漏洞影响说明",
|
|
||||||
},
|
|
||||||
"recommendation": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "修复建议",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": []string{"title", "severity"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
||||||
// 从参数中获取conversation_id(由Agent自动添加)
|
|
||||||
conversationID, _ := args["conversation_id"].(string)
|
|
||||||
if conversationID == "" {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
title, ok := args["title"].(string)
|
|
||||||
if !ok || title == "" {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "错误: title 参数必需且不能为空",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
severity, ok := args["severity"].(string)
|
|
||||||
if !ok || severity == "" {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "错误: severity 参数必需且不能为空",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证严重程度
|
|
||||||
validSeverities := map[string]bool{
|
|
||||||
"critical": true,
|
|
||||||
"high": true,
|
|
||||||
"medium": true,
|
|
||||||
"low": true,
|
|
||||||
"info": true,
|
|
||||||
}
|
|
||||||
if !validSeverities[severity] {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取可选参数
|
|
||||||
description := ""
|
|
||||||
if d, ok := args["description"].(string); ok {
|
|
||||||
description = d
|
|
||||||
}
|
|
||||||
|
|
||||||
vulnType := ""
|
|
||||||
if t, ok := args["vulnerability_type"].(string); ok {
|
|
||||||
vulnType = t
|
|
||||||
}
|
|
||||||
|
|
||||||
target := ""
|
|
||||||
if t, ok := args["target"].(string); ok {
|
|
||||||
target = t
|
|
||||||
}
|
|
||||||
|
|
||||||
proof := ""
|
|
||||||
if p, ok := args["proof"].(string); ok {
|
|
||||||
proof = p
|
|
||||||
}
|
|
||||||
|
|
||||||
impact := ""
|
|
||||||
if i, ok := args["impact"].(string); ok {
|
|
||||||
impact = i
|
|
||||||
}
|
|
||||||
|
|
||||||
recommendation := ""
|
|
||||||
if r, ok := args["recommendation"].(string); ok {
|
|
||||||
recommendation = r
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建漏洞记录
|
|
||||||
vuln := &database.Vulnerability{
|
|
||||||
ConversationID: conversationID,
|
|
||||||
Title: title,
|
|
||||||
Description: description,
|
|
||||||
Severity: severity,
|
|
||||||
Status: "open",
|
|
||||||
Type: vulnType,
|
|
||||||
Target: target,
|
|
||||||
Proof: proof,
|
|
||||||
Impact: impact,
|
|
||||||
Recommendation: recommendation,
|
|
||||||
}
|
|
||||||
|
|
||||||
created, err := db.CreateVulnerability(vuln)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("记录漏洞失败", zap.Error(err))
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("记录漏洞失败: %v", err),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("漏洞记录成功",
|
|
||||||
zap.String("id", created.ID),
|
|
||||||
zap.String("title", created.Title),
|
|
||||||
zap.String("severity", created.Severity),
|
|
||||||
zap.String("conversation_id", conversationID),
|
|
||||||
)
|
|
||||||
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mcpServer.RegisterTool(tool, handler)
|
|
||||||
logger.Info("漏洞记录工具注册成功")
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作
|
// registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作
|
||||||
func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) {
|
func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) {
|
||||||
if db == nil || webshellHandler == nil {
|
if db == nil || webshellHandler == nil {
|
||||||
|
|||||||
@@ -0,0 +1,336 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func projectIDFromConversation(db *database.DB, ctx context.Context) (string, error) {
|
||||||
|
convID := agent.ConversationIDFromContext(ctx)
|
||||||
|
if convID == "" {
|
||||||
|
return "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用项目事实工具")
|
||||||
|
}
|
||||||
|
pid, err := db.GetConversationProjectID(convID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(pid) == "" {
|
||||||
|
return "", fmt.Errorf("当前对话未绑定项目,请先在对话中选择项目或创建带项目的对话")
|
||||||
|
}
|
||||||
|
return pid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func textResult(msg string, isErr bool) *mcp.ToolResult {
|
||||||
|
return &mcp.ToolResult{
|
||||||
|
Content: []mcp.Content{{Type: "text", Text: msg}},
|
||||||
|
IsError: isErr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerProjectFactTools 注册项目黑板 MCP 工具。
|
||||||
|
func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *config.Config, logger *zap.Logger) {
|
||||||
|
if db == nil || cfg == nil || !cfg.Project.Enabled {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("项目黑板工具未注册(未启用)")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
upsertTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolUpsertProjectFact,
|
||||||
|
Description: "写入或更新项目黑板事实,用于跨会话沉淀可复现上下文(非正式漏洞条目;可交付漏洞另用 record_vulnerability)。" +
|
||||||
|
"边渗透边记录:每确认新认知(端口/入口/凭据/可利用点)后立即调用,同 fact_key 覆盖更新,勿等会话结束。" +
|
||||||
|
"禁止仅写结论:summary 须含什么+在哪+如何验证;body 须含攻击链/请求响应/命令等复现细节。" +
|
||||||
|
"发现类建议 fact_key 为 finding|chain|exploit|poc/<slug>,category 对应 finding|chain|exploit|poc,body 按攻击链模板填写。" +
|
||||||
|
"环境类用 target|auth|infra|business/<slug>。同 fact_key 覆盖更新。需当前对话已绑定项目。",
|
||||||
|
ShortDescription: "写入/更新项目事实(含攻击链 body)",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "项目内唯一 key:target/primary_domain、finding/sqli-login、exploit/upload-rce 等",
|
||||||
|
},
|
||||||
|
"category": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "target | auth | infra | business | finding | chain | exploit | poc | note",
|
||||||
|
"enum": []string{"target", "auth", "infra", "business", "finding", "chain", "exploit", "poc", "note"},
|
||||||
|
},
|
||||||
|
"summary": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "索引用一行:结论 + 位置 + 触发/验证要点(勿仅写「存在 XSS」等空话)",
|
||||||
|
},
|
||||||
|
"body": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "完整可复现详情(仅 get_project_fact 返回):须含攻击链步骤、原始 HTTP/命令、响应现象、证据与关联。" +
|
||||||
|
"发现/利用类首次写入必填;环境类建议含来源证据。攻击链类可参考模板章节:结论、目标与入口、攻击链、Exploit/POC、关键证据、关联、备注。" +
|
||||||
|
"更新已有 fact_key 时若省略或留空 body,将保留库中已有 body(可只改 summary)。",
|
||||||
|
},
|
||||||
|
"confidence": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "confirmed | tentative | deprecated",
|
||||||
|
"enum": []string{"confirmed", "tentative", "deprecated"},
|
||||||
|
},
|
||||||
|
"pinned": map[string]interface{}{
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否优先出现在黑板索引",
|
||||||
|
},
|
||||||
|
"related_vulnerability_id": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "可选:关联的漏洞记录 ID",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"fact_key", "summary"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpServer.RegisterTool(upsertTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
factKey, _ := args["fact_key"].(string)
|
||||||
|
summary, _ := args["summary"].(string)
|
||||||
|
if strings.TrimSpace(factKey) == "" || strings.TrimSpace(summary) == "" {
|
||||||
|
return textResult("错误: fact_key 与 summary 必填", true), nil
|
||||||
|
}
|
||||||
|
if len([]rune(summary)) > cfg.Project.FactSummaryMaxRunesEffective() {
|
||||||
|
return textResult(fmt.Sprintf("错误: summary 过长(最多 %d 字)", cfg.Project.FactSummaryMaxRunesEffective()), true), nil
|
||||||
|
}
|
||||||
|
f := &database.ProjectFact{
|
||||||
|
ProjectID: projectID,
|
||||||
|
FactKey: factKey,
|
||||||
|
Category: strArg(args, "category"),
|
||||||
|
Summary: summary,
|
||||||
|
Body: strArg(args, "body"),
|
||||||
|
Confidence: strArg(args, "confidence"),
|
||||||
|
Pinned: boolArg(args, "pinned"),
|
||||||
|
RelatedVulnerabilityID: strArg(args, "related_vulnerability_id"),
|
||||||
|
}
|
||||||
|
if convID := agent.ConversationIDFromContext(ctx); convID != "" {
|
||||||
|
f.SourceConversationID = convID
|
||||||
|
}
|
||||||
|
created, err := db.UpsertProjectFact(f)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||||
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
|
msg += warn
|
||||||
|
}
|
||||||
|
return textResult(msg, false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
getTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolGetProjectFact,
|
||||||
|
Description: "按 fact_key 获取项目事实完整 body 与元数据。摘要不足时必须调用本工具,禁止臆造细节。",
|
||||||
|
ShortDescription: "按 key 获取事实详情",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{"type": "string", "description": "事实 key"},
|
||||||
|
},
|
||||||
|
"required": []string{"fact_key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mcpServer.RegisterTool(getTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||||
|
if key == "" {
|
||||||
|
return textResult("错误: fact_key 必填", true), nil
|
||||||
|
}
|
||||||
|
f, err := db.GetProjectFactByKey(projectID, key)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("fact_key: %s\ncategory: %s\nconfidence: %s\nsummary: %s\nupdated_at: %s",
|
||||||
|
f.FactKey, f.Category, f.Confidence, f.Summary, f.UpdatedAt.Format("2006-01-02 15:04:05"))
|
||||||
|
if f.RelatedVulnerabilityID != "" {
|
||||||
|
msg += fmt.Sprintf("\nrelated_vulnerability_id: %s", f.RelatedVulnerabilityID)
|
||||||
|
}
|
||||||
|
if f.SourceConversationID != "" {
|
||||||
|
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||||
|
}
|
||||||
|
msg += "\n\n--- body ---\n" + f.Body
|
||||||
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
|
msg += warn
|
||||||
|
}
|
||||||
|
return textResult(msg, false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
listTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolListProjectFacts,
|
||||||
|
Description: "列出当前项目的事实(分页)。",
|
||||||
|
ShortDescription: "列出项目事实",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"category": map[string]interface{}{"type": "string"},
|
||||||
|
"confidence": map[string]interface{}{"type": "string"},
|
||||||
|
"limit": map[string]interface{}{"type": "integer"},
|
||||||
|
"offset": map[string]interface{}{"type": "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mcpServer.RegisterTool(listTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
limit := intArg(args, "limit", 50)
|
||||||
|
offset := intArg(args, "offset", 0)
|
||||||
|
filter := database.ProjectFactListFilter{
|
||||||
|
Category: strArg(args, "category"),
|
||||||
|
Confidence: strArg(args, "confidence"),
|
||||||
|
}
|
||||||
|
list, err := db.ListProjectFacts(projectID, filter, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("共 %d 条(limit=%d offset=%d):\n", len(list), limit, offset))
|
||||||
|
for _, f := range list {
|
||||||
|
b.WriteString(fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, f.Summary, f.Confidence))
|
||||||
|
}
|
||||||
|
return textResult(b.String(), false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
searchTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolSearchProjectFacts,
|
||||||
|
Description: "按关键词搜索项目事实(summary/body/fact_key)。",
|
||||||
|
ShortDescription: "搜索项目事实",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"query": map[string]interface{}{"type": "string"},
|
||||||
|
"limit": map[string]interface{}{"type": "integer"},
|
||||||
|
"offset": map[string]interface{}{"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": []string{"query"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mcpServer.RegisterTool(searchTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
q := strings.TrimSpace(strArg(args, "query"))
|
||||||
|
if q == "" {
|
||||||
|
return textResult("错误: query 必填", true), nil
|
||||||
|
}
|
||||||
|
list, err := db.ListProjectFacts(projectID, database.ProjectFactListFilter{Search: q}, intArg(args, "limit", 30), intArg(args, "offset", 0))
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("搜索 \"%s\" 命中 %d 条:\n", q, len(list)))
|
||||||
|
for _, f := range list {
|
||||||
|
b.WriteString(fmt.Sprintf("- [%s] %s — %s\n", f.FactKey, f.Category, f.Summary))
|
||||||
|
}
|
||||||
|
return textResult(b.String(), false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
deprecateTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolDeprecateProjectFact,
|
||||||
|
Description: "将事实标记为 deprecated,从黑板索引中排除。",
|
||||||
|
ShortDescription: "废弃项目事实",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []string{"fact_key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mcpServer.RegisterTool(deprecateTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||||
|
if err := db.DeprecateProjectFact(projectID, key); err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
return textResult("事实已标记为 deprecated: "+key, false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
restoreTool := mcp.Tool{
|
||||||
|
Name: builtin.ToolRestoreProjectFact,
|
||||||
|
Description: "将已废弃(deprecated)的事实恢复为 tentative 或 confirmed,重新参与黑板索引。",
|
||||||
|
ShortDescription: "恢复已废弃的项目事实",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"confidence": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "恢复后的置信度:tentative(默认)或 confirmed",
|
||||||
|
"enum": []string{"tentative", "confirmed"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"fact_key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mcpServer.RegisterTool(restoreTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
projectID, err := projectIDFromConversation(db, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||||
|
if key == "" {
|
||||||
|
return textResult("错误: fact_key 必填", true), nil
|
||||||
|
}
|
||||||
|
conf := strArg(args, "confidence")
|
||||||
|
if err := db.RestoreProjectFact(projectID, key, conf); err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
if conf == "" {
|
||||||
|
conf = "tentative"
|
||||||
|
}
|
||||||
|
return textResult(fmt.Sprintf("事实已恢复为 %s: %s", conf, key), false), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("项目黑板 MCP 工具注册成功")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func strArg(args map[string]interface{}, key string) string {
|
||||||
|
if v, ok := args[key].(string); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolArg(args map[string]interface{}, key string) bool {
|
||||||
|
if v, ok := args[key].(bool); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func intArg(args map[string]interface{}, key string, def int) int {
|
||||||
|
switch v := args[key].(type) {
|
||||||
|
case float64:
|
||||||
|
return int(v)
|
||||||
|
case int:
|
||||||
|
return v
|
||||||
|
case int64:
|
||||||
|
return int(v)
|
||||||
|
default:
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/vision"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerVisionTools(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) {
|
||||||
|
vision.RegisterAnalyzeImageTool(mcpServer, cfg, logger)
|
||||||
|
}
|
||||||
@@ -0,0 +1,405 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func conversationIDFromToolCtx(ctx context.Context) string {
|
||||||
|
if id := agent.ConversationIDFromContext(ctx); id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return mcp.MCPConversationIDFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。
|
||||||
|
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
|
||||||
|
if vuln == nil || convID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if projectID != "" {
|
||||||
|
if strings.TrimSpace(vuln.ProjectID) == projectID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 历史记录:写入时尚未绑定 project_id,但属于同一会话
|
||||||
|
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return vuln.ConversationID == convID
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
|
||||||
|
convID := conversationIDFromToolCtx(ctx)
|
||||||
|
if convID == "" {
|
||||||
|
return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID := ""
|
||||||
|
if pid, err := db.GetConversationProjectID(convID); err == nil {
|
||||||
|
projectID = strings.TrimSpace(pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := strings.TrimSpace(strArg(args, "scope"))
|
||||||
|
if scope == "" {
|
||||||
|
if projectID != "" {
|
||||||
|
scope = "project"
|
||||||
|
} else {
|
||||||
|
scope = "conversation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := database.VulnerabilityListFilter{
|
||||||
|
Severity: strings.TrimSpace(strArg(args, "severity")),
|
||||||
|
Status: strings.TrimSpace(strArg(args, "status")),
|
||||||
|
}
|
||||||
|
if q := strings.TrimSpace(strArg(args, "q")); q != "" {
|
||||||
|
filter.Search = q
|
||||||
|
} else {
|
||||||
|
filter.Search = strings.TrimSpace(strArg(args, "search"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var scopeLabel string
|
||||||
|
switch scope {
|
||||||
|
case "project":
|
||||||
|
if projectID == "" {
|
||||||
|
return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
|
||||||
|
}
|
||||||
|
filter.ProjectID = projectID
|
||||||
|
scopeLabel = fmt.Sprintf("项目 %s", projectID)
|
||||||
|
case "conversation":
|
||||||
|
filter.ConversationID = convID
|
||||||
|
scopeLabel = fmt.Sprintf("会话 %s", convID)
|
||||||
|
default:
|
||||||
|
return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
|
||||||
|
}
|
||||||
|
return filter, scopeLabel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatVulnerabilityListItem(v *database.Vulnerability) string {
|
||||||
|
line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
|
||||||
|
if v.Type != "" {
|
||||||
|
line += fmt.Sprintf(" | type=%s", v.Type)
|
||||||
|
}
|
||||||
|
if v.Target != "" {
|
||||||
|
line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80))
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID))
|
||||||
|
b.WriteString(fmt.Sprintf("标题: %s\n", v.Title))
|
||||||
|
b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity))
|
||||||
|
b.WriteString(fmt.Sprintf("状态: %s\n", v.Status))
|
||||||
|
if v.Type != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("类型: %s\n", v.Type))
|
||||||
|
}
|
||||||
|
if v.Target != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("目标: %s\n", v.Target))
|
||||||
|
}
|
||||||
|
if v.ProjectID != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID))
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID))
|
||||||
|
if !v.CreatedAt.IsZero() {
|
||||||
|
b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
|
||||||
|
}
|
||||||
|
if v.Description != "" {
|
||||||
|
b.WriteString("\n--- 描述 ---\n")
|
||||||
|
b.WriteString(v.Description)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if v.Proof != "" {
|
||||||
|
b.WriteString("\n--- 证明(POC) ---\n")
|
||||||
|
b.WriteString(v.Proof)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if v.Impact != "" {
|
||||||
|
b.WriteString("\n--- 影响 ---\n")
|
||||||
|
b.WriteString(v.Impact)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if v.Recommendation != "" {
|
||||||
|
b.WriteString("\n--- 修复建议 ---\n")
|
||||||
|
b.WriteString(v.Recommendation)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateRunes(s string, max int) string {
|
||||||
|
r := []rune(s)
|
||||||
|
if len(r) <= max {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return string(r[:max]) + "…"
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。
|
||||||
|
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||||
|
registerRecordVulnerabilityTool(mcpServer, db, logger)
|
||||||
|
registerListVulnerabilitiesTool(mcpServer, db, logger)
|
||||||
|
registerGetVulnerabilityTool(mcpServer, db, logger)
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
|
||||||
|
builtin.ToolRecordVulnerability,
|
||||||
|
builtin.ToolListVulnerabilities,
|
||||||
|
builtin.ToolGetVulnerability,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||||
|
tool := mcp.Tool{
|
||||||
|
Name: builtin.ToolRecordVulnerability,
|
||||||
|
Description: "记录发现的漏洞详情到漏洞管理系统。边渗透边记录:每验证出一条可复现漏洞(含 POC/影响)后立即调用,勿等会话结束。包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
|
||||||
|
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"title": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞标题(必需)",
|
||||||
|
},
|
||||||
|
"description": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞详细描述",
|
||||||
|
},
|
||||||
|
"severity": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||||||
|
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||||
|
},
|
||||||
|
"vulnerability_type": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||||
|
},
|
||||||
|
"target": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||||
|
},
|
||||||
|
"proof": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||||
|
},
|
||||||
|
"impact": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞影响说明",
|
||||||
|
},
|
||||||
|
"recommendation": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "修复建议",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"title", "severity"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
|
||||||
|
if conversationID == "" {
|
||||||
|
conversationID = conversationIDFromToolCtx(ctx)
|
||||||
|
}
|
||||||
|
if conversationID == "" {
|
||||||
|
return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
title := strings.TrimSpace(strArg(args, "title"))
|
||||||
|
if title == "" {
|
||||||
|
return textResult("错误: title 参数必需且不能为空", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
severity := strings.TrimSpace(strArg(args, "severity"))
|
||||||
|
if severity == "" {
|
||||||
|
return textResult("错误: severity 参数必需且不能为空", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
validSeverities := map[string]bool{
|
||||||
|
"critical": true, "high": true, "medium": true, "low": true, "info": true,
|
||||||
|
}
|
||||||
|
if !validSeverities[severity] {
|
||||||
|
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID := ""
|
||||||
|
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
|
||||||
|
projectID = strings.TrimSpace(pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
vuln := &database.Vulnerability{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
ProjectID: projectID,
|
||||||
|
Title: title,
|
||||||
|
Description: strArg(args, "description"),
|
||||||
|
Severity: severity,
|
||||||
|
Status: "open",
|
||||||
|
Type: strArg(args, "vulnerability_type"),
|
||||||
|
Target: strArg(args, "target"),
|
||||||
|
Proof: strArg(args, "proof"),
|
||||||
|
Impact: strArg(args, "impact"),
|
||||||
|
Recommendation: strArg(args, "recommendation"),
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := db.CreateVulnerability(vuln)
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Error("记录漏洞失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("漏洞记录成功",
|
||||||
|
zap.String("id", created.ID),
|
||||||
|
zap.String("title", created.Title),
|
||||||
|
zap.String("severity", created.Severity),
|
||||||
|
zap.String("conversation_id", conversationID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
|
||||||
|
created.ID, created.Title, created.Severity, created.Status), false), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||||
|
tool := mcp.Tool{
|
||||||
|
Name: builtin.ToolListVulnerabilities,
|
||||||
|
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||||||
|
ShortDescription: "列出漏洞(默认当前项目)",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"scope": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
|
||||||
|
"enum": []string{"project", "conversation"},
|
||||||
|
},
|
||||||
|
"severity": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "按严重程度筛选:critical、high、medium、low、info",
|
||||||
|
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||||
|
},
|
||||||
|
"status": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "按状态筛选:open、confirmed、fixed、false_positive",
|
||||||
|
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
|
||||||
|
},
|
||||||
|
"q": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "关键词搜索(标题、描述、类型、目标等)",
|
||||||
|
},
|
||||||
|
"limit": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "返回条数上限,默认 30,最大 100",
|
||||||
|
},
|
||||||
|
"offset": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "分页偏移,默认 0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := intArg(args, "limit", 30)
|
||||||
|
if limit <= 0 || limit > 100 {
|
||||||
|
limit = 30
|
||||||
|
}
|
||||||
|
offset := intArg(args, "offset", 0)
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := db.CountVulnerabilities(filter)
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("统计漏洞失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
total = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := db.ListVulnerabilities(limit, offset, filter)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset))
|
||||||
|
if len(list) == 0 {
|
||||||
|
b.WriteString("(暂无漏洞记录)\n")
|
||||||
|
} else {
|
||||||
|
for _, v := range list {
|
||||||
|
b.WriteString(formatVulnerabilityListItem(v))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if total > offset+len(list) {
|
||||||
|
b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
|
||||||
|
return textResult(b.String(), false), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||||
|
tool := mcp.Tool{
|
||||||
|
Name: builtin.ToolGetVulnerability,
|
||||||
|
Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
|
||||||
|
ShortDescription: "按 ID 获取漏洞详情",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"id": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "漏洞 ID(list_vulnerabilities 返回的 id)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"id"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
convID := conversationIDFromToolCtx(ctx)
|
||||||
|
if convID == "" {
|
||||||
|
return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id := strings.TrimSpace(strArg(args, "id"))
|
||||||
|
if id == "" {
|
||||||
|
return textResult("错误: id 必填", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vuln, err := db.GetVulnerability(id)
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: 漏洞不存在或查询失败", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID := ""
|
||||||
|
if pid, perr := db.GetConversationProjectID(convID); perr == nil {
|
||||||
|
projectID = strings.TrimSpace(pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !canAccessVulnerability(vuln, convID, projectID) {
|
||||||
|
return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return textResult(formatVulnerabilityDetail(vuln), false), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
+65
-29
@@ -36,13 +36,40 @@ type Config struct {
|
|||||||
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
||||||
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter)
|
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter)
|
||||||
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
|
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
|
||||||
|
Project ProjectConfig `yaml:"project,omitempty" json:"project,omitempty"`
|
||||||
|
Vision VisionConfig `yaml:"vision,omitempty" json:"vision,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
// ProjectConfig 项目黑板(跨对话共享事实)配置。
|
||||||
|
type ProjectConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||||
|
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||||
|
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||||
|
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactIndexMaxRunesEffective 自动注入黑板索引的最大 rune 数。
|
||||||
|
func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
||||||
|
if c.FactIndexMaxRunes <= 0 {
|
||||||
|
return 3500
|
||||||
|
}
|
||||||
|
return c.FactIndexMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||||
|
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||||
|
if c.FactSummaryMaxRunes <= 0 {
|
||||||
|
return 200
|
||||||
|
}
|
||||||
|
return c.FactSummaryMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor)。
|
||||||
type MultiAgentConfig struct {
|
type MultiAgentConfig struct {
|
||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // eino_single | deep | plan_execute | supervisor
|
||||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
||||||
@@ -211,9 +238,6 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||||
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
|
|
||||||
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
|
|
||||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
|
||||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||||
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
|
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
|
||||||
@@ -228,6 +252,10 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||||
|
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||||
|
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||||
|
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||||
|
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -253,20 +281,6 @@ func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 {
|
|
||||||
v := c.HistoryInputBudgetRatio
|
|
||||||
if v <= 0 {
|
|
||||||
return 0.35
|
|
||||||
}
|
|
||||||
if v < 0.15 {
|
|
||||||
return 0.15
|
|
||||||
}
|
|
||||||
if v > 0.6 {
|
|
||||||
return 0.6
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
|
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
|
||||||
v := c.PlanExecuteUserInputBudgetRatio
|
v := c.PlanExecuteUserInputBudgetRatio
|
||||||
if v <= 0 {
|
if v <= 0 {
|
||||||
@@ -363,9 +377,9 @@ type MultiAgentSubConfig struct {
|
|||||||
|
|
||||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||||
type MultiAgentPublic struct {
|
type MultiAgentPublic struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||||
SubAgentCount int `json:"sub_agent_count"`
|
SubAgentCount int `json:"sub_agent_count"`
|
||||||
Orchestration string `json:"orchestration,omitempty"`
|
Orchestration string `json:"orchestration,omitempty"`
|
||||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||||
@@ -373,6 +387,28 @@ type MultiAgentPublic struct {
|
|||||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeAgentMode 解析代理模式(eino_single | deep | plan_execute | supervisor);空值默认 eino_single。
|
||||||
|
func NormalizeAgentMode(mode string) string {
|
||||||
|
s := strings.TrimSpace(strings.ToLower(mode))
|
||||||
|
switch s {
|
||||||
|
case "", "eino_single":
|
||||||
|
return "eino_single"
|
||||||
|
case "deep":
|
||||||
|
return "deep"
|
||||||
|
case "plan_execute", "plan-execute", "planexecute", "pe":
|
||||||
|
return "plan_execute"
|
||||||
|
case "supervisor", "super", "sv":
|
||||||
|
return "supervisor"
|
||||||
|
default:
|
||||||
|
return "eino_single"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeRobotAgentMode 解析机器人默认对话模式。
|
||||||
|
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
|
||||||
|
return NormalizeAgentMode(ma.RobotDefaultAgentMode)
|
||||||
|
}
|
||||||
|
|
||||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||||
func NormalizeMultiAgentOrchestration(s string) string {
|
func NormalizeMultiAgentOrchestration(s string) string {
|
||||||
v := strings.TrimSpace(strings.ToLower(s))
|
v := strings.TrimSpace(strings.ToLower(s))
|
||||||
@@ -388,9 +424,9 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
|||||||
|
|
||||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||||
type MultiAgentAPIUpdate struct {
|
type MultiAgentAPIUpdate struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||||
@@ -490,7 +526,7 @@ type OpenAIConfig struct {
|
|||||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||||
Model string `yaml:"model" json:"model"`
|
Model string `yaml:"model" json:"model"`
|
||||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||||
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
|
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(Eino 单/多代理路径生效)。
|
||||||
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -498,7 +534,7 @@ type OpenAIConfig struct {
|
|||||||
type OpenAIReasoningConfig struct {
|
type OpenAIReasoningConfig struct {
|
||||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||||
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
|
// Effort: low | medium | high | max | xhigh;max/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度。
|
||||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// VisionConfig 独立视觉模型与 analyze_image 工具参数;enabled 时注册 MCP 工具 analyze_image。
|
||||||
|
type VisionConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
|
||||||
|
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"`
|
||||||
|
Model string `yaml:"model,omitempty" json:"model,omitempty"`
|
||||||
|
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
|
||||||
|
TimeoutSeconds int `yaml:"timeout_seconds,omitempty" json:"timeout_seconds,omitempty"`
|
||||||
|
MaxImageBytes int64 `yaml:"max_image_bytes,omitempty" json:"max_image_bytes,omitempty"`
|
||||||
|
MaxDimension int `yaml:"max_dimension,omitempty" json:"max_dimension,omitempty"`
|
||||||
|
JPEGQuality int `yaml:"jpeg_quality,omitempty" json:"jpeg_quality,omitempty"`
|
||||||
|
MaxPayloadBytes int64 `yaml:"max_payload_bytes,omitempty" json:"max_payload_bytes,omitempty"`
|
||||||
|
SkipPreprocessBelowBytes int64 `yaml:"skip_preprocess_below_bytes,omitempty" json:"skip_preprocess_below_bytes,omitempty"` // 0=始终压缩;默认 2MB 且长边已<=max_dimension 时原图直传
|
||||||
|
Detail string `yaml:"detail,omitempty" json:"detail,omitempty"` // low | high | auto
|
||||||
|
AllowedRoots []string `yaml:"allowed_roots,omitempty" json:"allowed_roots,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) TimeoutSecondsEffective() int {
|
||||||
|
if v.TimeoutSeconds <= 0 {
|
||||||
|
return 60
|
||||||
|
}
|
||||||
|
return v.TimeoutSeconds
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) MaxImageBytesEffective() int64 {
|
||||||
|
if v.MaxImageBytes <= 0 {
|
||||||
|
return 5 * 1024 * 1024
|
||||||
|
}
|
||||||
|
return v.MaxImageBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) MaxDimensionEffective() int {
|
||||||
|
if v.MaxDimension <= 0 {
|
||||||
|
return 2048
|
||||||
|
}
|
||||||
|
return v.MaxDimension
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) JPEGQualityEffective() int {
|
||||||
|
if v.JPEGQuality <= 0 || v.JPEGQuality > 100 {
|
||||||
|
return 82
|
||||||
|
}
|
||||||
|
return v.JPEGQuality
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) MaxPayloadBytesEffective() int64 {
|
||||||
|
if v.MaxPayloadBytes <= 0 {
|
||||||
|
return 512 * 1024
|
||||||
|
}
|
||||||
|
return v.MaxPayloadBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipPreprocessBelowBytesEffective 低于该字节数且长边<=max_dimension、且<=max_payload 时可原图直传;0 表示始终压缩。
|
||||||
|
func (v VisionConfig) SkipPreprocessBelowBytesEffective() int64 {
|
||||||
|
if v.SkipPreprocessBelowBytes < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v.SkipPreprocessBelowBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VisionConfig) DetailEffective() string {
|
||||||
|
d := strings.ToLower(strings.TrimSpace(v.Detail))
|
||||||
|
switch d {
|
||||||
|
case "high", "low", "auto":
|
||||||
|
return d
|
||||||
|
default:
|
||||||
|
return "low"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAICfgEffective 合并主 openai 配置与 vision 覆盖项,供 VL ChatModel 使用。
|
||||||
|
// vision.api_key / base_url / provider 留空或省略时,沿用 main(openai)对应字段;vision.model 必填(由 Ready 校验)。
|
||||||
|
func (v VisionConfig) OpenAICfgEffective(main OpenAIConfig) OpenAIConfig {
|
||||||
|
out := main
|
||||||
|
if k := strings.TrimSpace(v.APIKey); k != "" {
|
||||||
|
out.APIKey = k
|
||||||
|
}
|
||||||
|
if u := strings.TrimSpace(v.BaseURL); u != "" {
|
||||||
|
out.BaseURL = u
|
||||||
|
}
|
||||||
|
if m := strings.TrimSpace(v.Model); m != "" {
|
||||||
|
out.Model = m
|
||||||
|
}
|
||||||
|
if p := strings.TrimSpace(v.Provider); p != "" {
|
||||||
|
out.Provider = p
|
||||||
|
}
|
||||||
|
out.Reasoning.Mode = "off"
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ready 表示已启用且模型名非空。
|
||||||
|
func (v VisionConfig) Ready() bool {
|
||||||
|
return v.Enabled && strings.TrimSpace(v.Model) != ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestVisionConfig_OpenAICfgEffective_fallbackToMain(t *testing.T) {
|
||||||
|
main := OpenAIConfig{
|
||||||
|
APIKey: "main-key",
|
||||||
|
BaseURL: "https://main.example/v1",
|
||||||
|
Model: "main-model",
|
||||||
|
Provider: "openai",
|
||||||
|
}
|
||||||
|
v := VisionConfig{Model: "qwen-vl-max"}
|
||||||
|
out := v.OpenAICfgEffective(main)
|
||||||
|
if out.APIKey != main.APIKey || out.BaseURL != main.BaseURL || out.Provider != main.Provider {
|
||||||
|
t.Fatalf("expected openai fallback, got key=%q url=%q provider=%q", out.APIKey, out.BaseURL, out.Provider)
|
||||||
|
}
|
||||||
|
if out.Model != "qwen-vl-max" {
|
||||||
|
t.Fatalf("model: %s", out.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionConfig_OpenAICfgEffective(t *testing.T) {
|
||||||
|
main := OpenAIConfig{
|
||||||
|
APIKey: "main-key",
|
||||||
|
BaseURL: "https://main.example/v1",
|
||||||
|
Model: "main-model",
|
||||||
|
Provider: "openai",
|
||||||
|
Reasoning: OpenAIReasoningConfig{Mode: "on"},
|
||||||
|
}
|
||||||
|
v := VisionConfig{
|
||||||
|
Model: "vl-model",
|
||||||
|
APIKey: "vl-key",
|
||||||
|
BaseURL: "https://vl.example/v1",
|
||||||
|
Provider: "claude",
|
||||||
|
}
|
||||||
|
out := v.OpenAICfgEffective(main)
|
||||||
|
if out.APIKey != "vl-key" || out.BaseURL != "https://vl.example/v1" || out.Model != "vl-model" {
|
||||||
|
t.Fatalf("unexpected merge: %+v", out)
|
||||||
|
}
|
||||||
|
if out.Provider != "claude" {
|
||||||
|
t.Fatalf("provider: %s", out.Provider)
|
||||||
|
}
|
||||||
|
if out.Reasoning.Mode != "off" {
|
||||||
|
t.Fatalf("reasoning should be off for vision, got %s", out.Reasoning.Mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionConfig_Ready(t *testing.T) {
|
||||||
|
if (VisionConfig{Enabled: true, Model: "x"}).Ready() != true {
|
||||||
|
t.Fatal("expected ready")
|
||||||
|
}
|
||||||
|
if (VisionConfig{Enabled: true}).Ready() != false {
|
||||||
|
t.Fatal("expected not ready without model")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ type BatchTaskQueueRow struct {
|
|||||||
LastScheduleTriggerAt sql.NullTime
|
LastScheduleTriggerAt sql.NullTime
|
||||||
LastScheduleError sql.NullString
|
LastScheduleError sql.NullString
|
||||||
LastRunError sql.NullString
|
LastRunError sql.NullString
|
||||||
|
ProjectID sql.NullString
|
||||||
Status string
|
Status string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
StartedAt sql.NullTime
|
StartedAt sql.NullTime
|
||||||
@@ -51,6 +52,7 @@ func (db *DB) CreateBatchQueue(
|
|||||||
scheduleMode string,
|
scheduleMode string,
|
||||||
cronExpr string,
|
cronExpr string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
|
projectID string,
|
||||||
tasks []map[string]interface{},
|
tasks []map[string]interface{},
|
||||||
) error {
|
) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
@@ -65,9 +67,13 @@ func (db *DB) CreateBatchQueue(
|
|||||||
nextRunAtValue = *nextRunAt
|
nextRunAtValue = *nextRunAt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var projectIDVal interface{}
|
||||||
|
if strings.TrimSpace(projectID) != "" {
|
||||||
|
projectIDVal = strings.TrimSpace(projectID)
|
||||||
|
}
|
||||||
_, err = tx.Exec(
|
_, err = tx.Exec(
|
||||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0,
|
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||||
@@ -101,9 +107,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
|||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
err := db.QueryRow(
|
err := db.QueryRow(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||||
queueID,
|
queueID,
|
||||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -127,7 +133,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
|||||||
// GetAllBatchQueues 获取所有批量任务队列
|
// GetAllBatchQueues 获取所有批量任务队列
|
||||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||||
@@ -138,7 +144,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
@@ -158,7 +164,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
|
|
||||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
|
||||||
// 状态筛选
|
// 状态筛选
|
||||||
@@ -186,7 +192,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
|
ProjectID string `json:"projectId,omitempty"`
|
||||||
Pinned bool `json:"pinned"`
|
Pinned bool `json:"pinned"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
UpdatedAt time.Time `json:"updatedAt"`
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
@@ -46,13 +47,32 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string,
|
|||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
|
projectID := strings.TrimSpace(meta.ProjectID)
|
||||||
|
if projectID != "" {
|
||||||
|
if _, err := db.GetProject(projectID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if webshellConnectionID != "" {
|
wsID := strings.TrimSpace(webshellConnectionID)
|
||||||
|
switch {
|
||||||
|
case wsID != "" && projectID != "":
|
||||||
|
_, err = db.Exec(
|
||||||
|
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id, project_id) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
id, title, now, now, wsID, projectID,
|
||||||
|
)
|
||||||
|
case wsID != "":
|
||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
|
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
|
||||||
id, title, now, now, webshellConnectionID,
|
id, title, now, now, wsID,
|
||||||
)
|
)
|
||||||
} else {
|
case projectID != "":
|
||||||
|
_, err = db.Exec(
|
||||||
|
"INSERT INTO conversations (id, title, created_at, updated_at, project_id) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
id, title, now, now, projectID,
|
||||||
|
)
|
||||||
|
default:
|
||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||||
id, title, now, now,
|
id, title, now, now,
|
||||||
@@ -65,11 +85,12 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string,
|
|||||||
conv := &Conversation{
|
conv := &Conversation{
|
||||||
ID: id,
|
ID: id,
|
||||||
Title: title,
|
Title: title,
|
||||||
|
ProjectID: projectID,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
if webshellConnectionID != "" {
|
if wsID != "" {
|
||||||
meta.WebShellConnectionID = webshellConnectionID
|
meta.WebShellConnectionID = wsID
|
||||||
}
|
}
|
||||||
notifyConversationCreated(conv, meta)
|
notifyConversationCreated(conv, meta)
|
||||||
return conv, nil
|
return conv, nil
|
||||||
@@ -210,16 +231,20 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
|
|||||||
var createdAt, updatedAt string
|
var createdAt, updatedAt string
|
||||||
var pinned int
|
var pinned int
|
||||||
|
|
||||||
|
var projectID sql.NullString
|
||||||
err := db.QueryRow(
|
err := db.QueryRow(
|
||||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||||
id,
|
id,
|
||||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, fmt.Errorf("对话不存在")
|
return nil, fmt.Errorf("对话不存在")
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if projectID.Valid {
|
||||||
|
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试多种时间格式解析
|
// 尝试多种时间格式解析
|
||||||
var err1, err2 error
|
var err1, err2 error
|
||||||
@@ -292,16 +317,20 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
|||||||
var createdAt, updatedAt string
|
var createdAt, updatedAt string
|
||||||
var pinned int
|
var pinned int
|
||||||
|
|
||||||
|
var projectID sql.NullString
|
||||||
err := db.QueryRow(
|
err := db.QueryRow(
|
||||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||||
id,
|
id,
|
||||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, fmt.Errorf("对话不存在")
|
return nil, fmt.Errorf("对话不存在")
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if projectID.Valid {
|
||||||
|
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试多种时间格式解析
|
// 尝试多种时间格式解析
|
||||||
var err1, err2 error
|
var err1, err2 error
|
||||||
@@ -341,7 +370,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
|||||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||||
FROM conversations c
|
FROM conversations c
|
||||||
WHERE c.title LIKE ?
|
WHERE c.title LIKE ?
|
||||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||||
@@ -351,7 +380,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||||
limit, offset,
|
limit, offset,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -366,10 +395,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
|||||||
var conv Conversation
|
var conv Conversation
|
||||||
var createdAt, updatedAt string
|
var createdAt, updatedAt string
|
||||||
var pinned int
|
var pinned int
|
||||||
|
var projectID sql.NullString
|
||||||
|
|
||||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
|
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
|
||||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if projectID.Valid {
|
||||||
|
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试多种时间格式解析
|
// 尝试多种时间格式解析
|
||||||
var err1, err2 error
|
var err1, err2 error
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package database
|
|||||||
type ConversationCreateMeta struct {
|
type ConversationCreateMeta struct {
|
||||||
Source string
|
Source string
|
||||||
WebShellConnectionID string
|
WebShellConnectionID string
|
||||||
|
ProjectID string
|
||||||
ClientIP string
|
ClientIP string
|
||||||
SessionHint string
|
SessionHint string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -12,19 +13,106 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SQLite 在 WAL 模式下建议使用较保守的连接数,降低长读快照导致 checkpoint 饥饿的概率。
|
||||||
|
sqliteMaxOpenConns = 25
|
||||||
|
sqliteMaxIdleConns = 5
|
||||||
|
// 以页为单位的自动 checkpoint 触发阈值(默认 1000 页,约 4MB @ 4KB/page)。
|
||||||
|
sqliteWALAutoCheckpointPages = 1000
|
||||||
|
// 控制 WAL 目标上限,避免异常场景持续膨胀(256MB)。
|
||||||
|
sqliteJournalSizeLimitBytes = 256 * 1024 * 1024
|
||||||
|
// 定时执行 PASSIVE checkpoint,平滑推进 WAL 回收。
|
||||||
|
sqlitePassiveCheckpointInterval = 300 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性
|
// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性
|
||||||
func configureDBPool(db *sql.DB) {
|
func configureDBPool(db *sql.DB) {
|
||||||
// SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误
|
// SQLite 同一时间只允许一个写入者;过高连接数会放大锁竞争和 WAL 回收延迟。
|
||||||
db.SetMaxOpenConns(25)
|
db.SetMaxOpenConns(sqliteMaxOpenConns)
|
||||||
db.SetMaxIdleConns(5)
|
db.SetMaxIdleConns(sqliteMaxIdleConns)
|
||||||
db.SetConnMaxLifetime(30 * time.Minute)
|
db.SetConnMaxLifetime(30 * time.Minute)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configureSQLitePragmas 调整 WAL 回收行为,降低 -wal 文件长期膨胀风险。
|
||||||
|
func configureSQLitePragmas(db *sql.DB) error {
|
||||||
|
if _, err := db.Exec(fmt.Sprintf("PRAGMA wal_autocheckpoint=%d", sqliteWALAutoCheckpointPages)); err != nil {
|
||||||
|
return fmt.Errorf("设置 wal_autocheckpoint 失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(fmt.Sprintf("PRAGMA journal_size_limit=%d", sqliteJournalSizeLimitBytes)); err != nil {
|
||||||
|
return fmt.Errorf("设置 journal_size_limit 失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DB 数据库连接
|
// DB 数据库连接
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
conversationArtifactsDir string
|
conversationArtifactsDir string
|
||||||
|
checkpointLoopName string
|
||||||
|
checkpointStop chan struct{}
|
||||||
|
checkpointDone chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPassiveCheckpointLoop 启动后台 PASSIVE checkpoint 循环。
|
||||||
|
func (db *DB) startPassiveCheckpointLoop(name string) {
|
||||||
|
if sqlitePassiveCheckpointInterval <= 0 || db == nil || db.DB == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.checkpointLoopName = strings.TrimSpace(name)
|
||||||
|
db.checkpointStop = make(chan struct{})
|
||||||
|
db.checkpointDone = make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(db.checkpointDone)
|
||||||
|
ticker := time.NewTicker(sqlitePassiveCheckpointInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// 启动后先尝试一次,尽快回收已有 WAL 堆积。
|
||||||
|
db.runPassiveCheckpoint("startup")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-db.checkpointStop:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
db.runPassiveCheckpoint("ticker")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runPassiveCheckpoint 执行一次 PRAGMA wal_checkpoint(PASSIVE)。
|
||||||
|
func (db *DB) runPassiveCheckpoint(trigger string) {
|
||||||
|
if db == nil || db.DB == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
startAt := time.Now()
|
||||||
|
var busy, logFrames, checkpointed int
|
||||||
|
err := db.QueryRow("PRAGMA wal_checkpoint(PASSIVE)").Scan(&busy, &logFrames, &checkpointed)
|
||||||
|
if db.logger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.String("db", db.checkpointLoopName),
|
||||||
|
zap.String("trigger", trigger),
|
||||||
|
zap.Int("busy", busy),
|
||||||
|
zap.Int("log_frames", logFrames),
|
||||||
|
zap.Int("checkpointed_frames", checkpointed),
|
||||||
|
zap.Int64("elapsed_ms", time.Since(startAt).Milliseconds()),
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Warn("SQLite PASSIVE checkpoint 完成(失败)",
|
||||||
|
append(fields, zap.Error(err))...,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if busy > 0 {
|
||||||
|
db.logger.Info("SQLite PASSIVE checkpoint 完成(部分推进)", fields...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.logger.Info("SQLite PASSIVE checkpoint 完成(成功)", fields...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDB 创建数据库连接
|
// NewDB 创建数据库连接
|
||||||
@@ -37,8 +125,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
configureDBPool(db)
|
configureDBPool(db)
|
||||||
|
|
||||||
if err := db.Ping(); err != nil {
|
if err := db.Ping(); err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := configureSQLitePragmas(db); err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, fmt.Errorf("配置数据库 PRAGMA 失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
database := &DB{
|
database := &DB{
|
||||||
DB: db,
|
DB: db,
|
||||||
@@ -54,8 +147,10 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
|
|
||||||
// 初始化表
|
// 初始化表
|
||||||
if err := database.initTables(); err != nil {
|
if err := database.initTables(); err != nil {
|
||||||
|
_ = db.Close()
|
||||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
database.startPassiveCheckpointLoop("conversations")
|
||||||
|
|
||||||
return database, nil
|
return database, nil
|
||||||
}
|
}
|
||||||
@@ -213,6 +308,59 @@ func (db *DB) initTables() error {
|
|||||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||||
);`
|
);`
|
||||||
|
|
||||||
|
// 创建项目表
|
||||||
|
createProjectsTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS projects (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
scope_json TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
pinned INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
updated_at DATETIME NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
// 创建项目事实表(黑板)
|
||||||
|
createProjectFactsTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS project_facts (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
project_id TEXT NOT NULL,
|
||||||
|
fact_key TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL DEFAULT 'note',
|
||||||
|
summary TEXT NOT NULL DEFAULT '',
|
||||||
|
body TEXT,
|
||||||
|
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||||
|
source_conversation_id TEXT,
|
||||||
|
source_message_id TEXT,
|
||||||
|
pinned INTEGER NOT NULL DEFAULT 0,
|
||||||
|
supersedes_fact_id TEXT,
|
||||||
|
related_vulnerability_id TEXT,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
updated_at DATETIME NOT NULL,
|
||||||
|
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(project_id, fact_key)
|
||||||
|
);`
|
||||||
|
|
||||||
|
createProjectFactVersionsTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
fact_id TEXT NOT NULL,
|
||||||
|
project_id TEXT NOT NULL,
|
||||||
|
fact_key TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL DEFAULT 'note',
|
||||||
|
summary TEXT NOT NULL DEFAULT '',
|
||||||
|
body TEXT,
|
||||||
|
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||||
|
source_conversation_id TEXT,
|
||||||
|
source_message_id TEXT,
|
||||||
|
pinned INTEGER NOT NULL DEFAULT 0,
|
||||||
|
related_vulnerability_id TEXT,
|
||||||
|
archived_at DATETIME NOT NULL,
|
||||||
|
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||||
|
);`
|
||||||
|
|
||||||
// 创建漏洞表
|
// 创建漏洞表
|
||||||
createVulnerabilitiesTable := `
|
createVulnerabilitiesTable := `
|
||||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||||
@@ -240,7 +388,7 @@ func (db *DB) initTables() error {
|
|||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
title TEXT,
|
title TEXT,
|
||||||
role TEXT,
|
role TEXT,
|
||||||
agent_mode TEXT NOT NULL DEFAULT 'single',
|
agent_mode TEXT NOT NULL DEFAULT 'eino_single',
|
||||||
schedule_mode TEXT NOT NULL DEFAULT 'manual',
|
schedule_mode TEXT NOT NULL DEFAULT 'manual',
|
||||||
cron_expr TEXT,
|
cron_expr TEXT,
|
||||||
next_run_at DATETIME,
|
next_run_at DATETIME,
|
||||||
@@ -445,6 +593,14 @@ func (db *DB) initTables() error {
|
|||||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
||||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
||||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||||
@@ -516,6 +672,18 @@ func (db *DB) initTables() error {
|
|||||||
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
|
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(createProjectsTable); err != nil {
|
||||||
|
return fmt.Errorf("创建projects表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(createProjectFactsTable); err != nil {
|
||||||
|
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(createProjectFactVersionsTable); err != nil {
|
||||||
|
return fmt.Errorf("创建project_fact_versions表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||||
}
|
}
|
||||||
@@ -583,6 +751,13 @@ func (db *DB) initTables() error {
|
|||||||
// 不返回错误,允许继续运行
|
// 不返回错误,允许继续运行
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := db.migrateProjectsTable(); err != nil {
|
||||||
|
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
if err := db.migrateProjectFactVersionsTable(); err != nil {
|
||||||
|
db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
||||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||||
// 不返回错误,允许继续运行
|
// 不返回错误,允许继续运行
|
||||||
@@ -809,14 +984,14 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
|||||||
var agentModeCount int
|
var agentModeCount int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount)
|
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil {
|
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'eino_single'"); addErr != nil {
|
||||||
errMsg := strings.ToLower(addErr.Error())
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr))
|
db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if agentModeCount == 0 {
|
} else if agentModeCount == 0 {
|
||||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil {
|
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'eino_single'"); err != nil {
|
||||||
db.logger.Warn("添加agent_mode字段失败", zap.Error(err))
|
db.logger.Warn("添加agent_mode字段失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -930,6 +1105,79 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var projectIDCount int
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='project_id'").Scan(&projectIDCount)
|
||||||
|
if err != nil {
|
||||||
|
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); addErr != nil {
|
||||||
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(addErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if projectIDCount == 0 {
|
||||||
|
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); err != nil {
|
||||||
|
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateProjectsTable 迁移 projects / conversations / vulnerabilities 的项目关联字段。
|
||||||
|
func (db *DB) migrateProjectsTable() error {
|
||||||
|
for _, col := range []struct {
|
||||||
|
table string
|
||||||
|
name string
|
||||||
|
stmt string
|
||||||
|
}{
|
||||||
|
{"conversations", "project_id", "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE SET NULL"},
|
||||||
|
{"vulnerabilities", "project_id", "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||||
|
} {
|
||||||
|
var count int
|
||||||
|
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info(?) WHERE name=?", col.table, col.name).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||||
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||||
|
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateProjectFactVersionsTable 为已有库创建事实版本表。
|
||||||
|
func (db *DB) migrateProjectFactVersionsTable() error {
|
||||||
|
ddl := `
|
||||||
|
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
fact_id TEXT NOT NULL,
|
||||||
|
project_id TEXT NOT NULL,
|
||||||
|
fact_key TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL DEFAULT 'note',
|
||||||
|
summary TEXT NOT NULL DEFAULT '',
|
||||||
|
body TEXT,
|
||||||
|
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||||
|
source_conversation_id TEXT,
|
||||||
|
source_message_id TEXT,
|
||||||
|
pinned INTEGER NOT NULL DEFAULT 0,
|
||||||
|
related_vulnerability_id TEXT,
|
||||||
|
archived_at DATETIME NOT NULL,
|
||||||
|
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||||
|
);`
|
||||||
|
if _, err := db.Exec(ddl); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`)
|
||||||
|
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -941,6 +1189,7 @@ func (db *DB) migrateVulnerabilitiesTable() error {
|
|||||||
}{
|
}{
|
||||||
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
|
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
|
||||||
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
|
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
|
||||||
|
{name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
@@ -1005,8 +1254,13 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
configureDBPool(sqlDB)
|
configureDBPool(sqlDB)
|
||||||
|
|
||||||
if err := sqlDB.Ping(); err != nil {
|
if err := sqlDB.Ping(); err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := configureSQLitePragmas(sqlDB); err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
return nil, fmt.Errorf("配置知识库数据库 PRAGMA 失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
database := &DB{
|
database := &DB{
|
||||||
DB: sqlDB,
|
DB: sqlDB,
|
||||||
@@ -1015,8 +1269,10 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
|
|
||||||
// 初始化知识库表
|
// 初始化知识库表
|
||||||
if err := database.initKnowledgeTables(); err != nil {
|
if err := database.initKnowledgeTables(); err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
|
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
database.startPassiveCheckpointLoop("knowledge")
|
||||||
|
|
||||||
return database, nil
|
return database, nil
|
||||||
}
|
}
|
||||||
@@ -1130,5 +1386,19 @@ func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
|
|||||||
|
|
||||||
// Close 关闭数据库连接
|
// Close 关闭数据库连接
|
||||||
func (db *DB) Close() error {
|
func (db *DB) Close() error {
|
||||||
return db.DB.Close()
|
if db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
db.closeOnce.Do(func() {
|
||||||
|
if db.checkpointStop != nil {
|
||||||
|
close(db.checkpointStop)
|
||||||
|
if db.checkpointDone != nil {
|
||||||
|
<-db.checkpointDone
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if db.DB != nil {
|
||||||
|
db.closeErr = db.DB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return db.closeErr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,513 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)
|
||||||
|
|
||||||
|
// ValidateFactKey 校验事实 key(项目内唯一标识)。
|
||||||
|
func ValidateFactKey(key string) error {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return fmt.Errorf("fact_key 不能为空")
|
||||||
|
}
|
||||||
|
if len(key) > 128 {
|
||||||
|
return fmt.Errorf("fact_key 过长(最多 128 字符)")
|
||||||
|
}
|
||||||
|
if !factKeyPattern.MatchString(key) {
|
||||||
|
return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Project 渗透测试项目(跨对话共享黑板)。
|
||||||
|
type Project struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
ScopeJSON string `json:"scope_json,omitempty"`
|
||||||
|
Status string `json:"status"` // active | archived
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFact 项目事实(黑板条目)。
|
||||||
|
type ProjectFact struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
|
FactKey string `json:"fact_key"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Summary string `json:"summary"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||||
|
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||||
|
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
|
||||||
|
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactListFilter 事实列表筛选。
|
||||||
|
type ProjectFactListFilter struct {
|
||||||
|
Category string
|
||||||
|
Confidence string
|
||||||
|
Search string
|
||||||
|
RelatedVulnerabilityID string
|
||||||
|
ExcludeDeprecated bool // 为 true 时排除 confidence=deprecated
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProject 创建项目。
|
||||||
|
func (db *DB) CreateProject(p *Project) (*Project, error) {
|
||||||
|
if p.ID == "" {
|
||||||
|
p.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(p.Status) == "" {
|
||||||
|
p.Status = "active"
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if p.CreatedAt.IsZero() {
|
||||||
|
p.CreatedAt = now
|
||||||
|
}
|
||||||
|
p.UpdatedAt = now
|
||||||
|
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建项目失败: %w", err)
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProject 获取项目。
|
||||||
|
func (db *DB) GetProject(id string) (*Project, error) {
|
||||||
|
var p Project
|
||||||
|
var pinned int
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
err := db.QueryRow(
|
||||||
|
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||||
|
FROM projects WHERE id = ?`, id,
|
||||||
|
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("项目不存在")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("获取项目失败: %w", err)
|
||||||
|
}
|
||||||
|
p.Pinned = pinned != 0
|
||||||
|
p.CreatedAt = parseDBTime(createdAt)
|
||||||
|
p.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjects 列出项目。
|
||||||
|
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 200
|
||||||
|
}
|
||||||
|
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||||
|
FROM projects WHERE 1=1`
|
||||||
|
args := []interface{}{}
|
||||||
|
if s := strings.TrimSpace(status); s != "" {
|
||||||
|
query += " AND status = ?"
|
||||||
|
args = append(args, s)
|
||||||
|
}
|
||||||
|
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||||
|
args = append(args, limit, offset)
|
||||||
|
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("列出项目失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var out []*Project
|
||||||
|
for rows.Next() {
|
||||||
|
var p Project
|
||||||
|
var pinned int
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.Pinned = pinned != 0
|
||||||
|
p.CreatedAt = parseDBTime(createdAt)
|
||||||
|
p.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
out = append(out, &p)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateProject 更新项目。
|
||||||
|
func (db *DB) UpdateProject(p *Project) error {
|
||||||
|
p.UpdatedAt = time.Now()
|
||||||
|
_, err := db.Exec(
|
||||||
|
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
|
||||||
|
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("更新项目失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理;漏洞 project_id 置空)。
|
||||||
|
func (db *DB) DeleteProject(id string) error {
|
||||||
|
if _, err := db.Exec(`UPDATE vulnerabilities SET project_id = NULL WHERE project_id = ?`, id); err != nil {
|
||||||
|
return fmt.Errorf("解除漏洞项目关联失败: %w", err)
|
||||||
|
}
|
||||||
|
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("删除项目失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationProjectID 返回对话绑定的项目 ID。
|
||||||
|
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
|
||||||
|
var pid sql.NullString
|
||||||
|
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", fmt.Errorf("对话不存在")
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if pid.Valid {
|
||||||
|
return strings.TrimSpace(pid.String), nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
|
||||||
|
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID != "" {
|
||||||
|
if _, err := db.GetProject(projectID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var val interface{}
|
||||||
|
if projectID == "" {
|
||||||
|
val = nil
|
||||||
|
} else {
|
||||||
|
val = projectID
|
||||||
|
}
|
||||||
|
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("设置对话项目失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
|
||||||
|
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
||||||
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
|
FROM project_facts WHERE project_id = ?`
|
||||||
|
args := []interface{}{projectID}
|
||||||
|
if !includeDeprecated {
|
||||||
|
query += " AND confidence != 'deprecated'"
|
||||||
|
}
|
||||||
|
query += " ORDER BY pinned DESC, updated_at DESC"
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFacts(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFacts 分页列出项目事实。
|
||||||
|
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
|
FROM project_facts WHERE project_id = ?`
|
||||||
|
args := []interface{}{projectID}
|
||||||
|
if c := strings.TrimSpace(filter.Category); c != "" {
|
||||||
|
query += " AND category = ?"
|
||||||
|
args = append(args, c)
|
||||||
|
}
|
||||||
|
if c := strings.TrimSpace(filter.Confidence); c != "" {
|
||||||
|
query += " AND confidence = ?"
|
||||||
|
args = append(args, c)
|
||||||
|
}
|
||||||
|
if filter.ExcludeDeprecated {
|
||||||
|
query += " AND confidence != 'deprecated'"
|
||||||
|
}
|
||||||
|
if rid := strings.TrimSpace(filter.RelatedVulnerabilityID); rid != "" {
|
||||||
|
query += " AND related_vulnerability_id = ?"
|
||||||
|
args = append(args, rid)
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(filter.Search); s != "" {
|
||||||
|
pat := "%" + s + "%"
|
||||||
|
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
|
||||||
|
args = append(args, pat, pat, pat)
|
||||||
|
}
|
||||||
|
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||||
|
args = append(args, limit, offset)
|
||||||
|
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFacts(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectFactByKey 按 key 获取事实。
|
||||||
|
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
|
||||||
|
row := db.QueryRow(
|
||||||
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
|
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
||||||
|
projectID, factKey,
|
||||||
|
)
|
||||||
|
return scanProjectFactRow(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectFact 按 ID 获取事实。
|
||||||
|
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
|
||||||
|
row := db.QueryRow(
|
||||||
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||||
|
FROM project_facts WHERE id = ?`, id,
|
||||||
|
)
|
||||||
|
return scanProjectFactRow(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeFactBodyOnUpdate 更新时若 incoming body 为空则保留已有内容,避免仅改 summary 时丢失攻击链。
|
||||||
|
func mergeFactBodyOnUpdate(incoming, existing string) string {
|
||||||
|
if strings.TrimSpace(incoming) == "" {
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
return incoming
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
|
||||||
|
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
||||||
|
if err := ValidateFactKey(f.FactKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(f.Category) == "" {
|
||||||
|
f.Category = "note"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(f.Confidence) == "" {
|
||||||
|
f.Confidence = "tentative"
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
|
||||||
|
if err == nil && existing != nil {
|
||||||
|
f.ID = existing.ID
|
||||||
|
f.CreatedAt = existing.CreatedAt
|
||||||
|
f.UpdatedAt = now
|
||||||
|
f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body)
|
||||||
|
if strings.TrimSpace(f.Category) == "" {
|
||||||
|
f.Category = existing.Category
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(f.Confidence) == "" {
|
||||||
|
f.Confidence = existing.Confidence
|
||||||
|
}
|
||||||
|
if projectFactContentChanged(existing, f) {
|
||||||
|
versionID, verr := db.InsertProjectFactVersion(existing)
|
||||||
|
if verr != nil {
|
||||||
|
return nil, verr
|
||||||
|
}
|
||||||
|
f.SupersedesFactID = versionID
|
||||||
|
} else if f.SupersedesFactID == "" {
|
||||||
|
f.SupersedesFactID = existing.SupersedesFactID
|
||||||
|
}
|
||||||
|
_, err = db.Exec(
|
||||||
|
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
||||||
|
source_conversation_id = COALESCE(?, source_conversation_id),
|
||||||
|
source_message_id = COALESCE(?, source_message_id),
|
||||||
|
pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
|
||||||
|
WHERE id = ?`,
|
||||||
|
f.Category, f.Summary, f.Body, f.Confidence,
|
||||||
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||||
|
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("更新事实失败: %w", err)
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.ID == "" {
|
||||||
|
f.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
f.CreatedAt = now
|
||||||
|
f.UpdatedAt = now
|
||||||
|
_, err = db.Exec(
|
||||||
|
`INSERT INTO project_facts (
|
||||||
|
id, project_id, fact_key, category, summary, body, confidence,
|
||||||
|
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
|
||||||
|
created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||||
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||||
|
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
|
||||||
|
f.CreatedAt, f.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建事实失败: %w", err)
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecateProjectFact 将事实标记为 deprecated。
|
||||||
|
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||||
|
res, err := db.Exec(
|
||||||
|
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||||
|
time.Now(), projectID, factKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return fmt.Errorf("事实不存在")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||||
|
func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
||||||
|
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||||
|
if confidence == "" {
|
||||||
|
confidence = "tentative"
|
||||||
|
}
|
||||||
|
if confidence != "confirmed" && confidence != "tentative" {
|
||||||
|
return fmt.Errorf("confidence 须为 confirmed 或 tentative")
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := db.GetProjectFactByKey(projectID, factKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("事实不存在")
|
||||||
|
}
|
||||||
|
if strings.ToLower(strings.TrimSpace(existing.Confidence)) != "deprecated" {
|
||||||
|
return fmt.Errorf("事实未处于废弃状态")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(
|
||||||
|
`UPDATE project_facts SET confidence = ?, updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||||
|
confidence, time.Now(), projectID, factKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProjectFact 删除事实。
|
||||||
|
func (db *DB) DeleteProjectFact(id string) error {
|
||||||
|
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
|
||||||
|
var out []*ProjectFact
|
||||||
|
for rows.Next() {
|
||||||
|
f, err := scanProjectFactFromRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, f)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
|
||||||
|
var f ProjectFact
|
||||||
|
var pinned int
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
err := row.Scan(
|
||||||
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||||
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||||
|
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("事实不存在")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.Pinned = pinned != 0
|
||||||
|
f.CreatedAt = parseDBTime(createdAt)
|
||||||
|
f.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
|
||||||
|
var f ProjectFact
|
||||||
|
var pinned int
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
err := rows.Scan(
|
||||||
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||||
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||||
|
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.Pinned = pinned != 0
|
||||||
|
f.CreatedAt = parseDBTime(createdAt)
|
||||||
|
f.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolToInt(b bool) int {
|
||||||
|
if b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullIfEmpty(s string) interface{} {
|
||||||
|
if strings.TrimSpace(s) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDBTime(s string) time.Time {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
// go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态
|
||||||
|
layouts := []string{
|
||||||
|
time.RFC3339Nano,
|
||||||
|
time.RFC3339,
|
||||||
|
"2006-01-02 15:04:05.999999999-07:00",
|
||||||
|
"2006-01-02 15:04:05-07:00",
|
||||||
|
"2006-01-02T15:04:05.999999999-07:00",
|
||||||
|
"2006-01-02T15:04:05-07:00",
|
||||||
|
"2006-01-02 15:04:05.999999999",
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
"2006-01-02T15:04:05.999999999",
|
||||||
|
"2006-01-02T15:04:05",
|
||||||
|
}
|
||||||
|
for _, layout := range layouts {
|
||||||
|
if t, e := time.Parse(layout, s); e == nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpsertProjectFact_preservesBodyOnEmptyUpdate(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const body = "## 攻击链\n1. step\n```http\nGET / HTTP/1.1\n```\n"
|
||||||
|
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "finding/sqli-login",
|
||||||
|
Category: "finding",
|
||||||
|
Summary: "SQLi on /login",
|
||||||
|
Body: body,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "finding/sqli-login",
|
||||||
|
Summary: "SQLi on /login (confirmed)",
|
||||||
|
Body: "",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if updated.Summary != "SQLi on /login (confirmed)" {
|
||||||
|
t.Fatalf("summary=%q", updated.Summary)
|
||||||
|
}
|
||||||
|
if updated.Body != body {
|
||||||
|
t.Fatalf("returned body=%q want preserved attack chain", updated.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
fromDB, err := db.GetProjectFactByKey(proj.ID, "finding/sqli-login")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if fromDB.Body != body {
|
||||||
|
t.Fatalf("stored body=%q want preserved", fromDB.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpsertProjectFact_replacesBodyWhenProvided(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "target/primary",
|
||||||
|
Summary: "v1",
|
||||||
|
Body: "old body",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const newBody = "new body with evidence"
|
||||||
|
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "target/primary",
|
||||||
|
Summary: "v2",
|
||||||
|
Body: newBody,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if updated.Body != newBody {
|
||||||
|
t.Fatalf("body=%q want %q", updated.Body, newBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRestoreProjectFact(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&Project{Name: "restore-test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
key := "target/restore-me"
|
||||||
|
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: key,
|
||||||
|
Summary: "s",
|
||||||
|
Confidence: "confirmed",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.DeprecateProjectFact(proj.ID, key); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.RestoreProjectFact(proj.ID, key, "confirmed"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f, err := db.GetProjectFactByKey(proj.ID, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if f.Confidence != "confirmed" {
|
||||||
|
t.Fatalf("confidence=%q want confirmed", f.Confidence)
|
||||||
|
}
|
||||||
|
if err := db.RestoreProjectFact(proj.ID, key, ""); err == nil {
|
||||||
|
t.Fatal("expected error when not deprecated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&Project{Name: "version-test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "finding/xss",
|
||||||
|
Category: "finding",
|
||||||
|
Summary: "v1",
|
||||||
|
Body: "body v1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if created.SupersedesFactID != "" {
|
||||||
|
t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "finding/xss",
|
||||||
|
Summary: "v2",
|
||||||
|
Body: "body v2",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if updated.SupersedesFactID == "" {
|
||||||
|
t.Fatal("expected supersedes_fact_id after content change")
|
||||||
|
}
|
||||||
|
prev, err := db.GetProjectFactVersion(updated.SupersedesFactID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if prev.Summary != "v1" || prev.Body != "body v1" {
|
||||||
|
t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeFactBodyOnUpdate(t *testing.T) {
|
||||||
|
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
|
||||||
|
t.Fatalf("empty incoming: got %q", got)
|
||||||
|
}
|
||||||
|
if got := mergeFactBodyOnUpdate(" ", "keep"); got != "keep" {
|
||||||
|
t.Fatalf("whitespace incoming: got %q", got)
|
||||||
|
}
|
||||||
|
if got := mergeFactBodyOnUpdate("new", "old"); got != "new" {
|
||||||
|
t.Fatalf("non-empty incoming: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。
|
||||||
|
type ProjectFactVersion struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FactID string `json:"fact_id"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
|
FactKey string `json:"fact_key"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Summary string `json:"summary"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||||
|
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||||
|
ArchivedAt time.Time `json:"archived_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertProjectFactVersion 将当前事实行快照写入版本表。
|
||||||
|
func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) {
|
||||||
|
if f == nil || f.ID == "" {
|
||||||
|
return "", fmt.Errorf("无效的事实记录")
|
||||||
|
}
|
||||||
|
id := uuid.New().String()
|
||||||
|
now := time.Now()
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO project_fact_versions (
|
||||||
|
id, fact_id, project_id, fact_key, category, summary, body, confidence,
|
||||||
|
source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||||
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||||
|
nullIfEmpty(f.RelatedVulnerabilityID), now,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("归档事实版本失败: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectFactVersion 按版本 ID 获取快照。
|
||||||
|
func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) {
|
||||||
|
row := db.QueryRow(
|
||||||
|
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(related_vulnerability_id,''), archived_at
|
||||||
|
FROM project_fact_versions WHERE id = ?`, versionID,
|
||||||
|
)
|
||||||
|
return scanProjectFactVersionRow(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。
|
||||||
|
func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||||
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||||
|
COALESCE(related_vulnerability_id,''), archived_at
|
||||||
|
FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`,
|
||||||
|
factID, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*ProjectFactVersion
|
||||||
|
for rows.Next() {
|
||||||
|
v, err := scanProjectFactVersionFromRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, v)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func projectFactContentChanged(existing, incoming *ProjectFact) bool {
|
||||||
|
if existing == nil || incoming == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body)
|
||||||
|
inCat := stringsTrimDefault(incoming.Category, existing.Category)
|
||||||
|
inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence)
|
||||||
|
return existing.Summary != incoming.Summary ||
|
||||||
|
existing.Body != mergedBody ||
|
||||||
|
existing.Category != inCat ||
|
||||||
|
existing.Confidence != inConf
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsTrimDefault(s, fallback string) string {
|
||||||
|
if strings.TrimSpace(s) == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) {
|
||||||
|
var v ProjectFactVersion
|
||||||
|
var pinned int
|
||||||
|
var archivedAt string
|
||||||
|
err := row.Scan(
|
||||||
|
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||||
|
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||||
|
&v.RelatedVulnerabilityID, &archivedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("事实版本不存在")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v.Pinned = pinned != 0
|
||||||
|
v.ArchivedAt = parseDBTime(archivedAt)
|
||||||
|
return &v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) {
|
||||||
|
var v ProjectFactVersion
|
||||||
|
var pinned int
|
||||||
|
var archivedAt string
|
||||||
|
err := rows.Scan(
|
||||||
|
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||||
|
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||||
|
&v.RelatedVulnerabilityID, &archivedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v.Pinned = pinned != 0
|
||||||
|
v.ArchivedAt = parseDBTime(archivedAt)
|
||||||
|
return &v, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProjectStats 项目聚合统计。
|
||||||
|
type ProjectStats struct {
|
||||||
|
FactCount int `json:"fact_count"`
|
||||||
|
VulnCount int `json:"vuln_count"`
|
||||||
|
ConversationCount int `json:"conversation_count"`
|
||||||
|
SparseFactCount int `json:"sparse_fact_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectStatsCounts 统计项目下事实、漏洞、对话数量(不含 sparse,由 project 包补全)。
|
||||||
|
func (db *DB) GetProjectStatsCounts(projectID string) (*ProjectStats, error) {
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return nil, fmt.Errorf("project_id 不能为空")
|
||||||
|
}
|
||||||
|
if _, err := db.GetProject(projectID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stats := &ProjectStats{}
|
||||||
|
if err := db.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||||
|
projectID,
|
||||||
|
).Scan(&stats.FactCount); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计事实失败: %w", err)
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM vulnerabilities WHERE project_id = ?`,
|
||||||
|
projectID,
|
||||||
|
).Scan(&stats.VulnCount); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计漏洞失败: %w", err)
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM conversations WHERE project_id = ?`,
|
||||||
|
projectID,
|
||||||
|
).Scan(&stats.ConversationCount); err != nil {
|
||||||
|
return nil, fmt.Errorf("统计对话失败: %w", err)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFactsForSparseCheck 返回用于待补全检测的事实字段(非 deprecated)。
|
||||||
|
func (db *DB) ListProjectFactsForSparseCheck(projectID string) ([]struct {
|
||||||
|
Category string
|
||||||
|
FactKey string
|
||||||
|
Body string
|
||||||
|
}, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT category, fact_key, COALESCE(body,'') FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||||
|
projectID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []struct {
|
||||||
|
Category string
|
||||||
|
FactKey string
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
for rows.Next() {
|
||||||
|
var row struct {
|
||||||
|
Category string
|
||||||
|
FactKey string
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&row.Category, &row.FactKey, &row.Body); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, row)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListConversationsByProjectID 列出绑定到项目的对话。
|
||||||
|
func (db *DB) ListConversationsByProjectID(projectID string, limit, offset int) ([]*Conversation, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id
|
||||||
|
FROM conversations WHERE project_id = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?`,
|
||||||
|
projectID, limit, offset,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询项目对话失败: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var conversations []*Conversation
|
||||||
|
for rows.Next() {
|
||||||
|
var conv Conversation
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
var pinned int
|
||||||
|
var pid sql.NullString
|
||||||
|
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &pid); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if pid.Valid {
|
||||||
|
conv.ProjectID = strings.TrimSpace(pid.String)
|
||||||
|
}
|
||||||
|
conv.CreatedAt = parseDBTime(createdAt)
|
||||||
|
conv.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
conv.Pinned = pinned != 0
|
||||||
|
conversations = append(conversations, &conv)
|
||||||
|
}
|
||||||
|
return conversations, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountConversationsByProjectID 统计项目绑定对话数。
|
||||||
|
func (db *DB) CountConversationsByProjectID(projectID string) (int, error) {
|
||||||
|
var n int
|
||||||
|
err := db.QueryRow(`SELECT COUNT(*) FROM conversations WHERE project_id = ?`, projectID).Scan(&n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseDBTime_projectFactFormats(t *testing.T) {
|
||||||
|
cases := []string{
|
||||||
|
"2026-05-26 11:13:07.442143+08:00",
|
||||||
|
"2026-05-26 11:13:07",
|
||||||
|
"2026-05-26T11:13:07.442143+08:00",
|
||||||
|
}
|
||||||
|
for _, s := range cases {
|
||||||
|
got := parseDBTime(s)
|
||||||
|
if got.IsZero() {
|
||||||
|
t.Fatalf("parseDBTime(%q) returned zero", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListProjectFacts_updatedAtJSON(t *testing.T) {
|
||||||
|
root, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Skip(err)
|
||||||
|
}
|
||||||
|
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
|
||||||
|
if _, err := os.Stat(dbPath); err != nil {
|
||||||
|
t.Skip("conversations.db not found")
|
||||||
|
}
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
projects, err := db.ListProjects("", 1, 0)
|
||||||
|
if err != nil || len(projects) == 0 {
|
||||||
|
t.Skip("no projects")
|
||||||
|
}
|
||||||
|
pid := projects[0].ID
|
||||||
|
|
||||||
|
list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
t.Skip("no facts")
|
||||||
|
}
|
||||||
|
for _, f := range list {
|
||||||
|
if f.UpdatedAt.IsZero() {
|
||||||
|
t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey)
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(f)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var m map[string]interface{}
|
||||||
|
if err := json.Unmarshal(b, &m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
raw, ok := m["updated_at"].(string)
|
||||||
|
if !ok || raw == "" || raw[:4] == "0001" {
|
||||||
|
t.Fatalf("bad updated_at in JSON: %v", m["updated_at"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDBTime_zeroOnGarbage(t *testing.T) {
|
||||||
|
if !parseDBTime("").IsZero() {
|
||||||
|
t.Fatal("expected zero for empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure RFC3339 round-trip used by API is after year 2000.
|
||||||
|
func TestParseDBTime_marshalRoundTrip(t *testing.T) {
|
||||||
|
s := "2026-05-26 11:13:07.442143+08:00"
|
||||||
|
tm := parseDBTime(s)
|
||||||
|
b, err := json.Marshal(tm)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var back time.Time
|
||||||
|
if err := json.Unmarshal(b, &back); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if back.IsZero() {
|
||||||
|
t.Fatalf("unmarshal zero from %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,16 +3,94 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
|
||||||
|
type VulnerabilityListFilter struct {
|
||||||
|
ID string
|
||||||
|
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
|
||||||
|
ConversationID string
|
||||||
|
ProjectID string
|
||||||
|
Severity string
|
||||||
|
Status string
|
||||||
|
TaskID string
|
||||||
|
ConversationTag string
|
||||||
|
TaskTag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeVulnerabilityLikePattern(s string) string {
|
||||||
|
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||||
|
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||||
|
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||||
|
return "%" + s + "%"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
|
||||||
|
if f.ID != "" {
|
||||||
|
query += " AND id = ?"
|
||||||
|
args = append(args, f.ID)
|
||||||
|
}
|
||||||
|
if f.ConversationID != "" {
|
||||||
|
query += " AND conversation_id = ?"
|
||||||
|
args = append(args, f.ConversationID)
|
||||||
|
}
|
||||||
|
if f.ProjectID != "" {
|
||||||
|
query += " AND project_id = ?"
|
||||||
|
args = append(args, f.ProjectID)
|
||||||
|
}
|
||||||
|
if f.TaskID != "" {
|
||||||
|
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||||
|
args = append(args, f.TaskID, f.TaskID)
|
||||||
|
}
|
||||||
|
if f.ConversationTag != "" {
|
||||||
|
query += " AND conversation_tag = ?"
|
||||||
|
args = append(args, f.ConversationTag)
|
||||||
|
}
|
||||||
|
if f.TaskTag != "" {
|
||||||
|
query += " AND task_tag = ?"
|
||||||
|
args = append(args, f.TaskTag)
|
||||||
|
}
|
||||||
|
if f.Severity != "" {
|
||||||
|
query += " AND severity = ?"
|
||||||
|
args = append(args, f.Severity)
|
||||||
|
}
|
||||||
|
if f.Status != "" {
|
||||||
|
query += " AND status = ?"
|
||||||
|
args = append(args, f.Status)
|
||||||
|
}
|
||||||
|
search := strings.TrimSpace(f.Search)
|
||||||
|
if search != "" {
|
||||||
|
pattern := escapeVulnerabilityLikePattern(search)
|
||||||
|
query += ` AND (
|
||||||
|
LOWER(id) LIKE LOWER(?) OR
|
||||||
|
LOWER(title) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||||
|
)`
|
||||||
|
for i := 0; i < 11; i++ {
|
||||||
|
args = append(args, pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return query, args
|
||||||
|
}
|
||||||
|
|
||||||
// Vulnerability 漏洞
|
// Vulnerability 漏洞
|
||||||
type Vulnerability struct {
|
type Vulnerability struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
ConversationID string `json:"conversation_id"`
|
ConversationID string `json:"conversation_id"`
|
||||||
|
ProjectID string `json:"project_id,omitempty"`
|
||||||
ConversationTag string `json:"conversation_tag,omitempty"`
|
ConversationTag string `json:"conversation_tag,omitempty"`
|
||||||
TaskTag string `json:"task_tag,omitempty"`
|
TaskTag string `json:"task_tag,omitempty"`
|
||||||
TaskID string `json:"task_id,omitempty"`
|
TaskID string `json:"task_id,omitempty"`
|
||||||
@@ -44,17 +122,23 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|||||||
}
|
}
|
||||||
vuln.UpdatedAt = now
|
vuln.UpdatedAt = now
|
||||||
|
|
||||||
|
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" {
|
||||||
|
if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil {
|
||||||
|
vuln.ProjectID = pid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO vulnerabilities (
|
INSERT INTO vulnerabilities (
|
||||||
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
|
id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
|
||||||
vulnerability_type, target, proof, impact, recommendation,
|
vulnerability_type, target, proof, impact, recommendation,
|
||||||
created_at, updated_at
|
created_at, updated_at
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
query,
|
query,
|
||||||
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||||
vuln.CreatedAt, vuln.UpdatedAt,
|
vuln.CreatedAt, vuln.UpdatedAt,
|
||||||
@@ -70,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|||||||
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||||
var vuln Vulnerability
|
var vuln Vulnerability
|
||||||
query := `
|
query := `
|
||||||
SELECT id, conversation_id, title, description, severity, status,
|
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status,
|
||||||
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
||||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||||
@@ -80,7 +164,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|||||||
`
|
`
|
||||||
|
|
||||||
err := db.QueryRow(query, id).Scan(
|
err := db.QueryRow(query, id).Scan(
|
||||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||||
&vuln.TaskID, &vuln.TaskQueueID,
|
&vuln.TaskID, &vuln.TaskQueueID,
|
||||||
@@ -97,9 +181,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListVulnerabilities 列出漏洞
|
// ListVulnerabilities 列出漏洞
|
||||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
||||||
vulnerability_type, target, proof, impact, recommendation,
|
vulnerability_type, target, proof, impact, recommendation,
|
||||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||||
@@ -108,35 +192,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
|||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
`
|
`
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
query, args = filter.appendWhere(query, args)
|
||||||
if id != "" {
|
|
||||||
query += " AND id = ?"
|
|
||||||
args = append(args, id)
|
|
||||||
}
|
|
||||||
if conversationID != "" {
|
|
||||||
query += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
if conversationTag != "" {
|
|
||||||
query += " AND conversation_tag = ?"
|
|
||||||
args = append(args, conversationTag)
|
|
||||||
}
|
|
||||||
if taskTag != "" {
|
|
||||||
query += " AND task_tag = ?"
|
|
||||||
args = append(args, taskTag)
|
|
||||||
}
|
|
||||||
if severity != "" {
|
|
||||||
query += " AND severity = ?"
|
|
||||||
args = append(args, severity)
|
|
||||||
}
|
|
||||||
if status != "" {
|
|
||||||
query += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
args = append(args, limit, offset)
|
args = append(args, limit, offset)
|
||||||
@@ -151,7 +207,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var vuln Vulnerability
|
var vuln Vulnerability
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||||
&vuln.TaskID, &vuln.TaskQueueID,
|
&vuln.TaskID, &vuln.TaskQueueID,
|
||||||
@@ -168,38 +224,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
|
||||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
query, args = filter.appendWhere(query, args)
|
||||||
if id != "" {
|
|
||||||
query += " AND id = ?"
|
|
||||||
args = append(args, id)
|
|
||||||
}
|
|
||||||
if conversationID != "" {
|
|
||||||
query += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
if conversationTag != "" {
|
|
||||||
query += " AND conversation_tag = ?"
|
|
||||||
args = append(args, conversationTag)
|
|
||||||
}
|
|
||||||
if taskTag != "" {
|
|
||||||
query += " AND task_tag = ?"
|
|
||||||
args = append(args, taskTag)
|
|
||||||
}
|
|
||||||
if severity != "" {
|
|
||||||
query += " AND severity = ?"
|
|
||||||
args = append(args, severity)
|
|
||||||
}
|
|
||||||
if status != "" {
|
|
||||||
query += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
err := db.QueryRow(query, args...).Scan(&count)
|
err := db.QueryRow(query, args...).Scan(&count)
|
||||||
@@ -216,7 +244,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|||||||
|
|
||||||
query := `
|
query := `
|
||||||
UPDATE vulnerabilities
|
UPDATE vulnerabilities
|
||||||
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||||
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
||||||
recommendation = ?, updated_at = ?
|
recommendation = ?, updated_at = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
@@ -224,7 +252,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|||||||
|
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
query,
|
query,
|
||||||
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||||
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
||||||
vuln.Recommendation, vuln.UpdatedAt, id,
|
vuln.Recommendation, vuln.UpdatedAt, id,
|
||||||
)
|
)
|
||||||
@@ -237,27 +265,32 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|||||||
|
|
||||||
// DeleteVulnerability 删除漏洞
|
// DeleteVulnerability 删除漏洞
|
||||||
func (db *DB) DeleteVulnerability(id string) error {
|
func (db *DB) DeleteVulnerability(id string) error {
|
||||||
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return fmt.Errorf("开启事务失败: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
// 删除漏洞前先解除项目事实中的关联,避免前端继续显示已删除漏洞的短 ID。
|
||||||
|
if _, err := tx.Exec("UPDATE project_facts SET related_vulnerability_id = NULL WHERE related_vulnerability_id = ?", id); err != nil {
|
||||||
|
return fmt.Errorf("清理事实漏洞关联失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec("DELETE FROM vulnerabilities WHERE id = ?", id); err != nil {
|
||||||
return fmt.Errorf("删除漏洞失败: %w", err)
|
return fmt.Errorf("删除漏洞失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("提交事务失败: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
||||||
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
|
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
|
||||||
stats := make(map[string]interface{})
|
stats := make(map[string]interface{})
|
||||||
|
|
||||||
where := "WHERE 1=1"
|
where := "WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
if conversationID != "" {
|
where, args = filter.appendWhere(where, args)
|
||||||
where += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 总漏洞数
|
// 总漏洞数
|
||||||
var totalCount int
|
var totalCount int
|
||||||
@@ -357,10 +390,15 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
|
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
|
||||||
}
|
}
|
||||||
|
projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询项目ID建议失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return map[string][]string{
|
return map[string][]string{
|
||||||
"vulnerability_ids": vulnIDs,
|
"vulnerability_ids": vulnIDs,
|
||||||
"conversation_ids": conversationIDs,
|
"conversation_ids": conversationIDs,
|
||||||
|
"project_ids": projectIDs,
|
||||||
"task_ids": taskIDs,
|
"task_ids": taskIDs,
|
||||||
"queue_ids": queueIDs,
|
"queue_ids": queueIDs,
|
||||||
"conversation_tags": conversationTags,
|
"conversation_tags": conversationTags,
|
||||||
|
|||||||
@@ -96,6 +96,17 @@ type runHandler struct {
|
|||||||
seq atomic.Uint64
|
seq atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func safeRunInfo(info *callbacks.RunInfo) callbacks.RunInfo {
|
||||||
|
if info == nil {
|
||||||
|
return callbacks.RunInfo{
|
||||||
|
Name: "unknown",
|
||||||
|
Type: "unknown",
|
||||||
|
Component: components.Component("unknown"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *info
|
||||||
|
}
|
||||||
|
|
||||||
func (h *runHandler) genSpanID() string {
|
func (h *runHandler) genSpanID() string {
|
||||||
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||||
}
|
}
|
||||||
@@ -134,6 +145,7 @@ func (h *runHandler) popMatching(want string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||||
|
ri := safeRunInfo(info)
|
||||||
var parentID string
|
var parentID string
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
if len(h.spanStack) > 0 {
|
if len(h.spanStack) > 0 {
|
||||||
@@ -151,9 +163,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
|||||||
ctx, sp = tracer.Start(ctx, spanName,
|
ctx, sp = tracer.Start(ctx, spanName,
|
||||||
trace.WithSpanKind(trace.SpanKindInternal),
|
trace.WithSpanKind(trace.SpanKindInternal),
|
||||||
trace.WithAttributes(
|
trace.WithAttributes(
|
||||||
attribute.String("eino.component", string(info.Component)),
|
attribute.String("eino.component", string(ri.Component)),
|
||||||
attribute.String("eino.name", info.Name),
|
attribute.String("eino.name", ri.Name),
|
||||||
attribute.String("eino.type", info.Type),
|
attribute.String("eino.type", ri.Type),
|
||||||
attribute.String("cyberstrike.run_id", h.runID),
|
attribute.String("cyberstrike.run_id", h.runID),
|
||||||
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||||
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||||
@@ -169,9 +181,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
|||||||
zap.String("runId", h.runID),
|
zap.String("runId", h.runID),
|
||||||
zap.String("spanId", spanID),
|
zap.String("spanId", spanID),
|
||||||
zap.String("parentSpanId", parentID),
|
zap.String("parentSpanId", parentID),
|
||||||
zap.String("component", string(info.Component)),
|
zap.String("component", string(ri.Component)),
|
||||||
zap.String("name", info.Name),
|
zap.String("name", ri.Name),
|
||||||
zap.String("type", info.Type),
|
zap.String("type", ri.Type),
|
||||||
zap.String("phase", "start"),
|
zap.String("phase", "start"),
|
||||||
}
|
}
|
||||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||||
@@ -195,9 +207,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
|||||||
"parentSpanId": parentID,
|
"parentSpanId": parentID,
|
||||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
"component": string(info.Component),
|
"component": string(ri.Component),
|
||||||
"name": info.Name,
|
"name": ri.Name,
|
||||||
"type": info.Type,
|
"type": ri.Type,
|
||||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
"inputSummary": inSum,
|
"inputSummary": inSum,
|
||||||
"source": "eino_callbacks",
|
"source": "eino_callbacks",
|
||||||
@@ -208,6 +220,7 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||||
|
ri := safeRunInfo(info)
|
||||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||||
if spanID == "" {
|
if spanID == "" {
|
||||||
spanID = h.popSpan()
|
spanID = h.popSpan()
|
||||||
@@ -226,9 +239,9 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
|||||||
fields := []zap.Field{
|
fields := []zap.Field{
|
||||||
zap.String("runId", h.runID),
|
zap.String("runId", h.runID),
|
||||||
zap.String("spanId", spanID),
|
zap.String("spanId", spanID),
|
||||||
zap.String("component", string(info.Component)),
|
zap.String("component", string(ri.Component)),
|
||||||
zap.String("name", info.Name),
|
zap.String("name", ri.Name),
|
||||||
zap.String("type", info.Type),
|
zap.String("type", ri.Type),
|
||||||
zap.String("phase", "end"),
|
zap.String("phase", "end"),
|
||||||
}
|
}
|
||||||
if h.cfg.ZapVerbose {
|
if h.cfg.ZapVerbose {
|
||||||
@@ -243,9 +256,9 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
|||||||
"spanId": spanID,
|
"spanId": spanID,
|
||||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
"component": string(info.Component),
|
"component": string(ri.Component),
|
||||||
"name": info.Name,
|
"name": ri.Name,
|
||||||
"type": info.Type,
|
"type": ri.Type,
|
||||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
"outputSummary": outSum,
|
"outputSummary": outSum,
|
||||||
"source": "eino_callbacks",
|
"source": "eino_callbacks",
|
||||||
@@ -255,6 +268,7 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||||
|
ri := safeRunInfo(info)
|
||||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||||
if spanID == "" {
|
if spanID == "" {
|
||||||
spanID = h.popSpan()
|
spanID = h.popSpan()
|
||||||
@@ -276,9 +290,9 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
|||||||
h.params.Logger.Warn("eino_callback_error",
|
h.params.Logger.Warn("eino_callback_error",
|
||||||
zap.String("runId", h.runID),
|
zap.String("runId", h.runID),
|
||||||
zap.String("spanId", spanID),
|
zap.String("spanId", spanID),
|
||||||
zap.String("component", string(info.Component)),
|
zap.String("component", string(ri.Component)),
|
||||||
zap.String("name", info.Name),
|
zap.String("name", ri.Name),
|
||||||
zap.String("type", info.Type),
|
zap.String("type", ri.Type),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -288,9 +302,9 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
|||||||
"spanId": spanID,
|
"spanId": spanID,
|
||||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
"component": string(info.Component),
|
"component": string(ri.Component),
|
||||||
"name": info.Name,
|
"name": ri.Name,
|
||||||
"type": info.Type,
|
"type": ri.Type,
|
||||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
"error": msg,
|
"error": msg,
|
||||||
"source": "eino_callbacks",
|
"source": "eino_callbacks",
|
||||||
@@ -300,28 +314,30 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||||
|
ri := safeRunInfo(info)
|
||||||
if input != nil {
|
if input != nil {
|
||||||
input.Close()
|
input.Close()
|
||||||
}
|
}
|
||||||
if h.params.Logger != nil {
|
if h.params.Logger != nil {
|
||||||
h.params.Logger.Debug("eino_callback_stream_in",
|
h.params.Logger.Debug("eino_callback_stream_in",
|
||||||
zap.String("runId", h.runID),
|
zap.String("runId", h.runID),
|
||||||
zap.String("component", string(info.Component)),
|
zap.String("component", string(ri.Component)),
|
||||||
zap.String("name", info.Name),
|
zap.String("name", ri.Name),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||||
|
ri := safeRunInfo(info)
|
||||||
if output != nil {
|
if output != nil {
|
||||||
output.Close()
|
output.Close()
|
||||||
}
|
}
|
||||||
if h.params.Logger != nil {
|
if h.params.Logger != nil {
|
||||||
h.params.Logger.Debug("eino_callback_stream_out",
|
h.params.Logger.Debug("eino_callback_stream_out",
|
||||||
zap.String("runId", h.runID),
|
zap.String("runId", h.runID),
|
||||||
zap.String("component", string(info.Component)),
|
zap.String("component", string(ri.Component)),
|
||||||
zap.String("name", info.Name),
|
zap.String("name", ri.Name),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
|
|||||||
+232
-813
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -65,6 +66,7 @@ type BatchTaskQueue struct {
|
|||||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||||
LastRunError string `json:"lastRunError,omitempty"`
|
LastRunError string `json:"lastRunError,omitempty"`
|
||||||
|
ProjectID string `json:"projectId,omitempty"`
|
||||||
Tasks []*BatchTask `json:"tasks"`
|
Tasks []*BatchTask `json:"tasks"`
|
||||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
@@ -103,7 +105,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
|||||||
|
|
||||||
// CreateBatchQueue 创建批量任务队列
|
// CreateBatchQueue 创建批量任务队列
|
||||||
func (m *BatchTaskManager) CreateBatchQueue(
|
func (m *BatchTaskManager) CreateBatchQueue(
|
||||||
title, role, agentMode, scheduleMode, cronExpr string,
|
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
tasks []string,
|
tasks []string,
|
||||||
) (*BatchTaskQueue, error) {
|
) (*BatchTaskQueue, error) {
|
||||||
@@ -126,7 +128,8 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
ID: queueID,
|
ID: queueID,
|
||||||
Title: title,
|
Title: title,
|
||||||
Role: role,
|
Role: role,
|
||||||
AgentMode: normalizeBatchQueueAgentMode(agentMode),
|
ProjectID: strings.TrimSpace(projectID),
|
||||||
|
AgentMode: config.NormalizeAgentMode(agentMode),
|
||||||
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
||||||
CronExpr: strings.TrimSpace(cronExpr),
|
CronExpr: strings.TrimSpace(cronExpr),
|
||||||
NextRunAt: nextRunAt,
|
NextRunAt: nextRunAt,
|
||||||
@@ -171,6 +174,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
queue.ScheduleMode,
|
queue.ScheduleMode,
|
||||||
queue.CronExpr,
|
queue.CronExpr,
|
||||||
queue.NextRunAt,
|
queue.NextRunAt,
|
||||||
|
queue.ProjectID,
|
||||||
dbTasks,
|
dbTasks,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||||
@@ -222,7 +226,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
|||||||
|
|
||||||
queue := &BatchTaskQueue{
|
queue := &BatchTaskQueue{
|
||||||
ID: queueRow.ID,
|
ID: queueRow.ID,
|
||||||
AgentMode: "single",
|
AgentMode: "eino_single",
|
||||||
ScheduleMode: "manual",
|
ScheduleMode: "manual",
|
||||||
Status: queueRow.Status,
|
Status: queueRow.Status,
|
||||||
CreatedAt: queueRow.CreatedAt,
|
CreatedAt: queueRow.CreatedAt,
|
||||||
@@ -237,7 +241,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
|||||||
queue.Role = queueRow.Role.String
|
queue.Role = queueRow.Role.String
|
||||||
}
|
}
|
||||||
if queueRow.AgentMode.Valid {
|
if queueRow.AgentMode.Valid {
|
||||||
queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String)
|
queue.AgentMode = config.NormalizeAgentMode(queueRow.AgentMode.String)
|
||||||
}
|
}
|
||||||
if queueRow.ScheduleMode.Valid {
|
if queueRow.ScheduleMode.Valid {
|
||||||
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
||||||
@@ -263,6 +267,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
|||||||
if queueRow.LastRunError.Valid {
|
if queueRow.LastRunError.Valid {
|
||||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||||
}
|
}
|
||||||
|
if queueRow.ProjectID.Valid {
|
||||||
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
|
}
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -458,7 +465,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
|||||||
|
|
||||||
queue := &BatchTaskQueue{
|
queue := &BatchTaskQueue{
|
||||||
ID: queueRow.ID,
|
ID: queueRow.ID,
|
||||||
AgentMode: "single",
|
AgentMode: "eino_single",
|
||||||
ScheduleMode: "manual",
|
ScheduleMode: "manual",
|
||||||
Status: queueRow.Status,
|
Status: queueRow.Status,
|
||||||
CreatedAt: queueRow.CreatedAt,
|
CreatedAt: queueRow.CreatedAt,
|
||||||
@@ -473,7 +480,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
|||||||
queue.Role = queueRow.Role.String
|
queue.Role = queueRow.Role.String
|
||||||
}
|
}
|
||||||
if queueRow.AgentMode.Valid {
|
if queueRow.AgentMode.Valid {
|
||||||
queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String)
|
queue.AgentMode = config.NormalizeAgentMode(queueRow.AgentMode.String)
|
||||||
}
|
}
|
||||||
if queueRow.ScheduleMode.Valid {
|
if queueRow.ScheduleMode.Valid {
|
||||||
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
||||||
@@ -499,6 +506,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
|||||||
if queueRow.LastRunError.Valid {
|
if queueRow.LastRunError.Valid {
|
||||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||||
}
|
}
|
||||||
|
if queueRow.ProjectID.Valid {
|
||||||
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
|
}
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -660,7 +670,7 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
|
|||||||
|
|
||||||
// 如果未传 agentMode,保留原值
|
// 如果未传 agentMode,保留原值
|
||||||
if strings.TrimSpace(agentMode) != "" {
|
if strings.TrimSpace(agentMode) != "" {
|
||||||
agentMode = normalizeBatchQueueAgentMode(agentMode)
|
agentMode = config.NormalizeAgentMode(agentMode)
|
||||||
} else {
|
} else {
|
||||||
agentMode = queue.AgentMode
|
agentMode = queue.AgentMode
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
|
||||||
@@ -134,7 +135,7 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
|
|
||||||
【何时用】用户明确要批量排队执行、Cron 周期跑同一批指令、或需要与任务管理页面对齐时调用。需要即时追问、强依赖当前对话上下文的分析/编码,应在本对话内直接完成,不要为了”委派”而创建队列。
|
【何时用】用户明确要批量排队执行、Cron 周期跑同一批指令、或需要与任务管理页面对齐时调用。需要即时追问、强依赖当前对话上下文的分析/编码,应在本对话内直接完成,不要为了”委派”而创建队列。
|
||||||
|
|
||||||
【参数】tasks(字符串数组)或 tasks_text(多行,每行一条)二选一;每项是一条将来由系统按队列顺序执行的指令文案。agent_mode:single(原生 ReAct,默认)、eino_single(Eino ADK 单代理)、deep / plan_execute / supervisor(需系统启用多代理);兼容旧值 multi(视为 deep)。非”把主对话拆给子代理”。schedule_mode:manual(默认)或 cron;cron 须填 cron_expr(5 段,如 “0 */6 * * *”)。
|
【参数】tasks(字符串数组)或 tasks_text(多行,每行一条)二选一;每项是一条将来由系统按队列顺序执行的指令文案。agent_mode:eino_single(Eino ADK 单代理,默认)、deep / plan_execute / supervisor(需系统启用多代理)。非”把主对话拆给子代理”。schedule_mode:manual(默认)或 cron;cron 须填 cron_expr(5 段,如 “0 */6 * * *”)。
|
||||||
|
|
||||||
【执行】默认创建后为 pending,不自动跑。execute_now=true 可创建后立即跑;否则之后调用 batch_task_start。Cron 自动下一轮需 schedule_enabled 为 true(可用 batch_task_schedule_enabled)。`,
|
【执行】默认创建后为 pending,不自动跑。execute_now=true 可创建后立即跑;否则之后调用 batch_task_start。Cron 自动下一轮需 schedule_enabled 为 true(可用 batch_task_schedule_enabled)。`,
|
||||||
ShortDescription: "任务管理:创建批量任务队列(登记多条指令,可选立即或 Cron)",
|
ShortDescription: "任务管理:创建批量任务队列(登记多条指令,可选立即或 Cron)",
|
||||||
@@ -160,8 +161,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
},
|
},
|
||||||
"agent_mode": map[string]interface{}{
|
"agent_mode": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "执行模式:single(原生 ReAct)、eino_single(Eino ADK)、deep/plan_execute/supervisor(Eino 编排,需启用多代理);multi 兼容为 deep",
|
"description": "执行模式:eino_single(Eino ADK,默认)、deep/plan_execute/supervisor(Eino 编排,需启用多代理)",
|
||||||
"enum": []string{"single", "eino_single", "deep", "plan_execute", "supervisor", "multi"},
|
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||||
},
|
},
|
||||||
"schedule_mode": map[string]interface{}{
|
"schedule_mode": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -176,6 +177,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
||||||
},
|
},
|
||||||
|
"project_id": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
@@ -185,7 +190,7 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
}
|
}
|
||||||
title := mcpArgString(args, "title")
|
title := mcpArgString(args, "title")
|
||||||
role := mcpArgString(args, "role")
|
role := mcpArgString(args, "role")
|
||||||
agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode"))
|
agentMode := config.NormalizeAgentMode(mcpArgString(args, "agent_mode"))
|
||||||
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
||||||
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
||||||
var nextRunAt *time.Time
|
var nextRunAt *time.Time
|
||||||
@@ -204,7 +209,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
if !ok {
|
if !ok {
|
||||||
executeNow = false
|
executeNow = false
|
||||||
}
|
}
|
||||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||||
|
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||||
}
|
}
|
||||||
@@ -388,8 +394,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
},
|
},
|
||||||
"agent_mode": map[string]interface{}{
|
"agent_mode": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "代理模式:single、eino_single、deep、plan_execute、supervisor;multi 视为 deep",
|
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||||
"enum": []string{"single", "eino_single", "deep", "plan_execute", "supervisor", "multi"},
|
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"queue_id"},
|
"required": []string{"queue_id"},
|
||||||
|
|||||||
+155
-4
@@ -237,6 +237,7 @@ func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) err
|
|||||||
// GetConfigResponse 获取配置响应
|
// GetConfigResponse 获取配置响应
|
||||||
type GetConfigResponse struct {
|
type GetConfigResponse struct {
|
||||||
OpenAI config.OpenAIConfig `json:"openai"`
|
OpenAI config.OpenAIConfig `json:"openai"`
|
||||||
|
Vision config.VisionConfig `json:"vision"`
|
||||||
FOFA config.FofaConfig `json:"fofa"`
|
FOFA config.FofaConfig `json:"fofa"`
|
||||||
MCP config.MCPConfig `json:"mcp"`
|
MCP config.MCPConfig `json:"mcp"`
|
||||||
Tools []ToolConfigInfo `json:"tools"`
|
Tools []ToolConfigInfo `json:"tools"`
|
||||||
@@ -319,7 +320,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
multiPub := config.MultiAgentPublic{
|
multiPub := config.MultiAgentPublic{
|
||||||
Enabled: h.config.MultiAgent.Enabled,
|
Enabled: h.config.MultiAgent.Enabled,
|
||||||
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
|
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
|
||||||
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
||||||
SubAgentCount: subAgentCount,
|
SubAgentCount: subAgentCount,
|
||||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||||
@@ -333,6 +334,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, GetConfigResponse{
|
c.JSON(http.StatusOK, GetConfigResponse{
|
||||||
OpenAI: h.config.OpenAI,
|
OpenAI: h.config.OpenAI,
|
||||||
|
Vision: h.config.Vision,
|
||||||
FOFA: h.config.FOFA,
|
FOFA: h.config.FOFA,
|
||||||
MCP: h.config.MCP,
|
MCP: h.config.MCP,
|
||||||
Tools: tools,
|
Tools: tools,
|
||||||
@@ -638,6 +640,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
// UpdateConfigRequest 更新配置请求
|
// UpdateConfigRequest 更新配置请求
|
||||||
type UpdateConfigRequest struct {
|
type UpdateConfigRequest struct {
|
||||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||||
|
Vision *config.VisionConfig `json:"vision,omitempty"`
|
||||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||||
@@ -707,6 +710,14 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Vision != nil {
|
||||||
|
h.config.Vision = *req.Vision
|
||||||
|
h.logger.Info("更新 Vision 配置",
|
||||||
|
zap.Bool("enabled", h.config.Vision.Enabled),
|
||||||
|
zap.String("model", h.config.Vision.Model),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// 更新FOFA配置
|
// 更新FOFA配置
|
||||||
if req.FOFA != nil {
|
if req.FOFA != nil {
|
||||||
h.config.FOFA = *req.FOFA
|
h.config.FOFA = *req.FOFA
|
||||||
@@ -779,8 +790,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||||
if req.MultiAgent != nil {
|
if req.MultiAgent != nil {
|
||||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||||
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
|
|
||||||
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
||||||
|
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
|
||||||
|
h.config.MultiAgent.RobotDefaultAgentMode = mode
|
||||||
|
} else {
|
||||||
|
h.config.MultiAgent.RobotDefaultAgentMode = "eino_single"
|
||||||
|
}
|
||||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||||
}
|
}
|
||||||
@@ -789,7 +804,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
h.logger.Info("更新多代理配置",
|
h.logger.Info("更新多代理配置",
|
||||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
|
||||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||||
@@ -1027,6 +1042,99 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
|
||||||
|
type TestVisionRequest struct {
|
||||||
|
Vision config.VisionConfig `json:"vision"`
|
||||||
|
OpenAI config.OpenAIConfig `json:"openai,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVision 测试视觉模型 API 连接(最小 chat completion)。
|
||||||
|
func (h *ConfigHandler) TestVision(c *gin.Context) {
|
||||||
|
var req TestVisionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oa := req.Vision.OpenAICfgEffective(req.OpenAI)
|
||||||
|
if strings.TrimSpace(oa.APIKey) == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空(可填写 vision.api_key 或 openai.api_key)"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(oa.Model) == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "视觉模型不能为空"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(strings.TrimSpace(oa.BaseURL), "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
||||||
|
baseURL = "https://api.anthropic.com"
|
||||||
|
} else {
|
||||||
|
baseURL = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"model": oa.Model,
|
||||||
|
"messages": []map[string]string{
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
},
|
||||||
|
"max_completion_tokens": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpCfg := &config.OpenAIConfig{
|
||||||
|
Provider: oa.Provider,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: strings.TrimSpace(oa.APIKey),
|
||||||
|
Model: oa.Model,
|
||||||
|
}
|
||||||
|
client := openai.NewClient(tmpCfg, nil, h.logger)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
var chatResp struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
err := client.ChatCompletion(ctx, payload, &chatResp)
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if apiErr, ok := err.(*openai.APIError); ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body),
|
||||||
|
"status_code": apiErr.StatusCode,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": "连接失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(chatResp.Choices) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"error": "API 响应缺少 choices 字段,请检查 Base URL 与视觉模型名称",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"model": chatResp.Model,
|
||||||
|
"latency_ms": latency.Milliseconds(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||||||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||||
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
|
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
|
||||||
@@ -1282,6 +1390,7 @@ func (h *ConfigHandler) saveConfig() error {
|
|||||||
updateAgentConfig(root, h.config.Agent)
|
updateAgentConfig(root, h.config.Agent)
|
||||||
updateMCPConfig(root, h.config.MCP)
|
updateMCPConfig(root, h.config.MCP)
|
||||||
updateOpenAIConfig(root, h.config.OpenAI)
|
updateOpenAIConfig(root, h.config.OpenAI)
|
||||||
|
updateVisionConfig(root, h.config.Vision)
|
||||||
updateFOFAConfig(root, h.config.FOFA)
|
updateFOFAConfig(root, h.config.FOFA)
|
||||||
updateKnowledgeConfig(root, h.config.Knowledge)
|
updateKnowledgeConfig(root, h.config.Knowledge)
|
||||||
updateC2Config(root, h.config.C2)
|
updateC2Config(root, h.config.C2)
|
||||||
@@ -1402,6 +1511,48 @@ func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
|||||||
setIntInMap(mcpNode, "port", cfg.Port)
|
setIntInMap(mcpNode, "port", cfg.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateVisionConfig(doc *yaml.Node, cfg config.VisionConfig) {
|
||||||
|
root := doc.Content[0]
|
||||||
|
visionNode := ensureMap(root, "vision")
|
||||||
|
setBoolInMap(visionNode, "enabled", cfg.Enabled)
|
||||||
|
if strings.TrimSpace(cfg.APIKey) != "" {
|
||||||
|
setStringInMap(visionNode, "api_key", cfg.APIKey)
|
||||||
|
} else {
|
||||||
|
setStringInMap(visionNode, "api_key", "")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.BaseURL) != "" {
|
||||||
|
setStringInMap(visionNode, "base_url", cfg.BaseURL)
|
||||||
|
} else {
|
||||||
|
setStringInMap(visionNode, "base_url", "")
|
||||||
|
}
|
||||||
|
setStringInMap(visionNode, "model", cfg.Model)
|
||||||
|
if strings.TrimSpace(cfg.Provider) != "" {
|
||||||
|
setStringInMap(visionNode, "provider", cfg.Provider)
|
||||||
|
}
|
||||||
|
if cfg.TimeoutSeconds > 0 {
|
||||||
|
setIntInMap(visionNode, "timeout_seconds", cfg.TimeoutSeconds)
|
||||||
|
}
|
||||||
|
if cfg.MaxImageBytes > 0 {
|
||||||
|
setIntInMap(visionNode, "max_image_bytes", int(cfg.MaxImageBytes))
|
||||||
|
}
|
||||||
|
if cfg.MaxDimension > 0 {
|
||||||
|
setIntInMap(visionNode, "max_dimension", cfg.MaxDimension)
|
||||||
|
}
|
||||||
|
if cfg.JPEGQuality > 0 {
|
||||||
|
setIntInMap(visionNode, "jpeg_quality", cfg.JPEGQuality)
|
||||||
|
}
|
||||||
|
if cfg.MaxPayloadBytes > 0 {
|
||||||
|
setIntInMap(visionNode, "max_payload_bytes", int(cfg.MaxPayloadBytes))
|
||||||
|
}
|
||||||
|
setIntInMap(visionNode, "skip_preprocess_below_bytes", int(cfg.SkipPreprocessBelowBytes))
|
||||||
|
if strings.TrimSpace(cfg.Detail) != "" {
|
||||||
|
setStringInMap(visionNode, "detail", cfg.Detail)
|
||||||
|
}
|
||||||
|
if len(cfg.AllowedRoots) > 0 {
|
||||||
|
setStringSliceInMap(visionNode, "allowed_roots", cfg.AllowedRoots)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
openaiNode := ensureMap(root, "openai")
|
openaiNode := ensureMap(root, "openai")
|
||||||
@@ -1571,7 +1722,7 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
|||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
maNode := ensureMap(root, "multi_agent")
|
maNode := ensureMap(root, "multi_agent")
|
||||||
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
||||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
|
||||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||||
mwNode := ensureMap(maNode, "eino_middleware")
|
mwNode := ensureMap(maNode, "eino_middleware")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/audit"
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
@@ -33,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
|
|||||||
|
|
||||||
// CreateConversationRequest 创建对话请求
|
// CreateConversationRequest 创建对话请求
|
||||||
type CreateConversationRequest struct {
|
type CreateConversationRequest struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
|
ProjectID string `json:"projectId,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConversationProjectRequest 设置对话所属项目
|
||||||
|
type SetConversationProjectRequest struct {
|
||||||
|
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateConversation 创建新对话
|
// CreateConversation 创建新对话
|
||||||
@@ -49,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
|||||||
title = "新对话"
|
title = "新对话"
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
|
meta := audit.ConversationCreateMetaFromGin(c, "api")
|
||||||
|
meta.ProjectID = strings.TrimSpace(req.ProjectID)
|
||||||
|
conv, err := h.db.CreateConversation(title, meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("创建对话失败", zap.Error(err))
|
h.logger.Error("创建对话失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -59,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, conv)
|
c.JSON(http.StatusOK, conv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConversationProject 设置或清除对话绑定的项目
|
||||||
|
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
var req SetConversationProjectRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := h.db.GetConversation(id); err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
|
||||||
|
}
|
||||||
|
|
||||||
// ListConversations 列出对话
|
// ListConversations 列出对话
|
||||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||||
limitStr := c.DefaultQuery("limit", "50")
|
limitStr := c.DefaultQuery("limit", "50")
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||||
|
if h.config != nil {
|
||||||
|
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||||
|
}
|
||||||
|
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||||
|
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||||
|
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||||
|
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
*curHistory = hist
|
||||||
|
}
|
||||||
|
if segmentUserMessage != "" {
|
||||||
|
*curFinalMessage = segmentUserMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||||
|
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||||
|
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
*curHistory = hist
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||||
|
*curFinalMessage = segmentUserMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||||
|
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||||
|
baseCtx context.Context,
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
runErr error,
|
||||||
|
transientAttempts *int,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
progressCallback func(eventType, message string, data interface{}),
|
||||||
|
sendProgress func(msg string, extra map[string]interface{}),
|
||||||
|
) (handled bool, fatal error) {
|
||||||
|
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||||
|
*transientAttempts++
|
||||||
|
if *transientAttempts > maxAttempts {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||||
|
}
|
||||||
|
attemptNo := *transientAttempts
|
||||||
|
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": attemptNo,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"backoffSec": int(backoff.Seconds()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-baseCtx.Done():
|
||||||
|
return false, context.Cause(baseCtx)
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": attemptNo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if sendProgress != nil {
|
||||||
|
sendProgress("正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "transient_retry",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
|
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
|
||||||
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
|
|
||||||
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
|
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
|
|
||||||
@@ -176,9 +177,41 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
|
var transientRunAttempts int
|
||||||
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
|
var mainIterationOffset int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
segmentMainIterationMax := 0
|
||||||
|
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
progressCallback := func(eventType, message string, data interface{}) {
|
||||||
|
if eventType == "iteration" {
|
||||||
|
if m, ok := data.(map[string]interface{}); ok {
|
||||||
|
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||||
|
raw := 0
|
||||||
|
switch v := m["iteration"].(type) {
|
||||||
|
case int:
|
||||||
|
raw = v
|
||||||
|
case int32:
|
||||||
|
raw = int(v)
|
||||||
|
case int64:
|
||||||
|
raw = int(v)
|
||||||
|
case float64:
|
||||||
|
raw = int(v)
|
||||||
|
case float32:
|
||||||
|
raw = int(v)
|
||||||
|
}
|
||||||
|
if raw > 0 {
|
||||||
|
if raw > segmentMainIterationMax {
|
||||||
|
segmentMainIterationMax = raw
|
||||||
|
}
|
||||||
|
m["iteration"] = raw + mainIterationOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rawProgressCallback(eventType, message, data)
|
||||||
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
@@ -197,17 +230,38 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
roleTools,
|
roleTools,
|
||||||
progressCallback,
|
progressCallback,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(conversationID),
|
||||||
)
|
)
|
||||||
timeoutCancel()
|
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||||
|
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||||
|
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||||
|
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||||
|
)
|
||||||
|
if handled {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fatalErr != nil {
|
||||||
|
runErr = fatalErr
|
||||||
|
}
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -231,10 +285,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,6 +319,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,6 +337,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"errorType": "timeout",
|
"errorType": "timeout",
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,9 +354,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
}
|
}
|
||||||
@@ -367,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
prep.RoleTools,
|
prep.RoleTools,
|
||||||
progressCallback,
|
progressCallback,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
|
||||||
@@ -691,35 +690,6 @@ func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRu
|
|||||||
return arguments, nil
|
return arguments, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) interceptHITLForReactTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName string, arguments map[string]interface{}, toolCallID string) (map[string]interface{}, error) {
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"argumentsObj": arguments,
|
|
||||||
"toolCallId": toolCallID,
|
|
||||||
"source": "react_pre_exec",
|
|
||||||
}
|
|
||||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, payload, sendEventFunc)
|
|
||||||
if err != nil || d == nil {
|
|
||||||
return arguments, err
|
|
||||||
}
|
|
||||||
if d.Decision == "reject" {
|
|
||||||
comment := strings.TrimSpace(d.Comment)
|
|
||||||
if comment == "" {
|
|
||||||
comment = "no extra feedback"
|
|
||||||
}
|
|
||||||
return arguments, errors.New("human rejected this tool call; feedback: " + comment)
|
|
||||||
}
|
|
||||||
if len(d.EditedArguments) > 0 {
|
|
||||||
return d.EditedArguments, nil
|
|
||||||
}
|
|
||||||
return arguments, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) injectReactHITLInterceptor(ctx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) context.Context {
|
|
||||||
return agent.WithToolCallInterceptor(ctx, func(c context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) {
|
|
||||||
return h.interceptHITLForReactTool(c, cancelRun, conversationID, assistantMessageID, sendEventFunc, toolName, args, toolCallID)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type hitlConfigReq struct {
|
type hitlConfigReq struct {
|
||||||
ConversationID string `json:"conversationId" binding:"required"`
|
ConversationID string `json:"conversationId" binding:"required"`
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
|
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
|
||||||
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||||
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
|
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
orch := strings.TrimSpace(req.Orchestration)
|
orch := strings.TrimSpace(req.Orchestration)
|
||||||
@@ -186,9 +187,41 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
|
var transientRunAttempts int
|
||||||
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
|
var mainIterationOffset int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
segmentMainIterationMax := 0
|
||||||
|
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
progressCallback := func(eventType, message string, data interface{}) {
|
||||||
|
if eventType == "iteration" {
|
||||||
|
if m, ok := data.(map[string]interface{}); ok {
|
||||||
|
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||||
|
raw := 0
|
||||||
|
switch v := m["iteration"].(type) {
|
||||||
|
case int:
|
||||||
|
raw = v
|
||||||
|
case int32:
|
||||||
|
raw = int(v)
|
||||||
|
case int64:
|
||||||
|
raw = int(v)
|
||||||
|
case float64:
|
||||||
|
raw = int(v)
|
||||||
|
case float32:
|
||||||
|
raw = int(v)
|
||||||
|
}
|
||||||
|
if raw > 0 {
|
||||||
|
if raw > segmentMainIterationMax {
|
||||||
|
segmentMainIterationMax = raw
|
||||||
|
}
|
||||||
|
m["iteration"] = raw + mainIterationOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rawProgressCallback(eventType, message, data)
|
||||||
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
@@ -209,17 +242,38 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
orch,
|
orch,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(conversationID),
|
||||||
)
|
)
|
||||||
timeoutCancel()
|
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||||
|
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||||
|
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||||
|
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||||
|
)
|
||||||
|
if handled {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fatalErr != nil {
|
||||||
|
runErr = fatalErr
|
||||||
|
}
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -243,10 +297,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,6 +331,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,6 +349,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"errorType": "timeout",
|
"errorType": "timeout",
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,9 +366,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
}
|
}
|
||||||
@@ -332,7 +395,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiAgentLoop Eino DeepAgent 非流式对话(与 POST /api/agent-loop 对齐,需 multi_agent.enabled)。
|
// MultiAgentLoop Eino DeepAgent 非流式对话(需 multi_agent.enabled)。
|
||||||
func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"})
|
||||||
@@ -381,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
strings.TrimSpace(req.Orchestration),
|
strings.TrimSpace(req.Orchestration),
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
|
|||||||
var conv *database.Conversation
|
var conv *database.Conversation
|
||||||
var err error
|
var err error
|
||||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||||
|
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||||
meta.Source = source + "_webshell"
|
meta.Source = source + "_webshell"
|
||||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||||
@@ -90,6 +91,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
|
|||||||
builtin.ToolWebshellFileRead,
|
builtin.ToolWebshellFileRead,
|
||||||
builtin.ToolWebshellFileWrite,
|
builtin.ToolWebshellFileWrite,
|
||||||
builtin.ToolRecordVulnerability,
|
builtin.ToolRecordVulnerability,
|
||||||
|
builtin.ToolListVulnerabilities,
|
||||||
|
builtin.ToolGetVulnerability,
|
||||||
|
builtin.ToolUpsertProjectFact,
|
||||||
|
builtin.ToolGetProjectFact,
|
||||||
|
builtin.ToolListProjectFacts,
|
||||||
|
builtin.ToolSearchProjectFacts,
|
||||||
|
builtin.ToolDeprecateProjectFact,
|
||||||
|
builtin.ToolRestoreProjectFact,
|
||||||
builtin.ToolListKnowledgeRiskTypes,
|
builtin.ToolListKnowledgeRiskTypes,
|
||||||
builtin.ToolSearchKnowledgeBase,
|
builtin.ToolSearchKnowledgeBase,
|
||||||
}
|
}
|
||||||
|
|||||||
+241
-150
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"description": "对话标题",
|
"description": "对话标题",
|
||||||
"example": "Web应用安全测试",
|
"example": "Web应用安全测试",
|
||||||
},
|
},
|
||||||
|
"projectId": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "绑定的项目 ID(可选,共享事实黑板)",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"SetConversationProjectRequest": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"projectId": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "项目 ID;空字符串表示解除绑定",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"projectId"},
|
||||||
|
},
|
||||||
"Conversation": map[string]interface{}{
|
"Conversation": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"format": "date-time",
|
"format": "date-time",
|
||||||
"description": "更新时间",
|
"description": "更新时间",
|
||||||
},
|
},
|
||||||
|
"projectId": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "绑定的项目 ID(可选)",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ConversationDetail": map[string]interface{}{
|
"ConversationDetail": map[string]interface{}{
|
||||||
@@ -405,8 +423,8 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
"agentMode": map[string]interface{}{
|
"agentMode": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "代理模式:single(原生 ReAct)| eino_single(Eino ADK 单代理)| deep | plan_execute | supervisor;react 同 single;旧值 multi 按 deep",
|
"description": "代理模式:eino_single(Eino ADK 单代理,默认)| deep | plan_execute | supervisor",
|
||||||
"enum": []string{"single", "eino_single", "deep", "plan_execute", "supervisor", "multi", "react"},
|
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||||
},
|
},
|
||||||
"scheduleMode": map[string]interface{}{
|
"scheduleMode": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -760,11 +778,55 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
"ConfigResponse": map[string]interface{}{
|
"ConfigResponse": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "配置信息",
|
"description": "配置信息(含 openai、vision、multi_agent 等)",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"vision": map[string]interface{}{
|
||||||
|
"$ref": "#/components/schemas/VisionConfig",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"UpdateConfigRequest": map[string]interface{}{
|
"UpdateConfigRequest": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "更新配置请求",
|
"description": "更新配置请求",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"vision": map[string]interface{}{
|
||||||
|
"$ref": "#/components/schemas/VisionConfig",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"VisionConfig": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"description": "视觉分析(analyze_image MCP 工具);enabled 且 model 非空时注册工具",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"},
|
||||||
|
"model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"},
|
||||||
|
"api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"},
|
||||||
|
"base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"},
|
||||||
|
"provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"},
|
||||||
|
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"},
|
||||||
|
"max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"},
|
||||||
|
"max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"},
|
||||||
|
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
|
||||||
|
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
|
||||||
|
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
|
||||||
|
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
|
||||||
|
"allowed_roots": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "额外允许读取的绝对路径根"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"AnalyzeImageToolCall": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"description": "内置 MCP 工具 analyze_image:分析服务器本地图片,返回纯文本(验证码/UI/报错等)",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"path": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "图片路径(cwd、chat_uploads、result_storage_dir 或 allowed_roots 下)",
|
||||||
|
},
|
||||||
|
"question": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "可选:重点问题;验证码建议「只输出验证码字符」",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"path"},
|
||||||
},
|
},
|
||||||
"ExternalMCPConfig": map[string]interface{}{
|
"ExternalMCPConfig": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -1103,7 +1165,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"对话管理"},
|
"tags": []string{"对话管理"},
|
||||||
"summary": "创建对话",
|
"summary": "创建对话",
|
||||||
"description": "创建一个新的安全测试对话。\n**重要说明**:\n- ✅ 创建的对话会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新对话\n- ✅ 与前端创建的对话**完全一致**\n**创建对话的两种方式**:\n**方式1(推荐):** 直接使用 `/api/agent-loop` 发送消息,**不提供** `conversationId` 参数,系统会自动创建新对话并发送消息。这是最简单的方式,一步完成创建和发送。\n**方式2:** 先调用此端点创建空对话,然后使用返回的 `conversationId` 调用 `/api/agent-loop` 发送消息。适用于需要先创建对话,稍后再发送消息的场景。\n**示例**:\n```json\n{\n \"title\": \"Web应用安全测试\"\n}\n```",
|
"description": "创建一个新的安全测试对话。\n**重要说明**:\n- ✅ 创建的对话会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新对话\n- ✅ 与前端创建的对话**完全一致**\n**创建对话的两种方式**:\n**方式1(推荐):** 直接使用 `/api/eino-agent` 发送消息,**不提供** `conversationId` 参数,系统会自动创建新对话并发送消息。这是最简单的方式,一步完成创建和发送。\n**方式2:** 先调用此端点创建空对话,然后使用返回的 `conversationId` 调用 `/api/eino-agent` 发送消息。适用于需要先创建对话,稍后再发送消息的场景。\n**示例**:\n```json\n{\n \"title\": \"Web应用安全测试\"\n}\n```",
|
||||||
"operationId": "createConversation",
|
"operationId": "createConversation",
|
||||||
"requestBody": map[string]interface{}{
|
"requestBody": map[string]interface{}{
|
||||||
"required": true,
|
"required": true,
|
||||||
@@ -1326,6 +1388,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/conversations/{id}/project": map[string]interface{}{
|
||||||
|
"put": map[string]interface{}{
|
||||||
|
"tags": []string{"对话管理"},
|
||||||
|
"summary": "设置对话所属项目",
|
||||||
|
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
|
||||||
|
"operationId": "setConversationProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "id", "in": "path", "required": true,
|
||||||
|
"description": "对话ID",
|
||||||
|
"schema": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"$ref": "#/components/schemas/SetConversationProjectRequest",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{"description": "设置成功"},
|
||||||
|
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
|
||||||
|
"404": map[string]interface{}{"description": "对话不存在"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"/api/conversations/{id}/results": map[string]interface{}{
|
"/api/conversations/{id}/results": map[string]interface{}{
|
||||||
"get": map[string]interface{}{
|
"get": map[string]interface{}{
|
||||||
"tags": []string{"对话管理"},
|
"tags": []string{"对话管理"},
|
||||||
@@ -1363,148 +1456,11 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"/api/agent-loop": map[string]interface{}{
|
|
||||||
"post": map[string]interface{}{
|
|
||||||
"tags": []string{"对话交互"},
|
|
||||||
"summary": "发送消息并获取AI回复(非流式)",
|
|
||||||
"description": "向AI发送消息并获取回复(非流式响应)。**这是与AI交互的核心端点**,与前端聊天功能完全一致。\n**重要说明**:\n- ✅ 通过此API创建/发送的消息会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新创建的对话和消息\n- ✅ 所有操作都有**完整的交互痕迹**,就像在前端操作一样\n- ✅ 支持角色配置,可以指定使用哪个测试角色\n**推荐使用流程**:\n1. **先创建对话**:调用 `POST /api/conversations` 创建新对话,获取 `conversationId`\n2. **再发送消息**:使用返回的 `conversationId` 调用此端点发送消息\n**使用示例**:\n**步骤1 - 创建对话:**\n```json\nPOST /api/conversations\n{\n \"title\": \"Web应用安全测试\"\n}\n```\n**步骤2 - 发送消息:**\n```json\nPOST /api/agent-loop\n{\n \"conversationId\": \"返回的对话ID\",\n \"message\": \"扫描 http://example.com 的SQL注入漏洞\",\n \"role\": \"渗透测试\"\n}\n```\n**其他方式**:\n如果不提供 `conversationId`,系统会自动创建新对话并发送消息。但**推荐先创建对话**,这样可以更好地管理对话列表。\n**响应**:返回AI的回复、对话ID和MCP执行ID列表。前端会自动刷新显示新消息。",
|
|
||||||
"operationId": "sendMessage",
|
|
||||||
"requestBody": map[string]interface{}{
|
|
||||||
"required": true,
|
|
||||||
"content": map[string]interface{}{
|
|
||||||
"application/json": map[string]interface{}{
|
|
||||||
"schema": map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "要发送的消息(必需)",
|
|
||||||
"example": "扫描 http://example.com 的SQL注入漏洞",
|
|
||||||
},
|
|
||||||
"conversationId": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "对话ID(可选)。\n- **不提供**:自动创建新对话并发送消息(推荐)\n- **提供**:消息会添加到指定对话中(对话必须存在)",
|
|
||||||
"example": "550e8400-e29b-41d4-a716-446655440000",
|
|
||||||
},
|
|
||||||
"role": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "角色名称(可选),如:默认、渗透测试、Web应用扫描等",
|
|
||||||
"example": "默认",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": []string{"message"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"responses": map[string]interface{}{
|
|
||||||
"200": map[string]interface{}{
|
|
||||||
"description": "消息发送成功,返回AI回复",
|
|
||||||
"content": map[string]interface{}{
|
|
||||||
"application/json": map[string]interface{}{
|
|
||||||
"schema": map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"response": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "AI的回复内容",
|
|
||||||
},
|
|
||||||
"conversationId": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "对话ID",
|
|
||||||
},
|
|
||||||
"mcpExecutionIds": map[string]interface{}{
|
|
||||||
"type": "array",
|
|
||||||
"description": "MCP执行ID列表",
|
|
||||||
"items": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"time": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"format": "date-time",
|
|
||||||
"description": "响应时间",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"400": map[string]interface{}{
|
|
||||||
"description": "请求参数错误",
|
|
||||||
},
|
|
||||||
"401": map[string]interface{}{
|
|
||||||
"description": "未授权,需要有效的Token",
|
|
||||||
},
|
|
||||||
"500": map[string]interface{}{
|
|
||||||
"description": "服务器内部错误",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"/api/agent-loop/stream": map[string]interface{}{
|
|
||||||
"post": map[string]interface{}{
|
|
||||||
"tags": []string{"对话交互"},
|
|
||||||
"summary": "发送消息并获取AI回复(流式)",
|
|
||||||
"description": "向AI发送消息并获取流式回复(Server-Sent Events)。**这是与AI交互的核心端点**,与前端聊天功能完全一致。\n**重要说明**:\n- ✅ 通过此API创建/发送的消息会**立即保存到数据库**\n- ✅ 前端页面会**自动刷新**显示新创建的对话和消息\n- ✅ 所有操作都有**完整的交互痕迹**,就像在前端操作一样\n- ✅ 支持角色配置,可以指定使用哪个测试角色\n- ✅ 返回流式响应,适合实时显示AI回复\n**推荐使用流程**:\n1. **先创建对话**:调用 `POST /api/conversations` 创建新对话,获取 `conversationId`\n2. **再发送消息**:使用返回的 `conversationId` 调用此端点发送消息\n**使用示例**:\n**步骤1 - 创建对话:**\n```json\nPOST /api/conversations\n{\n \"title\": \"Web应用安全测试\"\n}\n```\n**步骤2 - 发送消息(流式):**\n```json\nPOST /api/agent-loop/stream\n{\n \"conversationId\": \"返回的对话ID\",\n \"message\": \"扫描 http://example.com 的SQL注入漏洞\",\n \"role\": \"渗透测试\"\n}\n```\n**响应格式**:Server-Sent Events (SSE),事件类型包括:\n- `message`: 用户消息确认\n- `response`: AI回复片段\n- `progress`: 进度更新\n- `done`: 完成\n- `error`: 错误\n- `cancelled`: 已取消",
|
|
||||||
"operationId": "sendMessageStream",
|
|
||||||
"requestBody": map[string]interface{}{
|
|
||||||
"required": true,
|
|
||||||
"content": map[string]interface{}{
|
|
||||||
"application/json": map[string]interface{}{
|
|
||||||
"schema": map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "要发送的消息(必需)",
|
|
||||||
"example": "扫描 http://example.com 的SQL注入漏洞",
|
|
||||||
},
|
|
||||||
"conversationId": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "对话ID(可选)。\n- **不提供**:自动创建新对话并发送消息(推荐)\n- **提供**:消息会添加到指定对话中(对话必须存在)",
|
|
||||||
"example": "550e8400-e29b-41d4-a716-446655440000",
|
|
||||||
},
|
|
||||||
"role": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "角色名称(可选),如:默认、渗透测试、Web应用扫描等",
|
|
||||||
"example": "默认",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": []string{"message"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"responses": map[string]interface{}{
|
|
||||||
"200": map[string]interface{}{
|
|
||||||
"description": "流式响应(Server-Sent Events)",
|
|
||||||
"content": map[string]interface{}{
|
|
||||||
"text/event-stream": map[string]interface{}{
|
|
||||||
"schema": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "SSE流式数据",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"400": map[string]interface{}{
|
|
||||||
"description": "请求参数错误",
|
|
||||||
},
|
|
||||||
"401": map[string]interface{}{
|
|
||||||
"description": "未授权,需要有效的Token",
|
|
||||||
},
|
|
||||||
"500": map[string]interface{}{
|
|
||||||
"description": "服务器内部错误",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"/api/eino-agent": map[string]interface{}{
|
"/api/eino-agent": map[string]interface{}{
|
||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"对话交互"},
|
"tags": []string{"对话交互"},
|
||||||
"summary": "发送消息并获取 AI 回复(Eino ADK 单代理,非流式)",
|
"summary": "发送消息并获取 AI 回复(Eino ADK 单代理,非流式)",
|
||||||
"description": "与 `POST /api/agent-loop` 请求体相同,由 **CloudWeGo Eino** `adk.NewChatModelAgent` + `adk.NewRunner.Run` 执行(单代理 MCP 工具链)。**不依赖** `multi_agent.enabled`;`multi_agent.eino_skills` / `eino_middleware` 等与多代理主代理一致时可生效。支持 `webshellConnectionId`。",
|
"description": "向 AI 发送消息并获取回复(非流式)。由 **CloudWeGo Eino** `adk.NewChatModelAgent` + `adk.NewRunner.Run` 执行单代理 MCP 工具链。**不依赖** `multi_agent.enabled`;`multi_agent.eino_skills` / `eino_middleware` 等与多代理主代理一致时可生效。支持 `webshellConnectionId`、角色与附件。",
|
||||||
"operationId": "sendMessageEinoSingleAgent",
|
"operationId": "sendMessageEinoSingleAgent",
|
||||||
"requestBody": map[string]interface{}{
|
"requestBody": map[string]interface{}{
|
||||||
"required": true,
|
"required": true,
|
||||||
@@ -1524,7 +1480,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"responses": map[string]interface{}{
|
"responses": map[string]interface{}{
|
||||||
"200": map[string]interface{}{"description": "成功,响应格式同 /api/agent-loop"},
|
"200": map[string]interface{}{"description": "成功,响应格式同 /api/eino-agent"},
|
||||||
"400": map[string]interface{}{"description": "参数错误"},
|
"400": map[string]interface{}{"description": "参数错误"},
|
||||||
"401": map[string]interface{}{"description": "未授权"},
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
"500": map[string]interface{}{"description": "执行失败"},
|
"500": map[string]interface{}{"description": "执行失败"},
|
||||||
@@ -1535,7 +1491,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"对话交互"},
|
"tags": []string{"对话交互"},
|
||||||
"summary": "发送消息并获取 AI 回复(Eino ADK 单代理,SSE)",
|
"summary": "发送消息并获取 AI 回复(Eino ADK 单代理,SSE)",
|
||||||
"description": "与 `POST /api/agent-loop/stream` 类似;由 Eino **单代理** ADK 执行。事件类型与多代理流式一致(含 `tool_call` / `response_delta` 等)。**不依赖** `multi_agent.enabled`。",
|
"description": "向 AI 发送消息并获取流式回复(SSE)。由 Eino **单代理** ADK 执行;事件类型与多代理流式一致(含 `tool_call` / `response_delta` / `thinking` 等)。**不依赖** `multi_agent.enabled`。",
|
||||||
"operationId": "sendMessageEinoSingleAgentStream",
|
"operationId": "sendMessageEinoSingleAgentStream",
|
||||||
"requestBody": map[string]interface{}{
|
"requestBody": map[string]interface{}{
|
||||||
"required": true,
|
"required": true,
|
||||||
@@ -1574,7 +1530,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"对话交互"},
|
"tags": []string{"对话交互"},
|
||||||
"summary": "发送消息并获取 AI 回复(Eino 多代理,非流式)",
|
"summary": "发送消息并获取 AI 回复(Eino 多代理,非流式)",
|
||||||
"description": "与 `POST /api/agent-loop` 请求体相同,但由 **CloudWeGo Eino** 多代理执行。编排由请求体 `orchestration`(`deep` | `plan_execute` | `supervisor`)指定,缺省为 `deep`。**前提**:`multi_agent.enabled: true`;未启用时返回 404 JSON。支持 `webshellConnectionId`。",
|
"description": "与 `POST /api/eino-agent` 请求体相同,但由 **CloudWeGo Eino** 多代理执行。编排由请求体 `orchestration`(`deep` | `plan_execute` | `supervisor`)指定,缺省为 `deep`。**前提**:`multi_agent.enabled: true`;未启用时返回 404 JSON。支持 `webshellConnectionId`。",
|
||||||
"operationId": "sendMessageMultiAgent",
|
"operationId": "sendMessageMultiAgent",
|
||||||
"requestBody": map[string]interface{}{
|
"requestBody": map[string]interface{}{
|
||||||
"required": true,
|
"required": true,
|
||||||
@@ -1597,7 +1553,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
"webshellConnectionId": map[string]interface{}{
|
"webshellConnectionId": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "WebShell 连接 ID(可选,与 agent-loop 行为一致)",
|
"description": "WebShell 连接 ID(可选,与 Eino 单/多代理流式行为一致)",
|
||||||
},
|
},
|
||||||
"orchestration": map[string]interface{}{
|
"orchestration": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -1612,7 +1568,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
"responses": map[string]interface{}{
|
"responses": map[string]interface{}{
|
||||||
"200": map[string]interface{}{
|
"200": map[string]interface{}{
|
||||||
"description": "成功,响应格式同 /api/agent-loop",
|
"description": "成功,响应格式同 /api/eino-agent",
|
||||||
},
|
},
|
||||||
"400": map[string]interface{}{"description": "参数错误"},
|
"400": map[string]interface{}{"description": "参数错误"},
|
||||||
"401": map[string]interface{}{"description": "未授权"},
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
@@ -1625,7 +1581,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"对话交互"},
|
"tags": []string{"对话交互"},
|
||||||
"summary": "发送消息并获取 AI 回复(Eino 多代理,SSE)",
|
"summary": "发送消息并获取 AI 回复(Eino 多代理,SSE)",
|
||||||
"description": "与 `POST /api/agent-loop/stream` 类似;由 Eino 多代理执行。`orchestration` 指定 deep / plan_execute / supervisor,缺省 deep。**前提**:`multi_agent.enabled: true`;未启用时 SSE 内首条为 `type: error` 后接 `done`。支持 `webshellConnectionId`。",
|
"description": "与 `POST /api/eino-agent/stream` 类似;由 Eino 多代理执行。`orchestration` 指定 deep / plan_execute / supervisor,缺省 deep。**前提**:`multi_agent.enabled: true`;未启用时 SSE 内首条为 `type: error` 后接 `done`。支持 `webshellConnectionId`。",
|
||||||
"operationId": "sendMessageMultiAgentStream",
|
"operationId": "sendMessageMultiAgentStream",
|
||||||
"requestBody": map[string]interface{}{
|
"requestBody": map[string]interface{}{
|
||||||
"required": true,
|
"required": true,
|
||||||
@@ -2444,6 +2400,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/projects": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"},
|
||||||
|
"summary": "列出项目",
|
||||||
|
"operationId": "listProjects",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
|
||||||
|
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{"description": "项目列表"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"},
|
||||||
|
"summary": "创建项目",
|
||||||
|
"operationId": "createProject",
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"name": map[string]interface{}{"type": "string"},
|
||||||
|
"description": map[string]interface{}{"type": "string"},
|
||||||
|
"scope_json": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []string{"name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{"description": "创建成功"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
|
||||||
|
},
|
||||||
|
"put": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
|
||||||
|
},
|
||||||
|
"delete": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/facts": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
||||||
|
},
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
"/api/vulnerabilities": map[string]interface{}{
|
"/api/vulnerabilities": map[string]interface{}{
|
||||||
"get": map[string]interface{}{
|
"get": map[string]interface{}{
|
||||||
"tags": []string{"漏洞管理"},
|
"tags": []string{"漏洞管理"},
|
||||||
@@ -2502,6 +2538,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "project_id",
|
||||||
|
"in": "query",
|
||||||
|
"required": false,
|
||||||
|
"description": "项目ID",
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "severity",
|
"name": "severity",
|
||||||
"in": "query",
|
"in": "query",
|
||||||
@@ -4652,7 +4697,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"title": map[string]interface{}{"type": "string", "description": "队列标题"},
|
"title": map[string]interface{}{"type": "string", "description": "队列标题"},
|
||||||
"role": map[string]interface{}{"type": "string", "description": "使用的角色名称"},
|
"role": map[string]interface{}{"type": "string", "description": "使用的角色名称"},
|
||||||
"agentMode": map[string]interface{}{"type": "string", "description": "代理模式", "enum": []string{"single", "eino_single", "deep", "plan_execute", "supervisor"}},
|
"agentMode": map[string]interface{}{"type": "string", "description": "代理模式", "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -4899,6 +4944,52 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// ==================== 配置管理 - 缺失端点 ====================
|
// ==================== 配置管理 - 缺失端点 ====================
|
||||||
|
"/api/config/test-vision": map[string]interface{}{
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"配置管理"},
|
||||||
|
"summary": "测试视觉模型连接",
|
||||||
|
"description": "测试 Vision 模型 API 是否可用。vision.api_key/base_url 留空时可传 openai 段作回退。",
|
||||||
|
"operationId": "testVision",
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"required": []string{"vision"},
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"vision": map[string]interface{}{"$ref": "#/components/schemas/VisionConfig"},
|
||||||
|
"openai": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"description": "主 LLM 配置(vision 字段留空时用于 API Key/Base URL 回退)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{
|
||||||
|
"description": "测试结果",
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"success": map[string]interface{}{"type": "boolean"},
|
||||||
|
"error": map[string]interface{}{"type": "string"},
|
||||||
|
"model": map[string]interface{}{"type": "string"},
|
||||||
|
"latency_ms": map[string]interface{}{"type": "number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"400": map[string]interface{}{"description": "参数错误"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"/api/config/test-openai": map[string]interface{}{
|
"/api/config/test-openai": map[string]interface{}{
|
||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"配置管理"},
|
"tags": []string{"配置管理"},
|
||||||
@@ -6254,7 +6345,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取漏洞列表
|
// 获取漏洞列表
|
||||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
|
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
||||||
vulnList = []*database.Vulnerability{}
|
vulnList = []*database.Vulnerability{}
|
||||||
|
|||||||
@@ -0,0 +1,400 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProjectHandler 项目管理处理器。
|
||||||
|
type ProjectHandler struct {
|
||||||
|
db *database.DB
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProjectHandler 创建项目管理处理器。
|
||||||
|
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
|
||||||
|
return &ProjectHandler{db: db, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
type createProjectRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
ScopeJSON string `json:"scope_json"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
|
||||||
|
type updateProjectRequest struct {
|
||||||
|
Name *string `json:"name"`
|
||||||
|
Description *string `json:"description"`
|
||||||
|
ScopeJSON *string `json:"scope_json"`
|
||||||
|
Status *string `json:"status"`
|
||||||
|
Pinned *bool `json:"pinned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProject POST /api/projects
|
||||||
|
func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
||||||
|
var req createProjectRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := &database.Project{
|
||||||
|
Name: strings.TrimSpace(req.Name),
|
||||||
|
Description: req.Description,
|
||||||
|
ScopeJSON: req.ScopeJSON,
|
||||||
|
Status: strings.TrimSpace(req.Status),
|
||||||
|
}
|
||||||
|
created, err := h.db.CreateProject(p)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("创建项目失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, created)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjects GET /api/projects
|
||||||
|
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||||
|
status := c.Query("status")
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
|
||||||
|
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||||
|
list, err := h.db.ListProjects(status, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if list == nil {
|
||||||
|
list = []*database.Project{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectStats GET /api/projects/:id/stats
|
||||||
|
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
|
||||||
|
stats, err := project.GetProjectStats(h.db, c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "不存在") {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectConversations GET /api/projects/:id/conversations
|
||||||
|
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
if _, err := h.db.GetProject(projectID); err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
|
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||||
|
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if list == nil {
|
||||||
|
list = []*database.Conversation{}
|
||||||
|
}
|
||||||
|
total, _ := h.db.CountConversationsByProjectID(projectID)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"conversations": list,
|
||||||
|
"total": total,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProject GET /api/projects/:id
|
||||||
|
func (h *ProjectHandler) GetProject(c *gin.Context) {
|
||||||
|
p, err := h.db.GetProject(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateProject PUT /api/projects/:id
|
||||||
|
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
p, err := h.db.GetProject(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req updateProjectRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Name != nil {
|
||||||
|
if s := strings.TrimSpace(*req.Name); s != "" {
|
||||||
|
p.Name = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Description != nil {
|
||||||
|
p.Description = *req.Description
|
||||||
|
}
|
||||||
|
if req.ScopeJSON != nil {
|
||||||
|
p.ScopeJSON = *req.ScopeJSON
|
||||||
|
}
|
||||||
|
if req.Status != nil {
|
||||||
|
if s := strings.TrimSpace(*req.Status); s != "" {
|
||||||
|
p.Status = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Pinned != nil {
|
||||||
|
p.Pinned = *req.Pinned
|
||||||
|
}
|
||||||
|
if err := h.db.UpdateProject(p); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProject DELETE /api/projects/:id
|
||||||
|
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||||
|
if err := h.db.DeleteProject(c.Param("id")); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
type upsertFactRequest struct {
|
||||||
|
FactKey string `json:"fact_key" binding:"required"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Summary string `json:"summary" binding:"required"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||||
|
type updateFactRequest struct {
|
||||||
|
FactKey *string `json:"fact_key"`
|
||||||
|
Category *string `json:"category"`
|
||||||
|
Summary *string `json:"summary"`
|
||||||
|
Body *string `json:"body"`
|
||||||
|
Confidence *string `json:"confidence"`
|
||||||
|
Pinned *bool `json:"pinned"`
|
||||||
|
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||||
|
ClearBody bool `json:"clear_body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||||
|
func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
|
||||||
|
f, err := h.db.GetProjectFactByKey(projectID, key)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, f)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
|
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||||
|
filter := database.ProjectFactListFilter{
|
||||||
|
Category: c.Query("category"),
|
||||||
|
Confidence: c.Query("confidence"),
|
||||||
|
Search: c.Query("search"),
|
||||||
|
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
|
||||||
|
}
|
||||||
|
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
|
||||||
|
filter.ExcludeDeprecated = true
|
||||||
|
}
|
||||||
|
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if list == nil {
|
||||||
|
list = []*database.ProjectFact{}
|
||||||
|
}
|
||||||
|
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
|
||||||
|
filtered := make([]*database.ProjectFact, 0, len(list))
|
||||||
|
for _, f := range list {
|
||||||
|
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
|
||||||
|
filtered = append(filtered, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list = filtered
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
|
||||||
|
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
|
||||||
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
|
if err != nil || existing.ProjectID != c.Param("id") {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(existing.SupersedesFactID) == "" {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
|
||||||
|
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
|
||||||
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
|
if err != nil || existing.ProjectID != c.Param("id") {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if list == nil {
|
||||||
|
list = []*database.ProjectFactVersion{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFact POST /api/projects/:id/facts
|
||||||
|
func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||||
|
var req upsertFactRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f := &database.ProjectFact{
|
||||||
|
ProjectID: c.Param("id"),
|
||||||
|
FactKey: req.FactKey,
|
||||||
|
Category: req.Category,
|
||||||
|
Summary: req.Summary,
|
||||||
|
Body: req.Body,
|
||||||
|
Confidence: req.Confidence,
|
||||||
|
Pinned: req.Pinned,
|
||||||
|
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
|
||||||
|
}
|
||||||
|
created, err := h.db.UpsertProjectFact(f)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, created)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||||
|
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||||
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
|
if err != nil || existing.ProjectID != c.Param("id") {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req updateFactRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.FactKey != nil {
|
||||||
|
if k := strings.TrimSpace(*req.FactKey); k != "" {
|
||||||
|
existing.FactKey = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
|
||||||
|
existing.Category = *req.Category
|
||||||
|
}
|
||||||
|
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
|
||||||
|
existing.Summary = *req.Summary
|
||||||
|
}
|
||||||
|
if req.ClearBody {
|
||||||
|
existing.Body = ""
|
||||||
|
} else if req.Body != nil {
|
||||||
|
existing.Body = *req.Body
|
||||||
|
}
|
||||||
|
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
|
||||||
|
existing.Confidence = *req.Confidence
|
||||||
|
}
|
||||||
|
if req.Pinned != nil {
|
||||||
|
existing.Pinned = *req.Pinned
|
||||||
|
}
|
||||||
|
if req.RelatedVulnerabilityID != nil {
|
||||||
|
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
|
||||||
|
}
|
||||||
|
updated, err := h.db.UpsertProjectFact(existing)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||||
|
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
|
||||||
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
|
if err != nil || existing.ProjectID != c.Param("id") {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
type deprecateFactRequest struct {
|
||||||
|
FactKey string `json:"fact_key" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecateFact POST /api/projects/:id/facts/deprecate
|
||||||
|
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
|
||||||
|
var req deprecateFactRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
type restoreFactRequest struct {
|
||||||
|
FactKey string `json:"fact_key" binding:"required"`
|
||||||
|
Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreFact POST /api/projects/:id/facts/restore
|
||||||
|
func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
||||||
|
var req restoreFactRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||||
|
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||||
|
if h == nil || h.db == nil || h.config == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !h.config.Project.Enabled {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||||
|
if err != nil || projectID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(block)
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
|
||||||
|
func effectiveProjectID(cfg *config.Config, explicit string) string {
|
||||||
|
if pid := strings.TrimSpace(explicit); pid != "" {
|
||||||
|
return pid
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
return strings.TrimSpace(cfg.Project.DefaultProjectID)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
+218
-21
@@ -40,8 +40,13 @@ const (
|
|||||||
robotCmdRoles = "角色"
|
robotCmdRoles = "角色"
|
||||||
robotCmdRolesList = "角色列表"
|
robotCmdRolesList = "角色列表"
|
||||||
robotCmdSwitchRole = "切换角色"
|
robotCmdSwitchRole = "切换角色"
|
||||||
robotCmdDelete = "删除"
|
robotCmdDelete = "删除"
|
||||||
robotCmdVersion = "版本"
|
robotCmdVersion = "版本"
|
||||||
|
robotCmdProjects = "项目"
|
||||||
|
robotCmdProjectsList = "项目列表"
|
||||||
|
robotCmdBindProject = "绑定项目"
|
||||||
|
robotCmdNewProject = "新建项目"
|
||||||
|
robotCmdUnbindProject = "解除项目"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
||||||
@@ -133,7 +138,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
|||||||
} else {
|
} else {
|
||||||
t = safeTruncateString(t, 50)
|
t = safeTruncateString(t, 50)
|
||||||
}
|
}
|
||||||
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
|
meta := database.ConversationCreateMeta{Source: "robot:" + platform}
|
||||||
|
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||||
|
conv, err := h.db.CreateConversation(t, meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||||
return "", false
|
return "", false
|
||||||
@@ -188,7 +195,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
|||||||
// clearConversation 清空当前会话(切换到新对话)
|
// clearConversation 清空当前会话(切换到新对话)
|
||||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||||
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
|
meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
|
||||||
|
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||||
|
conv, err := h.db.CreateConversation(title, meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||||
return ""
|
return ""
|
||||||
@@ -230,7 +239,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
|||||||
_ = h.db.UpdateConversationTitle(convID, newTitle)
|
_ = h.db.UpdateConversationTitle(convID, newTitle)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), h.robotMessageTimeout())
|
||||||
sk := h.sessionKey(platform, userID)
|
sk := h.sessionKey(platform, userID)
|
||||||
h.cancelMu.Lock()
|
h.cancelMu.Lock()
|
||||||
h.runningCancels[sk] = cancel
|
h.runningCancels[sk] = cancel
|
||||||
@@ -248,6 +257,9 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
|||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
return "任务已取消。"
|
return "任务已取消。"
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return "任务执行超时,请稍后重试或精简本次请求范围。"
|
||||||
|
}
|
||||||
return "处理失败: " + err.Error()
|
return "处理失败: " + err.Error()
|
||||||
}
|
}
|
||||||
if newConvID != convID {
|
if newConvID != convID {
|
||||||
@@ -256,22 +268,182 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) robotMessageTimeout() time.Duration {
|
||||||
|
// 机器人整次消息处理超时(与单次工具超时 agent.tool_timeout_minutes 解耦)。
|
||||||
|
return 10 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdHelp() string {
|
func (h *RobotHandler) cmdHelp() string {
|
||||||
return "**【CyberStrikeAI 机器人命令】**\n\n" +
|
var b strings.Builder
|
||||||
"- `帮助` `help` — 显示本帮助 | Show this help\n" +
|
b.WriteString("【CyberStrikeAI 机器人命令】\n\n")
|
||||||
"- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" +
|
b.WriteString("【通用 General】\n")
|
||||||
"- `切换 <ID>` `switch <ID>` — 指定对话继续 | Switch to conversation\n" +
|
b.WriteString("· 帮助 / help — 显示本帮助\n")
|
||||||
"- `新对话` `new` — 开启新对话 | Start new conversation\n" +
|
b.WriteString("· 版本 / version — 显示当前版本号\n")
|
||||||
"- `清空` `clear` — 清空当前上下文 | Clear context\n" +
|
b.WriteString("\n【对话 Conversation】\n")
|
||||||
"- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" +
|
b.WriteString("· 列表 / list — 列出所有对话标题与 ID\n")
|
||||||
"- `停止` `stop` — 中断当前任务 | Stop running task\n" +
|
b.WriteString("· 切换 <ID> / switch <ID> — 指定对话继续\n")
|
||||||
"- `角色` `roles` — 列出所有可用角色 | List roles\n" +
|
b.WriteString("· 新对话 / new — 开启新对话\n")
|
||||||
"- `角色 <名>` `role <name>` — 切换当前角色 | Switch role\n" +
|
b.WriteString("· 清空 / clear — 清空当前上下文\n")
|
||||||
"- `删除 <ID>` `delete <ID>` — 删除指定对话 | Delete conversation\n" +
|
b.WriteString("· 当前 / current — 显示当前对话、角色与项目\n")
|
||||||
"- `版本` `version` — 显示当前版本号 | Show version\n\n" +
|
b.WriteString("· 停止 / stop — 中断当前任务\n")
|
||||||
"---\n" +
|
b.WriteString("· 删除 <ID> / delete <ID> — 删除指定对话\n")
|
||||||
"除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" +
|
b.WriteString("\n【角色 Role】\n")
|
||||||
"Otherwise, send any text for AI penetration testing / security analysis."
|
b.WriteString("· 角色 / roles — 列出所有可用角色\n")
|
||||||
|
b.WriteString("· 角色 <名> / role <name> — 切换当前角色\n")
|
||||||
|
if h.projectsEnabled() {
|
||||||
|
b.WriteString("\n【项目 Project】\n")
|
||||||
|
b.WriteString("· 项目 / projects — 列出所有项目\n")
|
||||||
|
b.WriteString("· 新建项目 <名称> / new project <name> — 创建并绑定当前对话\n")
|
||||||
|
b.WriteString("· 绑定项目 <ID或名称> / bind project <ID|name> — 绑定到已有项目\n")
|
||||||
|
b.WriteString("· 解除项目 / unbind project — 解除项目绑定\n")
|
||||||
|
}
|
||||||
|
b.WriteString("\n──────────────\n")
|
||||||
|
b.WriteString("除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) projectsEnabled() bool {
|
||||||
|
return h.config != nil && h.config.Project.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Project, string) {
|
||||||
|
idOrName = strings.TrimSpace(idOrName)
|
||||||
|
if idOrName == "" {
|
||||||
|
return nil, "请指定项目 ID 或名称,例如:绑定项目 xxx-xxx"
|
||||||
|
}
|
||||||
|
if p, err := h.db.GetProject(idOrName); err == nil {
|
||||||
|
return p, ""
|
||||||
|
}
|
||||||
|
list, err := h.db.ListProjects("", 200, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "查询项目失败: " + err.Error()
|
||||||
|
}
|
||||||
|
var matches []*database.Project
|
||||||
|
for _, p := range list {
|
||||||
|
if p.Name == idOrName {
|
||||||
|
matches = append(matches, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch len(matches) {
|
||||||
|
case 0:
|
||||||
|
return nil, fmt.Sprintf("项目「%s」不存在。发送「项目」查看列表。", idOrName)
|
||||||
|
case 1:
|
||||||
|
return matches[0], ""
|
||||||
|
default:
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("名称「%s」匹配到多个项目,请使用 ID 绑定:\n", idOrName))
|
||||||
|
for _, p := range matches {
|
||||||
|
b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", p.Name, p.ID))
|
||||||
|
}
|
||||||
|
return nil, strings.TrimSuffix(b.String(), "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) formatProjectLabel(projectID string) string {
|
||||||
|
if strings.TrimSpace(projectID) == "" {
|
||||||
|
return "未绑定"
|
||||||
|
}
|
||||||
|
if p, err := h.db.GetProject(projectID); err == nil {
|
||||||
|
return fmt.Sprintf("「%s」 (%s)", p.Name, p.ID)
|
||||||
|
}
|
||||||
|
return projectID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) cmdProjects() string {
|
||||||
|
if !h.projectsEnabled() {
|
||||||
|
return "项目功能未启用(config.project.enabled)。"
|
||||||
|
}
|
||||||
|
list, err := h.db.ListProjects("", 50, 0)
|
||||||
|
if err != nil {
|
||||||
|
return "获取项目列表失败: " + err.Error()
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
return "暂无项目。发送「新建项目 <名称>」创建并绑定到当前对话。"
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("【项目列表】\n")
|
||||||
|
for i, p := range list {
|
||||||
|
if i >= 20 {
|
||||||
|
b.WriteString("… 仅显示前 20 条\n")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
status := p.Status
|
||||||
|
if status == "" {
|
||||||
|
status = "active"
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("· %s [%s]\n ID: %s\n", p.Name, status, p.ID))
|
||||||
|
}
|
||||||
|
return strings.TrimSuffix(b.String(), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) cmdBindProject(platform, userID, idOrName string) string {
|
||||||
|
if !h.projectsEnabled() {
|
||||||
|
return "项目功能未启用(config.project.enabled)。"
|
||||||
|
}
|
||||||
|
p, errMsg := h.resolveProjectByIDOrName(idOrName)
|
||||||
|
if p == nil {
|
||||||
|
return errMsg
|
||||||
|
}
|
||||||
|
convID, _ := h.getOrCreateConversation(platform, userID, "")
|
||||||
|
if convID == "" {
|
||||||
|
return "无法获取当前对话,请稍后再试。"
|
||||||
|
}
|
||||||
|
if err := h.db.SetConversationProjectID(convID, p.ID); err != nil {
|
||||||
|
return "绑定失败: " + err.Error()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("已将当前对话绑定到项目:「%s」\nID: %s", p.Name, p.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) cmdNewProject(platform, userID, name string) string {
|
||||||
|
if !h.projectsEnabled() {
|
||||||
|
return "项目功能未启用(config.project.enabled)。"
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return "请指定项目名称,例如:新建项目 某目标渗透"
|
||||||
|
}
|
||||||
|
p := &database.Project{Name: name, Status: "active"}
|
||||||
|
created, err := h.db.CreateProject(p)
|
||||||
|
if err != nil {
|
||||||
|
return "创建项目失败: " + err.Error()
|
||||||
|
}
|
||||||
|
convID, _ := h.getOrCreateConversation(platform, userID, name)
|
||||||
|
if convID == "" {
|
||||||
|
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n(绑定当前对话失败,请手动发送「绑定项目 %s」)", created.Name, created.ID, created.ID)
|
||||||
|
}
|
||||||
|
if err := h.db.SetConversationProjectID(convID, created.ID); err != nil {
|
||||||
|
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n绑定失败: %s", created.Name, created.ID, err.Error())
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("已创建项目并绑定当前对话:「%s」\nID: %s", created.Name, created.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
||||||
|
if !h.projectsEnabled() {
|
||||||
|
return "项目功能未启用(config.project.enabled)。"
|
||||||
|
}
|
||||||
|
sk := h.sessionKey(platform, userID)
|
||||||
|
h.mu.RLock()
|
||||||
|
convID := h.sessions[sk]
|
||||||
|
h.mu.RUnlock()
|
||||||
|
if convID == "" {
|
||||||
|
if persistedConvID, _ := h.loadSessionBinding(sk); persistedConvID != "" {
|
||||||
|
convID = persistedConvID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if convID == "" {
|
||||||
|
return "当前没有进行中的对话,无需解除绑定。"
|
||||||
|
}
|
||||||
|
projectID, err := h.db.GetConversationProjectID(convID)
|
||||||
|
if err != nil {
|
||||||
|
return "获取对话项目失败: " + err.Error()
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(projectID) == "" {
|
||||||
|
return "当前对话未绑定项目。"
|
||||||
|
}
|
||||||
|
if err := h.db.SetConversationProjectID(convID, ""); err != nil {
|
||||||
|
return "解除绑定失败: " + err.Error()
|
||||||
|
}
|
||||||
|
return "已解除当前对话的项目绑定。"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdList() string {
|
func (h *RobotHandler) cmdList() string {
|
||||||
@@ -345,7 +517,12 @@ func (h *RobotHandler) cmdCurrent(platform, userID string) string {
|
|||||||
return "当前对话 ID: " + convID + "(获取标题失败)"
|
return "当前对话 ID: " + convID + "(获取标题失败)"
|
||||||
}
|
}
|
||||||
role := h.getRole(platform, userID)
|
role := h.getRole(platform, userID)
|
||||||
return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
reply := fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
||||||
|
if h.projectsEnabled() {
|
||||||
|
projectID, _ := h.db.GetConversationProjectID(conv.ID)
|
||||||
|
reply += "\n当前项目: " + h.formatProjectLabel(projectID)
|
||||||
|
}
|
||||||
|
return reply
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdRoles() string {
|
func (h *RobotHandler) cmdRoles() string {
|
||||||
@@ -482,6 +659,26 @@ func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string
|
|||||||
return h.cmdDelete(platform, userID, convID), true
|
return h.cmdDelete(platform, userID, convID), true
|
||||||
case text == robotCmdVersion || text == "version":
|
case text == robotCmdVersion || text == "version":
|
||||||
return h.cmdVersion(), true
|
return h.cmdVersion(), true
|
||||||
|
case text == robotCmdProjects || text == robotCmdProjectsList || text == "projects":
|
||||||
|
return h.cmdProjects(), true
|
||||||
|
case text == robotCmdUnbindProject || text == "unbind project":
|
||||||
|
return h.cmdUnbindProject(platform, userID), true
|
||||||
|
case strings.HasPrefix(text, robotCmdNewProject+" ") || strings.HasPrefix(text, "new project "):
|
||||||
|
var name string
|
||||||
|
if strings.HasPrefix(text, robotCmdNewProject+" ") {
|
||||||
|
name = strings.TrimSpace(text[len(robotCmdNewProject)+1:])
|
||||||
|
} else {
|
||||||
|
name = strings.TrimSpace(text[len("new project "):])
|
||||||
|
}
|
||||||
|
return h.cmdNewProject(platform, userID, name), true
|
||||||
|
case strings.HasPrefix(text, robotCmdBindProject+" ") || strings.HasPrefix(text, "bind project "):
|
||||||
|
var idOrName string
|
||||||
|
if strings.HasPrefix(text, robotCmdBindProject+" ") {
|
||||||
|
idOrName = strings.TrimSpace(text[len(robotCmdBindProject)+1:])
|
||||||
|
} else {
|
||||||
|
idOrName = strings.TrimSpace(text[len("bind project "):])
|
||||||
|
}
|
||||||
|
return h.cmdBindProject(platform, userID, idOrName), true
|
||||||
default:
|
default:
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
|||||||
// CreateVulnerabilityRequest 创建漏洞请求
|
// CreateVulnerabilityRequest 创建漏洞请求
|
||||||
type CreateVulnerabilityRequest struct {
|
type CreateVulnerabilityRequest struct {
|
||||||
ConversationID string `json:"conversation_id" binding:"required"`
|
ConversationID string `json:"conversation_id" binding:"required"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
ConversationTag string `json:"conversation_tag"`
|
ConversationTag string `json:"conversation_tag"`
|
||||||
TaskTag string `json:"task_tag"`
|
TaskTag string `json:"task_tag"`
|
||||||
Title string `json:"title" binding:"required"`
|
Title string `json:"title" binding:"required"`
|
||||||
@@ -59,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
|||||||
|
|
||||||
vuln := &database.Vulnerability{
|
vuln := &database.Vulnerability{
|
||||||
ConversationID: req.ConversationID,
|
ConversationID: req.ConversationID,
|
||||||
|
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||||
ConversationTag: req.ConversationTag,
|
ConversationTag: req.ConversationTag,
|
||||||
TaskTag: req.TaskTag,
|
TaskTag: req.TaskTag,
|
||||||
Title: req.Title,
|
Title: req.Title,
|
||||||
@@ -110,18 +112,30 @@ type ListVulnerabilitiesResponse struct {
|
|||||||
TotalPages int `json:"total_pages"`
|
TotalPages int `json:"total_pages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||||
|
q := strings.TrimSpace(c.Query("q"))
|
||||||
|
if q == "" {
|
||||||
|
q = strings.TrimSpace(c.Query("search"))
|
||||||
|
}
|
||||||
|
return database.VulnerabilityListFilter{
|
||||||
|
ProjectID: c.Query("project_id"),
|
||||||
|
ID: c.Query("id"),
|
||||||
|
Search: q,
|
||||||
|
ConversationID: c.Query("conversation_id"),
|
||||||
|
Severity: c.Query("severity"),
|
||||||
|
Status: c.Query("status"),
|
||||||
|
TaskID: c.Query("task_id"),
|
||||||
|
ConversationTag: c.Query("conversation_tag"),
|
||||||
|
TaskTag: c.Query("task_tag"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListVulnerabilities 列出漏洞
|
// ListVulnerabilities 列出漏洞
|
||||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||||
limitStr := c.DefaultQuery("limit", "20")
|
limitStr := c.DefaultQuery("limit", "20")
|
||||||
offsetStr := c.DefaultQuery("offset", "0")
|
offsetStr := c.DefaultQuery("offset", "0")
|
||||||
pageStr := c.Query("page")
|
pageStr := c.Query("page")
|
||||||
id := c.Query("id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
conversationID := c.Query("conversation_id")
|
|
||||||
severity := c.Query("severity")
|
|
||||||
status := c.Query("status")
|
|
||||||
taskID := c.Query("task_id")
|
|
||||||
conversationTag := c.Query("conversation_tag")
|
|
||||||
taskTag := c.Query("task_tag")
|
|
||||||
|
|
||||||
limit, _ := strconv.Atoi(limitStr)
|
limit, _ := strconv.Atoi(limitStr)
|
||||||
offset, _ := strconv.Atoi(offsetStr)
|
offset, _ := strconv.Atoi(offsetStr)
|
||||||
@@ -143,7 +157,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取总数
|
// 获取总数
|
||||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
total, err := h.db.CountVulnerabilities(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||||
// 继续执行,使用0作为总数
|
// 继续执行,使用0作为总数
|
||||||
@@ -151,7 +165,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取漏洞列表
|
// 获取漏洞列表
|
||||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -182,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
|
|
||||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||||
type UpdateVulnerabilityRequest struct {
|
type UpdateVulnerabilityRequest struct {
|
||||||
ConversationTag string `json:"conversation_tag"`
|
ProjectID *string `json:"project_id"`
|
||||||
TaskTag string `json:"task_tag"`
|
ConversationTag string `json:"conversation_tag"`
|
||||||
Title string `json:"title"`
|
TaskTag string `json:"task_tag"`
|
||||||
Description string `json:"description"`
|
Title string `json:"title"`
|
||||||
Severity string `json:"severity"`
|
Description string `json:"description"`
|
||||||
Status string `json:"status"`
|
Severity string `json:"severity"`
|
||||||
Type string `json:"type"`
|
Status string `json:"status"`
|
||||||
Target string `json:"target"`
|
Type string `json:"type"`
|
||||||
Proof string `json:"proof"`
|
Target string `json:"target"`
|
||||||
Impact string `json:"impact"`
|
Proof string `json:"proof"`
|
||||||
Recommendation string `json:"recommendation"`
|
Impact string `json:"impact"`
|
||||||
|
Recommendation string `json:"recommendation"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateVulnerability 更新漏洞
|
// UpdateVulnerability 更新漏洞
|
||||||
@@ -213,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新字段
|
// 更新字段
|
||||||
|
if req.ProjectID != nil {
|
||||||
|
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||||
|
}
|
||||||
if req.ConversationTag != "" {
|
if req.ConversationTag != "" {
|
||||||
existing.ConversationTag = req.ConversationTag
|
existing.ConversationTag = req.ConversationTag
|
||||||
}
|
}
|
||||||
@@ -263,7 +281,7 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
|||||||
|
|
||||||
if h.audit != nil {
|
if h.audit != nil {
|
||||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||||
"severity": updated.Severity, "status": updated.Status,
|
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, updated)
|
c.JSON(http.StatusOK, updated)
|
||||||
@@ -295,10 +313,9 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
|||||||
|
|
||||||
// GetVulnerabilityStats 获取漏洞统计
|
// GetVulnerabilityStats 获取漏洞统计
|
||||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||||
conversationID := c.Query("conversation_id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
taskID := c.Query("task_id")
|
|
||||||
|
|
||||||
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
|
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -332,15 +349,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := c.Query("id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
conversationID := c.Query("conversation_id")
|
|
||||||
severity := c.Query("severity")
|
|
||||||
status := c.Query("status")
|
|
||||||
taskID := c.Query("task_id")
|
|
||||||
conversationTag := c.Query("conversation_tag")
|
|
||||||
taskTag := c.Query("task_tag")
|
|
||||||
|
|
||||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
total, err := h.db.CountVulnerabilities(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -350,7 +361,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -133,6 +134,16 @@ func quoteCmdPath(p string) string {
|
|||||||
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。
|
||||||
|
// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。
|
||||||
|
func normalizeWindowsCmdPath(p string) string {
|
||||||
|
s := strings.TrimSpace(p)
|
||||||
|
if s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return strings.ReplaceAll(s, "/", "\\")
|
||||||
|
}
|
||||||
|
|
||||||
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
||||||
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
||||||
func quotePsSingle(s string) string {
|
func quotePsSingle(s string) string {
|
||||||
@@ -197,6 +208,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
p = "."
|
p = "."
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
p = normalizeWindowsCmdPath(p)
|
||||||
return "dir /a " + quoteCmdPath(p), nil
|
return "dir /a " + quoteCmdPath(p), nil
|
||||||
}
|
}
|
||||||
return "ls -la " + quoteShellSinglePosix(p), nil
|
return "ls -la " + quoteShellSinglePosix(p), nil
|
||||||
@@ -206,6 +218,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpPathRequired
|
return "", errFileOpPathRequired
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
return "type " + quoteCmdPath(path), nil
|
return "type " + quoteCmdPath(path), nil
|
||||||
}
|
}
|
||||||
return "cat " + quoteShellSinglePosix(path), nil
|
return "cat " + quoteShellSinglePosix(path), nil
|
||||||
@@ -215,6 +228,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpPathRequired
|
return "", errFileOpPathRequired
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
return "del /q /f " + quoteCmdPath(path), nil
|
return "del /q /f " + quoteCmdPath(path), nil
|
||||||
}
|
}
|
||||||
return "rm -f " + quoteShellSinglePosix(path), nil
|
return "rm -f " + quoteShellSinglePosix(path), nil
|
||||||
@@ -224,6 +238,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpPathRequired
|
return "", errFileOpPathRequired
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
||||||
return "md " + quoteCmdPath(path), nil
|
return "md " + quoteCmdPath(path), nil
|
||||||
}
|
}
|
||||||
@@ -236,6 +251,8 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpRenameNeedsBothPaths
|
return "", errFileOpRenameNeedsBothPaths
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
oldPath = normalizeWindowsCmdPath(oldPath)
|
||||||
|
newPath = normalizeWindowsCmdPath(newPath)
|
||||||
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
||||||
}
|
}
|
||||||
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
||||||
@@ -248,6 +265,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
||||||
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
return buildWindowsPowerShellWrite(path, b64), nil
|
return buildWindowsPowerShellWrite(path, b64), nil
|
||||||
}
|
}
|
||||||
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||||
@@ -260,6 +278,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpUploadTooLarge
|
return "", errFileOpUploadTooLarge
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||||
}
|
}
|
||||||
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||||
@@ -269,6 +288,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
|||||||
return "", errFileOpPathRequired
|
return "", errFileOpPathRequired
|
||||||
}
|
}
|
||||||
if targetOS == "windows" {
|
if targetOS == "windows" {
|
||||||
|
path = normalizeWindowsCmdPath(path)
|
||||||
if in.ChunkIndex == 0 {
|
if in.ChunkIndex == 0 {
|
||||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||||
}
|
}
|
||||||
@@ -318,8 +338,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
|||||||
return &WebShellHandler{
|
return &WebShellHandler{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{DisableKeepAlives: false},
|
Transport: &http.Transport{
|
||||||
|
DisableKeepAlives: false,
|
||||||
|
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||||
|
},
|
||||||
},
|
},
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s
|
|||||||
|
|
||||||
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
||||||
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
||||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
|
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、restore_project_fact、list_knowledge_risk_types、search_knowledge_base"
|
||||||
|
|
||||||
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
||||||
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
||||||
@@ -65,7 +65,7 @@ func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint,
|
|||||||
b.WriteString(conn.ID)
|
b.WriteString(conn.ID)
|
||||||
b.WriteString("\"):")
|
b.WriteString("\"):")
|
||||||
b.WriteString(webshellAssistantToolList)
|
b.WriteString(webshellAssistantToolList)
|
||||||
b.WriteString("。")
|
b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。")
|
||||||
b.WriteString(skillHint)
|
b.WriteString(skillHint)
|
||||||
b.WriteString("\n\n用户请求:")
|
b.WriteString("\n\n用户请求:")
|
||||||
b.WriteString(userMsg)
|
b.WriteString(userMsg)
|
||||||
|
|||||||
@@ -4,12 +4,25 @@ package builtin
|
|||||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||||
const (
|
const (
|
||||||
// 漏洞管理工具
|
// 漏洞管理工具
|
||||||
ToolRecordVulnerability = "record_vulnerability"
|
ToolRecordVulnerability = "record_vulnerability"
|
||||||
|
ToolListVulnerabilities = "list_vulnerabilities"
|
||||||
|
ToolGetVulnerability = "get_vulnerability"
|
||||||
|
|
||||||
|
// 项目黑板(事实)工具
|
||||||
|
ToolUpsertProjectFact = "upsert_project_fact"
|
||||||
|
ToolGetProjectFact = "get_project_fact"
|
||||||
|
ToolListProjectFacts = "list_project_facts"
|
||||||
|
ToolSearchProjectFacts = "search_project_facts"
|
||||||
|
ToolDeprecateProjectFact = "deprecate_project_fact"
|
||||||
|
ToolRestoreProjectFact = "restore_project_fact"
|
||||||
|
|
||||||
// 知识库工具
|
// 知识库工具
|
||||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||||
ToolSearchKnowledgeBase = "search_knowledge_base"
|
ToolSearchKnowledgeBase = "search_knowledge_base"
|
||||||
|
|
||||||
|
// 视觉分析(本地图片 → VL 模型 → 文本摘要)
|
||||||
|
ToolAnalyzeImage = "analyze_image"
|
||||||
|
|
||||||
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
|
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
|
||||||
ToolWebshellExec = "webshell_exec"
|
ToolWebshellExec = "webshell_exec"
|
||||||
ToolWebshellFileList = "webshell_file_list"
|
ToolWebshellFileList = "webshell_file_list"
|
||||||
@@ -53,8 +66,17 @@ const (
|
|||||||
func IsBuiltinTool(toolName string) bool {
|
func IsBuiltinTool(toolName string) bool {
|
||||||
switch toolName {
|
switch toolName {
|
||||||
case ToolRecordVulnerability,
|
case ToolRecordVulnerability,
|
||||||
|
ToolListVulnerabilities,
|
||||||
|
ToolGetVulnerability,
|
||||||
|
ToolUpsertProjectFact,
|
||||||
|
ToolGetProjectFact,
|
||||||
|
ToolListProjectFacts,
|
||||||
|
ToolSearchProjectFacts,
|
||||||
|
ToolDeprecateProjectFact,
|
||||||
|
ToolRestoreProjectFact,
|
||||||
ToolListKnowledgeRiskTypes,
|
ToolListKnowledgeRiskTypes,
|
||||||
ToolSearchKnowledgeBase,
|
ToolSearchKnowledgeBase,
|
||||||
|
ToolAnalyzeImage,
|
||||||
ToolWebshellExec,
|
ToolWebshellExec,
|
||||||
ToolWebshellFileList,
|
ToolWebshellFileList,
|
||||||
ToolWebshellFileRead,
|
ToolWebshellFileRead,
|
||||||
@@ -96,8 +118,17 @@ func IsBuiltinTool(toolName string) bool {
|
|||||||
func GetAllBuiltinTools() []string {
|
func GetAllBuiltinTools() []string {
|
||||||
return []string{
|
return []string{
|
||||||
ToolRecordVulnerability,
|
ToolRecordVulnerability,
|
||||||
|
ToolListVulnerabilities,
|
||||||
|
ToolGetVulnerability,
|
||||||
|
ToolUpsertProjectFact,
|
||||||
|
ToolGetProjectFact,
|
||||||
|
ToolListProjectFacts,
|
||||||
|
ToolSearchProjectFacts,
|
||||||
|
ToolDeprecateProjectFact,
|
||||||
|
ToolRestoreProjectFact,
|
||||||
ToolListKnowledgeRiskTypes,
|
ToolListKnowledgeRiskTypes,
|
||||||
ToolSearchKnowledgeBase,
|
ToolSearchKnowledgeBase,
|
||||||
|
ToolAnalyzeImage,
|
||||||
ToolWebshellExec,
|
ToolWebshellExec,
|
||||||
ToolWebshellFileList,
|
ToolWebshellFileList,
|
||||||
ToolWebshellFileRead,
|
ToolWebshellFileRead,
|
||||||
|
|||||||
@@ -44,11 +44,12 @@ func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, log
|
|||||||
|
|
||||||
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
|
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
|
||||||
type lazySDKClient struct {
|
type lazySDKClient struct {
|
||||||
serverCfg config.ExternalMCPServerConfig
|
serverCfg config.ExternalMCPServerConfig
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
inner ExternalMCPClient // 连接成功后为 *sdkClient
|
sessionCancel context.CancelFunc
|
||||||
mu sync.RWMutex
|
inner ExternalMCPClient // connected SDK client
|
||||||
status string
|
mu sync.RWMutex
|
||||||
|
status string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
|
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
|
||||||
@@ -92,14 +93,61 @@ func (c *lazySDKClient) Initialize(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
|
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||||
if err != nil {
|
type connectResult struct {
|
||||||
|
inner ExternalMCPClient
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan connectResult)
|
||||||
|
abandoned := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
inner, err := createSDKClient(sessionCtx, c.serverCfg, c.logger)
|
||||||
|
select {
|
||||||
|
case resultCh <- connectResult{inner: inner, err: err}:
|
||||||
|
case <-abandoned:
|
||||||
|
if inner != nil {
|
||||||
|
_ = inner.Close()
|
||||||
|
}
|
||||||
|
sessionCancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var result connectResult
|
||||||
|
select {
|
||||||
|
case result = <-resultCh:
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(abandoned)
|
||||||
|
sessionCancel()
|
||||||
|
c.setStatus("error")
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
sessionCancel()
|
||||||
|
if result.inner != nil {
|
||||||
|
_ = result.inner.Close()
|
||||||
|
}
|
||||||
c.setStatus("error")
|
c.setStatus("error")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.err != nil {
|
||||||
|
sessionCancel()
|
||||||
|
c.setStatus("error")
|
||||||
|
return result.err
|
||||||
|
}
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
c.inner = inner
|
if c.inner != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
sessionCancel()
|
||||||
|
if result.inner != nil {
|
||||||
|
_ = result.inner.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.inner = result.inner
|
||||||
|
c.sessionCancel = sessionCancel
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
c.setStatus("connected")
|
c.setStatus("connected")
|
||||||
return nil
|
return nil
|
||||||
@@ -128,9 +176,14 @@ func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[stri
|
|||||||
func (c *lazySDKClient) Close() error {
|
func (c *lazySDKClient) Close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
inner := c.inner
|
inner := c.inner
|
||||||
|
sessionCancel := c.sessionCancel
|
||||||
c.inner = nil
|
c.inner = nil
|
||||||
|
c.sessionCancel = nil
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
c.setStatus("disconnected")
|
c.setStatus("disconnected")
|
||||||
|
if sessionCancel != nil {
|
||||||
|
sessionCancel()
|
||||||
|
}
|
||||||
if inner != nil {
|
if inner != nil {
|
||||||
return inner.Close()
|
return inner.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,6 +77,9 @@ type einoADKRunLoopArgs struct {
|
|||||||
StreamsMainAssistant func(agent string) bool
|
StreamsMainAssistant func(agent string) bool
|
||||||
EinoRoleTag func(agent string) string
|
EinoRoleTag func(agent string) string
|
||||||
CheckpointDir string
|
CheckpointDir string
|
||||||
|
// RunRetryMaxAttempts / RunRetryMaxBackoffSec:429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
|
||||||
|
RunRetryMaxAttempts int
|
||||||
|
RunRetryMaxBackoffSec int
|
||||||
|
|
||||||
McpIDsMu *sync.Mutex
|
McpIDsMu *sync.Mutex
|
||||||
McpIDs *[]string
|
McpIDs *[]string
|
||||||
@@ -181,14 +184,19 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
mainAgentToolStep := make(map[string]int)
|
mainAgentToolStep := make(map[string]int)
|
||||||
pendingByID := make(map[string]toolCallPendingInfo)
|
pendingByID := make(map[string]toolCallPendingInfo)
|
||||||
pendingQueueByAgent := make(map[string][]string)
|
pendingQueueByAgent := make(map[string][]string)
|
||||||
|
var pendingMu sync.Mutex
|
||||||
markPending := func(tc toolCallPendingInfo) {
|
markPending := func(tc toolCallPendingInfo) {
|
||||||
if tc.ToolCallID == "" {
|
if tc.ToolCallID == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
pendingMu.Lock()
|
||||||
|
defer pendingMu.Unlock()
|
||||||
pendingByID[tc.ToolCallID] = tc
|
pendingByID[tc.ToolCallID] = tc
|
||||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||||
}
|
}
|
||||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||||
|
pendingMu.Lock()
|
||||||
|
defer pendingMu.Unlock()
|
||||||
q := pendingQueueByAgent[agentName]
|
q := pendingQueueByAgent[agentName]
|
||||||
for len(q) > 0 {
|
for len(q) > 0 {
|
||||||
id := q[0]
|
id := q[0]
|
||||||
@@ -205,19 +213,42 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if toolCallID == "" {
|
if toolCallID == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
pendingMu.Lock()
|
||||||
|
defer pendingMu.Unlock()
|
||||||
delete(pendingByID, toolCallID)
|
delete(pendingByID, toolCallID)
|
||||||
}
|
}
|
||||||
|
popAnyPending := func() (toolCallPendingInfo, bool) {
|
||||||
|
pendingMu.Lock()
|
||||||
|
defer pendingMu.Unlock()
|
||||||
|
for id, tc := range pendingByID {
|
||||||
|
delete(pendingByID, id)
|
||||||
|
return tc, true
|
||||||
|
}
|
||||||
|
return toolCallPendingInfo{}, false
|
||||||
|
}
|
||||||
|
pendingCount := func() int {
|
||||||
|
pendingMu.Lock()
|
||||||
|
defer pendingMu.Unlock()
|
||||||
|
return len(pendingByID)
|
||||||
|
}
|
||||||
flushAllPendingAsFailed := func(err error) {
|
flushAllPendingAsFailed := func(err error) {
|
||||||
|
pendingMu.Lock()
|
||||||
|
pendingSnapshot := make([]toolCallPendingInfo, 0, len(pendingByID))
|
||||||
|
for _, tc := range pendingByID {
|
||||||
|
pendingSnapshot = append(pendingSnapshot, tc)
|
||||||
|
}
|
||||||
|
pendingByID = make(map[string]toolCallPendingInfo)
|
||||||
|
pendingQueueByAgent = make(map[string][]string)
|
||||||
|
pendingMu.Unlock()
|
||||||
|
|
||||||
if progress == nil {
|
if progress == nil {
|
||||||
pendingByID = make(map[string]toolCallPendingInfo)
|
|
||||||
pendingQueueByAgent = make(map[string][]string)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg := ""
|
msg := ""
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg = err.Error()
|
msg = err.Error()
|
||||||
}
|
}
|
||||||
for _, tc := range pendingByID {
|
for _, tc := range pendingSnapshot {
|
||||||
toolName := tc.ToolName
|
toolName := tc.ToolName
|
||||||
if strings.TrimSpace(toolName) == "" {
|
if strings.TrimSpace(toolName) == "" {
|
||||||
toolName = "unknown"
|
toolName = "unknown"
|
||||||
@@ -235,8 +266,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
pendingByID = make(map[string]toolCallPendingInfo)
|
|
||||||
pendingQueueByAgent = make(map[string][]string)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
||||||
@@ -316,7 +345,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
runnerCfg := adk.RunnerConfig{
|
runnerCfg := adk.RunnerConfig{
|
||||||
Agent: da,
|
Agent: da,
|
||||||
|
// 启用 ADK 流式事件:plan_execute 也需要输出 reasoning/response 流,
|
||||||
|
// 与 deep/supervisor/eino_single 的前端体验保持一致。
|
||||||
EnableStreaming: true,
|
EnableStreaming: true,
|
||||||
}
|
}
|
||||||
var cpStore *fileCheckPointStore
|
var cpStore *fileCheckPointStore
|
||||||
@@ -437,6 +468,28 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return runErr
|
return runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||||
|
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||||
|
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||||
|
return false, handleRunErr(runErr)
|
||||||
|
}
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||||
|
zap.Error(runErr),
|
||||||
|
zap.String("orchestration", orchMode))
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"error": runErr.Error(),
|
||||||
|
"resumeKind": "trace_segment",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return false, ErrTransientRetryContinue
|
||||||
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||||
return nil, runErr
|
return nil, runErr
|
||||||
@@ -494,8 +547,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
return takePartial(ctxErr)
|
return takePartial(ctxErr)
|
||||||
}
|
}
|
||||||
if len(pendingByID) > 0 {
|
if orphanCount := pendingCount(); orphanCount > 0 {
|
||||||
orphanCount := len(pendingByID)
|
|
||||||
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
|
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("eino_pending_orphaned", "pending tool calls were force-closed at run end", map[string]interface{}{
|
progress("eino_pending_orphaned", "pending tool calls were force-closed at run end", map[string]interface{}{
|
||||||
@@ -519,7 +571,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ev.Err != nil {
|
if ev.Err != nil {
|
||||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -821,7 +873,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": einoRoleTag(ev.AgentName),
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -932,12 +984,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
toolCallID = inferred.ToolCallID
|
toolCallID = inferred.ToolCallID
|
||||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||||
toolCallID = inferred.ToolCallID
|
toolCallID = inferred.ToolCallID
|
||||||
} else {
|
} else if inferred, ok := popAnyPending(); ok {
|
||||||
for id := range pendingByID {
|
toolCallID = inferred.ToolCallID
|
||||||
toolCallID = id
|
|
||||||
delete(pendingByID, id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if toolCallID != "" {
|
if toolCallID != "" {
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
|||||||
}
|
}
|
||||||
plannerCfg := &planexecute.PlannerConfig{
|
plannerCfg := &planexecute.PlannerConfig{
|
||||||
ToolCallingChatModel: tcm,
|
ToolCallingChatModel: tcm,
|
||||||
|
NewPlan: newLenientPlan,
|
||||||
}
|
}
|
||||||
if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
|
if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
|
||||||
plannerCfg.GenInputFn = fn
|
plannerCfg.GenInputFn = fn
|
||||||
@@ -70,6 +71,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
|||||||
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
|
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
|
||||||
ChatModel: tcm,
|
ChatModel: tcm,
|
||||||
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
|
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
|
||||||
|
NewPlan: newLenientPlan,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute replanner: %w", err)
|
return nil, fmt.Errorf("plan_execute replanner: %w", err)
|
||||||
@@ -146,14 +148,12 @@ func planExecutePlannerGenInput(
|
|||||||
}
|
}
|
||||||
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
|
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
|
||||||
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
|
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
|
||||||
msgs := make([]adk.Message, 0, 1+len(userInput))
|
msgs := make([]adk.Message, 0, len(userInput))
|
||||||
if oi != "" {
|
|
||||||
msgs = append(msgs, schema.SystemMessage(oi))
|
|
||||||
}
|
|
||||||
msgs = append(msgs, userInput...)
|
msgs = append(msgs, userInput...)
|
||||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||||
msgs = rewritten
|
msgs = rewritten
|
||||||
}
|
}
|
||||||
|
msgs = normalizeSingleLeadingSystemMessage(msgs, oi)
|
||||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
|
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
|
||||||
return msgs, nil
|
return msgs, nil
|
||||||
}
|
}
|
||||||
@@ -182,9 +182,7 @@ func planExecuteExecutorGenInput(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oi != "" {
|
userMsgs = normalizeSingleLeadingSystemMessage(userMsgs, oi)
|
||||||
userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...)
|
|
||||||
}
|
|
||||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
|
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
|
||||||
return userMsgs, nil
|
return userMsgs, nil
|
||||||
}
|
}
|
||||||
@@ -231,17 +229,54 @@ func planExecuteReplannerGenInput(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oi != "" {
|
|
||||||
msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...)
|
|
||||||
}
|
|
||||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||||
msgs = rewritten
|
msgs = rewritten
|
||||||
}
|
}
|
||||||
|
msgs = normalizeSingleLeadingSystemMessage(msgs, oi)
|
||||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
|
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
|
||||||
return msgs, nil
|
return msgs, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeSingleLeadingSystemMessage enforces a provider-friendly message shape:
|
||||||
|
// exactly one system message at index 0 (when any system context exists).
|
||||||
|
// For strict OpenAI-compatible backends (e.g. qwen/vllm templates), this avoids
|
||||||
|
// "System message must be at the beginning" caused by multiple/disordered system messages.
|
||||||
|
func normalizeSingleLeadingSystemMessage(msgs []adk.Message, extraSystem string) []adk.Message {
|
||||||
|
extraSystem = strings.TrimSpace(extraSystem)
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
if extraSystem == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
return []adk.Message{schema.SystemMessage(extraSystem)}
|
||||||
|
}
|
||||||
|
|
||||||
|
systemParts := make([]string, 0, 2)
|
||||||
|
if extraSystem != "" {
|
||||||
|
systemParts = append(systemParts, extraSystem)
|
||||||
|
}
|
||||||
|
nonSystem := make([]adk.Message, 0, len(msgs))
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.Role == schema.System {
|
||||||
|
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||||
|
systemParts = append(systemParts, s)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nonSystem = append(nonSystem, msg)
|
||||||
|
}
|
||||||
|
if len(systemParts) == 0 {
|
||||||
|
return nonSystem
|
||||||
|
}
|
||||||
|
out := make([]adk.Message, 0, len(nonSystem)+1)
|
||||||
|
out = append(out, schema.SystemMessage(strings.Join(systemParts, "\n\n")))
|
||||||
|
out = append(out, nonSystem...)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||||
if len(input) == 0 {
|
if len(input) == 0 {
|
||||||
return input
|
return input
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeSingleLeadingSystemMessage_MergesMultipleSystems(t *testing.T) {
|
||||||
|
in := []adk.Message{
|
||||||
|
schema.SystemMessage("sys-1"),
|
||||||
|
schema.UserMessage("u1"),
|
||||||
|
schema.SystemMessage("sys-2"),
|
||||||
|
schema.AssistantMessage("a1", nil),
|
||||||
|
}
|
||||||
|
out := normalizeSingleLeadingSystemMessage(in, "orch")
|
||||||
|
if len(out) != 3 {
|
||||||
|
t.Fatalf("unexpected output length: got %d want 3", len(out))
|
||||||
|
}
|
||||||
|
if out[0].Role != schema.System {
|
||||||
|
t.Fatalf("first message role must be system, got %s", out[0].Role)
|
||||||
|
}
|
||||||
|
if got := out[0].Content; got != "orch\n\nsys-1\n\nsys-2" {
|
||||||
|
t.Fatalf("unexpected merged system content: %q", got)
|
||||||
|
}
|
||||||
|
if out[1].Role != schema.User || out[2].Role != schema.Assistant {
|
||||||
|
t.Fatalf("non-system message order changed unexpectedly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeSingleLeadingSystemMessage_NoSystemKeepsFlow(t *testing.T) {
|
||||||
|
in := []adk.Message{
|
||||||
|
schema.UserMessage("u1"),
|
||||||
|
schema.AssistantMessage("a1", nil),
|
||||||
|
}
|
||||||
|
out := normalizeSingleLeadingSystemMessage(in, "")
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("unexpected output length: got %d want 2", len(out))
|
||||||
|
}
|
||||||
|
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
|
||||||
|
t.Fatalf("message order changed unexpectedly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ import (
|
|||||||
const einoSingleAgentName = "cyberstrike-eino-single"
|
const einoSingleAgentName = "cyberstrike-eino-single"
|
||||||
|
|
||||||
// RunEinoSingleChatModelAgent 使用 Eino adk.NewChatModelAgent + adk.NewRunner.Run(官方 Quick Start 的 Query 同属 Runner API;此处用历史 + 用户消息切片等价于多轮 Query)。
|
// RunEinoSingleChatModelAgent 使用 Eino adk.NewChatModelAgent + adk.NewRunner.Run(官方 Quick Start 的 Query 同属 Runner API;此处用历史 + 用户消息切片等价于多轮 Query)。
|
||||||
// 不替代既有原生 ReAct;与 RunDeepAgent 共享 runEinoADKAgentLoop 的 SSE 映射与 MCP 桥。
|
// 与 RunDeepAgent 共享 runEinoADKAgentLoop 的 SSE 映射与 MCP 桥。
|
||||||
func RunEinoSingleChatModelAgent(
|
func RunEinoSingleChatModelAgent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
@@ -39,6 +39,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
roleTools []string,
|
roleTools []string,
|
||||||
progress func(eventType, message string, data interface{}),
|
progress func(eventType, message string, data interface{}),
|
||||||
reasoningClient *reasoning.ClientIntent,
|
reasoningClient *reasoning.ClientIntent,
|
||||||
|
systemPromptExtra string,
|
||||||
) (*RunResult, error) {
|
) (*RunResult, error) {
|
||||||
if appCfg == nil || ag == nil {
|
if appCfg == nil || ag == nil {
|
||||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||||
@@ -178,7 +179,9 @@ func RunEinoSingleChatModelAgent(
|
|||||||
},
|
},
|
||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
}
|
}
|
||||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
ins := project.AppendSystemPromptBlock(ag.EinoSingleAgentSystemInstruction(), systemPromptExtra)
|
||||||
|
ins = project.AppendVisionImageAnalysisIfReady(ins, appCfg.Vision.Ready())
|
||||||
|
ins = injectToolNamesOnlyInstruction(ctx, ins, mainTools, singleToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
names := collectToolNames(ctx, mainTools)
|
names := collectToolNames(ctx, mainTools)
|
||||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||||
@@ -213,7 +216,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||||
|
|
||||||
streamsMainAssistant := func(agent string) bool {
|
streamsMainAssistant := func(agent string) bool {
|
||||||
return agent == "" || agent == einoSingleAgentName
|
return agent == "" || agent == einoSingleAgentName
|
||||||
@@ -233,6 +236,8 @@ func RunEinoSingleChatModelAgent(
|
|||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
|
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||||
|
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// einoSummarizeUserInstruction 与单 Agent MemoryCompressor 目标一致:压缩时保留渗透关键信息。
|
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||||
|
|
||||||
必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
|
必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
|
||||||
@@ -29,7 +29,7 @@ const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完
|
|||||||
输出须使后续代理能无缝继续同一授权测试任务。`
|
输出须使后续代理能无缝继续同一授权测试任务。`
|
||||||
|
|
||||||
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
|
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
|
||||||
// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。
|
// 触发阈值:估算 token 超过 openai.max_total_tokens * summarization_trigger_ratio(默认 0.8)时摘要。
|
||||||
func newEinoSummarizationMiddleware(
|
func newEinoSummarizationMiddleware(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
summaryModel model.BaseChatModel,
|
summaryModel model.BaseChatModel,
|
||||||
|
|||||||
@@ -0,0 +1,173 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultEinoRunRetryMaxAttempts = 10
|
||||||
|
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||||
|
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||||
|
func isEinoTransientRunError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isEinoIterationLimitError(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||||
|
if msg == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
transientMarkers := []string{
|
||||||
|
"406",
|
||||||
|
"429",
|
||||||
|
"too many requests",
|
||||||
|
"rate limit",
|
||||||
|
"rate_limit",
|
||||||
|
"ratelimit",
|
||||||
|
"quota exceeded",
|
||||||
|
"overloaded",
|
||||||
|
"capacity",
|
||||||
|
"temporarily unavailable",
|
||||||
|
"service unavailable",
|
||||||
|
"bad gateway",
|
||||||
|
"gateway timeout",
|
||||||
|
"internal server error",
|
||||||
|
"connection reset",
|
||||||
|
"connection refused",
|
||||||
|
"connection closed",
|
||||||
|
"i/o timeout",
|
||||||
|
"no such host",
|
||||||
|
"network is unreachable",
|
||||||
|
"broken pipe",
|
||||||
|
"read tcp",
|
||||||
|
"write tcp",
|
||||||
|
"dial tcp",
|
||||||
|
"tls handshake timeout",
|
||||||
|
"stream error",
|
||||||
|
"unexpected eof",
|
||||||
|
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
|
||||||
|
"unexpected end of json",
|
||||||
|
"status code: 406",
|
||||||
|
"status code: 502",
|
||||||
|
"502",
|
||||||
|
"503",
|
||||||
|
"504",
|
||||||
|
"500",
|
||||||
|
}
|
||||||
|
for _, m := range transientMarkers {
|
||||||
|
if strings.Contains(msg, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||||
|
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||||
|
return args.RunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||||
|
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||||
|
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||||
|
return mw.RunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||||
|
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||||
|
max := defaultEinoRunRetryMaxBackoff
|
||||||
|
if maxBackoffSec > 0 {
|
||||||
|
max = time.Duration(maxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return einoTransientRetryBackoff(attempt, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||||
|
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||||
|
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxBackoff
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
|
||||||
|
type einoRunRestartContextSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
einoRestartContextInitial einoRunRestartContextSource = "initial"
|
||||||
|
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
|
||||||
|
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
|
||||||
|
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||||
|
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||||
|
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||||
|
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||||
|
}
|
||||||
|
if len(accumulated) > baseCount {
|
||||||
|
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||||
|
}
|
||||||
|
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||||
|
}
|
||||||
|
|
||||||
|
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
||||||
|
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||||
|
want = strings.TrimSpace(want)
|
||||||
|
if want == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i := len(msgs) - 1; i >= 0; i-- {
|
||||||
|
m := msgs[i]
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.Role == schema.User {
|
||||||
|
return strings.TrimSpace(m.Content) == want
|
||||||
|
}
|
||||||
|
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
||||||
|
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||||
|
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
return append(msgs, schema.UserMessage(userMessage))
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
|
||||||
|
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||||
|
if maxBackoff > 0 && backoff > maxBackoff {
|
||||||
|
backoff = maxBackoff
|
||||||
|
}
|
||||||
|
return backoff
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsEinoTransientRunError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false},
|
||||||
|
{"io eof", io.EOF, false},
|
||||||
|
{"plain eof text", errors.New("EOF"), false},
|
||||||
|
{"post chat completions eof", errors.New(`Post "https://token-plan-cn.xiaomimimo.com/v1/chat/completions": EOF`), true},
|
||||||
|
{"post eof wraps io.EOF", fmt.Errorf(`Post %q: %w`, "https://token-plan-cn.xiaomimimo.com/v1/chat/completions", io.EOF), true},
|
||||||
|
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||||
|
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||||
|
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||||
|
{"unexpected eof", errors.New("unexpected EOF"), true},
|
||||||
|
{"503", errors.New("upstream returned 503"), true},
|
||||||
|
{"iteration limit", errors.New("max iteration reached"), false},
|
||||||
|
{"canceled", context.Canceled, false},
|
||||||
|
{"deadline", context.DeadlineExceeded, false},
|
||||||
|
{"auth", errors.New("invalid api key"), false},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if got := isEinoTransientRunError(tc.err); got != tc.want {
|
||||||
|
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoTransientRetryBackoff(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
max := 30 * time.Second
|
||||||
|
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
|
||||||
|
t.Fatalf("attempt 0: got %v", got)
|
||||||
|
}
|
||||||
|
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
|
||||||
|
t.Fatalf("attempt 4 capped: got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoMessagesForRunRestart(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
base := []adk.Message{schema.UserMessage("hi")}
|
||||||
|
acc := append([]adk.Message(nil), base...)
|
||||||
|
acc = append(acc, schema.AssistantMessage("step1", nil))
|
||||||
|
|
||||||
|
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
|
||||||
|
if src != einoRestartContextAccumulated || len(got) != 2 {
|
||||||
|
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
holder := newModelFacingTraceHolder()
|
||||||
|
holder.storeFromState(&adk.ChatModelAgentState{
|
||||||
|
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
|
||||||
|
})
|
||||||
|
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
|
||||||
|
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
|
||||||
|
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
|
||||||
|
t.Fatal("nil args should use default")
|
||||||
|
}
|
||||||
|
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
|
||||||
|
t.Fatal("custom max attempts")
|
||||||
|
}
|
||||||
|
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
|
||||||
|
t.Fatal("config nil should use default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||||
|
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
|
||||||
|
if len(out) != 2 || out[1].Content != "你好,你是谁" {
|
||||||
|
t.Fatalf("should append user: len=%d", len(out))
|
||||||
|
}
|
||||||
|
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
|
||||||
|
if len(dup) != 2 {
|
||||||
|
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrTransientRetryContinue(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||||
|
t.Fatal("sentinel should match")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,3 +5,7 @@ import "errors"
|
|||||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||||
|
|
||||||
|
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||||
|
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||||
|
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/project"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
||||||
@@ -106,16 +106,14 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
|
|||||||
|
|
||||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||||
|
|
||||||
## 漏洞记录
|
` + project.FactRecordingBlackboardSection(true) + `
|
||||||
|
|
||||||
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
|
- **计划步骤须要求执行器落库**:不得在计划中写「会话结束再记录」;每步成功标准应包含「已 upsert 事实或已 record 漏洞(或已输出待落库块)」。
|
||||||
|
|
||||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。
|
|
||||||
|
|
||||||
## 技能库(Skills)与知识库
|
## 技能库(Skills)与知识库
|
||||||
|
|
||||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||||
- plan_execute 执行器通过 MCP 使用知识库与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。
|
- plan_execute 执行器通过 MCP 使用知识库、项目事实与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。
|
||||||
- 若需要完整 Skill 工作流而当前会话无 skill 工具,请在计划或对用户说明中建议切换多代理或 Eino 编排会话。
|
- 若需要完整 Skill 工作流而当前会话无 skill 工具,请在计划或对用户说明中建议切换多代理或 Eino 编排会话。
|
||||||
|
|
||||||
## 执行器对用户输出(重要)
|
## 执行器对用户输出(重要)
|
||||||
@@ -206,7 +204,8 @@ func DefaultSupervisorOrchestratorInstruction() string {
|
|||||||
- **委派优先**:可独立封装、需要专项上下文的子目标(枚举、验证、归纳、报告素材)优先 transfer 给匹配子代理,并在委派说明中写清:子目标、约束、期望交付物结构、证据要求。
|
- **委派优先**:可独立封装、需要专项上下文的子目标(枚举、验证、归纳、报告素材)优先 transfer 给匹配子代理,并在委派说明中写清:子目标、约束、期望交付物结构、证据要求。
|
||||||
- **亲自执行**:仅当无合适专家、需全局衔接或子代理结果不足时,由你直接调用工具。
|
- **亲自执行**:仅当无合适专家、需全局衔接或子代理结果不足时,由你直接调用工具。
|
||||||
- **汇总**:子代理输出是证据来源;你要对齐矛盾、补全上下文,给出统一结论与可复现验证步骤,避免机械拼接。
|
- **汇总**:子代理输出是证据来源;你要对齐矛盾、补全上下文,给出统一结论与可复现验证步骤,避免机械拼接。
|
||||||
- **漏洞**:有效漏洞应通过 ` + builtin.ToolRecordVulnerability + ` 记录(含 POC 与严重性:critical / high / medium / low / info)。
|
|
||||||
|
` + project.FactRecordingBlackboardSection(true) + `
|
||||||
|
|
||||||
## transfer 交接与防重复劳动
|
## transfer 交接与防重复劳动
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,157 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/prebuilt/planexecute"
|
||||||
|
)
|
||||||
|
|
||||||
|
// lenientPlan keeps plan_execute running even when model tool arguments contain minor JSON defects.
|
||||||
|
// It first tries strict JSON, then falls back to lightweight step extraction heuristics.
|
||||||
|
type lenientPlan struct {
|
||||||
|
Steps []string `json:"steps"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLenientPlan(context.Context) planexecute.Plan {
|
||||||
|
return &lenientPlan{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lenientPlan) FirstStep() string {
|
||||||
|
if p == nil || len(p.Steps) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return p.Steps[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lenientPlan) MarshalJSON() ([]byte, error) {
|
||||||
|
type alias lenientPlan
|
||||||
|
return json.Marshal((*alias)(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lenientPlan) UnmarshalJSON(b []byte) error {
|
||||||
|
type alias lenientPlan
|
||||||
|
var strict alias
|
||||||
|
if err := json.Unmarshal(b, &strict); err == nil {
|
||||||
|
strict.Steps = normalizePlanSteps(strict.Steps)
|
||||||
|
if len(strict.Steps) > 0 {
|
||||||
|
*p = lenientPlan(strict)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
steps := extractPlanStepsLenient(string(b))
|
||||||
|
if len(steps) == 0 {
|
||||||
|
steps = []string{"继续按当前目标执行下一步,并输出可验证证据。"}
|
||||||
|
}
|
||||||
|
p.Steps = steps
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPlanStepsLenient(raw string) []string {
|
||||||
|
s := strings.TrimSpace(stripCodeFence(raw))
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if extracted, ok := sliceByStepsArray(s); ok {
|
||||||
|
var arr []string
|
||||||
|
if err := json.Unmarshal([]byte(extracted), &arr); err == nil {
|
||||||
|
arr = normalizePlanSteps(arr)
|
||||||
|
if len(arr) > 0 {
|
||||||
|
return arr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if arr := splitStepsHeuristically(strings.Trim(extracted, "[]")); len(arr) > 0 {
|
||||||
|
return arr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last-resort: treat plaintext body as one actionable step.
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{s}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sliceByStepsArray(s string) (string, bool) {
|
||||||
|
lower := strings.ToLower(s)
|
||||||
|
key := `"steps"`
|
||||||
|
i := strings.Index(lower, key)
|
||||||
|
if i < 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
start := strings.Index(s[i:], "[")
|
||||||
|
if start < 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
start += i
|
||||||
|
depth := 0
|
||||||
|
for j := start; j < len(s); j++ {
|
||||||
|
switch s[j] {
|
||||||
|
case '[':
|
||||||
|
depth++
|
||||||
|
case ']':
|
||||||
|
depth--
|
||||||
|
if depth == 0 {
|
||||||
|
return s[start : j+1], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitStepsHeuristically(body string) []string {
|
||||||
|
body = strings.ReplaceAll(body, "\r\n", "\n")
|
||||||
|
body = strings.ReplaceAll(body, "\\n", "\n")
|
||||||
|
var parts []string
|
||||||
|
if strings.Contains(body, "\n") {
|
||||||
|
for _, line := range strings.Split(body, "\n") {
|
||||||
|
parts = append(parts, line)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, seg := range strings.Split(body, ",") {
|
||||||
|
parts = append(parts, seg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
t := strings.TrimSpace(part)
|
||||||
|
t = strings.Trim(t, "\"'`")
|
||||||
|
t = strings.TrimLeft(t, "-*0123456789.、 \t")
|
||||||
|
t = strings.TrimSpace(strings.ReplaceAll(t, `\"`, `"`))
|
||||||
|
if t == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, t)
|
||||||
|
}
|
||||||
|
return normalizePlanSteps(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePlanSteps(in []string) []string {
|
||||||
|
out := make([]string, 0, len(in))
|
||||||
|
for _, step := range in {
|
||||||
|
t := strings.TrimSpace(step)
|
||||||
|
if t == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, t)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripCodeFence(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if !strings.HasPrefix(s, "```") {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
s = strings.TrimPrefix(s, "```json")
|
||||||
|
s = strings.TrimPrefix(s, "```JSON")
|
||||||
|
s = strings.TrimPrefix(s, "```")
|
||||||
|
s = strings.TrimSuffix(strings.TrimSpace(s), "```")
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
@@ -64,6 +65,7 @@ func RunDeepAgent(
|
|||||||
agentsMarkdownDir string,
|
agentsMarkdownDir string,
|
||||||
orchestrationOverride string,
|
orchestrationOverride string,
|
||||||
reasoningClient *reasoning.ClientIntent,
|
reasoningClient *reasoning.ClientIntent,
|
||||||
|
systemPromptExtra string,
|
||||||
) (*RunResult, error) {
|
) (*RunResult, error) {
|
||||||
if appCfg == nil || ma == nil || ag == nil {
|
if appCfg == nil || ma == nil || ag == nil {
|
||||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||||
@@ -260,7 +262,8 @@ func RunDeepAgent(
|
|||||||
subHandlers = append(subHandlers, teleMw)
|
subHandlers = append(subHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
|
||||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools, subToolSearchActive)
|
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
|
||||||
|
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
subNames := collectToolNames(ctx, subTools)
|
subNames := collectToolNames(ctx, subTools)
|
||||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||||
@@ -339,6 +342,8 @@ func RunDeepAgent(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
orchInstruction = project.AppendSystemPromptBlock(orchInstruction, systemPromptExtra)
|
||||||
|
orchInstruction = project.AppendVisionImageAnalysisIfReady(orchInstruction, appCfg.Vision.Ready())
|
||||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
mainNames := collectToolNames(ctx, mainTools)
|
mainNames := collectToolNames(ctx, mainTools)
|
||||||
@@ -387,7 +392,8 @@ func RunDeepAgent(
|
|||||||
|
|
||||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||||
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
||||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes); mw != nil {
|
taskEnrichExtra := systemPromptExtra
|
||||||
|
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
|
||||||
deepHandlers = append(deepHandlers, mw)
|
deepHandlers = append(deepHandlers, mw)
|
||||||
}
|
}
|
||||||
if len(mainOrchestratorPre) > 0 {
|
if len(mainOrchestratorPre) > 0 {
|
||||||
@@ -538,7 +544,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||||
|
|
||||||
streamsMainAssistant := func(agent string) bool {
|
streamsMainAssistant := func(agent string) bool {
|
||||||
if orchMode == "plan_execute" {
|
if orchMode == "plan_execute" {
|
||||||
@@ -566,6 +572,8 @@ func RunDeepAgent(
|
|||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
|
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||||
|
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
@@ -595,6 +603,13 @@ func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
|||||||
argsStr = string(b)
|
argsStr = string(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Some OpenAI-compatible gateways require `function.arguments` to exist
|
||||||
|
// on every assistant tool_call message. When args are empty, omitempty may
|
||||||
|
// drop the field during serialization and cause "missing field arguments"
|
||||||
|
// on the next turn history replay.
|
||||||
|
if strings.TrimSpace(argsStr) == "" {
|
||||||
|
argsStr = "{}"
|
||||||
|
}
|
||||||
typ := tc.Type
|
typ := tc.Type
|
||||||
if typ == "" {
|
if typ == "" {
|
||||||
typ = "function"
|
typ = "function"
|
||||||
|
|||||||
@@ -30,8 +30,15 @@ type taskContextEnrichMiddleware struct {
|
|||||||
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
||||||
// descriptions with user conversation context. Returns nil if disabled
|
// descriptions with user conversation context. Returns nil if disabled
|
||||||
// (maxRunes < 0) or no user messages exist.
|
// (maxRunes < 0) or no user messages exist.
|
||||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware {
|
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
||||||
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
||||||
|
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
||||||
|
if supplement != "" {
|
||||||
|
supplement += "\n\n## 项目黑板索引\n" + bb
|
||||||
|
} else {
|
||||||
|
supplement = "\n\n## 项目黑板索引\n" + bb
|
||||||
|
}
|
||||||
|
}
|
||||||
if supplement == "" {
|
if supplement == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
|
|||||||
"继续测试",
|
"继续测试",
|
||||||
[]agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}},
|
[]agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}},
|
||||||
0,
|
0,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
t.Fatal("expected non-nil middleware")
|
t.Fatal("expected non-nil middleware")
|
||||||
@@ -149,7 +150,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
|
func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
|
||||||
mw := newTaskContextEnrichMiddleware("test", nil, 0)
|
mw := newTaskContextEnrichMiddleware("test", nil, 0, "")
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
t.Fatal("expected non-nil middleware")
|
t.Fatal("expected non-nil middleware")
|
||||||
}
|
}
|
||||||
@@ -175,7 +176,7 @@ func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) {
|
func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) {
|
||||||
mw := newTaskContextEnrichMiddleware("test", nil, -1)
|
mw := newTaskContextEnrichMiddleware("test", nil, -1, "")
|
||||||
if mw != nil {
|
if mw != nil {
|
||||||
t.Error("middleware should be nil when disabled")
|
t.Error("middleware should be nil when disabled")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppendSystemPromptBlock 将附加块追加到 system prompt。
|
||||||
|
func AppendSystemPromptBlock(base, block string) string {
|
||||||
|
base = strings.TrimSpace(base)
|
||||||
|
block = strings.TrimSpace(block)
|
||||||
|
if block == "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
if base == "" {
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
return base + "\n\n" + block
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
|
||||||
|
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||||
|
if db == nil || !cfg.Enabled {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proj, err := db.GetProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(facts) == 0 {
|
||||||
|
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.SliceStable(facts, func(i, j int) bool {
|
||||||
|
if facts[i].Pinned != facts[j].Pinned {
|
||||||
|
return facts[i].Pinned
|
||||||
|
}
|
||||||
|
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||||
|
})
|
||||||
|
|
||||||
|
maxRunes := cfg.FactIndexMaxRunesEffective()
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||||
|
used := len([]rune(b.String()))
|
||||||
|
omitted := 0
|
||||||
|
|
||||||
|
for _, f := range facts {
|
||||||
|
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||||
|
lineRunes := len([]rune(line))
|
||||||
|
if used+lineRunes > maxRunes {
|
||||||
|
omitted++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteString(line)
|
||||||
|
used += lineRunes
|
||||||
|
}
|
||||||
|
|
||||||
|
if omitted > 0 {
|
||||||
|
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
||||||
|
}
|
||||||
|
b.WriteString("需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
|
||||||
|
b.WriteString("写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。\n")
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。
|
||||||
|
const (
|
||||||
|
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
||||||
|
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
||||||
|
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
||||||
|
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
||||||
|
b.WriteString(factRhythmCore)
|
||||||
|
if coordinator {
|
||||||
|
b.WriteString(factRhythmCoordinatorSuffix)
|
||||||
|
}
|
||||||
|
if subAgent {
|
||||||
|
b.WriteString(factRhythmSubAgentSuffix)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
||||||
|
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||||
|
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
||||||
|
b.WriteString(builtin.ToolRecordVulnerability)
|
||||||
|
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
||||||
|
if coordinator {
|
||||||
|
b.WriteString(factRhythmCoordinatorSuffix)
|
||||||
|
}
|
||||||
|
if subAgent {
|
||||||
|
b.WriteString(factRhythmSubAgentSuffix)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
||||||
|
// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。
|
||||||
|
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||||
|
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
||||||
|
b.WriteString(builtin.ToolGetProjectFact)
|
||||||
|
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||||
|
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
||||||
|
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||||
|
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||||
|
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||||
|
b.WriteString("- **可交付漏洞**:使用 ")
|
||||||
|
b.WriteString(builtin.ToolRecordVulnerability)
|
||||||
|
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
||||||
|
b.WriteString(builtin.ToolListVulnerabilities)
|
||||||
|
b.WriteString(" 查重,详情用 ")
|
||||||
|
b.WriteString(builtin.ToolGetVulnerability)
|
||||||
|
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
||||||
|
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
||||||
|
b.WriteString(builtin.ToolDeprecateProjectFact)
|
||||||
|
b.WriteString(" 或漏洞状态 false_positive。\n")
|
||||||
|
b.WriteString("- 事实多时用 ")
|
||||||
|
b.WriteString(builtin.ToolListProjectFacts)
|
||||||
|
b.WriteString(" / ")
|
||||||
|
b.WriteString(builtin.ToolSearchProjectFacts)
|
||||||
|
b.WriteString(" 检索。\n\n")
|
||||||
|
b.WriteString(FactRecordingGuidanceBlock())
|
||||||
|
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
||||||
|
func FactRecordingSubAgentSection() string {
|
||||||
|
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
||||||
|
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||||
|
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||||
|
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||||
|
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||||
|
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
||||||
|
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
||||||
|
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
||||||
|
b.WriteString(FactRecordingGuidanceBlock())
|
||||||
|
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
|
||||||
|
const (
|
||||||
|
FactCategoryTarget = "target"
|
||||||
|
FactCategoryAuth = "auth"
|
||||||
|
FactCategoryInfra = "infra"
|
||||||
|
FactCategoryBusiness = "business"
|
||||||
|
FactCategoryFinding = "finding"
|
||||||
|
FactCategoryChain = "chain"
|
||||||
|
FactCategoryExploit = "exploit"
|
||||||
|
FactCategoryPOC = "poc"
|
||||||
|
FactCategoryNote = "note"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequiresAttackChainBody 判断该事实是否应携带可复现的攻击链 / exploit 详情(写在 body,非仅 summary)。
|
||||||
|
func RequiresAttackChainBody(category, factKey string) bool {
|
||||||
|
c := strings.ToLower(strings.TrimSpace(category))
|
||||||
|
switch c {
|
||||||
|
case FactCategoryFinding, FactCategoryChain, FactCategoryExploit, FactCategoryPOC, "vuln":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||||
|
for _, prefix := range []string{"finding/", "chain/", "exploit/", "poc/"} {
|
||||||
|
if strings.HasPrefix(key, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSparseFactBody 攻击链类事实 body 过短或缺少关键段落时返回 true(软校验,不阻断写入)。
|
||||||
|
func IsSparseFactBody(category, factKey, body string) bool {
|
||||||
|
if !RequiresAttackChainBody(category, factKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
body = strings.TrimSpace(body)
|
||||||
|
if body == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(body)
|
||||||
|
// 至少应包含可复现线索:步骤/请求/命令/代码块 之一
|
||||||
|
hasSteps := strings.Contains(lower, "攻击链") || strings.Contains(lower, "## 攻击") ||
|
||||||
|
strings.Contains(lower, "## exploit") || strings.Contains(lower, "## poc")
|
||||||
|
hasHTTP := strings.Contains(lower, "```http") || strings.Contains(lower, "```bash") ||
|
||||||
|
strings.Contains(lower, "curl ") || strings.Contains(lower, "get ") || strings.Contains(lower, "post ")
|
||||||
|
hasReq := strings.Contains(lower, "请求") || strings.Contains(lower, "响应") || strings.Contains(lower, "payload")
|
||||||
|
// 无攻击链/POC/请求等结构线索,视为仅结论性描述(不论长短)
|
||||||
|
return !(hasSteps || hasHTTP || hasReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactBodyTemplate 按 category 返回建议的 body Markdown 骨架(供 Agent 填入真实内容)。
|
||||||
|
func FactBodyTemplate(category, factKey string) string {
|
||||||
|
if RequiresAttackChainBody(category, factKey) {
|
||||||
|
return attackChainFactBodyTemplate
|
||||||
|
}
|
||||||
|
return envFactBodyTemplate
|
||||||
|
}
|
||||||
|
|
||||||
|
const attackChainFactBodyTemplate = `## 结论(可验证,一句话)
|
||||||
|
<勿仅写「存在漏洞」;写明类型 + 位置 + 触发条件>
|
||||||
|
|
||||||
|
## 目标与入口
|
||||||
|
- 目标: <URL / IP:Port / 主机名>
|
||||||
|
- 入口: <路径 / 接口 / 参数>
|
||||||
|
- 前置条件: <匿名 / 角色 / Cookie / 其他依赖>
|
||||||
|
|
||||||
|
## 攻击链(逐步可复现)
|
||||||
|
1. <侦察/发现>
|
||||||
|
2. <利用/触发>
|
||||||
|
3. <影响证明(读文件、RCE 回显、越权数据等)>
|
||||||
|
|
||||||
|
## Exploit / POC
|
||||||
|
### 请求
|
||||||
|
` + "```http\n<METHOD> <path> HTTP/1.1\nHost: ...\n...\n\n<body>\n```" + `
|
||||||
|
|
||||||
|
### 响应 / 现象
|
||||||
|
<关键响应片段、状态码、差异点>
|
||||||
|
|
||||||
|
### 命令 / 脚本(如有)
|
||||||
|
` + "```bash\n<command>\n```" + `
|
||||||
|
|
||||||
|
## 关键证据
|
||||||
|
- <工具输出摘要 / 截图路径 / 会话或消息 ID>
|
||||||
|
|
||||||
|
## 关联
|
||||||
|
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
|
||||||
|
- 依赖事实: <fact_key,如 auth/session_cookie>
|
||||||
|
|
||||||
|
## 备注与不确定性
|
||||||
|
<待验证假设、环境差异、绕过尝试记录>`
|
||||||
|
|
||||||
|
const envFactBodyTemplate = `## 摘要
|
||||||
|
<该事实的核心认知>
|
||||||
|
|
||||||
|
## 细节
|
||||||
|
<端口/版本/路径/凭据特征/业务规则等>
|
||||||
|
|
||||||
|
## 来源与证据
|
||||||
|
<命令输出、响应片段、发现时间>
|
||||||
|
|
||||||
|
## 关联
|
||||||
|
- 相关 fact_key: <可选>`
|
||||||
|
|
||||||
|
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
|
||||||
|
func FactRecordingGuidanceBlock() string {
|
||||||
|
return `### 事实写入规范(审计复现 / 知识沉淀)
|
||||||
|
|
||||||
|
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||||
|
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
||||||
|
- **category / fact_key 建议**:
|
||||||
|
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
||||||
|
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||||
|
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||||
|
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
|
||||||
|
func SparseBodyWarning(category, factKey string) string {
|
||||||
|
if !IsSparseFactBody(category, factKey, "") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"\n\n⚠ 提示:category=%q / fact_key=%q 属于攻击链类事实,但 body 为空或过简。请补充完整攻击链与 POC(参考模板),便于后续审计复现。\n建议 body 骨架:\n%s",
|
||||||
|
category, factKey, FactBodyTemplate(category, factKey),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SparseBodyWarningIfNeeded 根据实际 body 判断是否追加警告。
|
||||||
|
func SparseBodyWarningIfNeeded(category, factKey, body string) string {
|
||||||
|
if !IsSparseFactBody(category, factKey, body) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return SparseBodyWarning(category, factKey)
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequiresAttackChainBody(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
cat, key string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"finding", "note/misc", true},
|
||||||
|
{"note", "finding/sqli-login", true},
|
||||||
|
{"target", "target/primary_domain", false},
|
||||||
|
{"auth", "auth/admin_cookie", false},
|
||||||
|
{"chain", "x", true},
|
||||||
|
{"", "exploit/rce-upload", true},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
if got := RequiresAttackChainBody(tc.cat, tc.key); got != tc.want {
|
||||||
|
t.Errorf("RequiresAttackChainBody(%q,%q)=%v want %v", tc.cat, tc.key, got, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSparseFactBody(t *testing.T) {
|
||||||
|
long := strings.Repeat("x", 150)
|
||||||
|
if !IsSparseFactBody("finding", "finding/x", "") {
|
||||||
|
t.Error("empty body should be sparse")
|
||||||
|
}
|
||||||
|
if !IsSparseFactBody("finding", "finding/x", long) {
|
||||||
|
t.Error("body without repro clues should be sparse")
|
||||||
|
}
|
||||||
|
body := "## 攻击链\n1. step\n## Exploit\n```http\nGET / HTTP/1.1\n```\n"
|
||||||
|
if IsSparseFactBody("finding", "finding/x", body) {
|
||||||
|
t.Error("structured body should not be sparse")
|
||||||
|
}
|
||||||
|
if IsSparseFactBody("target", "target/x", "") {
|
||||||
|
t.Error("env fact empty body is ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// projectScopePayload 解析 projects.scope_json(约定字段,可扩展)。
|
||||||
|
type projectScopePayload struct {
|
||||||
|
Targets []string `json:"targets"`
|
||||||
|
Exclude []string `json:"exclude"`
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildScopeBlock 将项目 scope_json 格式化为 Agent 可读的授权范围块。
|
||||||
|
func BuildScopeBlock(proj *database.Project) string {
|
||||||
|
if proj == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(proj.ScopeJSON)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload projectScopePayload
|
||||||
|
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||||
|
return fmt.Sprintf("## 项目测试范围(project: %s)\n(scope_json 非合法 JSON,请人工核对配置)\n```\n%s\n```\n"+
|
||||||
|
"仅对明确授权目标执行测试;超出范围须停止并说明。\n", proj.Name, truncateRunes(raw, 800))
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("## 项目测试范围(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||||
|
b.WriteString("以下为授权边界,**必须遵守**:仅测试列出的 targets,避开 exclude,不得擅自扩大范围。\n")
|
||||||
|
|
||||||
|
if len(payload.Targets) > 0 {
|
||||||
|
b.WriteString("\n**允许测试(targets)**:\n")
|
||||||
|
for _, t := range payload.Targets {
|
||||||
|
t = strings.TrimSpace(t)
|
||||||
|
if t != "" {
|
||||||
|
b.WriteString("- " + t + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(payload.Exclude) > 0 {
|
||||||
|
b.WriteString("\n**明确排除(exclude)**:\n")
|
||||||
|
for _, t := range payload.Exclude {
|
||||||
|
t = strings.TrimSpace(t)
|
||||||
|
if t != "" {
|
||||||
|
b.WriteString("- " + t + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if n := strings.TrimSpace(payload.Notes); n != "" {
|
||||||
|
b.WriteString("\n**说明(notes)**:\n" + n + "\n")
|
||||||
|
}
|
||||||
|
if len(payload.Targets) == 0 && len(payload.Exclude) == 0 && strings.TrimSpace(payload.Notes) == "" {
|
||||||
|
b.WriteString("\n(scope_json 已配置但未识别 targets/exclude/notes 字段,原始内容供参考)\n```json\n")
|
||||||
|
b.WriteString(truncateRunes(raw, 1200))
|
||||||
|
b.WriteString("\n```\n")
|
||||||
|
}
|
||||||
|
b.WriteString("\n若目标不在 targets 内或命中 exclude,不得主动扫描/利用;需用户明确扩大授权后再继续。\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateRunes(s string, max int) string {
|
||||||
|
r := []rune(s)
|
||||||
|
if len(r) <= max {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return string(r[:max]) + "…"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildProjectBlackboardBlock 组合测试范围 + 事实黑板索引。
|
||||||
|
func BuildProjectBlackboardBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
proj, err := db.GetProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
parts := []string{}
|
||||||
|
if scope := strings.TrimSpace(BuildScopeBlock(proj)); scope != "" {
|
||||||
|
parts = append(parts, scope)
|
||||||
|
}
|
||||||
|
index, err := BuildFactIndexBlock(db, projectID, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(index) != "" {
|
||||||
|
parts = append(parts, index)
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "\n\n"), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildScopeBlock_targetsExcludeNotes(t *testing.T) {
|
||||||
|
proj := &database.Project{
|
||||||
|
ID: "p1",
|
||||||
|
Name: "Acme",
|
||||||
|
ScopeJSON: `{"targets":["https://app.example.com"],"exclude":["*.cdn.example.com"],"notes":"仅 Web 层"}`,
|
||||||
|
}
|
||||||
|
block := BuildScopeBlock(proj)
|
||||||
|
if !strings.Contains(block, "https://app.example.com") {
|
||||||
|
t.Fatalf("missing target: %s", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "cdn.example.com") {
|
||||||
|
t.Fatalf("missing exclude: %s", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "仅 Web 层") {
|
||||||
|
t.Fatalf("missing notes: %s", block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildScopeBlock_empty(t *testing.T) {
|
||||||
|
if BuildScopeBlock(&database.Project{Name: "X"}) != "" {
|
||||||
|
t.Fatal("expected empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildScopeBlock_invalidJSON(t *testing.T) {
|
||||||
|
proj := &database.Project{Name: "X", ScopeJSON: `{not json`}
|
||||||
|
block := BuildScopeBlock(proj)
|
||||||
|
if !strings.Contains(block, "非合法 JSON") {
|
||||||
|
t.Fatalf("unexpected: %s", block)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import "cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
// GetProjectStats 聚合项目统计(含待补全事实数)。
|
||||||
|
func GetProjectStats(db *database.DB, projectID string) (*database.ProjectStats, error) {
|
||||||
|
stats, err := db.GetProjectStatsCounts(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows, err := db.ListProjectFactsForSparseCheck(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, r := range rows {
|
||||||
|
if IsSparseFactBody(r.Category, r.FactKey, r.Body) {
|
||||||
|
stats.SparseFactCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。
|
||||||
|
func VisionImageAnalysisSection() string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 图片分析\n\n")
|
||||||
|
b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n")
|
||||||
|
b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n")
|
||||||
|
b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n")
|
||||||
|
b.WriteString("- 委派子代理时,若子任务含验证码/截图识读,在 task description 中写明图片路径与期望输出格式。\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendVisionImageAnalysisIfReady 仅在 vision.enabled 且 model 已配置时追加图片分析提示。
|
||||||
|
func AppendVisionImageAnalysisIfReady(base string, visionReady bool) string {
|
||||||
|
if !visionReady {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return AppendSystemPromptBlock(base, VisionImageAnalysisSection())
|
||||||
|
}
|
||||||
@@ -149,13 +149,18 @@ func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, all
|
|||||||
func normalizeEffort(s string) string {
|
func normalizeEffort(s string) string {
|
||||||
e := strings.ToLower(strings.TrimSpace(s))
|
e := strings.ToLower(strings.TrimSpace(s))
|
||||||
switch e {
|
switch e {
|
||||||
case "low", "medium", "high", "max":
|
case "low", "medium", "high", "max", "xhigh":
|
||||||
return e
|
return e
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。
|
||||||
|
func usesExtraFieldsReasoningEffort(e string) bool {
|
||||||
|
return e == "max" || e == "xhigh"
|
||||||
|
}
|
||||||
|
|
||||||
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
||||||
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
||||||
return wireClaude
|
return wireClaude
|
||||||
@@ -210,11 +215,11 @@ func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
|||||||
if e == "" {
|
if e == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if e == "max" {
|
if usesExtraFieldsReasoningEffort(e) {
|
||||||
if cfg.ExtraFields == nil {
|
if cfg.ExtraFields == nil {
|
||||||
cfg.ExtraFields = make(map[string]any)
|
cfg.ExtraFields = make(map[string]any)
|
||||||
}
|
}
|
||||||
cfg.ExtraFields["reasoning_effort"] = "max"
|
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch e {
|
switch e {
|
||||||
@@ -245,6 +250,6 @@ func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func effortStringForAPI(e string) string {
|
func effortStringForAPI(e string) string {
|
||||||
// Gateways expect lowercase strings; "max" kept as max.
|
// 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。
|
||||||
return strings.ToLower(strings.TrimSpace(e))
|
return strings.ToLower(strings.TrimSpace(e))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package reasoning
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEffortStringForAPI_passthrough(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"max": "max",
|
||||||
|
"xhigh": "xhigh",
|
||||||
|
"HIGH": "high",
|
||||||
|
"Medium": "medium",
|
||||||
|
}
|
||||||
|
for in, want := range cases {
|
||||||
|
if got := effortStringForAPI(in); got != want {
|
||||||
|
t.Fatalf("%q -> %q, want %q", in, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeEffort_maxAndXhigh(t *testing.T) {
|
||||||
|
if normalizeEffort("xhigh") != "xhigh" {
|
||||||
|
t.Fatal("xhigh not accepted")
|
||||||
|
}
|
||||||
|
if normalizeEffort("max") != "max" {
|
||||||
|
t.Fatal("max not accepted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Profile: "openai_compat",
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "xhigh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
t.Fatal("expected ExtraFields")
|
||||||
|
}
|
||||||
|
if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" {
|
||||||
|
t.Fatalf("reasoning_effort=%q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Profile: "openai_compat",
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "max",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
got, _ := cfg.ExtraFields["reasoning_effort"].(string)
|
||||||
|
if got != "max" {
|
||||||
|
t.Fatalf("max effort wire=%q, want max", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client 调用独立 Vision ChatModel(单次 Generate)。
|
||||||
|
type Client struct {
|
||||||
|
cfg config.VisionConfig
|
||||||
|
mainOA config.OpenAIConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 构造视觉客户端。
|
||||||
|
func NewClient(visionCfg config.VisionConfig, mainOpenAI config.OpenAIConfig) *Client {
|
||||||
|
return &Client{cfg: visionCfg, mainOA: mainOpenAI}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Analyze 将图片字节送入 VL 模型并返回文本描述。
|
||||||
|
func (c *Client) Analyze(ctx context.Context, img ImagePayload, question string) (string, error) {
|
||||||
|
if len(img.Bytes) == 0 {
|
||||||
|
return "", fmt.Errorf("empty image payload")
|
||||||
|
}
|
||||||
|
mime := strings.TrimSpace(img.MIMEType)
|
||||||
|
if mime == "" {
|
||||||
|
mime = "image/jpeg"
|
||||||
|
}
|
||||||
|
oa := c.cfg.OpenAICfgEffective(c.mainOA)
|
||||||
|
if strings.TrimSpace(oa.APIKey) == "" {
|
||||||
|
return "", fmt.Errorf("vision API key is empty (set vision.api_key or openai.api_key)")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(oa.Model) == "" {
|
||||||
|
return "", fmt.Errorf("vision model is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Duration(c.cfg.TimeoutSecondsEffective()) * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: timeout + 15*time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 60 * time.Second,
|
||||||
|
KeepAlive: 60 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
ResponseHeaderTimeout: timeout + 10*time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
httpClient = openai.NewEinoHTTPClient(&oa, httpClient)
|
||||||
|
|
||||||
|
modelCfg := &einoopenai.ChatModelConfig{
|
||||||
|
APIKey: oa.APIKey,
|
||||||
|
BaseURL: strings.TrimSuffix(oa.BaseURL, "/"),
|
||||||
|
Model: oa.Model,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}
|
||||||
|
chatModel, err := einoopenai.NewChatModel(ctx, modelCfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("vision chat model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b64 := base64.StdEncoding.EncodeToString(img.Bytes)
|
||||||
|
detail := schema.ImageURLDetailLow
|
||||||
|
switch c.cfg.DetailEffective() {
|
||||||
|
case "high":
|
||||||
|
detail = schema.ImageURLDetailHigh
|
||||||
|
case "auto":
|
||||||
|
detail = schema.ImageURLDetailAuto
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := buildVisionPrompt(question)
|
||||||
|
userMsg := &schema.Message{
|
||||||
|
Role: schema.User,
|
||||||
|
UserInputMultiContent: []schema.MessageInputPart{
|
||||||
|
{Type: schema.ChatMessagePartTypeText, Text: prompt},
|
||||||
|
{
|
||||||
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
|
Image: &schema.MessageInputImage{
|
||||||
|
MessagePartCommon: schema.MessagePartCommon{
|
||||||
|
Base64Data: &b64,
|
||||||
|
MIMEType: mime,
|
||||||
|
},
|
||||||
|
Detail: detail,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := chatModel.Generate(ctx, []*schema.Message{userMsg})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("vision generate: %w", err)
|
||||||
|
}
|
||||||
|
if resp == nil || strings.TrimSpace(resp.Content) == "" {
|
||||||
|
return "", fmt.Errorf("vision model returned empty content")
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(resp.Content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildVisionPrompt(question string) string {
|
||||||
|
q := strings.TrimSpace(question)
|
||||||
|
if q == "" {
|
||||||
|
q = "请对图片做通用描述,侧重授权安全测试场景(可见文本、表单、按钮、验证码、错误信息、技术栈线索)。"
|
||||||
|
}
|
||||||
|
extra := ""
|
||||||
|
if looksLikeCaptchaQuestion(q) {
|
||||||
|
extra = "\n若为验证码:仅输出你辨认出的字符序列,不要空格、标点、解释;看不清则明确说无法识别。"
|
||||||
|
}
|
||||||
|
return `你是授权安全测试助手。请根据图片回答用户问题,只描述你能从图中确认的内容,不要编造。
|
||||||
|
用户问题:` + q + extra
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeCaptchaQuestion(q string) bool {
|
||||||
|
s := strings.ToLower(q)
|
||||||
|
for _, kw := range []string{"验证码", "captcha", "verification code", "verify code", "vcode", "图形码"} {
|
||||||
|
if strings.Contains(s, kw) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Contains(s, "只输出") && (strings.Contains(s, "字符") || strings.Contains(s, "character"))
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestLooksLikeCaptchaQuestion(t *testing.T) {
|
||||||
|
if !looksLikeCaptchaQuestion("识别验证码,只输出字符") {
|
||||||
|
t.Fatal("expected captcha hint")
|
||||||
|
}
|
||||||
|
if looksLikeCaptchaQuestion("描述登录页布局") {
|
||||||
|
t.Fatal("expected non-captcha")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const chatUploadsDirName = "chat_uploads"
|
||||||
|
|
||||||
|
var allowedImageExt = map[string]struct{}{
|
||||||
|
".png": {}, ".jpg": {}, ".jpeg": {}, ".webp": {}, ".gif": {},
|
||||||
|
".bmp": {}, ".tif": {}, ".tiff": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathOptions 图片路径白名单根目录。
|
||||||
|
type PathOptions struct {
|
||||||
|
CWD string
|
||||||
|
ResultStorageDir string // 相对 CWD,如 tmp
|
||||||
|
ExtraRoots []string // vision.allowed_roots 绝对路径
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveImagePath 解析并校验可读图片路径(防穿越、symlink 逃逸)。
|
||||||
|
func ResolveImagePath(path string, opt PathOptions) (string, error) {
|
||||||
|
p := strings.TrimSpace(path)
|
||||||
|
if p == "" {
|
||||||
|
return "", fmt.Errorf("path is empty")
|
||||||
|
}
|
||||||
|
cwd := strings.TrimSpace(opt.CWD)
|
||||||
|
if cwd == "" {
|
||||||
|
var err error
|
||||||
|
cwd, err = os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("getwd: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cwdAbs, err := filepath.Abs(filepath.Clean(cwd))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var candidate string
|
||||||
|
if filepath.IsAbs(p) {
|
||||||
|
candidate = filepath.Clean(p)
|
||||||
|
} else {
|
||||||
|
candidate = filepath.Clean(filepath.Join(cwdAbs, p))
|
||||||
|
}
|
||||||
|
candidate = normalizeAbsPath(candidate)
|
||||||
|
if candidate == "" {
|
||||||
|
return "", fmt.Errorf("invalid path")
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := strings.ToLower(filepath.Ext(candidate))
|
||||||
|
if _, ok := allowedImageExt[ext]; !ok {
|
||||||
|
return "", fmt.Errorf("unsupported image extension %q", ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
roots := buildAllowedRoots(cwdAbs, opt)
|
||||||
|
resolved, err := evalUnderAllowedRoots(candidate, roots)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
st, err := os.Stat(resolved)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("stat: %w", err)
|
||||||
|
}
|
||||||
|
if st.IsDir() {
|
||||||
|
return "", fmt.Errorf("not a regular file")
|
||||||
|
}
|
||||||
|
if st.Size() > 0 && st.Size() > 1<<30 {
|
||||||
|
return "", fmt.Errorf("file too large on disk")
|
||||||
|
}
|
||||||
|
return resolved, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAbsPath(p string) string {
|
||||||
|
abs, err := filepath.Abs(filepath.Clean(p))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if link, err := filepath.EvalSymlinks(abs); err == nil {
|
||||||
|
return link
|
||||||
|
}
|
||||||
|
return abs
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAllowedRoots(cwdAbs string, opt PathOptions) []string {
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
var roots []string
|
||||||
|
add := func(r string) {
|
||||||
|
r = strings.TrimSpace(r)
|
||||||
|
if r == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
abs := normalizeAbsPath(r)
|
||||||
|
if abs == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[abs]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[abs] = struct{}{}
|
||||||
|
roots = append(roots, abs)
|
||||||
|
}
|
||||||
|
add(cwdAbs)
|
||||||
|
add(filepath.Join(cwdAbs, chatUploadsDirName))
|
||||||
|
rs := strings.TrimSpace(opt.ResultStorageDir)
|
||||||
|
if rs == "" {
|
||||||
|
rs = "tmp"
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(rs) {
|
||||||
|
add(rs)
|
||||||
|
} else {
|
||||||
|
add(filepath.Join(cwdAbs, rs))
|
||||||
|
}
|
||||||
|
for _, r := range opt.ExtraRoots {
|
||||||
|
add(r)
|
||||||
|
}
|
||||||
|
return roots
|
||||||
|
}
|
||||||
|
|
||||||
|
func evalUnderAllowedRoots(candidate string, roots []string) (string, error) {
|
||||||
|
check := normalizeAbsPath(candidate)
|
||||||
|
for _, root := range roots {
|
||||||
|
if isUnderRoot(check, root) {
|
||||||
|
return candidate, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("path %q is outside allowed directories", candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUnderRoot(path, root string) bool {
|
||||||
|
path = filepath.Clean(path)
|
||||||
|
root = filepath.Clean(root)
|
||||||
|
if path == root {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
sep := string(filepath.Separator)
|
||||||
|
return strings.HasPrefix(path, root+sep)
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveImagePath_underCWD(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
img := filepath.Join(dir, "shot.png")
|
||||||
|
if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got, err := ResolveImagePath(img, PathOptions{CWD: dir, ResultStorageDir: "tmp"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := normalizeAbsPath(img)
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveImagePath_rejectsTraversal(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
_, err := ResolveImagePath("../../../etc/passwd", PathOptions{CWD: dir})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for path outside roots")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveImagePath_rejectsNonImageExt(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
f := filepath.Join(dir, "notes.txt")
|
||||||
|
if err := os.WriteFile(f, []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err := ResolveImagePath(f, PathOptions{CWD: dir})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-image extension")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/disintegration/imaging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImagePayload 送入 VL API 的图片字节与 MIME。
|
||||||
|
type ImagePayload struct {
|
||||||
|
Bytes []byte
|
||||||
|
MIMEType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreprocessMeta 记录缩放与编码结果,供工具输出与排障。
|
||||||
|
type PreprocessMeta struct {
|
||||||
|
OriginalPath string
|
||||||
|
OriginalBytes int64
|
||||||
|
OriginalWidth int
|
||||||
|
OriginalHeight int
|
||||||
|
OutputWidth int
|
||||||
|
OutputHeight int
|
||||||
|
OutputBytes int
|
||||||
|
OutputMIMEType string
|
||||||
|
JPEGQuality int // 0 表示未 JPEG 重编码(原图直传)
|
||||||
|
PreprocessMode string // passthrough | jpeg
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreprocessOptions 图片预处理参数。
|
||||||
|
type PreprocessOptions struct {
|
||||||
|
MaxImageBytes int64
|
||||||
|
MaxDimension int
|
||||||
|
JPEGQuality int
|
||||||
|
MaxPayloadBytes int64
|
||||||
|
SkipPreprocessBelowBytes int64 // 0 = 始终压缩;>0 时小图+尺寸合规可直传
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreprocessImageFile 读取图片;大图或超尺寸走 imaging 缩放+JPEG,否则可原图直传。
|
||||||
|
func PreprocessImageFile(path string, opt PreprocessOptions) (ImagePayload, PreprocessMeta, error) {
|
||||||
|
var meta PreprocessMeta
|
||||||
|
meta.OriginalPath = path
|
||||||
|
|
||||||
|
st, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return ImagePayload{}, meta, err
|
||||||
|
}
|
||||||
|
meta.OriginalBytes = st.Size()
|
||||||
|
if opt.MaxImageBytes > 0 && st.Size() > opt.MaxImageBytes {
|
||||||
|
return ImagePayload{}, meta, fmt.Errorf("file size %d exceeds max_image_bytes %d", st.Size(), opt.MaxImageBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgW, cfgH, format, err := imageDimensions(path)
|
||||||
|
if err != nil {
|
||||||
|
return ImagePayload{}, meta, err
|
||||||
|
}
|
||||||
|
meta.OriginalWidth = cfgW
|
||||||
|
meta.OriginalHeight = cfgH
|
||||||
|
|
||||||
|
maxDim := opt.MaxDimension
|
||||||
|
if maxDim <= 0 {
|
||||||
|
maxDim = 2048
|
||||||
|
}
|
||||||
|
maxPayload := opt.MaxPayloadBytes
|
||||||
|
if maxPayload <= 0 {
|
||||||
|
maxPayload = 512 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload, meta, ok, err := tryPassthrough(path, st.Size(), cfgW, cfgH, format, opt, maxDim, maxPayload); ok {
|
||||||
|
return payload, meta, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return compressWithImaging(path, opt, maxDim, maxPayload, meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryPassthrough(path string, size int64, w, h int, format string, opt PreprocessOptions, maxDim int, maxPayload int64) (ImagePayload, PreprocessMeta, bool, error) {
|
||||||
|
var meta PreprocessMeta
|
||||||
|
meta.OriginalPath = path
|
||||||
|
meta.OriginalBytes = size
|
||||||
|
meta.OriginalWidth = w
|
||||||
|
meta.OriginalHeight = h
|
||||||
|
|
||||||
|
threshold := opt.SkipPreprocessBelowBytes
|
||||||
|
if threshold <= 0 {
|
||||||
|
return ImagePayload{}, meta, false, nil
|
||||||
|
}
|
||||||
|
if size > threshold {
|
||||||
|
return ImagePayload{}, meta, false, nil
|
||||||
|
}
|
||||||
|
longEdge := w
|
||||||
|
if h > longEdge {
|
||||||
|
longEdge = h
|
||||||
|
}
|
||||||
|
if longEdge > maxDim {
|
||||||
|
return ImagePayload{}, meta, false, nil
|
||||||
|
}
|
||||||
|
if size > maxPayload {
|
||||||
|
return ImagePayload{}, meta, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return ImagePayload{}, meta, false, err
|
||||||
|
}
|
||||||
|
mime := mimeFromImageFormat(format)
|
||||||
|
if mime == "" {
|
||||||
|
return ImagePayload{}, meta, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
meta.OutputWidth = w
|
||||||
|
meta.OutputHeight = h
|
||||||
|
meta.OutputBytes = len(raw)
|
||||||
|
meta.OutputMIMEType = mime
|
||||||
|
meta.PreprocessMode = "passthrough"
|
||||||
|
return ImagePayload{Bytes: raw, MIMEType: mime}, meta, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressWithImaging(path string, opt PreprocessOptions, maxDim int, maxPayload int64, meta PreprocessMeta) (ImagePayload, PreprocessMeta, error) {
|
||||||
|
src, err := imaging.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return ImagePayload{}, meta, fmt.Errorf("open image: %w", err)
|
||||||
|
}
|
||||||
|
bounds := src.Bounds()
|
||||||
|
meta.OriginalWidth = bounds.Dx()
|
||||||
|
meta.OriginalHeight = bounds.Dy()
|
||||||
|
|
||||||
|
dst := imaging.Fit(src, maxDim, maxDim, imaging.Lanczos)
|
||||||
|
outBounds := dst.Bounds()
|
||||||
|
meta.OutputWidth = outBounds.Dx()
|
||||||
|
meta.OutputHeight = outBounds.Dy()
|
||||||
|
|
||||||
|
quality := opt.JPEGQuality
|
||||||
|
if quality <= 0 || quality > 100 {
|
||||||
|
quality = 82
|
||||||
|
}
|
||||||
|
|
||||||
|
dim := maxDim
|
||||||
|
for attempt := 0; attempt < 6; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
dim = int(float64(dim) * 0.85)
|
||||||
|
if dim < 256 {
|
||||||
|
dim = 256
|
||||||
|
}
|
||||||
|
dst = imaging.Fit(src, dim, dim, imaging.Lanczos)
|
||||||
|
outBounds = dst.Bounds()
|
||||||
|
meta.OutputWidth = outBounds.Dx()
|
||||||
|
meta.OutputHeight = outBounds.Dy()
|
||||||
|
}
|
||||||
|
q := quality
|
||||||
|
for q >= 60 {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := imaging.Encode(&buf, dst, imaging.JPEG, imaging.JPEGQuality(q)); err != nil {
|
||||||
|
return ImagePayload{}, meta, fmt.Errorf("encode jpeg: %w", err)
|
||||||
|
}
|
||||||
|
if int64(buf.Len()) <= maxPayload {
|
||||||
|
meta.JPEGQuality = q
|
||||||
|
meta.OutputBytes = buf.Len()
|
||||||
|
meta.OutputMIMEType = "image/jpeg"
|
||||||
|
meta.PreprocessMode = "jpeg"
|
||||||
|
return ImagePayload{Bytes: buf.Bytes(), MIMEType: "image/jpeg"}, meta, nil
|
||||||
|
}
|
||||||
|
q -= 5
|
||||||
|
}
|
||||||
|
quality = 75
|
||||||
|
}
|
||||||
|
return ImagePayload{}, meta, fmt.Errorf("could not compress image under max_payload_bytes %d", maxPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func imageDimensions(path string) (w, h int, format string, err error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, "", err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
cfg, format, err := image.DecodeConfig(f)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, "", fmt.Errorf("decode image config: %w", err)
|
||||||
|
}
|
||||||
|
return cfg.Width, cfg.Height, format, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mimeFromImageFormat(format string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(format)) {
|
||||||
|
case "jpeg", "jpg":
|
||||||
|
return "image/jpeg"
|
||||||
|
case "png":
|
||||||
|
return "image/png"
|
||||||
|
case "gif":
|
||||||
|
return "image/gif"
|
||||||
|
case "webp":
|
||||||
|
return "image/webp"
|
||||||
|
case "bmp":
|
||||||
|
return "image/bmp"
|
||||||
|
case "tiff":
|
||||||
|
return "image/tiff"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeImageConfig 用于测试:确认文件可被解码。
|
||||||
|
func DecodeImageConfig(path string) (image.Config, string, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return image.Config{}, "", err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
return image.DecodeConfig(f)
|
||||||
|
}
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/png"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/disintegration/imaging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreprocessImageFile_scalesAndLimitsPayload(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "big.png")
|
||||||
|
img := imaging.New(3000, 2000, color.White)
|
||||||
|
if err := imaging.Save(img, path); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, meta, err := PreprocessImageFile(path, PreprocessOptions{
|
||||||
|
MaxImageBytes: 10 * 1024 * 1024,
|
||||||
|
MaxDimension: 1024,
|
||||||
|
JPEGQuality: 85,
|
||||||
|
MaxPayloadBytes: 600 * 1024,
|
||||||
|
SkipPreprocessBelowBytes: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(out.Bytes) == 0 {
|
||||||
|
t.Fatal("empty output")
|
||||||
|
}
|
||||||
|
if meta.PreprocessMode != "jpeg" {
|
||||||
|
t.Fatalf("mode: %s", meta.PreprocessMode)
|
||||||
|
}
|
||||||
|
if meta.OutputWidth > 1024 || meta.OutputHeight > 1024 {
|
||||||
|
t.Fatalf("expected fit within 1024, got %dx%d", meta.OutputWidth, meta.OutputHeight)
|
||||||
|
}
|
||||||
|
if int64(len(out.Bytes)) > 600*1024 {
|
||||||
|
t.Fatalf("payload %d exceeds max", len(out.Bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreprocessImageFile_passthroughSmallPNG(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "small.png")
|
||||||
|
if err := imaging.Save(imaging.New(400, 300, color.White), path); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, meta, err := PreprocessImageFile(path, PreprocessOptions{
|
||||||
|
MaxImageBytes: 5 * 1024 * 1024,
|
||||||
|
MaxDimension: 2048,
|
||||||
|
MaxPayloadBytes: 512 * 1024,
|
||||||
|
SkipPreprocessBelowBytes: 2 * 1024 * 1024,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if meta.PreprocessMode != "passthrough" {
|
||||||
|
t.Fatalf("expected passthrough, got %s", meta.PreprocessMode)
|
||||||
|
}
|
||||||
|
if out.MIMEType != "image/png" {
|
||||||
|
t.Fatalf("mime: %s", out.MIMEType)
|
||||||
|
}
|
||||||
|
if meta.OutputWidth != 400 || meta.OutputHeight != 300 {
|
||||||
|
t.Fatalf("dims: %dx%d", meta.OutputWidth, meta.OutputHeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreprocessImageFile_passthroughDisabled(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "small.png")
|
||||||
|
if err := imaging.Save(imaging.New(100, 100, color.White), path); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, meta, err := PreprocessImageFile(path, PreprocessOptions{
|
||||||
|
MaxDimension: 2048,
|
||||||
|
MaxPayloadBytes: 512 * 1024,
|
||||||
|
SkipPreprocessBelowBytes: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if meta.PreprocessMode != "jpeg" {
|
||||||
|
t.Fatalf("expected jpeg compress, got %s", meta.PreprocessMode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreprocessImageFile_rejectsOversizeFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "tiny.png")
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := png.Encode(f, image.NewRGBA(image.Rect(0, 0, 2, 2))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
_, _, err = PreprocessImageFile(path, PreprocessOptions{MaxImageBytes: 1})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when file exceeds max_image_bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
package vision
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterAnalyzeImageTool 在 vision.enabled 且 model 已配置时注册 MCP 工具 analyze_image。
|
||||||
|
func RegisterAnalyzeImageTool(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) {
|
||||||
|
if mcpServer == nil || cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !cfg.Vision.Ready() {
|
||||||
|
if cfg.Vision.Enabled && logger != nil {
|
||||||
|
logger.Warn("vision.enabled 但 vision.model 为空,跳过注册 analyze_image")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("vision: getwd failed, skip analyze_image", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pathOpt := PathOptions{
|
||||||
|
CWD: cwd,
|
||||||
|
ResultStorageDir: cfg.Agent.ResultStorageDir,
|
||||||
|
ExtraRoots: cfg.Vision.AllowedRoots,
|
||||||
|
}
|
||||||
|
preOpt := PreprocessOptions{
|
||||||
|
MaxImageBytes: cfg.Vision.MaxImageBytesEffective(),
|
||||||
|
MaxDimension: cfg.Vision.MaxDimensionEffective(),
|
||||||
|
JPEGQuality: cfg.Vision.JPEGQualityEffective(),
|
||||||
|
MaxPayloadBytes: cfg.Vision.MaxPayloadBytesEffective(),
|
||||||
|
SkipPreprocessBelowBytes: cfg.Vision.SkipPreprocessBelowBytesEffective(),
|
||||||
|
}
|
||||||
|
client := NewClient(cfg.Vision, cfg.OpenAI)
|
||||||
|
|
||||||
|
tool := mcp.Tool{
|
||||||
|
Name: builtin.ToolAnalyzeImage,
|
||||||
|
Description: "分析服务器上的本地图片并返回文字描述(验证码、UI 元素、报错、架构图要点等)。" +
|
||||||
|
"输入为文件路径(如用户上传的 chat_uploads 路径或工具截图路径)。" +
|
||||||
|
"输出仅为文本,不含图片数据。不要对二进制图片使用 read_file 指望理解内容。",
|
||||||
|
ShortDescription: "分析本地图片并返回文字描述(验证码/UI/报错等)",
|
||||||
|
InputSchema: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"path": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "图片绝对路径或相对于进程工作目录的路径",
|
||||||
|
},
|
||||||
|
"question": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "可选:希望模型重点回答的问题。验证码图建议:只输出验证码字符,不要空格和解释",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"path"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
|
path, _ := args["path"].(string)
|
||||||
|
question, _ := args["question"].(string)
|
||||||
|
|
||||||
|
abs, err := ResolveImagePath(path, pathOpt)
|
||||||
|
if err != nil {
|
||||||
|
return textResult(fmt.Sprintf("路径校验失败: %v", err), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
img, meta, err := PreprocessImageFile(abs, preOpt)
|
||||||
|
if err != nil {
|
||||||
|
return textResult(fmt.Sprintf("图片预处理失败: %v", err), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := client.Analyze(ctx, img, question)
|
||||||
|
if err != nil {
|
||||||
|
return textResult(fmt.Sprintf("视觉模型调用失败: %v", err), true), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body := formatAnalysisResult(abs, meta, summary)
|
||||||
|
return textResult(body, false), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mcpServer.RegisterTool(tool, handler)
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("vision: analyze_image 工具已注册", zap.String("model", cfg.Vision.Model))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func textResult(text string, isError bool) *mcp.ToolResult {
|
||||||
|
return &mcp.ToolResult{
|
||||||
|
Content: []mcp.Content{{Type: "text", Text: text}},
|
||||||
|
IsError: isError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatAnalysisResult(path string, meta PreprocessMeta, summary string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## Image analysis\n")
|
||||||
|
b.WriteString("- **path**: ")
|
||||||
|
b.WriteString(path)
|
||||||
|
b.WriteString("\n")
|
||||||
|
switch meta.PreprocessMode {
|
||||||
|
case "passthrough":
|
||||||
|
b.WriteString(fmt.Sprintf("- **preprocess**: passthrough %dx%d, %s, %dKB (original %dKB)\n\n",
|
||||||
|
meta.OutputWidth, meta.OutputHeight, meta.OutputMIMEType,
|
||||||
|
(meta.OutputBytes+1023)/1024, (meta.OriginalBytes+1023)/1024))
|
||||||
|
default:
|
||||||
|
b.WriteString(fmt.Sprintf("- **preprocess**: %dx%d → %dx%d, jpeg q=%d, %dKB (original %dKB)\n\n",
|
||||||
|
meta.OriginalWidth, meta.OriginalHeight,
|
||||||
|
meta.OutputWidth, meta.OutputHeight,
|
||||||
|
meta.JPEGQuality, (meta.OutputBytes+1023)/1024,
|
||||||
|
(meta.OriginalBytes+1023)/1024))
|
||||||
|
}
|
||||||
|
b.WriteString("### Summary\n")
|
||||||
|
b.WriteString(strings.TrimSpace(summary))
|
||||||
|
b.WriteString("\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
- Configure **Host / Port / HTTPS / Password** and choose an agent mode
|
- Configure **Host / Port / HTTPS / Password** and choose an agent mode
|
||||||
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
||||||
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
|
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest** (agent modes: **Eino Single**, Deep, Plan-Execute, Supervisor — maps to `/api/eino-agent/stream` or `/api/multi-agent/stream`)
|
||||||
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
||||||
- Output is split into **collapsible Progress** + **Final Response** (Markdown rendering supported)
|
- Output is split into **collapsible Progress** + **Final Response** (Markdown rendering supported)
|
||||||
- View captured **Request / Response** for each run
|
- View captured **Request / Response** for each run
|
||||||
|
|||||||
@@ -10,8 +10,8 @@
|
|||||||
- 右键任意 HTTP 请求包 → **Send to CyberStrikeAI (stream test)**:
|
- 右键任意 HTTP 请求包 → **Send to CyberStrikeAI (stream test)**:
|
||||||
- 将该 HTTP 请求(含 headers/body;若存在响应则附带截断片段)发送到 CyberStrikeAI
|
- 将该 HTTP 请求(含 headers/body;若存在响应则附带截断片段)发送到 CyberStrikeAI
|
||||||
- 以 **SSE 流式**接收返回内容,并在标签页中实时展示
|
- 以 **SSE 流式**接收返回内容,并在标签页中实时展示
|
||||||
- 单 Agent:`POST /api/agent-loop/stream`
|
- 单 Agent:`POST /api/eino-agent/stream`
|
||||||
- 多 Agent:`POST /api/multi-agent/stream`(需要服务端启用 `multi_agent.enabled: true`)
|
- 多 Agent:`POST /api/multi-agent/stream`(需 `multi_agent.enabled: true`,请求体 `orchestration`)
|
||||||
- **测试历史侧边栏(可搜索)**:每次发送都会新增一条记录,方便回看与对比
|
- **测试历史侧边栏(可搜索)**:每次发送都会新增一条记录,方便回看与对比
|
||||||
- **Output 分区**:`Progress`(可折叠)+ `Final Response`(主区域)
|
- **Output 分区**:`Progress`(可折叠)+ `Final Response`(主区域)
|
||||||
- **Markdown 渲染**:最终输出可在 Output 主区域渲染为富文本(可开关)
|
- **Markdown 渲染**:最终输出可在 Output 主区域渲染为富文本(可开关)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user