mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 11:37:57 +02:00
Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b9e5527131 | |||
| 3d5e2bc4c7 | |||
| d58c4642f7 | |||
| 9df6de088b | |||
| aae71a0c3e | |||
| 059a33029e | |||
| 15daad97d4 | |||
| f02c0d175b | |||
| a8da115d28 | |||
| e4a01089e7 | |||
| bbf8c416fc | |||
| d41decd707 | |||
| 93a600d60e | |||
| c86825d365 | |||
| 4af5e2691e | |||
| 85400cd3f8 | |||
| a66b8fc821 | |||
| 58be62fa24 | |||
| a3739210e4 | |||
| e936c63754 | |||
| 1f46d4a930 | |||
| 3a995183a6 | |||
| 3ed7499a0b | |||
| f26354d483 | |||
| ebd872b373 | |||
| 07439bce6e | |||
| 625ac4358f | |||
| eb6b9d6f45 | |||
| ad97544bbe | |||
| 12a1ebe9cd | |||
| b97e726237 | |||
| 2eb923e5fa | |||
| 745a69f93b | |||
| 011a242acc | |||
| 6a52ef96f4 | |||
| 52f8c377b6 | |||
| 8d04b0c266 | |||
| bcdff06702 | |||
| 3210bc727f | |||
| 5254ca52fb | |||
| 1ff2df68ac | |||
| fe60497863 | |||
| 7acd21bc98 | |||
| dbcf9b8418 | |||
| b3767b2deb | |||
| 7e764df0e8 | |||
| a1ffb20d6e | |||
| 125685f08f | |||
| b804635fa8 | |||
| c9fb5d11d3 | |||
| 926491b746 | |||
| 4e17691717 | |||
| 2e2a6dedd4 | |||
| b1323896c8 | |||
| 595074b7b0 | |||
| 2e063dd857 | |||
| a110d233e1 | |||
| 2f58d0a457 | |||
| 5b7f157802 | |||
| 09890db635 | |||
| c0171ef60a | |||
| 4eb73fb638 | |||
| d1b49cb20d | |||
| 930eb47013 | |||
| 9964e13197 | |||
| 4f7b21cb7e | |||
| 9fae9db906 | |||
| 7ecd8c61e8 | |||
| bdb0326e47 | |||
| 8dccc6aa06 | |||
| fd4bbe8d76 | |||
| d80651e4d8 | |||
| f920ff0a5d | |||
| ce8b57501d | |||
| ecb38a3959 | |||
| e69fdb71ca | |||
| 6aa1631748 | |||
| 52de3b0f41 | |||
| e537e55198 | |||
| dc20b4804e | |||
| 6245d69364 | |||
| ede32951bf | |||
| 866a8ebccf | |||
| 276b3f7ef5 | |||
| 81e461db54 | |||
| 02cd488a3d | |||
| b4b2f55665 | |||
| 7aa0ebea6d | |||
| 63ef4399f8 | |||
| 553d0ed6bf | |||
| d92bbbea07 | |||
| f89ad1b42d | |||
| bbe14c1861 | |||
| 2fc37fefd1 | |||
| ded8ac5a3f | |||
| bf44cf58d3 | |||
| 6d390e80d5 | |||
| cfc49ba16f | |||
| d03f2fcf2b | |||
| 6e67684bba | |||
| 8f9d2f381a | |||
| 89c275269f | |||
| cb4900c61d | |||
| 5c192cd308 | |||
| 8571e41138 | |||
| e1a74b29b1 | |||
| 39f1c72755 | |||
| dd3621e89d | |||
| 0bcb16e021 |
@@ -35,7 +35,18 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
|
||||
### System Dashboard Overview
|
||||
|
||||
<img src="./images/dashboard.png" alt="System Dashboard" width="100%">
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%" align="center">
|
||||
<strong>Light Mode</strong><br/>
|
||||
<img src="./images/dashboard.png" alt="System Dashboard (Light)" width="100%">
|
||||
</td>
|
||||
<td width="50%" align="center">
|
||||
<strong>Dark Mode</strong><br/>
|
||||
<img src="./images/dark.png" alt="System Dashboard (Dark)" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
*The dashboard provides a comprehensive overview of system runtime status, security vulnerabilities, tool usage, and knowledge base, helping users quickly understand the platform's core features and current state.*
|
||||
|
||||
@@ -110,12 +121,13 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 📄 Large-result pagination, compression, and searchable archives
|
||||
- 🔗 Attack-chain graph, risk scoring, and step-by-step replay
|
||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||
- 📚 Knowledge base (RAG): **Eino MultiQuery** query rewrite + multi-path vector retrieval + **HTTP rerank** (DashScope `gte-rerank` / Cohere-compatible) + post-processing (dedupe, budget); **Eino Compose** indexing pipeline
|
||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||
- 📂 **Project management**: shared facts (blackboard) across sessions, `upsert_project_fact` + `links` to chain paths; attack-chain and project fact graph views
|
||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||
- 🔀 **Graph orchestration**: visual workflow editor (Start / Agent / Tool / Condition / HITL / Output) with `{{previous.output}}` and `{{outputs.variable_name}}` for inter-node data passing; bind a graph to a role for automatic execution on chat. See [Graph orchestration guide](docs/workflow-graph_en.md)
|
||||
- 🧩 **Agent orchestration (CloudWeGo Eino)**: **single-agent** via **`/api/eino-agent/stream`** (Eino ADK `ChatModelAgent`); **multi-agent** via **`/api/multi-agent/stream`** with **`deep`** (coordinator + `task` sub-agents), **`plan_execute`**, or **`supervisor`** (`orchestration` in the request body). ADK **summarization** compresses long contexts; pre-compaction **transcripts** land at `data/conversation_artifacts/<conversation-id>/summarization/transcript.txt` (full user/assistant/tool turns; static system omitted). Markdown under `agents/`: `orchestrator.md`, `orchestrator-plan-execute.md`, `orchestrator-supervisor.md`, plus sub-agent `*.md` (see [Multi-agent doc](docs/MULTI_AGENT_EINO.md))
|
||||
- 🖼️ **Vision analysis (`analyze_image`)**: separate VL model (e.g. `qwen-vl-max`) via MCP for local screenshots, captchas, and UI; image bytes stay out of agent history (text summaries only). Configure `vision` in `config.yaml`; see [docs/VISION.md](docs/VISION.md)
|
||||
- 🎯 **Skills (refactored for Eino)**: packs under `skills_dir` follow **Agent Skills** layout (`SKILL.md` + optional files); **multi-agent** sessions use the official Eino ADK **`skill`** tool for **progressive disclosure** (load by name), with optional **host filesystem / shell** via `multi_agent.eino_skills`; optional **`eino_middleware`** adds patchtoolcalls, tool_search, **plantask** (`TaskCreate` / `TaskList` boards under `skills_dir/.eino/plantask/`), reduction, file **checkpoints** (`checkpoint_dir`), ChatModel **retries**, session **output key**, and Deep tuning—20+ sample domains (SQLi, XSS, API security, …) ship under `skills/`
|
||||
@@ -244,6 +256,7 @@ Requirements / tips:
|
||||
- **Conversation testing** – Natural-language prompts trigger toolchains with streaming SSE output.
|
||||
- **Single vs multi-agent** – Chat UI switches between **Eino single-agent** (`/api/eino-agent/stream`) and **multi-agent** (`/api/multi-agent/stream` with `orchestration`: `deep` | `plan_execute` | `supervisor`). Multi mode requires `multi_agent.enabled: true`. MCP tools are bridged the same way for both paths.
|
||||
- **Role-based testing** – Select from predefined security testing roles (Penetration Testing, CTF, Web App Scanning, API Security Testing, etc.) to customize AI behavior and tool availability. Each role applies custom system prompts and can restrict available tools for focused testing scenarios.
|
||||
- **Graph orchestration** – Design flows on the **Graph Orchestration** page (drag nodes, connect edges, save); bind `workflow_id` on a role to run the graph on chat (Agent, MCP tools, condition branches). Use `{{outputs.variable_name}}` to pass data across non-adjacent nodes. See [Graph orchestration guide](docs/workflow-graph_en.md).
|
||||
- **Tool monitor** – Inspect running jobs, execution logs, and large-result attachments.
|
||||
- **History & audit** – Every conversation and tool invocation is stored in SQLite with replay.
|
||||
- **Conversation groups** – Organize conversations into groups, pin important groups, rename or delete groups via context menu.
|
||||
@@ -455,16 +468,12 @@ A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation
|
||||
|
||||
### Knowledge Base
|
||||
- **Vector search** – AI agent can automatically search the knowledge base for relevant security knowledge during conversations using the `search_knowledge_base` tool.
|
||||
- **Vector retrieval** – cosine similarity over stored embeddings, aligned with Eino `retriever.Retriever` usage.
|
||||
- **Auto-indexing** – scans the `knowledge_base/` directory for Markdown files and automatically indexes them with embeddings.
|
||||
- **Web management** – create, update, delete knowledge items through the web UI, with category-based organization.
|
||||
- **RAG pipeline (always on)** – **MultiQuery** (LLM query rewrite) → vector prefetch & fusion → **HTTP rerank** (DashScope `gte-rerank` or Cohere-compatible `/v1/rerank`) → post-processing (normalized dedupe, char/token budget, final top_k). Rerank failures degrade to fusion order without breaking search.
|
||||
- **Vector retrieval** – cosine similarity over stored embeddings with configurable threshold, aligned with Eino `retriever.Retriever` usage.
|
||||
- **Auto-indexing** – scans the `knowledge_base/` directory for Markdown files and automatically indexes them with embeddings (Markdown header split + recursive chunking via Eino).
|
||||
- **Web management** – create, update, delete knowledge items through the web UI, with category-based organization; settings page exposes MultiQuery / rerank / prefetch options.
|
||||
- **Retrieval logs** – tracks all knowledge retrieval operations for audit and debugging.
|
||||
|
||||
**Quick Start (Using Pre-built Knowledge Base):**
|
||||
1. **Download the knowledge database** – Download the pre-built knowledge database file from [GitHub Releases](https://github.com/Ed1s0nZ/CyberStrikeAI/releases).
|
||||
2. **Extract and place** – Extract the downloaded knowledge database file (`knowledge.db`) and place it in the project's `data/` directory.
|
||||
3. **Restart the service** – Restart the CyberStrikeAI service, and the knowledge base will be ready to use immediately without rebuilding the index.
|
||||
|
||||
**Setting up the knowledge base:**
|
||||
1. **Enable in config** – set `knowledge.enabled: true` in `config.yaml`:
|
||||
```yaml
|
||||
@@ -479,6 +488,17 @@ A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation
|
||||
retrieval:
|
||||
top_k: 5
|
||||
similarity_threshold: 0.7
|
||||
multi_query:
|
||||
max_queries: 4 # LLM rewrite variants (always on)
|
||||
rerank: # always on; empty fields inherit openai/embedding credentials
|
||||
provider: "" # auto: dashscope | cohere from base_url
|
||||
model: "" # empty: gte-rerank (DashScope) or rerank-multilingual-v3.0 (Cohere)
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
post_retrieve:
|
||||
prefetch_top_k: 20 # vector candidates per MultiQuery variant; 0 = max(top_k×4, 20)
|
||||
max_context_chars: 0
|
||||
max_context_tokens: 0
|
||||
```
|
||||
2. **Add knowledge files** – place Markdown files in `knowledge_base/` directory, organized by category (e.g., `knowledge_base/SQL Injection/README.md`).
|
||||
3. **Scan and index** – use the web UI to scan the knowledge base directory, which will automatically import files and build vector embeddings.
|
||||
@@ -539,6 +559,17 @@ knowledge:
|
||||
retrieval:
|
||||
top_k: 5 # Number of top results to return
|
||||
similarity_threshold: 0.7 # Minimum cosine similarity (0-1)
|
||||
multi_query:
|
||||
max_queries: 4 # MultiQuery rewrite variants (always on)
|
||||
rerank: # HTTP rerank (always on); empty fields inherit openai/embedding credentials
|
||||
provider: ""
|
||||
model: ""
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
post_retrieve:
|
||||
prefetch_top_k: 20 # per MultiQuery variant; 0 = max(top_k×4, 20)
|
||||
max_context_chars: 0
|
||||
max_context_tokens: 0
|
||||
roles_dir: "roles" # Role configuration directory (relative to config file)
|
||||
skills_dir: "skills" # Skills directory (relative to config file)
|
||||
agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-agents)
|
||||
@@ -601,6 +632,7 @@ enabled: true
|
||||
## Related documentation
|
||||
|
||||
- [Multi-agent mode (Eino)](docs/MULTI_AGENT_EINO.md): **Deep**, **Plan-Execute**, **Supervisor**, `agents/*.md`, `eino_skills` / `eino_middleware`, APIs, and chat/stream behavior.
|
||||
- [Graph orchestration guide](docs/workflow-graph_en.md): visual workflow design, node configuration, `previous` / `outputs` variable passing, and role binding.
|
||||
- [Robot / Chatbot guide (DingTalk & Lark)](docs/robot_en.md): Full setup, commands, and troubleshooting for using CyberStrikeAI from DingTalk or Lark on your phone. **Follow this doc to avoid common pitfalls.**
|
||||
|
||||
## Project Layout
|
||||
@@ -653,8 +685,6 @@ CyberStrikeAI has joined [404Starlink](https://github.com/knownsec/404StarLink)
|
||||
</a>
|
||||
</div>
|
||||
|
||||
## Stargazers over time
|
||||

|
||||
|
||||
|
||||
---
|
||||
|
||||
+42
-12
@@ -34,7 +34,18 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
|
||||
### 系统仪表盘概览
|
||||
|
||||
<img src="./images/dashboard.png" alt="系统仪表盘" width="100%">
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%" align="center">
|
||||
<strong>浅色模式</strong><br/>
|
||||
<img src="./images/dashboard.png" alt="系统仪表盘(浅色)" width="100%">
|
||||
</td>
|
||||
<td width="50%" align="center">
|
||||
<strong>深色模式</strong><br/>
|
||||
<img src="./images/dark.png" alt="系统仪表盘(深色)" width="100%">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
*仪表盘提供系统运行状态、安全漏洞、工具使用情况和知识库的全面概览,帮助用户快速了解平台核心功能和当前状态。*
|
||||
|
||||
@@ -109,12 +120,13 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 📄 大结果分页、压缩与全文检索
|
||||
- 🔗 攻击链可视化、风险打分与步骤回放
|
||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||
- 📚 知识库(RAG):**Eino MultiQuery** 查询改写 + 多路向量检索 + **HTTP 精排**(DashScope `gte-rerank` / Cohere 兼容)+ 后处理(去重、预算);索引侧为 **Eino Compose** 流水线
|
||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||
- 📂 **项目管理**:共享事实(黑板)跨会话沉淀认知,`upsert_project_fact` + `links` 串联攻击路径;聊天攻击链与项目事实图可视化
|
||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||
- 🔀 **图编排**:可视化流程编排(开始 / Agent / 工具 / 条件 / 审批 / 输出),节点间用 `{{previous.output}}` 或 `{{outputs.变量名}}` 传参;绑定角色后对话自动按图执行。详见 [图编排使用说明](docs/workflow-graph.md)
|
||||
- 🧩 **Agent 编排(CloudWeGo Eino)**:**单代理** `POST /api/eino-agent/stream`(Eino ADK);**多代理** `POST /api/multi-agent/stream`,`orchestration` 选 **`deep`** / **`plan_execute`** / **`supervisor`**。ADK **Summarization** 在上下文过长时压缩历史;压缩前将可恢复 **转录** 写入 `data/conversation_artifacts/<会话ID>/summarization/transcript.txt`(保留完整 user/assistant/tool 轮次,省略静态 system)。`agents/` 下主代理与子代理 Markdown 见 [多代理说明](docs/MULTI_AGENT_EINO.md)
|
||||
- 🖼️ **视觉分析(`analyze_image`)**:独立 Vision 模型(如 `qwen-vl-max`),MCP 工具分析本地截图/验证码/UI;图片仅在单次 VL 调用中出现,对话上下文只保留文字摘要。配置见 `config.yaml` → `vision` 与 [视觉分析说明](docs/VISION.md)
|
||||
- 🎯 **Skills(面向 Eino 重构)**:技能包放在 **`skills_dir`**,遵循 **Agent Skills** 目录规范(`SKILL.md` + 可选文件);**多代理** 下通过 Eino 官方 **`skill`** 工具 **渐进式披露**(按 name 加载)。**`multi_agent.eino_skills`** 控制是否启用、本机文件/Shell 工具、工具名覆盖;**`eino_middleware`** 可选 patch、tool_search、**plantask**(`TaskCreate` / `TaskList` 任务板,落在 `skills_dir/.eino/plantask/`)、reduction、文件型 **checkpoint**(`checkpoint_dir`)、ChatModel **重试**、会话 **输出键** 及 Deep 调参。20+ 领域示例仍可绑定角色
|
||||
@@ -242,6 +254,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **对话测试**:自然语言触发多步工具编排,SSE 实时输出。
|
||||
- **单代理 / 多代理**:聊天可选 **Eino 单代理**(`/api/eino-agent/stream`)与 **多代理**(`/api/multi-agent/stream` + `orchestration`)。多代理需 `multi_agent.enabled: true`。MCP 工具桥接一致。
|
||||
- **角色化测试**:从预设的安全测试角色(渗透测试、CTF、Web 应用扫描、API 安全测试等)中选择,自定义 AI 行为和可用工具。每个角色可应用自定义系统提示词,并可限制可用工具列表,实现聚焦的测试场景。
|
||||
- **图编排**:在 **图编排** 页拖拽节点、连线并保存流程;在角色中绑定 `workflow_id` 后,该角色对话将按图执行(Agent、MCP 工具、条件分支等)。跨节点传参优先用 `{{outputs.变量名}}`。详见 [图编排使用说明](docs/workflow-graph.md)。
|
||||
- **工具监控**:查看任务队列、执行日志、大文件附件。
|
||||
- **会话历史**:所有对话与工具调用保存在 SQLite,可随时重放。
|
||||
- **对话分组**:将对话按项目或主题组织到不同分组,支持置顶、重命名、删除等操作,所有数据持久化存储。
|
||||
@@ -453,16 +466,12 @@ CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
|
||||
|
||||
### 知识库功能
|
||||
- **向量检索**:AI 智能体在对话过程中可自动调用 `search_knowledge_base` 工具搜索知识库中的安全知识。
|
||||
- **向量检索**:基于嵌入余弦相似度与相似度阈值过滤(与 Eino `retriever.Retriever` 语义一致)。
|
||||
- **自动索引**:扫描 `knowledge_base/` 目录下的 Markdown 文件,自动构建向量嵌入索引。
|
||||
- **Web 管理**:通过 Web 界面创建、更新、删除知识项,支持分类管理。
|
||||
- **RAG 管线(始终启用)**:**MultiQuery**(LLM 查询改写)→ 向量预取与融合 → **HTTP 精排**(DashScope `gte-rerank` 或 Cohere 兼容 `/v1/rerank`)→ 后处理(规范化去重、字符/token 预算、最终 top_k)。精排失败时自动降级为融合排序,检索仍可用。
|
||||
- **向量相似度**:基于嵌入余弦相似度与相似度阈值过滤(与 Eino `retriever.Retriever` 语义一致)。
|
||||
- **自动索引**:扫描 `knowledge_base/` 目录下的 Markdown 文件,自动构建向量嵌入索引(Eino Markdown 标题切分 + 递归分块)。
|
||||
- **Web 管理**:通过 Web 界面创建、更新、删除知识项,支持分类管理;设置页可配置 MultiQuery / 精排 / 预取候选数。
|
||||
- **检索日志**:记录所有知识检索操作,便于审计与调试。
|
||||
|
||||
**快速开始(使用预构建知识库):**
|
||||
1. **下载知识数据库**:从 [GitHub Releases](https://github.com/Ed1s0nZ/CyberStrikeAI/releases) 下载预构建的知识数据库文件。
|
||||
2. **解压并放置**:将下载的知识数据库文件(`knowledge.db`)解压后放到项目的 `data/` 目录下。
|
||||
3. **重启服务**:重启 CyberStrikeAI 服务,知识库即可直接使用,无需重新构建索引。
|
||||
|
||||
**知识库配置步骤:**
|
||||
1. **启用功能**:在 `config.yaml` 中设置 `knowledge.enabled: true`:
|
||||
```yaml
|
||||
@@ -477,6 +486,17 @@ CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
|
||||
retrieval:
|
||||
top_k: 5
|
||||
similarity_threshold: 0.7
|
||||
multi_query:
|
||||
max_queries: 4 # LLM 改写变体上限(始终启用)
|
||||
rerank: # 精排始终启用;留空则继承 openai/embedding 凭据
|
||||
provider: "" # 空=按 base_url 推断 dashscope | cohere
|
||||
model: "" # 空=DashScope→gte-rerank,Cohere→rerank-multilingual-v3.0
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
post_retrieve:
|
||||
prefetch_top_k: 20 # 每条 MultiQuery 变体的向量候选数;0=max(top_k×4, 20)
|
||||
max_context_chars: 0
|
||||
max_context_tokens: 0
|
||||
```
|
||||
2. **添加知识文件**:将 Markdown 文件放入 `knowledge_base/` 目录,按分类组织(如 `knowledge_base/SQL注入/README.md`)。
|
||||
3. **扫描索引**:在 Web 界面中点击"扫描知识库",系统会自动导入文件并构建向量索引。
|
||||
@@ -537,6 +557,17 @@ knowledge:
|
||||
retrieval:
|
||||
top_k: 5 # 检索返回的 Top-K 结果数量
|
||||
similarity_threshold: 0.7 # 余弦相似度阈值(0-1),低于此值的结果将被过滤
|
||||
multi_query:
|
||||
max_queries: 4 # MultiQuery 改写变体上限(始终启用)
|
||||
rerank: # HTTP 精排(始终启用);留空则继承 openai/embedding 凭据
|
||||
provider: ""
|
||||
model: ""
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
post_retrieve:
|
||||
prefetch_top_k: 20 # 每条 MultiQuery 变体;0=max(top_k×4, 20)
|
||||
max_context_chars: 0
|
||||
max_context_tokens: 0
|
||||
roles_dir: "roles" # 角色配置文件目录(相对于配置文件所在目录)
|
||||
skills_dir: "skills" # Skills 目录(相对于配置文件所在目录)
|
||||
agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代理 *.md)
|
||||
@@ -599,6 +630,7 @@ enabled: true
|
||||
## 相关文档
|
||||
|
||||
- [多代理模式(Eino)](docs/MULTI_AGENT_EINO.md):**Deep**、**Plan-Execute**、**Supervisor**、`agents/*.md`、`eino_skills` / `eino_middleware`、接口与流式说明。
|
||||
- [图编排使用说明](docs/workflow-graph.md):可视化流程搭建、节点配置、`previous` / `outputs` 变量传参与角色绑定。
|
||||
- [机器人使用说明(钉钉 / 飞书)](docs/robot.md):在手机端通过钉钉、飞书与 CyberStrikeAI 对话的完整配置步骤、命令与排查说明,**建议按该文档操作以避免走弯路**。
|
||||
|
||||
## 项目结构
|
||||
@@ -650,8 +682,6 @@ CyberStrikeAI 现已加入 [404星链计划](https://github.com/knownsec/404Star
|
||||
</a>
|
||||
</div>
|
||||
|
||||
## Stargazers over time
|
||||

|
||||
|
||||
---
|
||||
|
||||
|
||||
+77
-6
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.47"
|
||||
version: "v1.6.50"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -102,9 +102,71 @@ agent:
|
||||
|
||||
system_prompt_path: ""
|
||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||
# 非白名单工具在审批方=审计 Agent 时,按会话 HITL 模式选用提示词:
|
||||
# approval → audit_agent_prompt
|
||||
# review_edit → audit_agent_prompt_review_edit(可改参后放行)
|
||||
hitl:
|
||||
# 全局默认审批方:human=人工审批,audit_agent=审计 Agent;未选会话时切换会写入本项,重启后仍生效
|
||||
default_reviewer: human
|
||||
# 已决策审计日志保留天数(与 MCP 监控一致;省略默认 90;0 表示不自动清理)
|
||||
retention_days: 90
|
||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||
tool_whitelist: [read_file, list_dir, glob, grep, tool_search]
|
||||
# audit_agent_prompt: | # 审批模式;留空使用内置默认,可在「人机协同」页编辑
|
||||
# audit_agent_prompt_review_edit: | # 审查编辑模式;留空使用内置默认
|
||||
|
||||
audit_agent_prompt: |-
|
||||
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||
|
||||
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||
|
||||
裁决基调(默认放行):
|
||||
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||
- 仅在「可能对系统造成实质影响」时 → reject
|
||||
|
||||
必须 reject 的高危情形(示例,非穷举):
|
||||
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||
|
||||
不应单独作为 reject 理由的情形:
|
||||
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||
- 参数略显宽泛但无明确破坏意图
|
||||
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
|
||||
|
||||
仅输出一行 JSON,不要 markdown 代码块:
|
||||
{"decision":"approve"|"reject","comment":"简要理由"}
|
||||
audit_agent_prompt_review_edit: |-
|
||||
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||
|
||||
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||
|
||||
裁决基调(默认放行):
|
||||
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||
- 仅在「可能对系统造成实质影响」时 → reject;参数可安全收窄时优先 approve + editedArguments
|
||||
|
||||
必须 reject 的高危情形(示例,非穷举):
|
||||
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||
|
||||
不应单独作为 reject 理由的情形:
|
||||
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||
- 参数略显宽泛但无明确破坏意图(应收窄参数后 approve)
|
||||
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
|
||||
|
||||
仅输出一行 JSON,不要 markdown 代码块:
|
||||
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
|
||||
|
||||
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
|
||||
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
|
||||
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
|
||||
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
|
||||
- 无法安全改参且存在上述高危情形时应 reject,不要勉强 approve
|
||||
# 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent*)
|
||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||
# Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode
|
||||
@@ -115,7 +177,6 @@ multi_agent:
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
|
||||
plan_execute_loop_max_iterations: 0
|
||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中注入用户原文;0=不截断(默认),>0=总字符上限,负数=禁用
|
||||
user_verbatim_anchor_max_runes: 0 # 主代理 system 中逐轮保留用户原文(压缩后刷新);0=不截断(默认),>0=总字符上限,负数=禁用
|
||||
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
||||
without_write_todos: false
|
||||
orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认
|
||||
@@ -126,7 +187,8 @@ multi_agent:
|
||||
disable: false # true:不注册 skill 渐进式披露中间件,也不挂本机 FS/Shell 工具;false:按下方开关加载
|
||||
filesystem_tools: true # true:注册 read_file/glob/grep/write/edit/execute(授权环境慎用);false:仅 skill,不暴露本机读写与 Shell
|
||||
skill_tool_name: skill # 模型侧可调用的「加载技能」工具名,一般保持 skill;与技能包文档中的调用名一致即可
|
||||
# Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig)
|
||||
# Eino ADK 中间件与 Deep/Supervisor/plan_execute Executor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig)
|
||||
# plan_execute:下列 patch/reduction/tool_search/plantask 等同样作用于 Executor(经 ExecPreMiddlewares);Planner/Replanner 不挂 MCP 前置中间件。
|
||||
eino_middleware:
|
||||
patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true
|
||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
@@ -150,6 +212,7 @@ multi_agent:
|
||||
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
||||
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||
empty_response_continue_max_attempts: 0 # Run 成功但未捕获助手正文(含流式中断)时 Handler 退避续跑次数;0=默认 5
|
||||
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
||||
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
@@ -222,9 +285,17 @@ knowledge:
|
||||
retrieval:
|
||||
top_k: 5 # 检索返回的Top-K结果数量
|
||||
similarity_threshold: 0.4 # 余弦相似度阈值(0-1),低于此值的结果将被过滤
|
||||
# 检索后处理:固定正文规范化去重;上下文预算;可选代码注入 DocumentReranker 做重排
|
||||
# Eino MultiQuery:LLM 改写查询后多路向量检索再融合(始终启用)
|
||||
multi_query:
|
||||
max_queries: 4 # 改写变体上限(含语义覆盖);建议 3~4
|
||||
# 精排(始终启用):dashscope 用 gte-rerank;其他 OpenAI 兼容端点走 /v1/rerank
|
||||
rerank:
|
||||
provider: "" # 空=按 base_url 推断:dashscope | cohere
|
||||
model: "" # 空=dashscope→gte-rerank,cohere→rerank-multilingual-v3.0
|
||||
base_url: "" # 留空则用 embedding / openai 的 base_url
|
||||
api_key: "" # 留空则用 embedding / openai 的 api_key
|
||||
post_retrieve:
|
||||
prefetch_top_k: 0 # 0 与 top_k 相同;可设为 15~30 以便去重后仍填满 top_k
|
||||
prefetch_top_k: 20 # 每条 MultiQuery 变体的向量候选数;0=内置 max(top_k*4,20)
|
||||
max_context_chars: 0 # 0 不限制;否则返回的正文总 Unicode 字符上限(整段 chunk)
|
||||
max_context_tokens: 0 # 0 不限制;tiktoken 总 token 上限
|
||||
sub_index_filter: ""
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
| OpenAPI | 多代理路径说明已更新(流式未启用为 SSE 错误事件)。 |
|
||||
| 机器人 | `ProcessMessageForRobot` 按 `robot_default_agent_mode`(默认 `eino_single`)调用 `RunEinoSingleChatModelAgent` 或 `RunDeepAgent`。 |
|
||||
| 预置编排 | 聊天 / WebShell:`POST /api/multi-agent*` 请求体 `orchestration`:`deep` \| `plan_execute` \| `supervisor`(缺省 `deep`)。`plan_execute` 不构建 YAML/Markdown 子代理;`plan_execute_loop_max_iterations` 仍来自配置。`supervisor` 至少需一个子代理。 |
|
||||
| Eino 中间件 | `multi_agent.eino_middleware`(可选):`patchtoolcalls`(默认开)、`toolsearch`(按阈值拆分 MCP 工具列表)、`plantask`(需 `eino_skills`)、`reduction`(大工具输出截断/落盘)、`checkpoint_dir`(Runner 断点)、`deep_output_key` / `deep_model_retry_max_retries` / `task_tool_description_prefix`(Deep 与 supervisor 主代理共享其中模型重试与 OutputKey)。`plan_execute` 的 Executor 无 Handlers:仅继承 **ToolsConfig** 侧效果(如 `tool_search` 列表拆分),不挂载 patch/plantask/reduction 中间件。 |
|
||||
| Eino 中间件 | `multi_agent.eino_middleware`(可选):`patchtoolcalls`(默认开)、`toolsearch`(按阈值拆分 MCP 工具列表)、`plantask`(需 `eino_skills`)、`reduction`(大工具输出截断/落盘)、`checkpoint_dir`(Runner 断点)、`deep_output_key` / `deep_model_retry_max_retries` / `task_tool_description_prefix`(Deep 与 supervisor 主代理共享其中模型重试与 OutputKey)。**`plan_execute`**:`runner.go` 将 `prependEinoMiddlewares(einoMWMain)` 产物作为 `ExecPreMiddlewares` 挂到 **Executor**(与 Deep/Supervisor 主代理同序:patch → reduction → toolsearch → plantask → filesystem → skill → summarization tail);Planner/Replanner 仅 summarization tail + prompt 预算截断,不跑 MCP 工具链。 |
|
||||
|
||||
## 进行中 / 待办( backlog )
|
||||
|
||||
@@ -37,7 +37,8 @@
|
||||
|
||||
## 关键文件索引
|
||||
|
||||
- `internal/multiagent/runner.go` — DeepAgent 组装与事件循环
|
||||
- `internal/multiagent/runner.go` — DeepAgent / plan_execute / supervisor 组装与事件循环
|
||||
- `internal/multiagent/eino_orchestration.go` — PlanExecute 根节点与 Executor 中间件栈(`buildPlanExecuteExecutorHandlers`)
|
||||
- `internal/handler/multi_agent.go` — SSE 与(同步)HTTP
|
||||
- `internal/handler/multi_agent_prepare.go` — 会话准备(含 WebShell)
|
||||
- `internal/einomcp/` — MCP → Eino Tool
|
||||
@@ -59,4 +60,5 @@
|
||||
| 2026-03-22 | `orchestrator.md` / `kind: orchestrator` 主代理、列表主/子标记、与 `orchestrator_instruction` 优先级。 |
|
||||
| 2026-04-19 | 主聊天「对话模式」:原生 ReAct 与 Deep / Plan-Execute / Supervisor;`POST /api/multi-agent*` 请求体 `orchestration` 与界面一致;`config.yaml` / 设置页不再维护预置编排字段(机器人/批量默认 `deep`)。 |
|
||||
| 2026-04-21 | 移除角色 `skills` 与 `/api/roles/skills/list`;`bind_role` 仅继承 tools;Skills 仅通过 Eino `skill` 工具按需加载。 |
|
||||
| 2026-07-02 | **plan_execute Executor 中间件对齐**:`ExecPreMiddlewares` 与 Deep 主代理同源;`buildPlanExecuteExecutorHandlers` + 回归测试;文档更正。 |
|
||||
| 2026-06-02 | **移除原生 ReAct**:删除 `/api/agent-loop*` 执行入口与 `AgentLoopWithProgress`;统一 Eino ADK(单代理 `/api/eino-agent*`,多代理 `/api/multi-agent*`);任务 cancel/tasks API 保留。 |
|
||||
|
||||
@@ -0,0 +1,403 @@
|
||||
# CyberStrikeAI 图编排使用说明
|
||||
|
||||
[English](workflow-graph_en.md)
|
||||
|
||||
本文档说明 **图编排(Graph Orchestration)** 的完整使用方式:如何在画布上搭建流程、配置各类型节点、在节点之间传递数据,以及如何将流程绑定到角色并自动运行。
|
||||
|
||||
---
|
||||
|
||||
## 一、在哪里使用图编排
|
||||
|
||||
1. 登录 CyberStrikeAI Web 端
|
||||
2. 左侧导航进入 **图编排**
|
||||
3. 在左侧列表选择已有流程,或新建流程
|
||||
4. 在中央画布拖拽、连线、配置节点
|
||||
5. 填写流程 **ID**、**名称**、**描述** 后点击 **保存**
|
||||
|
||||
保存后的流程可在 **角色管理** 中绑定到某个角色。绑定后,用户与该角色对话时会按流程图自动执行(`workflow_policy: auto`)。
|
||||
|
||||
---
|
||||
|
||||
## 二、画布基本操作
|
||||
|
||||
| 操作 | 说明 |
|
||||
|------|------|
|
||||
| 添加节点 | 点击画布上方节点类型按钮(开始、工具、Agent、条件、审批、输出、结束) |
|
||||
| 连线 | 点击 **连线**,依次点击源节点和目标节点;再次点击 **连线** 退出连线模式 |
|
||||
| 选中元素 | 单击节点或连线,右侧显示 **节点属性** |
|
||||
| 删除选中 | 点击 **删除选中** 删除当前节点或连线 |
|
||||
| 自动布局 | 点击 **自动布局** 整理节点位置 |
|
||||
| 删除流程 | 点击 **删除** 删除整个流程定义 |
|
||||
|
||||
**建议:** 每个流程至少包含 **1 个开始节点** 和 **1 个输出节点**;开始节点不应有入边,输出节点不应有出边。
|
||||
|
||||
---
|
||||
|
||||
## 三、执行模型(先理解再配置)
|
||||
|
||||
图编排按 **有向图** 执行,引擎从 **开始** 节点出发,沿连线依次运行下游节点。
|
||||
|
||||
每次运行会维护一份内部状态,模板变量 `{{...}}` 从这里取值:
|
||||
|
||||
| 内部状态 | 模板前缀 | 含义 |
|
||||
|----------|----------|------|
|
||||
| `inputs` | `{{inputs.xxx}}` | 流程启动时的输入(用户消息、会话 ID 等) |
|
||||
| `lastOutput` | `{{previous.xxx}}` | **上一个刚执行完** 的节点的输出 |
|
||||
| `outputs` | `{{outputs.xxx}}` | 全局 **命名变量池**(由节点的「输出变量名」写入) |
|
||||
| `nodeOutputs` | `{{节点ID.xxx}}` | 指定节点 ID 的完整输出对象 |
|
||||
|
||||
### 3.1 `previous` 是什么?
|
||||
|
||||
`{{previous.output}}` 表示 **紧邻的上一个执行节点** 的 `output` 字段。
|
||||
|
||||
- 每执行完一个节点,引擎都会更新 `lastOutput`
|
||||
- **不是**「画布上画线的上游」,而是 **实际执行顺序上的上一步**
|
||||
|
||||
示例:
|
||||
|
||||
```text
|
||||
开始 → Agent A → Agent B
|
||||
```
|
||||
|
||||
Agent B 的 `{{previous.output}}` = Agent A 的输出。
|
||||
|
||||
但若中间有条件节点:
|
||||
|
||||
```text
|
||||
开始 → Agent A → 条件 → Agent B
|
||||
```
|
||||
|
||||
Agent B 的 `{{previous.output}}` = **条件节点** 的输出(`true` / `false`),**不是** Agent A 的结果。
|
||||
|
||||
### 3.2 `outputs` 是什么?
|
||||
|
||||
`outputs` 是引擎在运行过程中维护的 **命名变量注册表**。
|
||||
|
||||
当 Agent、工具、输出 等节点配置了 **输出变量名**(字段 `output_key`)后,节点执行成功会把结果写入:
|
||||
|
||||
```text
|
||||
outputs["你填的变量名"] = 节点输出内容
|
||||
```
|
||||
|
||||
之后 **任意下游节点** 都可以通过 `{{outputs.变量名}}` 引用,不要求两个节点直接相连。
|
||||
|
||||
示例:
|
||||
|
||||
- Agent A 的 **输出变量名** 填 `agent_result1`
|
||||
- Agent B 的 **输入来源** 填 `{{outputs.agent_result1}}`
|
||||
|
||||
即使 A 和 B 之间隔着条件节点,B 仍能拿到 A 的输出。
|
||||
|
||||
### 3.3 什么时候用 `previous`,什么时候用 `outputs`?
|
||||
|
||||
| 场景 | 推荐写法 |
|
||||
|------|----------|
|
||||
| 两个节点 **直连**,只取上一步结果 | `{{previous.output}}` |
|
||||
| 中间有其他节点(条件、工具、审批等) | `{{outputs.变量名}}` |
|
||||
| 需要引用 **更早** 的某个节点结果 | `{{outputs.变量名}}` 或 `{{节点ID.output}}` |
|
||||
| 条件判断要基于某 Agent 的输出 | `{{outputs.变量名}} != ""` |
|
||||
| 读取用户最初输入 | `{{inputs.message}}` |
|
||||
|
||||
**记忆口诀:**
|
||||
|
||||
- `previous` = 上一步(链式、紧邻)
|
||||
- `outputs` = 按名字取(跨节点、可回溯)
|
||||
|
||||
---
|
||||
|
||||
## 四、模板语法
|
||||
|
||||
### 4.1 基本格式
|
||||
|
||||
```text
|
||||
{{变量路径}}
|
||||
```
|
||||
|
||||
支持字母、数字、下划线、点、连字符,例如:
|
||||
|
||||
```text
|
||||
{{previous.output}}
|
||||
{{outputs.agent_result1}}
|
||||
{{inputs.message}}
|
||||
{{inputs.conversationId}}
|
||||
{{previous.matched}}
|
||||
{{node-abc123.output}}
|
||||
```
|
||||
|
||||
### 4.2 可用路径一览
|
||||
|
||||
| 路径 | 说明 |
|
||||
|------|------|
|
||||
| `{{inputs.message}}` | 用户消息(开始节点输入) |
|
||||
| `{{inputs.conversationId}}` | 会话 ID |
|
||||
| `{{inputs.projectId}}` | 项目 ID |
|
||||
| `{{previous.output}}` | 上一节点主输出 |
|
||||
| `{{previous.matched}}` | 上一条件节点的匹配结果(`true` / `false`) |
|
||||
| `{{outputs.变量名}}` | 某节点注册过的命名输出 |
|
||||
| `{{节点ID.output}}` | 指定节点 ID 的 `output` 字段 |
|
||||
|
||||
### 4.3 条件表达式
|
||||
|
||||
条件节点和连线条件支持简单比较:
|
||||
|
||||
```text
|
||||
{{outputs.agent_result1}} != ""
|
||||
{{previous.output}} == "ok"
|
||||
{{outputs.count}} == "100"
|
||||
```
|
||||
|
||||
规则:
|
||||
|
||||
- 使用 `==` 或 `!=` 做字符串比较(两侧会自动去掉首尾空格和引号)
|
||||
- 无比较符时,非空且不为 `false` / `0` / `null` 视为真
|
||||
|
||||
---
|
||||
|
||||
## 五、节点类型与配置
|
||||
|
||||
### 5.1 开始(start)
|
||||
|
||||
流程入口,将用户输入注入 `inputs`。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| 输入变量 | 逗号分隔的输入键名 | `message, conversationId, projectId` |
|
||||
|
||||
开始节点输出包含:`output`、`message`、`conversationId`、`projectId`。
|
||||
|
||||
### 5.2 Agent(agent)
|
||||
|
||||
调用大模型 Agent 处理任务,支持多种运行模式。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| Agent 模式 | `eino_single` / `deep` / `plan_execute` / `supervisor` | `eino_single` |
|
||||
| 输入来源 | 上游数据的模板表达式 | `{{previous.output}}` |
|
||||
| 节点指令 | 本节点要完成的任务描述 | 空 |
|
||||
| 输出变量名 | 写入 `outputs` 的键名 | `agent_result` |
|
||||
|
||||
**消息拼装规则:**
|
||||
|
||||
- 仅填 **节点指令**:直接把指令发给 Agent
|
||||
- 仅填 **输入来源**:生成「请基于上游节点输出继续处理:…」
|
||||
- 两者都填:合并为「上游输入 + 节点指令」
|
||||
|
||||
Agent 节点执行后:
|
||||
|
||||
- `previous.output` 更新为本节点响应文本
|
||||
- 若配置了 **输出变量名**,同时写入 `outputs[输出变量名]`
|
||||
|
||||
### 5.3 工具(tool)
|
||||
|
||||
调用已启用的 MCP 工具。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| MCP 工具 | 工具名称(必填) | — |
|
||||
| 参数模板 | JSON,支持 `{{...}}` 模板 | `{}` |
|
||||
| 超时秒数 | 可选 | 空 |
|
||||
|
||||
示例参数模板:
|
||||
|
||||
```json
|
||||
{"target": "{{inputs.message}}", "port": "443"}
|
||||
```
|
||||
|
||||
若配置了 **输出变量名**,工具返回结果会写入 `outputs`。
|
||||
|
||||
### 5.4 条件(condition)
|
||||
|
||||
根据表达式计算分支,输出 `matched`(`true` / `false`)。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| 条件表达式 | 支持 `{{...}}` 与 `==` / `!=` | `{{previous.output}} != ""` |
|
||||
|
||||
**分支规则:**
|
||||
|
||||
- 从条件节点连出的 **第一条线** 默认为 **「是」** 分支(`matched == true`)
|
||||
- **第二条线** 默认为 **「否」** 分支(`matched == false`)
|
||||
- 连线标签可写 `是` / `否`(或 `yes` / `no`、`true` / `false`)辅助识别
|
||||
- 第三条及以后的出边需在 **连线条件** 中自定义表达式
|
||||
|
||||
连线条件示例(选中连线后在右侧配置):
|
||||
|
||||
```text
|
||||
{{previous.matched}} == "true"
|
||||
{{previous.matched}} == "false"
|
||||
```
|
||||
|
||||
### 5.5 审批(hitl)
|
||||
|
||||
人工确认检查点(当前为记录模式,自动标记 `approved: true` 并继续)。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| 审批提示 | 支持模板 | `请审批该步骤是否继续执行` |
|
||||
| 审批方 | `human` / `audit_agent` | `human` |
|
||||
|
||||
### 5.6 输出(output)
|
||||
|
||||
将流程最终结果写入 `outputs`,供结束摘要和对话展示使用。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| 输出变量名 | 必填,最终结果的键名 | `result` |
|
||||
| 变量来源 | 模板表达式,决定写入的值 | `{{previous.output}}` |
|
||||
|
||||
**注意:** 输出节点是流程的「出口」,不应再有出边。
|
||||
|
||||
### 5.7 结束(end)
|
||||
|
||||
可选节点,用于生成结束摘要模板(角色绑定流程中较少单独使用)。
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| 结束摘要模板 | 支持 `{{outputs.xxx}}` | `{{outputs.result}}` |
|
||||
|
||||
---
|
||||
|
||||
## 六、连线配置
|
||||
|
||||
选中 **连线** 后,右侧可配置 **连线条件**。
|
||||
|
||||
| 场景 | 示例 |
|
||||
|------|------|
|
||||
| 普通节点后的过滤 | `{{previous.output}} == "ok"` |
|
||||
| 条件节点「是」分支 | `{{previous.matched}} == "true"` |
|
||||
| 条件节点「否」分支 | `{{previous.matched}} == "false"` |
|
||||
|
||||
若不填连线条件:
|
||||
|
||||
- 非条件节点:连线始终放行
|
||||
- 条件节点:按出边顺序自动分配是/否分支
|
||||
|
||||
---
|
||||
|
||||
## 七、完整示例:跨条件节点传递 Agent 输出
|
||||
|
||||
### 7.1 流程结构
|
||||
|
||||
```text
|
||||
开始 → Agent(生成初始值)→ 条件 → Agent(加工)→ 输出
|
||||
↘ 否 → 输出
|
||||
```
|
||||
|
||||
### 7.2 节点配置
|
||||
|
||||
**Agent 1(第一个 Agent)**
|
||||
|
||||
| 字段 | 值 |
|
||||
|------|-----|
|
||||
| 节点指令 | 只输出 `123333333` |
|
||||
| 输出变量名 | `agent_result1` |
|
||||
|
||||
**条件**
|
||||
|
||||
| 字段 | 值 |
|
||||
|------|-----|
|
||||
| 条件表达式 | `{{outputs.agent_result1}} != ""` |
|
||||
|
||||
**Agent 2(第二个 Agent)**
|
||||
|
||||
| 字段 | 值 |
|
||||
|------|-----|
|
||||
| 输入来源 | `{{outputs.agent_result1}}` |
|
||||
| 节点指令 | 在输入基础上加 100,然后输出 |
|
||||
| 输出变量名 | `agent_result` |
|
||||
|
||||
**输出**
|
||||
|
||||
| 字段 | 值 |
|
||||
|------|-----|
|
||||
| 输出变量名 | `result` |
|
||||
| 变量来源 | `{{outputs.agent_result}}` |
|
||||
|
||||
### 7.3 常见错误
|
||||
|
||||
| 错误配置 | 原因 |
|
||||
|----------|------|
|
||||
| Agent 2 输入来源写 `{{previous.output}}` | `previous` 指向条件节点,得到的是 `true`/`false`,不是 Agent 1 的文本 |
|
||||
| 未给 Agent 1 填输出变量名 | `outputs.agent_result1` 不存在,下游取到空值 |
|
||||
| 条件表达式写 `{{previous.output}}` | 判断的是开始节点或上一节点的输出,而非 Agent 1 的命名变量 |
|
||||
|
||||
---
|
||||
|
||||
## 八、绑定角色并运行
|
||||
|
||||
### 8.1 在角色管理中绑定
|
||||
|
||||
1. 进入 **角色管理**,编辑或新建角色
|
||||
2. 选择 **工作流 / 图编排** 绑定的流程 ID
|
||||
3. 策略设为 `auto`(默认:有 `workflow_id` 时自动执行)
|
||||
4. 保存角色
|
||||
|
||||
也可在角色 YAML 中直接配置:
|
||||
|
||||
```yaml
|
||||
name: 工作流测试
|
||||
workflow_id: "1233"
|
||||
workflow_version: latest
|
||||
workflow_policy: auto
|
||||
```
|
||||
|
||||
### 8.2 运行效果
|
||||
|
||||
用户选择该角色并发送消息后:
|
||||
|
||||
1. 引擎加载对应 `graph_json` 并按图执行
|
||||
2. 对话页可看到 `workflow_start`、`workflow_node_start`、Agent 推理等进度事件
|
||||
3. 流程结束后返回摘要,列出 `outputs` 中所有命名输出
|
||||
|
||||
若未配置输出节点或条件未命中,`outputs` 可能为空,摘要会提示检查输出节点与分支。
|
||||
|
||||
---
|
||||
|
||||
## 九、保存前校验规则
|
||||
|
||||
保存时系统会自动检查:
|
||||
|
||||
| 规则 | 说明 |
|
||||
|------|------|
|
||||
| 必须有开始节点 | 至少 1 个 `start` |
|
||||
| 必须有输出节点 | 至少 1 个 `output`,且填写输出变量名 |
|
||||
| 连线合法 | 源/目标节点存在,不能自环 |
|
||||
| 开始节点无入边 | 开始节点不能被指向 |
|
||||
| 输出节点无出边 | 输出节点后不应再连线 |
|
||||
| 工具节点 | 必须选择 MCP 工具 |
|
||||
| 条件节点 | 必须填写表达式;建议 1~2 条出边(是/否) |
|
||||
|
||||
---
|
||||
|
||||
## 十、排错指南
|
||||
|
||||
| 现象 | 可能原因 | 处理建议 |
|
||||
|------|----------|----------|
|
||||
| 下游拿到空值 | 上游未配置输出变量名 | 给上游 Agent/工具填 **输出变量名**,下游用 `{{outputs.xxx}}` |
|
||||
| 下游拿到 `true`/`false` | 误用 `{{previous.output}}`,上一步是条件节点 | 改用 `{{outputs.xxx}}` |
|
||||
| 条件总走「否」 | 表达式与真实输出格式不一致 | 检查 Agent 输出是否带引号、换行;用 `!= ""` 先验证 |
|
||||
| 流程无最终输出 | 未命中输出节点所在分支 | 检查条件分支连线;确保至少一条路径到达 **输出** 节点 |
|
||||
| 角色对话未跑流程 | 角色未绑定或未启用 | 确认 `workflow_id`、`workflow_policy: auto`、流程 `enabled: true` |
|
||||
| 工具节点失败 | 参数 JSON 不合法或工具未启用 | 检查参数模板;在 MCP 中启用对应工具 |
|
||||
|
||||
---
|
||||
|
||||
## 十一、最佳实践
|
||||
|
||||
1. **命名规范**:为每个需要被引用的节点设置有意义的输出变量名,如 `scan_result`、`parsed_targets`,避免都叫 `agent_result`。
|
||||
2. **跨节点传参优先用 `outputs`**:只要中间可能插入条件、工具、审批节点,就应用命名变量。
|
||||
3. **`previous` 仅用于直连**:A → B 且无中间节点时,`{{previous.output}}` 最简洁。
|
||||
4. **条件判断引用源数据**:判断 Agent 输出时用 `{{outputs.xxx}}`,不要用 `{{previous.output}}`(除非条件紧跟在目标 Agent 之后)。
|
||||
5. **每条路径都要有出口**:确保「是」「否」分支最终都能到达 **输出** 节点(或你期望的终点)。
|
||||
6. **保存前跑一遍**:用简单指令(如固定字符串输出)验证数据传递,再替换为真实业务逻辑。
|
||||
|
||||
---
|
||||
|
||||
## 十二、相关代码位置(开发者参考)
|
||||
|
||||
| 模块 | 路径 |
|
||||
|------|------|
|
||||
| 执行引擎 | `internal/workflow/runner.go` |
|
||||
| 画布前端 | `web/static/js/workflows.js` |
|
||||
| 流程 API | `internal/handler/workflow.go` |
|
||||
| 角色绑定 | `internal/config/config.go`(`workflow_id` 字段) |
|
||||
@@ -0,0 +1,403 @@
|
||||
# CyberStrikeAI Graph Orchestration Guide
|
||||
|
||||
[中文](workflow-graph.md)
|
||||
|
||||
This document explains how to use **Graph Orchestration**: building workflows on the canvas, configuring node types, passing data between nodes, and binding a graph to a role for automatic execution.
|
||||
|
||||
---
|
||||
|
||||
## 1. Where to find Graph Orchestration
|
||||
|
||||
1. Log in to the CyberStrikeAI web UI.
|
||||
2. Open **Graph Orchestration** in the left sidebar.
|
||||
3. Select an existing workflow from the list, or create a new one.
|
||||
4. Drag nodes, draw edges, and configure properties on the canvas.
|
||||
5. Fill in **ID**, **Name**, and **Description**, then click **Save**.
|
||||
|
||||
Saved workflows can be bound to a role under **Role Management**. When `workflow_policy` is `auto`, chatting with that role runs the bound graph automatically.
|
||||
|
||||
---
|
||||
|
||||
## 2. Canvas basics
|
||||
|
||||
| Action | Description |
|
||||
|--------|-------------|
|
||||
| Add node | Click a node type button above the canvas (Start, Tool, Agent, Condition, HITL, Output, End) |
|
||||
| Connect | Click **Connect**, then click source and target nodes; click **Connect** again to exit connect mode |
|
||||
| Select | Click a node or edge; properties appear in the right panel |
|
||||
| Delete selected | Remove the current node or edge |
|
||||
| Auto layout | Rearrange node positions |
|
||||
| Delete workflow | Remove the entire workflow definition |
|
||||
|
||||
**Requirements:** Every workflow needs at least **one Start node** and **one Output node**. Start nodes must not have incoming edges; Output nodes must not have outgoing edges.
|
||||
|
||||
---
|
||||
|
||||
## 3. Execution model (read this before configuring)
|
||||
|
||||
The engine executes the workflow as a **directed graph**, starting from the **Start** node and following edges to downstream nodes.
|
||||
|
||||
During a run, the engine keeps internal state. Template expressions `{{...}}` read from that state:
|
||||
|
||||
| Internal state | Template prefix | Meaning |
|
||||
|----------------|-----------------|---------|
|
||||
| `inputs` | `{{inputs.xxx}}` | Workflow inputs at start (user message, conversation ID, etc.) |
|
||||
| `lastOutput` | `{{previous.xxx}}` | Output of the **most recently executed** node |
|
||||
| `outputs` | `{{outputs.xxx}}` | Global **named variable pool** (written by nodes with an output key) |
|
||||
| `nodeOutputs` | `{{nodeId.xxx}}` | Full output object of a specific node ID |
|
||||
|
||||
### 3.1 What is `previous`?
|
||||
|
||||
`{{previous.output}}` is the `output` field of the **immediately preceding executed node**.
|
||||
|
||||
- After every node finishes, the engine updates `lastOutput`.
|
||||
- It is **not** “the node drawn upstream on the canvas”; it is **the previous step in actual execution order**.
|
||||
|
||||
Example:
|
||||
|
||||
```text
|
||||
Start → Agent A → Agent B
|
||||
```
|
||||
|
||||
For Agent B, `{{previous.output}}` = Agent A’s output.
|
||||
|
||||
With a condition in between:
|
||||
|
||||
```text
|
||||
Start → Agent A → Condition → Agent B
|
||||
```
|
||||
|
||||
For Agent B, `{{previous.output}}` = the **condition node** output (`true` / `false`), **not** Agent A’s result.
|
||||
|
||||
### 3.2 What is `outputs`?
|
||||
|
||||
`outputs` is a **named variable registry** maintained by the engine during execution.
|
||||
|
||||
When an Agent, Tool, or Output node sets an **Output variable name** (`output_key`), the result is stored as:
|
||||
|
||||
```text
|
||||
outputs["your_variable_name"] = node_output
|
||||
```
|
||||
|
||||
Any downstream node can then reference it via `{{outputs.variable_name}}`, even if other nodes sit in between.
|
||||
|
||||
Example:
|
||||
|
||||
- Agent A **Output variable name**: `agent_result1`
|
||||
- Agent B **Input source**: `{{outputs.agent_result1}}`
|
||||
|
||||
Agent B still receives Agent A’s output even when a condition node lies between them.
|
||||
|
||||
### 3.3 When to use `previous` vs `outputs`
|
||||
|
||||
| Scenario | Recommended |
|
||||
|----------|-------------|
|
||||
| Two nodes are **directly connected**; you only need the last step | `{{previous.output}}` |
|
||||
| Other nodes sit in between (condition, tool, HITL, etc.) | `{{outputs.variable_name}}` |
|
||||
| Reference output from an **earlier** node | `{{outputs.variable_name}}` or `{{nodeId.output}}` |
|
||||
| Condition should test an Agent’s output | `{{outputs.variable_name}} != ""` |
|
||||
| Read the original user input | `{{inputs.message}}` |
|
||||
|
||||
**Rule of thumb:**
|
||||
|
||||
- `previous` = last step (chained, adjacent)
|
||||
- `outputs` = by name (cross-node, look back)
|
||||
|
||||
---
|
||||
|
||||
## 4. Template syntax
|
||||
|
||||
### 4.1 Basic format
|
||||
|
||||
```text
|
||||
{{path.to.value}}
|
||||
```
|
||||
|
||||
Allowed characters in paths: letters, digits, underscore, dot, hyphen. Examples:
|
||||
|
||||
```text
|
||||
{{previous.output}}
|
||||
{{outputs.agent_result1}}
|
||||
{{inputs.message}}
|
||||
{{inputs.conversationId}}
|
||||
{{previous.matched}}
|
||||
{{node-abc123.output}}
|
||||
```
|
||||
|
||||
### 4.2 Available paths
|
||||
|
||||
| Path | Description |
|
||||
|------|-------------|
|
||||
| `{{inputs.message}}` | User message (Start node input) |
|
||||
| `{{inputs.conversationId}}` | Conversation ID |
|
||||
| `{{inputs.projectId}}` | Project ID |
|
||||
| `{{previous.output}}` | Primary output of the previous node |
|
||||
| `{{previous.matched}}` | Match result of the previous condition node (`true` / `false`) |
|
||||
| `{{outputs.variable_name}}` | Named output registered by a node |
|
||||
| `{{nodeId.output}}` | `output` field of the node with that ID |
|
||||
|
||||
### 4.3 Condition expressions
|
||||
|
||||
Condition nodes and edge conditions support simple comparisons:
|
||||
|
||||
```text
|
||||
{{outputs.agent_result1}} != ""
|
||||
{{previous.output}} == "ok"
|
||||
{{outputs.count}} == "100"
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- Use `==` or `!=` for string comparison (leading/trailing spaces and quotes are trimmed)
|
||||
- Without a comparator, non-empty values that are not `false`, `0`, or `null` are treated as true
|
||||
|
||||
---
|
||||
|
||||
## 5. Node types and configuration
|
||||
|
||||
### 5.1 Start
|
||||
|
||||
Workflow entry point; injects user input into `inputs`.
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Input keys | Comma-separated input key names | `message, conversationId, projectId` |
|
||||
|
||||
Start node output includes: `output`, `message`, `conversationId`, `projectId`.
|
||||
|
||||
### 5.2 Agent
|
||||
|
||||
Runs an LLM Agent task. Supports multiple modes.
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Agent mode | `eino_single` / `deep` / `plan_execute` / `supervisor` | `eino_single` |
|
||||
| Input source | Template for upstream data | `{{previous.output}}` |
|
||||
| Node instruction | Task description for this node | empty |
|
||||
| Output variable name | Key written into `outputs` | `agent_result` |
|
||||
|
||||
**Message assembly:**
|
||||
|
||||
- Instruction only → send instruction to the Agent
|
||||
- Input source only → “Continue based on upstream output: …”
|
||||
- Both → combined “upstream input + node instruction”
|
||||
|
||||
After execution:
|
||||
|
||||
- `previous.output` becomes this node’s response text
|
||||
- If **Output variable name** is set, the value is also stored in `outputs[variable_name]`
|
||||
|
||||
### 5.3 Tool
|
||||
|
||||
Calls an enabled MCP tool.
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| MCP tool | Tool name (required) | — |
|
||||
| Argument template | JSON with `{{...}}` templates | `{}` |
|
||||
| Timeout (seconds) | Optional | empty |
|
||||
|
||||
Example argument template:
|
||||
|
||||
```json
|
||||
{"target": "{{inputs.message}}", "port": "443"}
|
||||
```
|
||||
|
||||
If an output variable name is configured, the tool result is written to `outputs`.
|
||||
|
||||
### 5.4 Condition
|
||||
|
||||
Evaluates an expression and outputs `matched` (`true` / `false`).
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Expression | Supports `{{...}}` and `==` / `!=` | `{{previous.output}} != ""` |
|
||||
|
||||
**Branching rules:**
|
||||
|
||||
- The **first outgoing edge** defaults to the **“yes”** branch (`matched == true`)
|
||||
- The **second outgoing edge** defaults to the **“no”** branch (`matched == false`)
|
||||
- Edge labels such as `是` / `否` (or `yes` / `no`, `true` / `false`) help identify branches
|
||||
- A third or later edge needs a custom **edge condition**
|
||||
|
||||
Edge condition examples (select an edge, configure in the right panel):
|
||||
|
||||
```text
|
||||
{{previous.matched}} == "true"
|
||||
{{previous.matched}} == "false"
|
||||
```
|
||||
|
||||
### 5.5 HITL (human-in-the-loop)
|
||||
|
||||
Human approval checkpoint (currently record-only; marks `approved: true` and continues).
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Prompt | Supports templates | `Please approve before continuing` |
|
||||
| Reviewer | `human` / `audit_agent` | `human` |
|
||||
|
||||
### 5.6 Output
|
||||
|
||||
Writes the final workflow result into `outputs` for summary and chat display.
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Output variable name | Required key for the final result | `result` |
|
||||
| Variable source | Template deciding what to write | `{{previous.output}}` |
|
||||
|
||||
**Note:** Output nodes are workflow exits and must not have outgoing edges.
|
||||
|
||||
### 5.7 End
|
||||
|
||||
Optional node for an end summary template (less common in role-bound flows).
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| Result template | Supports `{{outputs.xxx}}` | `{{outputs.result}}` |
|
||||
|
||||
---
|
||||
|
||||
## 6. Edge configuration
|
||||
|
||||
Select an **edge** to configure its **condition** in the right panel.
|
||||
|
||||
| Scenario | Example |
|
||||
|----------|---------|
|
||||
| Filter after a normal node | `{{previous.output}} == "ok"` |
|
||||
| “Yes” branch from a condition | `{{previous.matched}} == "true"` |
|
||||
| “No” branch from a condition | `{{previous.matched}} == "false"` |
|
||||
|
||||
If no edge condition is set:
|
||||
|
||||
- Non-condition nodes: edge is always allowed
|
||||
- Condition nodes: yes/no branches are assigned by edge order automatically
|
||||
|
||||
---
|
||||
|
||||
## 7. Full example: passing Agent output across a condition
|
||||
|
||||
### 7.1 Graph structure
|
||||
|
||||
```text
|
||||
Start → Agent (initial value) → Condition → Agent (transform) → Output
|
||||
↘ no → Output
|
||||
```
|
||||
|
||||
### 7.2 Node configuration
|
||||
|
||||
**Agent 1**
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Node instruction | Output only `123333333` |
|
||||
| Output variable name | `agent_result1` |
|
||||
|
||||
**Condition**
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Expression | `{{outputs.agent_result1}} != ""` |
|
||||
|
||||
**Agent 2**
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Input source | `{{outputs.agent_result1}}` |
|
||||
| Node instruction | Add 100 to the input, then output |
|
||||
| Output variable name | `agent_result` |
|
||||
|
||||
**Output**
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| Output variable name | `result` |
|
||||
| Variable source | `{{outputs.agent_result}}` |
|
||||
|
||||
### 7.3 Common mistakes
|
||||
|
||||
| Wrong config | Why it fails |
|
||||
|--------------|--------------|
|
||||
| Agent 2 input source = `{{previous.output}}` | `previous` points to the condition node → `true`/`false`, not Agent 1’s text |
|
||||
| Agent 1 has no output variable name | `outputs.agent_result1` does not exist → empty downstream |
|
||||
| Condition uses `{{previous.output}}` | Tests the wrong upstream value instead of Agent 1’s named output |
|
||||
|
||||
---
|
||||
|
||||
## 8. Bind to a role and run
|
||||
|
||||
### 8.1 Bind in Role Management
|
||||
|
||||
1. Open **Role Management**, edit or create a role.
|
||||
2. Select the workflow / graph ID to bind.
|
||||
3. Set policy to `auto` (default when `workflow_id` is set).
|
||||
4. Save the role.
|
||||
|
||||
You can also configure this in role YAML:
|
||||
|
||||
```yaml
|
||||
name: workflow-test
|
||||
workflow_id: "1233"
|
||||
workflow_version: latest
|
||||
workflow_policy: auto
|
||||
```
|
||||
|
||||
### 8.2 Runtime behavior
|
||||
|
||||
When a user chats with that role:
|
||||
|
||||
1. The engine loads `graph_json` and executes the graph.
|
||||
2. The chat UI shows progress events (`workflow_start`, `workflow_node_start`, Agent reasoning, etc.).
|
||||
3. When finished, a summary lists all named entries in `outputs`.
|
||||
|
||||
If no Output node is reached or no branch matches, `outputs` may be empty and the summary will suggest checking the Output node and branches.
|
||||
|
||||
---
|
||||
|
||||
## 9. Validation before save
|
||||
|
||||
On save, the system checks:
|
||||
|
||||
| Rule | Description |
|
||||
|------|-------------|
|
||||
| Start node required | At least one `start` node |
|
||||
| Output node required | At least one `output` node with an output variable name |
|
||||
| Valid edges | Source and target exist; no self-loops |
|
||||
| Start has no incoming edges | Start must not be targeted |
|
||||
| Output has no outgoing edges | Nothing after Output |
|
||||
| Tool nodes | MCP tool must be selected |
|
||||
| Condition nodes | Expression required; ideally 1–2 outgoing edges (yes/no) |
|
||||
|
||||
---
|
||||
|
||||
## 10. Troubleshooting
|
||||
|
||||
| Symptom | Likely cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| Downstream gets empty value | Upstream has no output variable name | Set **Output variable name** on upstream; use `{{outputs.xxx}}` downstream |
|
||||
| Downstream gets `true`/`false` | Used `{{previous.output}}` while previous node is a condition | Use `{{outputs.xxx}}` instead |
|
||||
| Condition always takes “no” | Expression does not match actual output format | Check Agent output for quotes/newlines; try `!= ""` first |
|
||||
| No final output | Output node branch not reached | Verify condition wiring; ensure every path reaches an **Output** node |
|
||||
| Role chat does not run workflow | Role not bound or disabled | Check `workflow_id`, `workflow_policy: auto`, workflow `enabled: true` |
|
||||
| Tool node fails | Invalid JSON in arguments or tool disabled | Fix argument template; enable the tool in MCP settings |
|
||||
|
||||
---
|
||||
|
||||
## 11. Best practices
|
||||
|
||||
1. **Meaningful names**: Use descriptive output variable names (`scan_result`, `parsed_targets`) instead of reusing `agent_result` everywhere.
|
||||
2. **Prefer `outputs` for cross-node data**: If a condition, tool, or HITL node might sit in between, use named variables.
|
||||
3. **Use `previous` only for direct links**: `A → B` with nothing in between is the ideal case for `{{previous.output}}`.
|
||||
4. **Conditions should reference source data**: When testing Agent output, use `{{outputs.xxx}}` unless the condition immediately follows that Agent.
|
||||
5. **Every path needs an exit**: Ensure both yes and no branches eventually reach an **Output** node (or your intended end).
|
||||
6. **Validate with a simple run**: Use fixed-string outputs to verify data flow before swapping in real business logic.
|
||||
|
||||
---
|
||||
|
||||
## 12. Code references (for developers)
|
||||
|
||||
| Module | Path |
|
||||
|--------|------|
|
||||
| Execution engine | `internal/workflow/runner.go` |
|
||||
| Canvas UI | `web/static/js/workflows.js` |
|
||||
| Workflow API | `internal/handler/workflow.go` |
|
||||
| Role binding | `internal/config/config.go` (`workflow_id` field) |
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.0 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 181 KiB After Width: | Height: | Size: 265 KiB |
+41
-14
@@ -21,6 +21,7 @@ import (
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
"cyberstrike-ai/internal/hitl"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -109,6 +110,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
monitorRetention.PurgeExpired()
|
||||
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||
|
||||
hitlRetention := hitl.NewService(db, cfg, log.Logger)
|
||||
hitlRetention.PurgeExpired()
|
||||
hitl.StartRetentionLoop(hitlRetention, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
@@ -202,14 +207,12 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建检索器
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: cfg.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
|
||||
SubIndexFilter: cfg.Knowledge.Retrieval.SubIndexFilter,
|
||||
PostRetrieve: cfg.Knowledge.Retrieval.PostRetrieve,
|
||||
}
|
||||
// 创建检索器(Eino MultiQuery + 重排流水线)
|
||||
retrievalConfig := knowledge.RetrievalConfigFromYAML(cfg.Knowledge.Retrieval)
|
||||
knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
|
||||
if err := knowledge.WireRetrieverPipeline(context.Background(), knowledgeRetriever, &cfg.OpenAI); err != nil {
|
||||
return nil, fmt.Errorf("初始化知识库检索流水线失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建索引器(Eino Compose 链)
|
||||
knowledgeIndexer, err = knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, log.Logger, &cfg.Knowledge)
|
||||
@@ -353,6 +356,9 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
projectHandler := handler.NewProjectHandler(db, log.Logger)
|
||||
workflowHandler := handler.NewWorkflowHandler(db, log.Logger)
|
||||
workflowHandler.SetAudit(auditSvc)
|
||||
workflowHandler.SetRuntime(agent, cfg)
|
||||
vulnerabilityHandler.SetAudit(auditSvc)
|
||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||
webshellHandler.SetAudit(auditSvc)
|
||||
@@ -363,6 +369,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
configHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||
agentHandler.SetHitlAuditStrategySaver(configHandler)
|
||||
agentHandler.SetHitlDefaultReviewerSaver(configHandler)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
externalMCPHandler.SetAudit(auditSvc)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
@@ -513,6 +521,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
projectHandler,
|
||||
workflowHandler,
|
||||
webshellHandler,
|
||||
chatUploadsHandler,
|
||||
roleHandler,
|
||||
@@ -759,6 +768,7 @@ func setupRoutes(
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
projectHandler *handler.ProjectHandler,
|
||||
workflowHandler *handler.WorkflowHandler,
|
||||
webshellHandler *handler.WebShellHandler,
|
||||
chatUploadsHandler *handler.ChatUploadsHandler,
|
||||
roleHandler *handler.RoleHandler,
|
||||
@@ -812,11 +822,20 @@ func setupRoutes(
|
||||
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
||||
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
||||
protected.GET("/hitl/pending", agentHandler.ListHITLPending)
|
||||
protected.GET("/hitl/logs", agentHandler.ListHITLLogs)
|
||||
protected.DELETE("/hitl/logs", agentHandler.DeleteHITLLogs)
|
||||
protected.GET("/hitl/logs/:id", agentHandler.GetHITLLog)
|
||||
protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt)
|
||||
protected.POST("/hitl/dismiss", agentHandler.DismissHITLInterrupt)
|
||||
protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig)
|
||||
protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig)
|
||||
protected.GET("/hitl/tool-whitelist", agentHandler.GetHITLGlobalToolWhitelist)
|
||||
protected.PUT("/hitl/tool-whitelist", agentHandler.SetHITLGlobalToolWhitelist)
|
||||
protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist)
|
||||
protected.GET("/hitl/default-reviewer", agentHandler.GetHITLDefaultReviewer)
|
||||
protected.PUT("/hitl/default-reviewer", agentHandler.UpdateHITLDefaultReviewer)
|
||||
protected.GET("/hitl/audit-strategy", agentHandler.GetHITLAuditStrategy)
|
||||
protected.PUT("/hitl/audit-strategy", agentHandler.UpdateHITLAuditStrategy)
|
||||
// Agent Loop 取消与任务列表
|
||||
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||
@@ -1178,6 +1197,16 @@ func setupRoutes(
|
||||
protected.PUT("/roles/:name", roleHandler.UpdateRole)
|
||||
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
|
||||
|
||||
// 图编排 / 工作流定义(图结构固定,业务字段保存在 graph_json 中)
|
||||
protected.GET("/workflows/runs/pending", workflowHandler.ListPendingRuns)
|
||||
protected.GET("/workflows/runs/:runId", workflowHandler.GetRun)
|
||||
protected.POST("/workflows/runs/:runId/resume", workflowHandler.ResumeRun)
|
||||
protected.GET("/workflows", workflowHandler.List)
|
||||
protected.GET("/workflows/:id", workflowHandler.Get)
|
||||
protected.POST("/workflows", workflowHandler.Create)
|
||||
protected.PUT("/workflows/:id", workflowHandler.Update)
|
||||
protected.DELETE("/workflows/:id", workflowHandler.Delete)
|
||||
|
||||
// Skills管理(具体路径需注册在 /skills/:name 之前)
|
||||
protected.GET("/skills", skillsHandler.GetSkills)
|
||||
protected.GET("/skills/stats", skillsHandler.GetSkillStats)
|
||||
@@ -1787,14 +1816,12 @@ func initializeKnowledge(
|
||||
return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建检索器
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: cfg.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
|
||||
SubIndexFilter: cfg.Knowledge.Retrieval.SubIndexFilter,
|
||||
PostRetrieve: cfg.Knowledge.Retrieval.PostRetrieve,
|
||||
}
|
||||
// 创建检索器(Eino MultiQuery + 重排流水线)
|
||||
retrievalConfig := knowledge.RetrievalConfigFromYAML(cfg.Knowledge.Retrieval)
|
||||
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
|
||||
if err := knowledge.WireRetrieverPipeline(context.Background(), knowledgeRetriever, &cfg.OpenAI); err != nil {
|
||||
return nil, fmt.Errorf("初始化知识库检索流水线失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建索引器(Eino Compose 链)
|
||||
knowledgeIndexer, err := knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, logger, &cfg.Knowledge)
|
||||
|
||||
@@ -120,9 +120,19 @@ func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||||
b.WriteString(v.Description)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Proof != "" {
|
||||
b.WriteString("\n--- 证明(POC) ---\n")
|
||||
b.WriteString(v.Proof)
|
||||
if v.Preconditions != "" {
|
||||
b.WriteString("\n--- 前置条件 ---\n")
|
||||
b.WriteString(v.Preconditions)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.ReproSteps != "" {
|
||||
b.WriteString("\n--- 复现步骤 ---\n")
|
||||
b.WriteString(v.ReproSteps)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Evidence != "" {
|
||||
b.WriteString("\n--- 证据 / POC ---\n")
|
||||
b.WriteString(v.Evidence)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Impact != "" {
|
||||
@@ -135,9 +145,36 @@ func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||||
b.WriteString(v.Recommendation)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.RetestNotes != "" {
|
||||
b.WriteString("\n--- 复测方式 ---\n")
|
||||
b.WriteString(v.RetestNotes)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func missingVulnerabilityReproFields(args map[string]interface{}) []string {
|
||||
required := []struct {
|
||||
key string
|
||||
label string
|
||||
}{
|
||||
{"target", "target(受影响的 URL/IP/服务/接口)"},
|
||||
{"vulnerability_type", "vulnerability_type(漏洞类型)"},
|
||||
{"description", "description(漏洞摘要与触发点)"},
|
||||
{"reproduction_steps", "reproduction_steps(可逐步执行的复现步骤)"},
|
||||
{"evidence", "evidence(POC、原始请求/响应、命令输出或截图/日志证据)"},
|
||||
{"impact", "impact(确认后的实际影响)"},
|
||||
{"recommendation", "recommendation(修复建议)"},
|
||||
}
|
||||
missing := make([]string, 0)
|
||||
for _, item := range required {
|
||||
if strings.TrimSpace(strArg(args, item.key)) == "" {
|
||||
missing = append(missing, item.label)
|
||||
}
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func truncateRunes(s string, max int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= max {
|
||||
@@ -163,18 +200,18 @@ func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *
|
||||
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。边渗透边记录:每验证出一条可复现漏洞(含 POC/影响)后立即调用,勿等会话结束。包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。必须按“仅看本记录即可复现”的标准填写:目标、触发点、前置条件、复现步骤、证据/POC、实际影响、修复建议和复测方式。边渗透边记录:每验证出一条可复现漏洞后立即调用,勿等会话结束。记录前可先 list_vulnerabilities 避免重复。",
|
||||
ShortDescription: "记录可复现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞标题(必需)",
|
||||
"description": "漏洞标题(必需)。建议格式:<资产/接口> 存在 <漏洞类型>,例如“/api/login 存在 SQL 注入”。",
|
||||
},
|
||||
"description": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞详细描述",
|
||||
"description": "漏洞摘要与触发点(必需):说明哪个功能/参数/入口存在问题、为什么可被利用。不要只写结论。",
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
@@ -183,26 +220,38 @@ func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, log
|
||||
},
|
||||
"vulnerability_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等(必需)",
|
||||
},
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||
"description": "受影响的目标(必需):尽量精确到 URL、IP:端口、服务名、接口路径和参数名。",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"preconditions": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||
"description": "前置条件:登录状态、权限、账号、Header/Cookie、特定数据、网络位置、环境/版本等;无前置条件写“无”。",
|
||||
},
|
||||
"reproduction_steps": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "复现步骤(必需):按 1/2/3 编号,写清入口、参数、payload、执行命令、观察点。应让未参与对话的人照做即可复现。",
|
||||
},
|
||||
"evidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "证据 / POC(必需):原始 HTTP 请求/响应、curl/工具命令、截图文字说明、日志、DNSLog/回连记录、数据库结果、文件路径、时间戳等。优先放最小可验证证据。",
|
||||
},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞影响说明",
|
||||
"description": "漏洞影响说明(必需):结合已验证事实说明可造成什么后果,避免泛泛而谈。",
|
||||
},
|
||||
"recommendation": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
"description": "修复建议(必需):给出针对该触发点/参数/组件的具体修复和复测建议。",
|
||||
},
|
||||
"retest_notes": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "复测方式:修复后如何验证漏洞已关闭,包括应返回的状态码、错误信息或访问控制结果。",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "severity"},
|
||||
"required": []string{"title", "description", "severity", "vulnerability_type", "target", "reproduction_steps", "evidence", "impact", "recommendation"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -231,6 +280,9 @@ func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, log
|
||||
if !validSeverities[severity] {
|
||||
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
|
||||
}
|
||||
if missing := missingVulnerabilityReproFields(args); len(missing) > 0 {
|
||||
return textResult("错误: 漏洞记录缺少复现所需信息,请补充后再记录:\n- "+strings.Join(missing, "\n- ")+"\n\n最佳实践:漏洞管理中的单条记录应独立包含目标、前置条件、复现步骤、证据/POC、影响和修复/复测方式。", true), nil
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
|
||||
@@ -246,9 +298,12 @@ func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, log
|
||||
Status: "open",
|
||||
Type: strArg(args, "vulnerability_type"),
|
||||
Target: strArg(args, "target"),
|
||||
Proof: strArg(args, "proof"),
|
||||
Preconditions: strArg(args, "preconditions"),
|
||||
ReproSteps: strArg(args, "reproduction_steps"),
|
||||
Evidence: strArg(args, "evidence"),
|
||||
Impact: strArg(args, "impact"),
|
||||
Recommendation: strArg(args, "recommendation"),
|
||||
RetestNotes: strArg(args, "retest_notes"),
|
||||
}
|
||||
|
||||
created, err := db.CreateVulnerability(vuln)
|
||||
@@ -275,8 +330,8 @@ func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, log
|
||||
|
||||
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolListVulnerabilities,
|
||||
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||||
Name: builtin.ToolListVulnerabilities,
|
||||
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||||
ShortDescription: "列出漏洞(默认当前项目)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
|
||||
+313
-96
@@ -30,7 +30,7 @@ type Config struct {
|
||||
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
|
||||
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
|
||||
@@ -79,7 +79,7 @@ func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||
type MultiAgentConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // eino_single | deep | plan_execute | supervisor
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||
// MaxIteration 已废弃:统一使用 agent.max_iterations(YAML 中保留字段仅为兼容旧配置,运行时不读取)。
|
||||
@@ -87,10 +87,10 @@ type MultiAgentConfig struct {
|
||||
// PlanExecuteLoopMaxIterations plan_execute 模式下 execute↔replan 外层循环上限;0 表示用 Eino 默认 10。
|
||||
PlanExecuteLoopMaxIterations int `yaml:"plan_execute_loop_max_iterations,omitempty" json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
// SubAgentMaxIterations 已废弃:子代理与主代理均使用 agent.max_iterations(Markdown max_iterations>0 可覆盖)。
|
||||
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations,omitempty" json:"sub_agent_max_iterations,omitempty"`
|
||||
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
|
||||
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
|
||||
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
|
||||
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations,omitempty" json:"sub_agent_max_iterations,omitempty"`
|
||||
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
|
||||
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
|
||||
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
|
||||
// OrchestratorInstructionPlanExecute plan_execute 主代理(规划侧)系统提示;非空且 agents/orchestrator-plan-execute.md 正文为空或未存在时生效。不与 Deep 的 orchestrator_instruction 混用。
|
||||
OrchestratorInstructionPlanExecute string `yaml:"orchestrator_instruction_plan_execute,omitempty" json:"orchestrator_instruction_plan_execute,omitempty"`
|
||||
// OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。
|
||||
@@ -99,9 +99,6 @@ type MultiAgentConfig struct {
|
||||
// SubAgentUserContextMaxRunes caps user-context supplement for sub-agent task descriptions.
|
||||
// 0 (default) preserves all user turns verbatim; >0 caps total runes; negative disables injection.
|
||||
SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"`
|
||||
// UserVerbatimAnchorMaxRunes injects all user turns verbatim into system prompt (survives summarization refresh).
|
||||
// 0 (default) = no cap; >0 = total rune cap; negative disables anchor injection.
|
||||
UserVerbatimAnchorMaxRunes int `yaml:"user_verbatim_anchor_max_runes,omitempty" json:"user_verbatim_anchor_max_runes,omitempty"`
|
||||
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||
@@ -110,11 +107,6 @@ type MultiAgentConfig struct {
|
||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||
}
|
||||
|
||||
// UserVerbatimAnchorMaxRunesEffective returns max runes for user verbatim anchor; 0 = unlimited; negative = disabled.
|
||||
func (c MultiAgentConfig) UserVerbatimAnchorMaxRunesEffective() int {
|
||||
return c.UserVerbatimAnchorMaxRunes
|
||||
}
|
||||
|
||||
// SubAgentUserContextMaxRunesEffective returns max runes for sub-agent task supplement; 0 = unlimited; negative = disabled.
|
||||
func (c MultiAgentConfig) SubAgentUserContextMaxRunesEffective() int {
|
||||
return c.SubAgentUserContextMaxRunes
|
||||
@@ -138,11 +130,11 @@ type MultiAgentEinoCallbacksConfig struct {
|
||||
|
||||
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
|
||||
type MultiAgentEinoCallbacksOtelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||
}
|
||||
|
||||
// EinoCallbacksModeEffective returns off | log_only | sse | full.
|
||||
@@ -253,12 +245,12 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
// PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask).
|
||||
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
|
||||
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
||||
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
||||
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
|
||||
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
||||
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
||||
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
||||
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
|
||||
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
||||
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
|
||||
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
||||
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
||||
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
||||
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
|
||||
// SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8).
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
@@ -283,6 +275,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
// EmptyResponseContinueMaxAttempts Run 成功但未捕获助手正文时 Handler 层退避续跑次数;0=默认 5。
|
||||
EmptyResponseContinueMaxAttempts int `yaml:"empty_response_continue_max_attempts,omitempty" json:"empty_response_continue_max_attempts,omitempty"`
|
||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||
}
|
||||
@@ -404,13 +398,13 @@ type MultiAgentSubConfig struct {
|
||||
|
||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||
type MultiAgentPublic struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||
}
|
||||
|
||||
@@ -451,10 +445,10 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
||||
|
||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||
type MultiAgentAPIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
}
|
||||
@@ -470,14 +464,14 @@ type RobotsConfig struct {
|
||||
|
||||
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
|
||||
type RobotWechatConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||
}
|
||||
|
||||
// RobotSessionConfig 机器人会话隔离策略
|
||||
@@ -503,21 +497,32 @@ type RobotWecomConfig struct {
|
||||
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
|
||||
}
|
||||
|
||||
// ValidateWecomConfig 校验企业微信机器人配置;启用时必须配置 token,否则回调无法防伪造。
|
||||
func ValidateWecomConfig(w RobotWecomConfig) error {
|
||||
if !w.Enabled {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(w.Token) == "" {
|
||||
return fmt.Errorf("robots.wecom.enabled 为 true 时必须配置 robots.wecom.token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RobotDingtalkConfig 钉钉机器人配置
|
||||
type RobotDingtalkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
|
||||
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
|
||||
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
|
||||
AllowConversationIDFallback bool `yaml:"allow_conversation_id_fallback" json:"allow_conversation_id_fallback"` // sender_id 缺失时是否允许回退到会话 ID
|
||||
}
|
||||
|
||||
// RobotLarkConfig 飞书机器人配置
|
||||
type RobotLarkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
|
||||
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
|
||||
AllowChatIDFallback bool `yaml:"allow_chat_id_fallback" json:"allow_chat_id_fallback"` // 用户 ID 缺失时是否允许回退到 chat_id
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
|
||||
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
|
||||
AllowChatIDFallback bool `yaml:"allow_chat_id_fallback" json:"allow_chat_id_fallback"` // 用户 ID 缺失时是否允许回退到 chat_id
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -616,8 +621,8 @@ type DatabaseConfig struct {
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
|
||||
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
|
||||
// WorkspaceRootDir 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离。
|
||||
@@ -627,10 +632,112 @@ type AgentConfig struct {
|
||||
}
|
||||
|
||||
// HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。
|
||||
// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效;其他字段若仅改文件仍需重启。
|
||||
// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效。
|
||||
// audit_agent_prompt / audit_agent_prompt_review_edit 可在人机协同页编辑并立即生效;空则使用内置默认。
|
||||
type HitlConfig struct {
|
||||
// ToolWhitelist 全局免审批工具名(与每条会话配置的 sensitiveTools 语义相同:白名单内工具不触发 HITL)。
|
||||
// ToolWhitelist 全局免审批工具名(与白名单内工具不触发 HITL 审批)。
|
||||
ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"`
|
||||
// AuditAgentPrompt 审批模式(approval)下审计 Agent 系统提示词。
|
||||
AuditAgentPrompt string `yaml:"audit_agent_prompt,omitempty" json:"audit_agent_prompt,omitempty"`
|
||||
// AuditAgentPromptReviewEdit 审查编辑模式(review_edit)下审计 Agent 系统提示词。
|
||||
AuditAgentPromptReviewEdit string `yaml:"audit_agent_prompt_review_edit,omitempty" json:"audit_agent_prompt_review_edit,omitempty"`
|
||||
// RetentionDays 已决策审计日志(hitl_interrupts 非 pending)保留天数;省略时默认 90;0 表示不自动清理。
|
||||
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
// DefaultReviewer 全局默认审批方(human | audit_agent);未选会话时切换会写入 config.yaml;新建会话无独立配置时沿用。
|
||||
DefaultReviewer string `yaml:"default_reviewer,omitempty" json:"default_reviewer,omitempty"`
|
||||
}
|
||||
|
||||
// EffectiveDefaultReviewer returns human or audit_agent; omitted or unknown values default to human.
|
||||
func (h HitlConfig) EffectiveDefaultReviewer() string {
|
||||
switch strings.ToLower(strings.TrimSpace(h.DefaultReviewer)) {
|
||||
case "audit_agent", "agent", "ai":
|
||||
return "audit_agent"
|
||||
default:
|
||||
return "human"
|
||||
}
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
|
||||
func (h HitlConfig) RetentionDaysEffective() int {
|
||||
if h.RetentionDays == nil {
|
||||
return 90
|
||||
}
|
||||
if *h.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return *h.RetentionDays
|
||||
}
|
||||
|
||||
const hitlAuditAgentPromptBase = `你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||
|
||||
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||
|
||||
裁决基调(默认放行):
|
||||
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||
- 仅在「可能对系统造成实质影响」时 → reject
|
||||
|
||||
必须 reject 的高危情形(示例,非穷举):
|
||||
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||
|
||||
不应单独作为 reject 理由的情形:
|
||||
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||
- 参数略显宽泛但无明确破坏意图(审查编辑模式可收窄参数后 approve)
|
||||
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点`
|
||||
|
||||
const hitlAuditAgentPromptApprovalOutput = `
|
||||
仅输出一行 JSON,不要 markdown 代码块:
|
||||
{"decision":"approve"|"reject","comment":"简要理由"}`
|
||||
|
||||
const hitlAuditAgentPromptReviewEditOutput = `
|
||||
仅输出一行 JSON,不要 markdown 代码块:
|
||||
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
|
||||
|
||||
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
|
||||
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
|
||||
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
|
||||
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
|
||||
- 无法安全改参时应 reject,不要勉强 approve`
|
||||
|
||||
// DefaultHitlAuditAgentPrompt 内置审批模式审计 Agent 提示词。
|
||||
func DefaultHitlAuditAgentPrompt() string {
|
||||
return hitlAuditAgentPromptBase + hitlAuditAgentPromptApprovalOutput
|
||||
}
|
||||
|
||||
// DefaultHitlAuditAgentPromptReviewEdit 内置审查编辑模式审计 Agent 提示词。
|
||||
func DefaultHitlAuditAgentPromptReviewEdit() string {
|
||||
return hitlAuditAgentPromptBase + hitlAuditAgentPromptReviewEditOutput
|
||||
}
|
||||
|
||||
// EffectiveAuditAgentPrompt 返回审批模式生效的审计 Agent 提示词。
|
||||
func (c HitlConfig) EffectiveAuditAgentPrompt() string {
|
||||
return c.EffectiveAuditAgentPromptForMode("approval")
|
||||
}
|
||||
|
||||
// EffectiveAuditAgentPromptForMode 按 HITL 模式返回生效的审计 Agent 提示词。
|
||||
func (c HitlConfig) EffectiveAuditAgentPromptForMode(mode string) string {
|
||||
if normalizeHitlModeForPrompt(mode) == "review_edit" {
|
||||
if s := strings.TrimSpace(c.AuditAgentPromptReviewEdit); s != "" {
|
||||
return s
|
||||
}
|
||||
return DefaultHitlAuditAgentPromptReviewEdit()
|
||||
}
|
||||
if s := strings.TrimSpace(c.AuditAgentPrompt); s != "" {
|
||||
return s
|
||||
}
|
||||
return DefaultHitlAuditAgentPrompt()
|
||||
}
|
||||
|
||||
func normalizeHitlModeForPrompt(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "review_edit":
|
||||
return "review_edit"
|
||||
default:
|
||||
return "approval"
|
||||
}
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
@@ -661,9 +768,9 @@ func (m MonitorConfig) RetentionDaysEffective() int {
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
|
||||
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
|
||||
}
|
||||
@@ -817,33 +924,13 @@ func Load(path string) (*Config, error) {
|
||||
|
||||
// 如果配置了工具目录,从目录加载工具配置
|
||||
if cfg.Security.ToolsDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
toolsDir := cfg.Security.ToolsDir
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(toolsDir) {
|
||||
toolsDir = filepath.Join(configDir, toolsDir)
|
||||
}
|
||||
|
||||
tools, err := LoadToolsFromDir(toolsDir)
|
||||
inlineTools := append([]ToolConfig(nil), cfg.Security.Tools...)
|
||||
toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, path)
|
||||
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
|
||||
existingTools := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
existingTools[tool.Name] = true
|
||||
}
|
||||
|
||||
// 添加主配置中不存在于目录中的工具(向后兼容)
|
||||
for _, tool := range cfg.Security.Tools {
|
||||
if !existingTools[tool.Name] {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Security.Tools = tools
|
||||
cfg.Security.Tools = merged
|
||||
}
|
||||
|
||||
// 外部 MCP:迁移 + 环境变量展开
|
||||
@@ -887,6 +974,10 @@ func Load(path string) (*Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := ValidateWecomConfig(cfg.Robots.Wecom); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -1111,6 +1202,75 @@ func PrintMCPConfigJSON(mcp MCPConfig) {
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
}
|
||||
|
||||
// ResolveToolsDir 将 tools_dir 解析为绝对路径(相对路径相对于 configPath 所在目录)。
|
||||
func ResolveToolsDir(toolsDir, configPath string) string {
|
||||
toolsDir = strings.TrimSpace(toolsDir)
|
||||
if toolsDir == "" {
|
||||
return ""
|
||||
}
|
||||
if filepath.IsAbs(toolsDir) {
|
||||
return toolsDir
|
||||
}
|
||||
return filepath.Join(filepath.Dir(configPath), toolsDir)
|
||||
}
|
||||
|
||||
// MergeToolsFromDir 从目录加载工具并与 inline 列表合并:目录中的工具优先,主配置中的工具作为补充。
|
||||
func MergeToolsFromDir(toolsDir string, inlineTools []ToolConfig) ([]ToolConfig, error) {
|
||||
dirTools, err := LoadToolsFromDir(toolsDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existing := make(map[string]bool, len(dirTools))
|
||||
for _, tool := range dirTools {
|
||||
existing[tool.Name] = true
|
||||
}
|
||||
merged := append([]ToolConfig(nil), dirTools...)
|
||||
for _, tool := range inlineTools {
|
||||
if !existing[tool.Name] {
|
||||
merged = append(merged, tool)
|
||||
}
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// loadInlineSecurityToolsFromYAML 读取 config.yaml 中 security.tools(不含 tools_dir 扫描结果)。
|
||||
func loadInlineSecurityToolsFromYAML(configPath string) ([]ToolConfig, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
var partial struct {
|
||||
Security struct {
|
||||
Tools []ToolConfig `yaml:"tools"`
|
||||
} `yaml:"security"`
|
||||
}
|
||||
if err := yaml.Unmarshal(data, &partial); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
if partial.Security.Tools == nil {
|
||||
return []ToolConfig{}, nil
|
||||
}
|
||||
return partial.Security.Tools, nil
|
||||
}
|
||||
|
||||
// ReloadSecurityToolsFromDir 从 tools_dir 重新加载工具并更新 cfg.Security.Tools(ApplyConfig 热重载用)。
|
||||
func ReloadSecurityToolsFromDir(cfg *Config, configPath string) error {
|
||||
if cfg == nil || strings.TrimSpace(cfg.Security.ToolsDir) == "" {
|
||||
return nil
|
||||
}
|
||||
inlineTools, err := loadInlineSecurityToolsFromYAML(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, configPath)
|
||||
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
|
||||
if err != nil {
|
||||
return fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||
}
|
||||
cfg.Security.Tools = merged
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToolsFromDir 从目录加载所有工具配置文件
|
||||
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
||||
var tools []ToolConfig
|
||||
@@ -1288,8 +1448,8 @@ func Default() *Config {
|
||||
},
|
||||
Agent: AgentConfig{
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||
@@ -1329,7 +1489,12 @@ func Default() *Config {
|
||||
},
|
||||
Retrieval: RetrievalConfig{
|
||||
TopK: 5,
|
||||
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
|
||||
SimilarityThreshold: 0.65,
|
||||
MultiQuery: MultiQueryConfig{MaxQueries: 4},
|
||||
Rerank: RerankConfig{},
|
||||
PostRetrieve: PostRetrieveConfig{
|
||||
PrefetchTopK: 20,
|
||||
},
|
||||
},
|
||||
Indexing: IndexingConfig{
|
||||
ChunkStrategy: "markdown_then_recursive",
|
||||
@@ -1425,7 +1590,7 @@ type EmbeddingConfig struct {
|
||||
|
||||
// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。
|
||||
type PostRetrieveConfig struct {
|
||||
// PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。
|
||||
// PrefetchTopK 向量检索阶段每条 MultiQuery 变体最多保留的候选数;0 表示使用内置默认 max(top_k*4, 20)。
|
||||
PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"`
|
||||
// MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。
|
||||
MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"`
|
||||
@@ -1433,13 +1598,62 @@ type PostRetrieveConfig struct {
|
||||
MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// MultiQueryConfig Eino MultiQuery 查询改写(始终启用,无关闭开关)。
|
||||
type MultiQueryConfig struct {
|
||||
// MaxQueries LLM 生成的检索变体上限(含原问语义覆盖);0 表示默认 4。
|
||||
MaxQueries int `yaml:"max_queries,omitempty" json:"max_queries,omitempty"`
|
||||
}
|
||||
|
||||
func (c MultiQueryConfig) MaxQueriesEffective() int {
|
||||
if c.MaxQueries <= 0 {
|
||||
return 4
|
||||
}
|
||||
if c.MaxQueries > 8 {
|
||||
return 8
|
||||
}
|
||||
return c.MaxQueries
|
||||
}
|
||||
|
||||
// RerankConfig 检索精排(始终启用);支持 dashscope 与 Cohere 兼容 HTTP API。
|
||||
type RerankConfig struct {
|
||||
// Provider: dashscope | cohere;空则按 base_url 自动推断。
|
||||
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
|
||||
Model string `yaml:"model,omitempty" json:"model,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"`
|
||||
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
|
||||
}
|
||||
|
||||
func (c RerankConfig) ProviderEffective(baseURL string) string {
|
||||
p := strings.TrimSpace(strings.ToLower(c.Provider))
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
u := strings.ToLower(baseURL)
|
||||
if strings.Contains(u, "dashscope") {
|
||||
return "dashscope"
|
||||
}
|
||||
return "cohere"
|
||||
}
|
||||
|
||||
func (c RerankConfig) ModelEffective(provider string) string {
|
||||
if m := strings.TrimSpace(c.Model); m != "" {
|
||||
return m
|
||||
}
|
||||
if provider == "dashscope" {
|
||||
return "gte-rerank"
|
||||
}
|
||||
return "rerank-multilingual-v3.0"
|
||||
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值
|
||||
// SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。
|
||||
SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
|
||||
// PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。
|
||||
SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
|
||||
MultiQuery MultiQueryConfig `yaml:"multi_query" json:"multi_query"`
|
||||
Rerank RerankConfig `yaml:"rerank" json:"rerank"`
|
||||
// PostRetrieve 检索后处理(去重、预算截断);精排在 MultiQuery 融合后执行。
|
||||
PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1451,11 +1665,14 @@ type RolesConfig struct {
|
||||
|
||||
// RoleConfig 单个角色配置
|
||||
type RoleConfig struct {
|
||||
Name string `yaml:"name" json:"name"` // 角色名称
|
||||
Description string `yaml:"description" json:"description"` // 角色描述
|
||||
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
|
||||
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
|
||||
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName")
|
||||
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
|
||||
Name string `yaml:"name" json:"name"` // 角色名称
|
||||
Description string `yaml:"description" json:"description"` // 角色描述
|
||||
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
|
||||
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
|
||||
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName")
|
||||
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
|
||||
WorkflowID string `yaml:"workflow_id,omitempty" json:"workflow_id,omitempty"` // 可选:绑定图编排流程 ID
|
||||
WorkflowVersion string `yaml:"workflow_version,omitempty" json:"workflow_version,omitempty"` // latest 或具体版本号;空等同 latest
|
||||
WorkflowPolicy string `yaml:"workflow_policy,omitempty" json:"workflow_policy,omitempty"` // auto | off;空且 workflow_id 非空时按 auto
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateWecomConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg RobotWecomConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "disabled without token",
|
||||
cfg: RobotWecomConfig{Enabled: false, Token: ""},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "enabled with token",
|
||||
cfg: RobotWecomConfig{Enabled: true, Token: "secret"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "enabled without token",
|
||||
cfg: RobotWecomConfig{Enabled: true, Token: ""},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "enabled with whitespace token",
|
||||
cfg: RobotWecomConfig{Enabled: true, Token: " "},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := ValidateWecomConfig(tt.cfg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("ValidateWecomConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReloadSecurityToolsFromDir(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
toolsDir := filepath.Join(root, "tools")
|
||||
if err := os.MkdirAll(toolsDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(root, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(`security:
|
||||
tools_dir: tools
|
||||
tools:
|
||||
- name: inline-only
|
||||
command: inline-cmd
|
||||
enabled: true
|
||||
description: inline tool
|
||||
`), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
writeTool := func(name, command string) {
|
||||
t.Helper()
|
||||
content := "name: " + name + "\ncommand: " + command + "\nenabled: true\ndescription: test\n"
|
||||
if err := os.WriteFile(filepath.Join(toolsDir, name+".yaml"), []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
writeTool("alpha", "alpha-cmd")
|
||||
|
||||
cfg := &Config{
|
||||
Security: SecurityConfig{
|
||||
ToolsDir: "tools",
|
||||
Tools: []ToolConfig{
|
||||
{Name: "stale", Command: "stale-cmd", Enabled: true, Description: "should be removed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
|
||||
t.Fatalf("reload: %v", err)
|
||||
}
|
||||
if len(cfg.Security.Tools) != 2 {
|
||||
t.Fatalf("expected 2 tools, got %d", len(cfg.Security.Tools))
|
||||
}
|
||||
|
||||
names := map[string]string{}
|
||||
for _, tool := range cfg.Security.Tools {
|
||||
names[tool.Name] = tool.Command
|
||||
}
|
||||
if names["alpha"] != "alpha-cmd" {
|
||||
t.Fatalf("alpha tool missing or wrong command: %#v", names)
|
||||
}
|
||||
if names["inline-only"] != "inline-cmd" {
|
||||
t.Fatalf("inline-only tool missing: %#v", names)
|
||||
}
|
||||
if _, ok := names["stale"]; ok {
|
||||
t.Fatal("stale in-memory tool should not survive reload")
|
||||
}
|
||||
|
||||
writeTool("beta", "beta-cmd")
|
||||
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
|
||||
t.Fatalf("second reload: %v", err)
|
||||
}
|
||||
if len(cfg.Security.Tools) != 3 {
|
||||
t.Fatalf("expected 3 tools after add, got %d", len(cfg.Security.Tools))
|
||||
}
|
||||
foundBeta := false
|
||||
for _, tool := range cfg.Security.Tools {
|
||||
if tool.Name == "beta" {
|
||||
foundBeta = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundBeta {
|
||||
t.Fatal("beta tool not found after second reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeToolsFromDir_DirOverridesInline(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
toolsDir := filepath.Join(root, "tools")
|
||||
if err := os.MkdirAll(toolsDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
content := "name: shared\ncommand: dir-cmd\nenabled: true\ndescription: from dir\n"
|
||||
if err := os.WriteFile(filepath.Join(toolsDir, "shared.yaml"), []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
inline := []ToolConfig{
|
||||
{Name: "shared", Command: "inline-cmd", Enabled: true, Description: "from inline"},
|
||||
}
|
||||
merged, err := MergeToolsFromDir(toolsDir, inline)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(merged) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(merged))
|
||||
}
|
||||
if merged[0].Command != "dir-cmd" {
|
||||
t.Fatalf("dir tool should win, got command %q", merged[0].Command)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -577,6 +578,19 @@ func (db *DB) ListUngroupedConversations(limit, offset int, sortBy, projectID st
|
||||
return conversations, rows.Err()
|
||||
}
|
||||
|
||||
// GetConversationTitle 获取对话标题(轻量查询,不加载消息)
|
||||
func (db *DB) GetConversationTitle(id string) (string, error) {
|
||||
var title string
|
||||
err := db.QueryRow("SELECT title FROM conversations WHERE id = ?", id).Scan(&title)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", fmt.Errorf("查询对话标题失败: %w", err)
|
||||
}
|
||||
return title, nil
|
||||
}
|
||||
|
||||
// UpdateConversationTitle 更新对话标题
|
||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
||||
@@ -1057,6 +1071,77 @@ type ProcessDetail struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// GetTurnUserMessage 返回锚点消息所在轮次中的用户原文(最近一条 user 消息,不含完整历史)。
|
||||
func (db *DB) GetTurnUserMessage(conversationID, anchorMessageID string) (string, error) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
anchorMessageID = strings.TrimSpace(anchorMessageID)
|
||||
if conversationID == "" || anchorMessageID == "" {
|
||||
return "", nil
|
||||
}
|
||||
var content string
|
||||
err := db.QueryRow(`
|
||||
SELECT m.content FROM messages m
|
||||
WHERE m.conversation_id = ? AND m.role = 'user'
|
||||
AND m.created_at <= COALESCE((SELECT created_at FROM messages WHERE id = ? AND conversation_id = ?), m.created_at)
|
||||
ORDER BY m.created_at DESC, m.rowid DESC
|
||||
LIMIT 1`, conversationID, anchorMessageID, conversationID).Scan(&content)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("query turn user message: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// AssistantCognitionTexts 单条助手消息上的思考/推理/规划文本。
|
||||
type AssistantCognitionTexts struct {
|
||||
Thinking string
|
||||
ReasoningChain string
|
||||
Planning string
|
||||
}
|
||||
|
||||
// GetAssistantCognitionTexts 聚合助手消息在 process_details 中的 thinking / reasoning_chain / planning。
|
||||
func (db *DB) GetAssistantCognitionTexts(assistantMessageID string) (AssistantCognitionTexts, error) {
|
||||
assistantMessageID = strings.TrimSpace(assistantMessageID)
|
||||
if assistantMessageID == "" {
|
||||
return AssistantCognitionTexts{}, nil
|
||||
}
|
||||
rows, err := db.Query(`
|
||||
SELECT event_type, message FROM process_details
|
||||
WHERE message_id = ? AND event_type IN ('thinking', 'reasoning_chain', 'planning')
|
||||
ORDER BY created_at ASC, rowid ASC`, assistantMessageID)
|
||||
if err != nil {
|
||||
return AssistantCognitionTexts{}, fmt.Errorf("query assistant cognition: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var thinkingParts, reasoningParts, planningParts []string
|
||||
for rows.Next() {
|
||||
var eventType, message string
|
||||
if err := rows.Scan(&eventType, &message); err != nil {
|
||||
continue
|
||||
}
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
switch eventType {
|
||||
case "thinking":
|
||||
thinkingParts = append(thinkingParts, msg)
|
||||
case "reasoning_chain":
|
||||
reasoningParts = append(reasoningParts, msg)
|
||||
case "planning":
|
||||
planningParts = append(planningParts, msg)
|
||||
}
|
||||
}
|
||||
return AssistantCognitionTexts{
|
||||
Thinking: strings.Join(thinkingParts, "\n\n"),
|
||||
ReasoningChain: strings.Join(reasoningParts, "\n\n"),
|
||||
Planning: strings.Join(planningParts, "\n\n"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddProcessDetail 添加过程详情事件
|
||||
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@@ -388,9 +388,12 @@ func (db *DB) initTables() error {
|
||||
status TEXT NOT NULL DEFAULT 'open',
|
||||
vulnerability_type TEXT,
|
||||
target TEXT,
|
||||
proof TEXT,
|
||||
preconditions TEXT,
|
||||
reproduction_steps TEXT,
|
||||
evidence TEXT,
|
||||
impact TEXT,
|
||||
recommendation TEXT,
|
||||
retest_notes TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
project_id TEXT,
|
||||
@@ -584,6 +587,53 @@ func (db *DB) initTables() error {
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
createWorkflowDefinitionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS workflow_definitions (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
graph_json TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
createWorkflowRunsTable := `
|
||||
CREATE TABLE IF NOT EXISTS workflow_runs (
|
||||
id TEXT PRIMARY KEY,
|
||||
workflow_id TEXT NOT NULL,
|
||||
workflow_version INTEGER NOT NULL DEFAULT 1,
|
||||
conversation_id TEXT,
|
||||
project_id TEXT,
|
||||
role_id TEXT,
|
||||
status TEXT NOT NULL,
|
||||
input_json TEXT,
|
||||
output_json TEXT,
|
||||
error TEXT,
|
||||
pending_hitl_node_id TEXT,
|
||||
pending_hitl_json TEXT,
|
||||
started_at DATETIME NOT NULL,
|
||||
finished_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL
|
||||
);`
|
||||
|
||||
createWorkflowNodeRunsTable := `
|
||||
CREATE TABLE IF NOT EXISTS workflow_node_runs (
|
||||
id TEXT PRIMARY KEY,
|
||||
run_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
input_json TEXT,
|
||||
output_json TEXT,
|
||||
error TEXT,
|
||||
started_at DATETIME NOT NULL,
|
||||
finished_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES workflow_runs(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
@@ -642,6 +692,12 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_definitions_updated_at ON workflow_definitions(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_definitions_enabled ON workflow_definitions(enabled);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_runs_workflow ON workflow_runs(workflow_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_runs_conversation ON workflow_runs(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_runs_status ON workflow_runs(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_node_runs_run ON workflow_node_runs(run_id);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -727,6 +783,16 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建audit_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"workflow_definitions": createWorkflowDefinitionsTable,
|
||||
"workflow_runs": createWorkflowRunsTable,
|
||||
"workflow_node_runs": createWorkflowNodeRunsTable,
|
||||
} {
|
||||
if _, err := db.Exec(ddl); err != nil {
|
||||
return fmt.Errorf("创建%s表失败: %w", tableName, err)
|
||||
}
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"c2_listeners": createC2ListenersTable,
|
||||
"c2_sessions": createC2SessionsTable,
|
||||
@@ -784,6 +850,9 @@ func (db *DB) initTables() error {
|
||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
if err := db.migrateWorkflowRunsTable(); err != nil {
|
||||
db.logger.Warn("迁移workflow_runs表失败", zap.Error(err))
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
@@ -1224,9 +1293,12 @@ func (db *DB) migrateVulnerabilitiesConversationFK() error {
|
||||
status TEXT NOT NULL DEFAULT 'open',
|
||||
vulnerability_type TEXT,
|
||||
target TEXT,
|
||||
proof TEXT,
|
||||
preconditions TEXT,
|
||||
reproduction_steps TEXT,
|
||||
evidence TEXT,
|
||||
impact TEXT,
|
||||
recommendation TEXT,
|
||||
retest_notes TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
project_id TEXT,
|
||||
@@ -1239,12 +1311,15 @@ func (db *DB) migrateVulnerabilitiesConversationFK() error {
|
||||
const copyRows = `
|
||||
INSERT INTO vulnerabilities_new (
|
||||
id, conversation_id, conversation_tag, task_tag, title, description,
|
||||
severity, status, vulnerability_type, target, proof, impact, recommendation,
|
||||
severity, status, vulnerability_type, target, preconditions, reproduction_steps,
|
||||
evidence, impact, recommendation, retest_notes,
|
||||
created_at, updated_at, project_id
|
||||
)
|
||||
SELECT
|
||||
id, conversation_id, conversation_tag, task_tag, title, description,
|
||||
severity, status, vulnerability_type, target, proof, impact, recommendation,
|
||||
severity, status, vulnerability_type, target,
|
||||
COALESCE(preconditions, ''), COALESCE(reproduction_steps, ''),
|
||||
COALESCE(evidence, ''), impact, recommendation, COALESCE(retest_notes, ''),
|
||||
created_at, updated_at, project_id
|
||||
FROM vulnerabilities;`
|
||||
if _, err := tx.Exec(copyRows); err != nil {
|
||||
@@ -1315,6 +1390,10 @@ func (db *DB) migrateVulnerabilitiesTable() error {
|
||||
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
|
||||
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
|
||||
{name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||
{name: "preconditions", stmt: "ALTER TABLE vulnerabilities ADD COLUMN preconditions TEXT"},
|
||||
{name: "reproduction_steps", stmt: "ALTER TABLE vulnerabilities ADD COLUMN reproduction_steps TEXT"},
|
||||
{name: "evidence", stmt: "ALTER TABLE vulnerabilities ADD COLUMN evidence TEXT"},
|
||||
{name: "retest_notes", stmt: "ALTER TABLE vulnerabilities ADD COLUMN retest_notes TEXT"},
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DeleteHitlInterruptLogsByIDs deletes decided HITL audit logs by id (pending rows are skipped).
|
||||
func (db *DB) DeleteHitlInterruptLogsByIDs(ids []string) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, fmt.Errorf("database is nil")
|
||||
}
|
||||
clean := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
clean = append(clean, id)
|
||||
}
|
||||
}
|
||||
if len(clean) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
placeholders := strings.TrimRight(strings.Repeat("?,", len(clean)), ",")
|
||||
q := fmt.Sprintf(`DELETE FROM hitl_interrupts WHERE status != 'pending' AND id IN (%s)`, placeholders)
|
||||
args := make([]interface{}, len(clean))
|
||||
for i, id := range clean {
|
||||
args[i] = id
|
||||
}
|
||||
res, err := db.Exec(q, args...)
|
||||
if err != nil {
|
||||
db.logger.Error("批量删除人机协同审计日志失败", zap.Error(err), zap.Int("count", len(clean)))
|
||||
return 0, fmt.Errorf("批量删除人机协同审计日志失败: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DeleteHitlInterruptLogsMatching deletes decided logs matching whereSQL (e.g. "WHERE 1=1 AND status != 'pending' ...").
|
||||
func (db *DB) DeleteHitlInterruptLogsMatching(whereSQL string, args []interface{}) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, fmt.Errorf("database is nil")
|
||||
}
|
||||
whereSQL = strings.TrimSpace(whereSQL)
|
||||
if whereSQL == "" {
|
||||
return 0, fmt.Errorf("where clause is required")
|
||||
}
|
||||
q := `DELETE FROM hitl_interrupts ` + whereSQL
|
||||
res, err := db.Exec(q, args...)
|
||||
if err != nil {
|
||||
db.logger.Error("清空人机协同审计日志失败", zap.Error(err))
|
||||
return 0, fmt.Errorf("清空人机协同审计日志失败: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// PurgeHitlInterruptLogsBefore deletes decided logs with decided/created time before cutoff.
|
||||
func (db *DB) PurgeHitlInterruptLogsBefore(cutoff time.Time) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, fmt.Errorf("database is nil")
|
||||
}
|
||||
res, err := db.Exec(
|
||||
`DELETE FROM hitl_interrupts WHERE status != 'pending' AND datetime(COALESCE(decided_at, created_at)) < datetime(?)`,
|
||||
cutoff.UTC().Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Error("清理过期人机协同审计日志失败", zap.Error(err))
|
||||
return 0, fmt.Errorf("清理过期人机协同审计日志失败: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func ensureHitlInterruptsTable(t *testing.T, db *DB) {
|
||||
t.Helper()
|
||||
if _, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
message_id TEXT,
|
||||
mode TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
payload TEXT,
|
||||
status TEXT NOT NULL,
|
||||
decision TEXT,
|
||||
decision_comment TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
decided_at DATETIME
|
||||
);`); err != nil {
|
||||
t.Fatalf("create hitl_interrupts: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteHitlInterruptLogsByIDs_skipsPending(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
ensureHitlInterruptsTable(t, db)
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, mode, tool_name, status, created_at)
|
||||
VALUES ('pending-1', 'c1', 'approval', 'exec', 'pending', ?)`, now); err != nil {
|
||||
t.Fatalf("insert pending: %v", err)
|
||||
}
|
||||
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||
VALUES ('done-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, now, now); err != nil {
|
||||
t.Fatalf("insert decided: %v", err)
|
||||
}
|
||||
|
||||
deleted, err := db.DeleteHitlInterruptLogsByIDs([]string{"pending-1", "done-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteHitlInterruptLogsByIDs: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("deleted = %d, want 1", deleted)
|
||||
}
|
||||
|
||||
var status string
|
||||
if err := db.QueryRow(`SELECT status FROM hitl_interrupts WHERE id = 'pending-1'`).Scan(&status); err != nil {
|
||||
t.Fatalf("pending row missing: %v", err)
|
||||
}
|
||||
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'done-1'`).Scan(new(string)); err == nil {
|
||||
t.Fatal("decided row should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPurgeHitlInterruptLogsBefore(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
ensureHitlInterruptsTable(t, db)
|
||||
|
||||
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
|
||||
recent := time.Now().AddDate(0, 0, -1).UTC().Format(time.RFC3339)
|
||||
for _, row := range []struct{ id, decided string }{
|
||||
{"old-1", old},
|
||||
{"new-1", recent},
|
||||
} {
|
||||
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||
VALUES (?, 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, row.id, row.decided, row.decided); err != nil {
|
||||
t.Fatalf("insert %s: %v", row.id, err)
|
||||
}
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -90)
|
||||
deleted, err := db.PurgeHitlInterruptLogsBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeHitlInterruptLogsBefore: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("deleted = %d, want 1", deleted)
|
||||
}
|
||||
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err == nil {
|
||||
t.Fatal("old row should be purged")
|
||||
}
|
||||
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'new-1'`).Scan(new(string)); err != nil {
|
||||
t.Fatalf("new row should remain: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -111,19 +111,43 @@ func (db *DB) GetProject(id string) (*Project, error) {
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// CountProjects 统计项目数量。
|
||||
func (db *DB) CountProjects(status, search string) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM projects WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
func projectListSearchPattern(q string) string {
|
||||
q = strings.TrimSpace(q)
|
||||
if q == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteByte('%')
|
||||
for _, r := range q {
|
||||
switch r {
|
||||
case '%', '_', '\\':
|
||||
b.WriteByte('\\')
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
b.WriteByte('%')
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func appendProjectListFilters(query string, args []interface{}, status, search string) (string, []interface{}) {
|
||||
if s := strings.TrimSpace(status); s != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, s)
|
||||
}
|
||||
if q := strings.TrimSpace(search); q != "" {
|
||||
pattern := "%" + q + "%"
|
||||
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
|
||||
args = append(args, pattern, pattern)
|
||||
if pattern := projectListSearchPattern(search); pattern != "" {
|
||||
query += ` AND (LOWER(name) LIKE LOWER(?) ESCAPE '\' OR LOWER(COALESCE(description,'')) LIKE LOWER(?) ESCAPE '\' OR LOWER(id) LIKE LOWER(?) ESCAPE '\')`
|
||||
args = append(args, pattern, pattern, pattern)
|
||||
}
|
||||
return query, args
|
||||
}
|
||||
|
||||
// CountProjects 统计项目数量。
|
||||
func (db *DB) CountProjects(status, search string) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM projects WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
query, args = appendProjectListFilters(query, args, status, search)
|
||||
var count int
|
||||
if err := db.QueryRow(query, args...).Scan(&count); err != nil {
|
||||
return 0, fmt.Errorf("统计项目失败: %w", err)
|
||||
@@ -139,15 +163,7 @@ func (db *DB) ListProjects(status, search string, limit, offset int) ([]*Project
|
||||
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||
FROM projects WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if s := strings.TrimSpace(status); s != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, s)
|
||||
}
|
||||
if q := strings.TrimSpace(search); q != "" {
|
||||
pattern := "%" + q + "%"
|
||||
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
|
||||
args = append(args, pattern, pattern)
|
||||
}
|
||||
query, args = appendProjectListFilters(query, args, status, search)
|
||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestListProjectsSearchCaseInsensitive(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "projects-search.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p1, err := db.CreateProject(&Project{Name: "Alpha Security Review", Status: "active"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p2, err := db.CreateProject(&Project{Name: "beta-scan", Status: "active"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.CreateProject(&Project{Name: "Other", Status: "archived"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
search string
|
||||
status string
|
||||
want []string
|
||||
}{
|
||||
{name: "case insensitive name", search: "alpha", status: "active", want: []string{p1.ID}},
|
||||
{name: "upper query", search: "BETA", status: "active", want: []string{p2.ID}},
|
||||
{name: "search by id substring", search: p1.ID[:8], status: "", want: []string{p1.ID}},
|
||||
{name: "status filter", search: "alpha", status: "archived", want: nil},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
list, err := db.ListProjects(tc.status, tc.search, 50, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := make([]string, 0, len(list))
|
||||
for _, p := range list {
|
||||
got = append(got, p.ID)
|
||||
}
|
||||
if len(got) != len(tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tc.want[i] {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectListSearchPatternEscapesWildcards(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "projects-like.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&Project{Name: "100% coverage", Status: "active"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
list, err := db.ListProjects("active", "100%", 50, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(list) != 1 || list[0].ID != p.ID {
|
||||
t.Fatalf("expected exact match for literal %% query, got %#v", list)
|
||||
}
|
||||
}
|
||||
@@ -72,14 +72,17 @@ func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (
|
||||
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(preconditions, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(reproduction_steps, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(evidence, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(retest_notes, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||
)`
|
||||
for i := 0; i < 11; i++ {
|
||||
for i := 0; i < 14; i++ {
|
||||
args = append(args, pattern)
|
||||
}
|
||||
}
|
||||
@@ -101,9 +104,12 @@ type Vulnerability struct {
|
||||
Status string `json:"status"` // open, confirmed, fixed, false_positive, ignored
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Preconditions string `json:"preconditions"`
|
||||
ReproSteps string `json:"reproduction_steps"`
|
||||
Evidence string `json:"evidence"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
RetestNotes string `json:"retest_notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -131,16 +137,16 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
query := `
|
||||
INSERT INTO vulnerabilities (
|
||||
id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
vulnerability_type, target, preconditions, reproduction_steps, evidence, impact, recommendation, retest_notes,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||
vuln.Preconditions, vuln.ReproSteps, vuln.Evidence, vuln.Impact, vuln.Recommendation, vuln.RetestNotes,
|
||||
vuln.CreatedAt, vuln.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -155,7 +161,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
var vuln Vulnerability
|
||||
query := `
|
||||
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status,
|
||||
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
||||
conversation_tag, task_tag, vulnerability_type, target,
|
||||
COALESCE(preconditions,''), COALESCE(reproduction_steps,''), COALESCE(evidence,''),
|
||||
impact, recommendation, COALESCE(retest_notes,''),
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
created_at, updated_at
|
||||
@@ -166,7 +174,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.Preconditions, &vuln.ReproSteps, &vuln.Evidence, &vuln.Impact, &vuln.Recommendation, &vuln.RetestNotes,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
&vuln.CreatedAt, &vuln.UpdatedAt,
|
||||
)
|
||||
@@ -184,7 +192,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
vulnerability_type, target,
|
||||
COALESCE(preconditions,''), COALESCE(reproduction_steps,''), COALESCE(evidence,''),
|
||||
impact, recommendation, COALESCE(retest_notes,''),
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
created_at, updated_at
|
||||
@@ -209,7 +219,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFil
|
||||
err := rows.Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.Preconditions, &vuln.ReproSteps, &vuln.Evidence, &vuln.Impact, &vuln.Recommendation, &vuln.RetestNotes,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
&vuln.CreatedAt, &vuln.UpdatedAt,
|
||||
)
|
||||
@@ -245,16 +255,16 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
query := `
|
||||
UPDATE vulnerabilities
|
||||
SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
||||
recommendation = ?, updated_at = ?
|
||||
vulnerability_type = ?, target = ?, preconditions = ?, reproduction_steps = ?, evidence = ?, impact = ?,
|
||||
recommendation = ?, retest_notes = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
||||
vuln.Recommendation, vuln.UpdatedAt, id,
|
||||
vuln.Type, vuln.Target, vuln.Preconditions, vuln.ReproSteps, vuln.Evidence, vuln.Impact,
|
||||
vuln.Recommendation, vuln.RetestNotes, vuln.UpdatedAt, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新漏洞失败: %w", err)
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WorkflowDefinition is a persisted user-defined graph/workflow template.
|
||||
// graph_json intentionally remains opaque so users can define their own fields.
|
||||
type WorkflowDefinition struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Version int `json:"version"`
|
||||
GraphJSON string `json:"graph_json"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type WorkflowRun struct {
|
||||
ID string `json:"id"`
|
||||
WorkflowID string `json:"workflow_id"`
|
||||
WorkflowVersion int `json:"workflow_version"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
InputJSON string `json:"input_json,omitempty"`
|
||||
OutputJSON string `json:"output_json,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
PendingHITLNodeID string `json:"pending_hitl_node_id,omitempty"`
|
||||
PendingHITLJSON string `json:"pending_hitl_json,omitempty"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowNodeRun struct {
|
||||
ID string `json:"id"`
|
||||
RunID string `json:"run_id"`
|
||||
NodeID string `json:"node_id"`
|
||||
Status string `json:"status"`
|
||||
InputJSON string `json:"input_json,omitempty"`
|
||||
OutputJSON string `json:"output_json,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
}
|
||||
|
||||
func scanWorkflowDefinition(scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}) (*WorkflowDefinition, error) {
|
||||
var row WorkflowDefinition
|
||||
var desc sql.NullString
|
||||
var enabled int
|
||||
if err := scanner.Scan(&row.ID, &row.Name, &desc, &row.Version, &row.GraphJSON, &enabled, &row.CreatedAt, &row.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row.Description = desc.String
|
||||
row.Enabled = enabled != 0
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
const workflowDefinitionColumns = `id, name, description, version, graph_json, enabled, created_at, updated_at`
|
||||
|
||||
func (db *DB) ListWorkflowDefinitions(includeDisabled bool) ([]*WorkflowDefinition, error) {
|
||||
query := "SELECT " + workflowDefinitionColumns + " FROM workflow_definitions"
|
||||
if !includeDisabled {
|
||||
query += " WHERE enabled = 1"
|
||||
}
|
||||
query += " ORDER BY updated_at DESC"
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询工作流列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*WorkflowDefinition
|
||||
for rows.Next() {
|
||||
wf, err := scanWorkflowDefinition(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("扫描工作流失败: %w", err)
|
||||
}
|
||||
out = append(out, wf)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (db *DB) GetWorkflowDefinition(id string) (*WorkflowDefinition, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return nil, nil
|
||||
}
|
||||
wf, err := scanWorkflowDefinition(db.QueryRow("SELECT "+workflowDefinitionColumns+" FROM workflow_definitions WHERE id = ?", id))
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询工作流失败: %w", err)
|
||||
}
|
||||
return wf, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertWorkflowDefinition(wf *WorkflowDefinition) error {
|
||||
if wf == nil {
|
||||
return fmt.Errorf("工作流为空")
|
||||
}
|
||||
wf.ID = strings.TrimSpace(wf.ID)
|
||||
wf.Name = strings.TrimSpace(wf.Name)
|
||||
if wf.ID == "" || wf.Name == "" {
|
||||
return fmt.Errorf("工作流 id 和 name 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(wf.GraphJSON) == "" {
|
||||
wf.GraphJSON = `{"nodes":[],"edges":[],"config":{}}`
|
||||
}
|
||||
if wf.Version <= 0 {
|
||||
wf.Version = 1
|
||||
}
|
||||
now := time.Now()
|
||||
existing, err := db.GetWorkflowDefinition(wf.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing == nil {
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO workflow_definitions (id, name, description, version, graph_json, enabled, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
wf.ID, wf.Name, wf.Description, wf.Version, wf.GraphJSON, boolToInt(wf.Enabled), now, now,
|
||||
)
|
||||
} else {
|
||||
nextVersion := existing.Version + 1
|
||||
if wf.Version > existing.Version {
|
||||
nextVersion = wf.Version
|
||||
}
|
||||
_, err = db.Exec(
|
||||
`UPDATE workflow_definitions
|
||||
SET name = ?, description = ?, version = ?, graph_json = ?, enabled = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
wf.Name, wf.Description, nextVersion, wf.GraphJSON, boolToInt(wf.Enabled), now, wf.ID,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存工作流失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DeleteWorkflowDefinition(id string) error {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return fmt.Errorf("工作流 id 不能为空")
|
||||
}
|
||||
if _, err := db.Exec("DELETE FROM workflow_definitions WHERE id = ?", id); err != nil {
|
||||
return fmt.Errorf("删除工作流失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) CreateWorkflowRun(run *WorkflowRun) error {
|
||||
if run == nil {
|
||||
return fmt.Errorf("工作流运行为空")
|
||||
}
|
||||
if strings.TrimSpace(run.ID) == "" || strings.TrimSpace(run.WorkflowID) == "" {
|
||||
return fmt.Errorf("工作流运行 id 和 workflow_id 不能为空")
|
||||
}
|
||||
if run.WorkflowVersion <= 0 {
|
||||
run.WorkflowVersion = 1
|
||||
}
|
||||
if strings.TrimSpace(run.Status) == "" {
|
||||
run.Status = "running"
|
||||
}
|
||||
if run.StartedAt.IsZero() {
|
||||
run.StartedAt = time.Now()
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO workflow_runs (id, workflow_id, workflow_version, conversation_id, project_id, role_id, status, input_json, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
run.ID, run.WorkflowID, run.WorkflowVersion, nullString(run.ConversationID), nullString(run.ProjectID), nullString(run.RoleID), run.Status, run.InputJSON, run.StartedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建工作流运行失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) FinishWorkflowRun(runID, status, outputJSON, errText string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
status = "completed"
|
||||
}
|
||||
now := time.Now()
|
||||
_, err := db.Exec(
|
||||
`UPDATE workflow_runs SET status = ?, output_json = ?, error = ?, finished_at = ? WHERE id = ?`,
|
||||
status, outputJSON, errText, now, runID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流运行失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) CreateWorkflowNodeRun(n *WorkflowNodeRun) error {
|
||||
if n == nil {
|
||||
return fmt.Errorf("工作流节点运行为空")
|
||||
}
|
||||
if strings.TrimSpace(n.ID) == "" || strings.TrimSpace(n.RunID) == "" || strings.TrimSpace(n.NodeID) == "" {
|
||||
return fmt.Errorf("节点运行 id、run_id 和 node_id 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(n.Status) == "" {
|
||||
n.Status = "running"
|
||||
}
|
||||
if n.StartedAt.IsZero() {
|
||||
n.StartedAt = time.Now()
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO workflow_node_runs (id, run_id, node_id, status, input_json, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
n.ID, n.RunID, n.NodeID, n.Status, n.InputJSON, n.StartedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建工作流节点运行失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) FinishWorkflowNodeRun(nodeRunID, status, outputJSON, errText string) error {
|
||||
nodeRunID = strings.TrimSpace(nodeRunID)
|
||||
if nodeRunID == "" {
|
||||
return fmt.Errorf("节点运行 id 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
status = "completed"
|
||||
}
|
||||
now := time.Now()
|
||||
_, err := db.Exec(
|
||||
`UPDATE workflow_node_runs SET status = ?, output_json = ?, error = ?, finished_at = ? WHERE id = ?`,
|
||||
status, outputJSON, errText, now, nodeRunID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流节点运行失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanWorkflowRun(scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}) (*WorkflowRun, error) {
|
||||
var row WorkflowRun
|
||||
var convID, projectID, roleID, inputJSON, outputJSON, errText, pendingNode, pendingJSON sql.NullString
|
||||
var finishedAt sql.NullTime
|
||||
if err := scanner.Scan(
|
||||
&row.ID, &row.WorkflowID, &row.WorkflowVersion,
|
||||
&convID, &projectID, &roleID, &row.Status,
|
||||
&inputJSON, &outputJSON, &errText,
|
||||
&pendingNode, &pendingJSON,
|
||||
&row.StartedAt, &finishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row.ConversationID = convID.String
|
||||
row.ProjectID = projectID.String
|
||||
row.RoleID = roleID.String
|
||||
row.InputJSON = inputJSON.String
|
||||
row.OutputJSON = outputJSON.String
|
||||
row.Error = errText.String
|
||||
row.PendingHITLNodeID = pendingNode.String
|
||||
row.PendingHITLJSON = pendingJSON.String
|
||||
if finishedAt.Valid {
|
||||
t := finishedAt.Time
|
||||
row.FinishedAt = &t
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
const workflowRunColumns = `id, workflow_id, workflow_version, conversation_id, project_id, role_id, status, input_json, output_json, error, pending_hitl_node_id, pending_hitl_json, started_at, finished_at`
|
||||
|
||||
func (db *DB) GetWorkflowRun(runID string) (*WorkflowRun, error) {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
row, err := scanWorkflowRun(db.QueryRow("SELECT "+workflowRunColumns+" FROM workflow_runs WHERE id = ?", runID))
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询工作流运行失败: %w", err)
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func (db *DB) SetWorkflowRunStatus(runID, status string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
_, err := db.Exec(`UPDATE workflow_runs SET status = ? WHERE id = ?`, strings.TrimSpace(status), runID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流运行状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) SetWorkflowRunAwaitingHITL(runID, nodeID, pendingJSON string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`UPDATE workflow_runs SET status = 'awaiting_hitl', pending_hitl_node_id = ?, pending_hitl_json = ?, finished_at = NULL WHERE id = ?`,
|
||||
strings.TrimSpace(nodeID), pendingJSON, runID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流 HITL 等待状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordWorkflowRunHITLDecision stores a human decision on a paused workflow run.
|
||||
func (db *DB) RecordWorkflowRunHITLDecision(runID string, approved bool, comment string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
run, err := db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if run == nil {
|
||||
return fmt.Errorf("工作流运行不存在")
|
||||
}
|
||||
pending := map[string]interface{}{}
|
||||
if strings.TrimSpace(run.PendingHITLJSON) != "" {
|
||||
_ = json.Unmarshal([]byte(run.PendingHITLJSON), &pending)
|
||||
}
|
||||
if approved {
|
||||
pending["decision"] = "approved"
|
||||
} else {
|
||||
pending["decision"] = "rejected"
|
||||
}
|
||||
pending["comment"] = strings.TrimSpace(comment)
|
||||
raw, _ := json.Marshal(pending)
|
||||
_, err = db.Exec(
|
||||
`UPDATE workflow_runs SET pending_hitl_json = ? WHERE id = ? AND status = 'awaiting_hitl'`,
|
||||
string(raw), runID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("记录工作流审批决定失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ListWorkflowRunsAwaitingHITL(limit int) ([]*WorkflowRun, error) {
|
||||
return db.ListWorkflowRunsAwaitingHITLFiltered("", limit)
|
||||
}
|
||||
|
||||
// ListWorkflowRunsAwaitingHITLFiltered returns awaiting_hitl runs, optionally scoped to a conversation.
|
||||
func (db *DB) ListWorkflowRunsAwaitingHITLFiltered(conversationID string, limit int) ([]*WorkflowRun, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if conversationID != "" {
|
||||
rows, err = db.Query(
|
||||
`SELECT `+workflowRunColumns+` FROM workflow_runs WHERE status = 'awaiting_hitl' AND conversation_id = ? ORDER BY started_at DESC LIMIT ?`,
|
||||
conversationID, limit,
|
||||
)
|
||||
} else {
|
||||
rows, err = db.Query(
|
||||
`SELECT `+workflowRunColumns+` FROM workflow_runs WHERE status = 'awaiting_hitl' ORDER BY started_at DESC LIMIT ?`,
|
||||
limit,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询等待审批的工作流运行失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*WorkflowRun
|
||||
for rows.Next() {
|
||||
row, err := scanWorkflowRun(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (db *DB) migrateWorkflowRunsTable() error {
|
||||
cols := []struct{ name, ddl string }{
|
||||
{"pending_hitl_node_id", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_node_id TEXT"},
|
||||
{"pending_hitl_json", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_json TEXT"},
|
||||
}
|
||||
for _, col := range cols {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('workflow_runs') WHERE name=?", col.name).Scan(&count)
|
||||
if err != nil || count > 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := db.Exec(col.ddl); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nullString(v string) interface{} {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
+102
-9
@@ -77,6 +77,13 @@ type responsePlanAgg struct {
|
||||
b strings.Builder
|
||||
}
|
||||
|
||||
// thinkingBuf aggregates thinking_stream_* / reasoning_chain_stream_* before flush to process_details.
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
|
||||
func normalizeProcessDetailText(s string) string {
|
||||
s = strings.ReplaceAll(s, "\r\n", "\n")
|
||||
s = strings.ReplaceAll(s, "\r", "\n")
|
||||
@@ -178,7 +185,10 @@ type AgentHandler struct {
|
||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||
batchCronParser cron.Parser
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
hitlStrategySaver HitlAuditStrategySaver
|
||||
hitlDefaultReviewerSaver HitlDefaultReviewerSaver
|
||||
auditLLM *openai.Client
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
@@ -218,9 +228,10 @@ func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string)
|
||||
}
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
// HitlToolWhitelistSaver 合并/设置 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
SetHitlToolWhitelist(tools []string) error
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -236,6 +247,11 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
bus := NewTaskEventBus()
|
||||
tm := NewAgentTaskManager()
|
||||
tm.SetTaskEventBus(bus)
|
||||
llmHTTP := &http.Client{Timeout: 2 * time.Minute}
|
||||
var llmCfg *config.OpenAIConfig
|
||||
if cfg != nil {
|
||||
llmCfg = &cfg.OpenAI
|
||||
}
|
||||
handler := &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
@@ -246,6 +262,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
config: cfg,
|
||||
hitlManager: NewHITLManager(db, logger),
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
auditLLM: openai.NewClient(llmCfg, llmHTTP, logger),
|
||||
}
|
||||
tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation)
|
||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||
@@ -272,6 +289,23 @@ func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) {
|
||||
h.hitlWhitelistSaver = s
|
||||
}
|
||||
|
||||
// HitlDefaultReviewerSaver 持久化全局默认审批方到 config.yaml。
|
||||
type HitlDefaultReviewerSaver interface {
|
||||
UpdateHitlDefaultReviewer(reviewer string) error
|
||||
}
|
||||
|
||||
// SetHitlDefaultReviewerSaver 设置 HITL 默认审批方落盘。
|
||||
func (h *AgentHandler) SetHitlDefaultReviewerSaver(s HitlDefaultReviewerSaver) {
|
||||
h.hitlDefaultReviewerSaver = s
|
||||
}
|
||||
|
||||
func (h *AgentHandler) hitlEffectiveDefaultReviewer() string {
|
||||
if h != nil && h.config != nil {
|
||||
return normalizeHitlReviewer(h.config.Hitl.EffectiveDefaultReviewer())
|
||||
}
|
||||
return "human"
|
||||
}
|
||||
|
||||
// HITLNeedsToolApproval 供 C2 危险任务门控:与会话侧人机协同及免审批白名单判定一致。
|
||||
func (h *AgentHandler) HITLNeedsToolApproval(conversationID, toolName string) bool {
|
||||
if h == nil || h.hitlManager == nil {
|
||||
@@ -320,6 +354,7 @@ func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientInten
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Reviewer string `json:"reviewer,omitempty"` // human | audit_agent
|
||||
SensitiveTools []string `json:"sensitiveTools,omitempty"`
|
||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||
}
|
||||
@@ -849,11 +884,6 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
|
||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
@@ -866,6 +896,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||
var respPlan responsePlanAgg
|
||||
if assistantMessageID != "" {
|
||||
h.tasks.SetHitlAssistantMessageID(conversationID, assistantMessageID)
|
||||
}
|
||||
syncHitlCognition := func() {
|
||||
h.syncHitlCognitionFromProgress(conversationID, assistantMessageID, thinkingStreams, &respPlan)
|
||||
}
|
||||
flushResponsePlan := func() {
|
||||
if assistantMessageID == "" {
|
||||
return
|
||||
@@ -885,6 +921,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning"))
|
||||
}
|
||||
syncHitlCognition()
|
||||
respPlan.meta = nil
|
||||
respPlan.b.Reset()
|
||||
}
|
||||
@@ -921,6 +958,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
flushedThinking[sid] = true
|
||||
}
|
||||
syncHitlCognition()
|
||||
}
|
||||
|
||||
return func(eventType, message string, data interface{}) {
|
||||
@@ -981,6 +1019,25 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
}
|
||||
|
||||
if eventType == "tool_result" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
toolCallID, _ := dataMap["toolCallId"].(string)
|
||||
success := true
|
||||
if v, ok := dataMap["success"].(bool); ok {
|
||||
success = v
|
||||
}
|
||||
resultText := ""
|
||||
if r, ok := dataMap["result"].(string); ok {
|
||||
resultText = r
|
||||
}
|
||||
if strings.TrimSpace(resultText) == "" {
|
||||
resultText = message
|
||||
}
|
||||
h.recordHitlToolExecutionResult(conversationID, toolCallID, toolName, success, resultText)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理知识检索日志记录
|
||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
@@ -1188,6 +1245,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
respPlan.meta[k] = v
|
||||
}
|
||||
}
|
||||
syncHitlCognition()
|
||||
return
|
||||
}
|
||||
if eventType == "response" {
|
||||
@@ -1257,6 +1315,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
}
|
||||
}
|
||||
syncHitlCognition()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1489,17 +1548,51 @@ func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// enrichAgentTasksWithConversationTitles 为任务列表附加当前会话标题(供顶栏/任务页展示,重命名后自动同步)
|
||||
func (h *AgentHandler) enrichAgentTasksWithConversationTitles(tasks []*AgentTask) {
|
||||
if h == nil || h.db == nil {
|
||||
return
|
||||
}
|
||||
for _, task := range tasks {
|
||||
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
|
||||
continue
|
||||
}
|
||||
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
|
||||
task.Title = strings.TrimSpace(title)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// enrichCompletedTasksWithConversationTitles 为已完成任务附加当前会话标题
|
||||
func (h *AgentHandler) enrichCompletedTasksWithConversationTitles(tasks []*CompletedTask) {
|
||||
if h == nil || h.db == nil {
|
||||
return
|
||||
}
|
||||
for _, task := range tasks {
|
||||
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
|
||||
continue
|
||||
}
|
||||
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
|
||||
task.Title = strings.TrimSpace(title)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ListAgentTasks 列出所有运行中的任务
|
||||
func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
tasks := h.tasks.GetActiveTasks()
|
||||
h.enrichAgentTasksWithConversationTitles(tasks)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tasks": h.tasks.GetActiveTasks(),
|
||||
"tasks": tasks,
|
||||
})
|
||||
}
|
||||
|
||||
// ListCompletedTasks 列出最近完成的任务历史
|
||||
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||
tasks := h.tasks.GetCompletedTasks()
|
||||
h.enrichCompletedTasksWithConversationTitles(tasks)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tasks": h.tasks.GetCompletedTasks(),
|
||||
"tasks": tasks,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -798,6 +798,10 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
|
||||
// 更新机器人配置
|
||||
if req.Robots != nil {
|
||||
if err := config.ValidateWecomConfig(req.Robots.Wecom); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.config.Robots = *req.Robots
|
||||
h.logger.Info("更新机器人配置",
|
||||
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||
@@ -1329,6 +1333,17 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("已更新嵌入模型配置记录")
|
||||
}
|
||||
|
||||
// 从 tools 目录重新加载工具配置(新增/修改/删除 yaml 后无需重启)
|
||||
if err := config.ReloadSecurityToolsFromDir(h.config, h.configPath); err != nil {
|
||||
h.logger.Error("重新加载工具配置失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新加载工具", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新加载工具配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
h.logger.Info("已从 tools 目录重新加载工具配置", zap.Int("tools_count", len(h.config.Security.Tools)))
|
||||
|
||||
// 重新注册工具(根据新的启用状态)
|
||||
h.logger.Info("重新注册工具")
|
||||
|
||||
@@ -1417,12 +1432,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
|
||||
// 更新检索器配置(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: h.config.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
|
||||
SubIndexFilter: h.config.Knowledge.Retrieval.SubIndexFilter,
|
||||
PostRetrieve: h.config.Knowledge.Retrieval.PostRetrieve,
|
||||
}
|
||||
retrievalConfig := knowledge.RetrievalConfigFromYAML(h.config.Knowledge.Retrieval)
|
||||
h.retrieverUpdater.UpdateConfig(retrievalConfig)
|
||||
h.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", retrievalConfig.TopK),
|
||||
@@ -1705,6 +1715,13 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
|
||||
setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold)
|
||||
setStringInMap(retrievalNode, "sub_index_filter", cfg.Retrieval.SubIndexFilter)
|
||||
mqNode := ensureMap(retrievalNode, "multi_query")
|
||||
setIntInMap(mqNode, "max_queries", cfg.Retrieval.MultiQuery.MaxQueries)
|
||||
rerankNode := ensureMap(retrievalNode, "rerank")
|
||||
setStringInMap(rerankNode, "provider", cfg.Retrieval.Rerank.Provider)
|
||||
setStringInMap(rerankNode, "model", cfg.Retrieval.Rerank.Model)
|
||||
setStringInMap(rerankNode, "base_url", cfg.Retrieval.Rerank.BaseURL)
|
||||
setStringInMap(rerankNode, "api_key", cfg.Retrieval.Rerank.APIKey)
|
||||
postNode := ensureMap(retrievalNode, "post_retrieve")
|
||||
setIntInMap(postNode, "prefetch_top_k", cfg.Retrieval.PostRetrieve.PrefetchTopK)
|
||||
setIntInMap(postNode, "max_context_chars", cfg.Retrieval.PostRetrieve.MaxContextChars)
|
||||
@@ -1751,6 +1768,20 @@ func mergeHitlToolWhitelistSlice(existing, add []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// SetHitlToolWhitelist 将全局免审批工具白名单整表写入 config.yaml(替换,非合并)。
|
||||
func (h *ConfigHandler) SetHitlToolWhitelist(tools []string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.config.Hitl.ToolWhitelist = mergeHitlToolWhitelistSlice(nil, tools)
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 全局工具白名单已写入配置文件",
|
||||
zap.Int("count", len(h.config.Hitl.ToolWhitelist)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
|
||||
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
|
||||
h.mu.Lock()
|
||||
@@ -1771,6 +1802,34 @@ func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) {
|
||||
hitlNode := ensureMap(root, "hitl")
|
||||
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
|
||||
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
|
||||
setStringInMap(hitlNode, "default_reviewer", cfg.EffectiveDefaultReviewer())
|
||||
setStringInMap(hitlNode, "audit_agent_prompt", cfg.AuditAgentPrompt)
|
||||
setStringInMap(hitlNode, "audit_agent_prompt_review_edit", cfg.AuditAgentPromptReviewEdit)
|
||||
}
|
||||
|
||||
// UpdateHitlDefaultReviewer 更新全局默认审批方并写入 config.yaml。
|
||||
func (h *ConfigHandler) UpdateHitlDefaultReviewer(reviewer string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.config.Hitl.DefaultReviewer = config.HitlConfig{DefaultReviewer: reviewer}.EffectiveDefaultReviewer()
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 全局默认审批方已写入配置文件", zap.String("default_reviewer", h.config.Hitl.DefaultReviewer))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHitlAuditAgentStrategy 更新审批/审查编辑两套审计 Agent 提示词并写入 config.yaml。
|
||||
func (h *ConfigHandler) UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.config.Hitl.AuditAgentPrompt = strings.TrimSpace(approvalPrompt)
|
||||
h.config.Hitl.AuditAgentPromptReviewEdit = strings.TrimSpace(reviewEditPrompt)
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Info("HITL 审计 Agent 提示词已写入配置文件")
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// rebindEinoRunningTask 中断并继续 / 空正文续跑:重建 cancel 链与超时 ctx,保持任务 running。
|
||||
func (h *AgentHandler) rebindEinoRunningTask(conversationID string, timeoutCancel context.CancelFunc) (context.Context, context.CancelCauseFunc, context.Context, context.CancelFunc) {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, newTimeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
return baseCtx, cancelWithCause, taskCtx, newTimeoutCancel
|
||||
}
|
||||
|
||||
// tryContinueOnEinoEmptyResponse Run 成功但 Response 为 emptyHint 时退避续跑;true 表示已准备下一段 Run。
|
||||
func (h *AgentHandler) tryContinueOnEinoEmptyResponse(
|
||||
taskCtx context.Context,
|
||||
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
attempt *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
) bool {
|
||||
if result == nil || !multiagent.IsEinoEmptyResponseResult(result) || !multiagent.HasEinoResumeTrace(result) {
|
||||
return false
|
||||
}
|
||||
maxAttempts := multiagent.EmptyResponseContinueMaxAttemptsFromConfig(mw)
|
||||
if *attempt >= maxAttempts {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("eino empty response continue exhausted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
return false
|
||||
}
|
||||
*attempt++
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
|
||||
backoff := multiagent.EmptyResponseContinueBackoff(*attempt-1, mw)
|
||||
waitMsg := fmt.Sprintf("会话已结束但未捕获到助手正文,%d 秒后第 %d/%d 次自动续跑…",
|
||||
int(backoff.Seconds()), *attempt, maxAttempts)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", waitMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": *attempt,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-taskCtx.Done():
|
||||
return false
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
|
||||
inject := multiagent.FormatEmptyResponseContinueUserMessage()
|
||||
h.applyEinoTraceResumeSegment(conversationID, result, curHistory, curFinalMessage, inject)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", "已恢复上下文,正在续跑…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": *attempt,
|
||||
"maxAttempts": maxAttempts,
|
||||
"contextSource": "empty_response_continue",
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -116,6 +116,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
@@ -178,6 +181,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
var emptyResponseContinueAttempt int
|
||||
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
@@ -239,6 +243,13 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
mw := &h.config.MultiAgent.EinoMiddleware
|
||||
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
|
||||
continue
|
||||
}
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
@@ -377,6 +388,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
if h.runRoleWorkflowJSONIfBound(c, &req, prep) {
|
||||
return
|
||||
}
|
||||
|
||||
var progressBuf strings.Builder
|
||||
progressCallbackRaw := func(eventType, message string, data interface{}) {
|
||||
|
||||
+213
-91
@@ -23,6 +23,7 @@ import (
|
||||
type hitlRuntimeConfig struct {
|
||||
Enabled bool
|
||||
Mode string
|
||||
Reviewer string
|
||||
SensitiveTools map[string]struct{}
|
||||
Timeout time.Duration
|
||||
}
|
||||
@@ -49,6 +50,8 @@ type HITLManager struct {
|
||||
mu sync.RWMutex
|
||||
runtime map[string]hitlRuntimeConfig
|
||||
pending map[string]*pendingInterrupt
|
||||
// approvedExec 审批通过、待回写 tool_result 的队列(按会话 FIFO)
|
||||
approvedExec map[string][]hitlApprovedExecTrack
|
||||
}
|
||||
|
||||
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
||||
@@ -90,6 +93,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.migrateHitlSchemaColumns()
|
||||
|
||||
// On startup, cancel all orphaned pending interrupts from previous process.
|
||||
// Their in-memory channels are gone, so they can never be resolved.
|
||||
@@ -141,6 +145,7 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
|
||||
m.runtime[conversationID] = hitlRuntimeConfig{
|
||||
Enabled: true,
|
||||
Mode: normalizeHitlMode(req.Mode),
|
||||
Reviewer: normalizeHitlReviewer(req.Reviewer),
|
||||
SensitiveTools: tools,
|
||||
Timeout: timeout,
|
||||
}
|
||||
@@ -153,17 +158,14 @@ func (m *HITLManager) DeactivateConversation(conversationID string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。
|
||||
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空),并合并内置元工具免审批项。
|
||||
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
||||
if h == nil || h.config == nil {
|
||||
return nil
|
||||
return multiagent.MergeHitlExemptMetaTools(nil)
|
||||
}
|
||||
raw := h.config.Hitl.ToolWhitelist
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, len(raw))
|
||||
out := make([]string, 0, len(raw)+len(multiagent.HitlExemptMetaTools))
|
||||
for _, t := range raw {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
@@ -175,44 +177,35 @@ func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
||||
seen[n] = struct{}{}
|
||||
out = append(out, strings.TrimSpace(t))
|
||||
}
|
||||
return out
|
||||
return multiagent.MergeHitlExemptMetaTools(out)
|
||||
}
|
||||
|
||||
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。
|
||||
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单及内置元工具免审批项合并(并集),仅用于运行时 Activate;不写入数据库。
|
||||
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
|
||||
gw := h.hitlConfigGlobalToolWhitelist()
|
||||
if len(gw) == 0 {
|
||||
return req
|
||||
}
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
union := make([]string, 0, len(gw)+len(req.SensitiveTools))
|
||||
for _, t := range gw {
|
||||
union := make([]string, 0, len(req.SensitiveTools)+16)
|
||||
add := func(t string) {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
return
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
return
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
}
|
||||
for _, t := range h.hitlConfigGlobalToolWhitelist() {
|
||||
add(t)
|
||||
}
|
||||
for _, t := range req.SensitiveTools {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
add(t)
|
||||
}
|
||||
out := *req
|
||||
out.SensitiveTools = union
|
||||
out.SensitiveTools = multiagent.MergeHitlExemptMetaTools(union)
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -362,22 +355,22 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
|
||||
timeout = 0
|
||||
}
|
||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
(conversation_id, enabled, mode, reviewer, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(conversation_id) DO UPDATE SET
|
||||
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
|
||||
enabled=excluded.enabled, mode=excluded.mode, reviewer=excluded.reviewer, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, normalizeHitlReviewer(req.Reviewer), string(tools), timeout, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||
var enabledInt int
|
||||
var mode, toolsJSON string
|
||||
var mode, reviewer, toolsJSON string
|
||||
var timeout int
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, COALESCE(reviewer,'human'), sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &reviewer, &toolsJSON, &timeout)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
return &HITLRequest{Enabled: false, Mode: "off", Reviewer: "human", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -390,11 +383,24 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
|
||||
return &HITLRequest{
|
||||
Enabled: enabledInt == 1,
|
||||
Mode: mode,
|
||||
Reviewer: normalizeHitlReviewer(reviewer),
|
||||
SensitiveTools: tools,
|
||||
TimeoutSeconds: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *HITLManager) HasConversationConfig(conversationID string) (bool, error) {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return false, nil
|
||||
}
|
||||
var one int
|
||||
err := m.db.QueryRow(`SELECT 1 FROM hitl_conversation_configs WHERE conversation_id = ? LIMIT 1`, conversationID).Scan(&one)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
@@ -413,15 +419,16 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
|
||||
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
||||
d.EditedArguments = nil
|
||||
}
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='human' WHERE id=?`,
|
||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||
return d, nil
|
||||
case <-timeoutCh:
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
||||
comment := "HITL timeout auto-reject for safety"
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject', decision_comment=?, decided_at=?, decided_by='system' WHERE id=?`,
|
||||
comment, time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "reject", Comment: comment}, nil
|
||||
case <-ctx.Done():
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=?, decided_by='system' WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
||||
}
|
||||
@@ -432,25 +439,88 @@ func (h *AgentHandler) activateHITLForConversation(conversationID string, req *H
|
||||
return
|
||||
}
|
||||
if req == nil {
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
cfg, err := h.loadHITLConversationConfig(conversationID)
|
||||
if err == nil {
|
||||
req = cfg
|
||||
}
|
||||
}
|
||||
if req != nil && strings.TrimSpace(req.Reviewer) == "" {
|
||||
req.Reviewer = h.hitlEffectiveDefaultReviewer()
|
||||
}
|
||||
h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req))
|
||||
}
|
||||
|
||||
func (h *AgentHandler) loadHITLConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
has, err := h.hitlManager.HasConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
cfg.Reviewer = h.hitlEffectiveDefaultReviewer()
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) {
|
||||
cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName)
|
||||
if !need {
|
||||
return nil, nil
|
||||
}
|
||||
h.enrichHitlApprovalPayload(conversationID, assistantMessageID, payload)
|
||||
payloadRaw, _ := json.Marshal(payload)
|
||||
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
||||
if err != nil {
|
||||
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.Reviewer == "audit_agent" {
|
||||
ad := h.auditAgentReview(runCtx, cfg.Mode, toolName, payload)
|
||||
now := time.Now()
|
||||
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='audit_agent' WHERE id=?`,
|
||||
ad.Decision, ad.Comment, now, p.InterruptID)
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_audit_agent", "审计 Agent 已裁决", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"mode": cfg.Mode,
|
||||
"decision": ad.Decision,
|
||||
"comment": ad.Comment,
|
||||
"editedArgs": ad.EditedArguments,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
if ad.Decision == "reject" {
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_rejected", "审计 Agent 拒绝本次工具调用", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": ad.Comment,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
return &ad, nil
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_resumed", "审计 Agent 已通过,继续执行", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": ad.Comment,
|
||||
"editedArgs": ad.EditedArguments,
|
||||
"decidedBy": "audit_agent",
|
||||
})
|
||||
}
|
||||
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||
return &ad, nil
|
||||
}
|
||||
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -479,8 +549,12 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
||||
return nil, waitErr
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
rejectMsg := "人工拒绝本次工具调用,模型将基于反馈继续迭代"
|
||||
if strings.Contains(strings.ToLower(strings.TrimSpace(d.Comment)), "timeout") {
|
||||
rejectMsg = "审批超时,安全起见已自动拒绝,模型将基于反馈继续迭代"
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{
|
||||
sendEventFunc("hitl_rejected", rejectMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
@@ -498,6 +572,7 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
||||
"editedArgs": d.EditedArguments,
|
||||
})
|
||||
}
|
||||
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
@@ -527,11 +602,6 @@ func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun cont
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
if status == "" {
|
||||
status = "pending"
|
||||
}
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
@@ -539,15 +609,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
q, args := h.buildHitlListQuery(false)
|
||||
q, args = h.appendHitlListFilters(q, args, c)
|
||||
total, err := h.countHitlQuery(q, args)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
@@ -557,41 +624,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
items = append(items, map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
})
|
||||
items, err := h.scanHitlInterruptRows(rows)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total})
|
||||
}
|
||||
|
||||
type hitlDecisionReq struct {
|
||||
@@ -636,7 +674,7 @@ func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP
|
||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP, decided_by='human'
|
||||
WHERE id=? AND status='pending'`, req.InterruptID)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
@@ -702,7 +740,7 @@ func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
cfg, err := h.loadHITLConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -721,6 +759,7 @@ func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversationId": conversationID,
|
||||
"hitl": cfg,
|
||||
"defaultReviewer": h.hitlEffectiveDefaultReviewer(),
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
})
|
||||
}
|
||||
@@ -732,6 +771,10 @@ func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Mode = normalizeHitlMode(req.Mode)
|
||||
req.Reviewer = normalizeHitlReviewer(req.Reviewer)
|
||||
if strings.TrimSpace(req.Reviewer) == "" {
|
||||
req.Reviewer = h.hitlEffectiveDefaultReviewer()
|
||||
}
|
||||
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -753,6 +796,85 @@ type mergeHitlGlobalWhitelistReq struct {
|
||||
SensitiveTools []string `json:"sensitiveTools"`
|
||||
}
|
||||
|
||||
type setHitlGlobalWhitelistReq struct {
|
||||
ToolWhitelist []string `json:"toolWhitelist"`
|
||||
}
|
||||
|
||||
// GetHITLGlobalToolWhitelist 返回 config.yaml 中的全局免审批工具白名单。
|
||||
func (h *AgentHandler) GetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"defaultReviewer": h.hitlEffectiveDefaultReviewer(),
|
||||
})
|
||||
}
|
||||
|
||||
type setHitlDefaultReviewerReq struct {
|
||||
Reviewer string `json:"reviewer"`
|
||||
}
|
||||
|
||||
// GetHITLDefaultReviewer 返回 config.yaml 中的全局默认审批方。
|
||||
func (h *AgentHandler) GetHITLDefaultReviewer(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"defaultReviewer": h.hitlEffectiveDefaultReviewer(),
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateHITLDefaultReviewer 将全局默认审批方写入 config.yaml(未选会话时切换审批方)。
|
||||
func (h *AgentHandler) UpdateHITLDefaultReviewer(c *gin.Context) {
|
||||
if h.hitlDefaultReviewerSaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req setHitlDefaultReviewerReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
reviewer := normalizeHitlReviewer(req.Reviewer)
|
||||
if err := h.hitlDefaultReviewerSaver.UpdateHitlDefaultReviewer(reviewer); err != nil {
|
||||
h.logger.Warn("写入 HITL 默认审批方到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.config != nil {
|
||||
h.config.Hitl.DefaultReviewer = reviewer
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "default_reviewer_update", "HITL 全局默认审批方更新", "hitl_config", "default_reviewer", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"defaultReviewer": reviewer,
|
||||
})
|
||||
}
|
||||
|
||||
// SetHITLGlobalToolWhitelist 整表替换 config.yaml 中的全局免审批工具白名单。
|
||||
func (h *AgentHandler) SetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req setHitlGlobalWhitelistReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.hitlWhitelistSaver.SetHitlToolWhitelist(req.ToolWhitelist); err != nil {
|
||||
h.logger.Warn("写入 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "tool_whitelist_update", "HITL 全局白名单更新", "hitl_config", "tool_whitelist", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": false,
|
||||
})
|
||||
}
|
||||
|
||||
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
||||
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// auditAgentReview 在 reviewer=audit_agent 时由 LLM 代行审批。
|
||||
// 白名单工具在 shouldInterrupt 阶段已跳过,到达此处的一律需要裁决。
|
||||
func (h *AgentHandler) auditAgentReview(ctx context.Context, hitlMode, toolName string, payload map[string]interface{}) hitlDecision {
|
||||
if h == nil {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: handler unavailable"}
|
||||
}
|
||||
mode := normalizeHitlMode(hitlMode)
|
||||
prompt := config.DefaultHitlAuditAgentPrompt()
|
||||
if h.config != nil {
|
||||
prompt = h.config.Hitl.EffectiveAuditAgentPromptForMode(mode)
|
||||
}
|
||||
if h.auditLLM == nil {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 未配置"}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
callCtx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
userContent := buildAuditAgentReviewInput(mode, toolName, payload)
|
||||
requestBody := map[string]interface{}{
|
||||
"model": h.auditLLMModel(),
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": userContent},
|
||||
},
|
||||
"temperature": 0.1,
|
||||
"max_completion_tokens": 1024,
|
||||
// 审计裁决需要结构化 JSON;关闭 thinking 避免 Qwen 等把正文放进 reasoning_content 导致解析失败。
|
||||
"thinking": map[string]interface{}{"type": "disabled"},
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := h.auditLLM.ChatCompletion(callCtx, requestBody, &apiResponse); err != nil {
|
||||
h.logger.Warn("审计 Agent LLM 调用失败", zap.Error(err), zap.String("tool", toolName))
|
||||
return hitlDecision{
|
||||
Decision: "reject",
|
||||
Comment: "audit agent: LLM 调用失败,保守拒绝",
|
||||
}
|
||||
}
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 无有效响应,保守拒绝"}
|
||||
}
|
||||
msg := apiResponse.Choices[0].Message
|
||||
raw := strings.TrimSpace(msg.Content)
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(msg.ReasoningContent)
|
||||
}
|
||||
dec, err := parseAuditAgentLLMContent(raw)
|
||||
if err != nil {
|
||||
snippet := raw
|
||||
if len(snippet) > 240 {
|
||||
snippet = snippet[:240] + "..."
|
||||
}
|
||||
h.logger.Warn("审计 Agent 响应解析失败",
|
||||
zap.Error(err),
|
||||
zap.String("tool", toolName),
|
||||
zap.String("mode", mode),
|
||||
zap.String("snippet", snippet),
|
||||
)
|
||||
return hitlDecision{Decision: "reject", Comment: "audit agent: 响应无法解析,保守拒绝"}
|
||||
}
|
||||
if mode != "review_edit" && len(dec.EditedArguments) > 0 {
|
||||
h.logger.Warn("审计 Agent 在审批模式下返回 editedArguments,已忽略",
|
||||
zap.String("tool", toolName),
|
||||
)
|
||||
dec.EditedArguments = nil
|
||||
}
|
||||
if dec.Comment == "" {
|
||||
dec.Comment = "audit agent: " + dec.Decision
|
||||
} else if !strings.HasPrefix(strings.ToLower(dec.Comment), "audit agent") {
|
||||
dec.Comment = "audit agent: " + dec.Comment
|
||||
}
|
||||
return dec
|
||||
}
|
||||
|
||||
func (h *AgentHandler) auditLLMModel() string {
|
||||
if h.config != nil && strings.TrimSpace(h.config.OpenAI.Model) != "" {
|
||||
return strings.TrimSpace(h.config.OpenAI.Model)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildAuditAgentReviewInput(hitlMode, toolName string, payload map[string]interface{}) string {
|
||||
review := map[string]interface{}{
|
||||
"hitlMode": normalizeHitlMode(hitlMode),
|
||||
"toolName": strings.TrimSpace(toolName),
|
||||
}
|
||||
if payload != nil {
|
||||
for _, k := range []string{"arguments", "argumentsObj", "command", hitlPayloadUserMessage, hitlPayloadThinking, hitlPayloadReasoningChain, hitlPayloadPlanning} {
|
||||
if v, ok := payload[k]; ok && v != nil && fmt.Sprint(v) != "" {
|
||||
review[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.MarshalIndent(review, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Sprintf(`{"hitlMode":%q,"toolName":%q}`, normalizeHitlMode(hitlMode), toolName)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func parseAuditAgentLLMContent(content string) (hitlDecision, error) {
|
||||
s := strings.TrimSpace(content)
|
||||
if s == "" {
|
||||
return hitlDecision{}, errors.New("empty content")
|
||||
}
|
||||
for _, candidate := range auditAgentJSONCandidates(s) {
|
||||
dec, comment, editedArgs, err := parseAuditAgentDecisionObject(candidate)
|
||||
if err == nil {
|
||||
return hitlDecision{
|
||||
Decision: dec,
|
||||
Comment: comment,
|
||||
EditedArguments: editedArgs,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return hitlDecision{}, fmt.Errorf("no valid decision json in response")
|
||||
}
|
||||
|
||||
func auditAgentJSONCandidates(s string) []string {
|
||||
out := make([]string, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
add := func(c string) {
|
||||
c = strings.TrimSpace(c)
|
||||
if c == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
return
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
out = append(out, c)
|
||||
}
|
||||
add(s)
|
||||
add(stripMarkdownCodeFence(s))
|
||||
if obj := extractFirstJSONObject(s); obj != "" {
|
||||
add(obj)
|
||||
}
|
||||
if obj := extractFirstJSONObject(stripMarkdownCodeFence(s)); obj != "" {
|
||||
add(obj)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stripMarkdownCodeFence(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
for _, fence := range []string{"```json", "```JSON", "```"} {
|
||||
if strings.HasPrefix(s, fence) {
|
||||
s = strings.TrimPrefix(s, fence)
|
||||
}
|
||||
}
|
||||
s = strings.TrimSuffix(s, "```")
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func extractFirstJSONObject(s string) string {
|
||||
start := strings.Index(s, "{")
|
||||
if start < 0 {
|
||||
return ""
|
||||
}
|
||||
depth := 0
|
||||
inStr := false
|
||||
esc := false
|
||||
for i := start; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
if inStr {
|
||||
if esc {
|
||||
esc = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
esc = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inStr = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch ch {
|
||||
case '"':
|
||||
inStr = true
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[start : i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseAuditAgentDecisionObject(jsonText string) (decision, comment string, editedArgs map[string]interface{}, err error) {
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonText), &parsed); err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
rawDecision := auditAgentPickString(parsed, "decision", "Decision", "result", "action", "verdict", "决策", "决定")
|
||||
decision = normalizeAuditAgentDecision(rawDecision)
|
||||
if decision == "" {
|
||||
return "", "", nil, fmt.Errorf("missing decision")
|
||||
}
|
||||
comment = auditAgentPickString(parsed, "comment", "Comment", "reason", "message", "rationale", "备注", "理由", "说明")
|
||||
editedArgs = auditAgentPickObject(parsed, "editedArguments", "edited_arguments", "editedArgs")
|
||||
return decision, strings.TrimSpace(comment), editedArgs, nil
|
||||
}
|
||||
|
||||
func auditAgentPickString(m map[string]interface{}, keys ...string) string {
|
||||
for _, k := range keys {
|
||||
if v, ok := m[k]; ok && v != nil {
|
||||
s := strings.TrimSpace(fmt.Sprint(v))
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func auditAgentPickObject(m map[string]interface{}, keys ...string) map[string]interface{} {
|
||||
for _, k := range keys {
|
||||
v, ok := m[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case map[string]interface{}:
|
||||
if len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" || s == "{}" {
|
||||
continue
|
||||
}
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(s), &obj); err == nil && len(obj) > 0 {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAuditAgentDecision(v string) string {
|
||||
d := strings.ToLower(strings.TrimSpace(v))
|
||||
switch d {
|
||||
case "approve", "approved", "pass", "passed", "allow", "allowed", "yes", "ok", "accept", "accepted":
|
||||
return "approve"
|
||||
case "reject", "rejected", "deny", "denied", "no", "block", "blocked", "refuse", "refused":
|
||||
return "reject"
|
||||
}
|
||||
switch strings.TrimSpace(v) {
|
||||
case "通过", "批准", "允许", "同意", "放行":
|
||||
return "approve"
|
||||
case "拒绝", "驳回", "禁止", "否决":
|
||||
return "reject"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type hitlAuditStrategyReq struct {
|
||||
AuditAgentPrompt string `json:"auditAgentPrompt"`
|
||||
AuditAgentPromptReviewEdit string `json:"auditAgentPromptReviewEdit"`
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLAuditStrategy(c *gin.Context) {
|
||||
approvalPrompt := config.DefaultHitlAuditAgentPrompt()
|
||||
reviewEditPrompt := config.DefaultHitlAuditAgentPromptReviewEdit()
|
||||
approvalCustom := false
|
||||
reviewEditCustom := false
|
||||
if h.config != nil {
|
||||
approvalPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("approval")
|
||||
reviewEditPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("review_edit")
|
||||
approvalCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPrompt) != ""
|
||||
reviewEditCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPromptReviewEdit) != ""
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"auditAgentPrompt": approvalPrompt,
|
||||
"auditAgentPromptCustom": approvalCustom,
|
||||
"auditAgentPromptReviewEdit": reviewEditPrompt,
|
||||
"auditAgentPromptReviewEditCustom": reviewEditCustom,
|
||||
"defaultAuditAgentPrompt": config.DefaultHitlAuditAgentPrompt(),
|
||||
"defaultAuditAgentPromptReviewEdit": config.DefaultHitlAuditAgentPromptReviewEdit(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) UpdateHITLAuditStrategy(c *gin.Context) {
|
||||
if h.hitlStrategySaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 策略持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req hitlAuditStrategyReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
approvalPrompt := strings.TrimSpace(req.AuditAgentPrompt)
|
||||
reviewEditPrompt := strings.TrimSpace(req.AuditAgentPromptReviewEdit)
|
||||
if err := h.hitlStrategySaver.UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt); err != nil {
|
||||
h.logger.Warn("保存审计 Agent 提示词失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "audit_strategy_update", "HITL 审计策略更新", "hitl_config", "audit_agent_prompt", nil)
|
||||
}
|
||||
if h.config != nil {
|
||||
h.config.Hitl.AuditAgentPrompt = approvalPrompt
|
||||
h.config.Hitl.AuditAgentPromptReviewEdit = reviewEditPrompt
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"auditAgentPrompt": config.HitlConfig{AuditAgentPrompt: approvalPrompt}.EffectiveAuditAgentPromptForMode("approval"),
|
||||
"auditAgentPromptCustom": approvalPrompt != "",
|
||||
"auditAgentPromptReviewEdit": config.HitlConfig{AuditAgentPromptReviewEdit: reviewEditPrompt}.EffectiveAuditAgentPromptForMode("review_edit"),
|
||||
"auditAgentPromptReviewEditCustom": reviewEditPrompt != "",
|
||||
})
|
||||
}
|
||||
|
||||
// HitlAuditStrategySaver 持久化审计 Agent 提示词到 config.yaml。
|
||||
type HitlAuditStrategySaver interface {
|
||||
UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error
|
||||
}
|
||||
|
||||
// SetHitlAuditStrategySaver 设置审计策略落盘。
|
||||
func (h *AgentHandler) SetHitlAuditStrategySaver(s HitlAuditStrategySaver) {
|
||||
h.hitlStrategySaver = s
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseAuditAgentLLMContentApprove(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"与任务一致"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" || d.Comment != "与任务一致" {
|
||||
t.Fatalf("unexpected %+v", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentReject(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent("```json\n{\"decision\":\"reject\",\"comment\":\"风险过高\"}\n```")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "reject" {
|
||||
t.Fatalf("expected reject, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentInvalid(t *testing.T) {
|
||||
_, err := parseAuditAgentLLMContent(`{"decision":"maybe"}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid decision")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentProseWrapped(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent("好的,裁决如下:\n```json\n{\"decision\":\"approve\",\"comment\":\"只读 ls\"}\n```\n以上。")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentChineseDecision(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"通过","comment":"风险低"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuditAgentLLMContentWithEditedArguments(t *testing.T) {
|
||||
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"收窄路径","editedArguments":{"path":"/safe"}}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Decision != "approve" {
|
||||
t.Fatalf("expected approve, got %s", d.Decision)
|
||||
}
|
||||
if d.EditedArguments == nil || d.EditedArguments["path"] != "/safe" {
|
||||
t.Fatalf("unexpected edited args: %+v", d.EditedArguments)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuditAgentReviewInputIncludesMode(t *testing.T) {
|
||||
s := buildAuditAgentReviewInput("review_edit", "execute", map[string]interface{}{
|
||||
"arguments": `{"command":"pwd"}`,
|
||||
})
|
||||
if !strings.Contains(s, "review_edit") || !strings.Contains(s, "execute") {
|
||||
t.Fatalf("unexpected input: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuditAgentReviewInput(t *testing.T) {
|
||||
s := buildAuditAgentReviewInput("approval", "nmap", map[string]interface{}{
|
||||
"arguments": `{"target":"10.0.0.1"}`,
|
||||
"userMessage": "扫描内网",
|
||||
})
|
||||
if s == "" {
|
||||
t.Fatal("expected non-empty input")
|
||||
}
|
||||
if !strings.Contains(s, "nmap") || !strings.Contains(s, "10.0.0.1") || !strings.Contains(s, "扫描内网") {
|
||||
t.Fatalf("unexpected input: %s", s)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type hitlCognitionState struct {
|
||||
AssistantMessageID string
|
||||
UserMessage string
|
||||
Thinking string
|
||||
ReasoningChain string
|
||||
Planning string
|
||||
}
|
||||
|
||||
// GetHitlCognition 返回当前运行任务上缓存的本轮 HITL 上下文(不含会话历史)。
|
||||
func (m *AgentTaskManager) GetHitlCognition(conversationID string) hitlCognitionFields {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return hitlCognitionFields{}
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil || t.hitlCognition == nil {
|
||||
return hitlCognitionFields{}
|
||||
}
|
||||
c := t.hitlCognition
|
||||
return hitlCognitionFields{
|
||||
UserMessage: c.UserMessage,
|
||||
Thinking: c.Thinking,
|
||||
ReasoningChain: c.ReasoningChain,
|
||||
Planning: c.Planning,
|
||||
}
|
||||
}
|
||||
|
||||
// ResetHitlCognition 新任务开始时重置本轮 HITL 上下文。
|
||||
func (m *AgentTaskManager) ResetHitlCognition(conversationID, userMessage string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
t.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(userMessage)}
|
||||
}
|
||||
|
||||
// SetHitlAssistantMessageID 记录当前助手消息 ID,供 HITL 与 DB 回退对齐。
|
||||
func (m *AgentTaskManager) SetHitlAssistantMessageID(conversationID, assistantMessageID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
assistantMessageID = strings.TrimSpace(assistantMessageID)
|
||||
if m == nil || conversationID == "" || assistantMessageID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
if t.hitlCognition == nil {
|
||||
t.hitlCognition = &hitlCognitionState{}
|
||||
}
|
||||
t.hitlCognition.AssistantMessageID = assistantMessageID
|
||||
}
|
||||
|
||||
// UpdateHitlCognitionSnapshot 从进行中的进度流快照更新 thinking / reasoning / planning。
|
||||
func (m *AgentTaskManager) UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoningChain, planning string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if m == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil {
|
||||
return
|
||||
}
|
||||
if t.hitlCognition == nil {
|
||||
t.hitlCognition = &hitlCognitionState{}
|
||||
}
|
||||
if id := strings.TrimSpace(assistantMessageID); id != "" {
|
||||
t.hitlCognition.AssistantMessageID = id
|
||||
}
|
||||
if s := strings.TrimSpace(thinking); s != "" {
|
||||
t.hitlCognition.Thinking = s
|
||||
}
|
||||
if s := strings.TrimSpace(reasoningChain); s != "" {
|
||||
t.hitlCognition.ReasoningChain = s
|
||||
}
|
||||
if s := strings.TrimSpace(planning); s != "" {
|
||||
t.hitlCognition.Planning = s
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
hitlPayloadUserMessage = "userMessage"
|
||||
hitlPayloadThinking = "thinking"
|
||||
hitlPayloadReasoningChain = "reasoningChain"
|
||||
hitlPayloadPlanning = "planning"
|
||||
)
|
||||
|
||||
type hitlCognitionFields struct {
|
||||
UserMessage string
|
||||
Thinking string
|
||||
ReasoningChain string
|
||||
Planning string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) enrichHitlApprovalPayload(conversationID, assistantMessageID string, payload map[string]interface{}) {
|
||||
if h == nil || payload == nil {
|
||||
return
|
||||
}
|
||||
cog := h.collectHitlCognition(conversationID, assistantMessageID)
|
||||
if s := strings.TrimSpace(cog.UserMessage); s != "" {
|
||||
payload[hitlPayloadUserMessage] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.Thinking); s != "" {
|
||||
payload[hitlPayloadThinking] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.ReasoningChain); s != "" {
|
||||
payload[hitlPayloadReasoningChain] = s
|
||||
}
|
||||
if s := strings.TrimSpace(cog.Planning); s != "" {
|
||||
payload[hitlPayloadPlanning] = s
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) collectHitlCognition(conversationID, assistantMessageID string) hitlCognitionFields {
|
||||
var out hitlCognitionFields
|
||||
if h.tasks != nil {
|
||||
out = h.tasks.GetHitlCognition(conversationID)
|
||||
}
|
||||
if strings.TrimSpace(out.UserMessage) == "" && h.db != nil {
|
||||
if msg, err := h.db.GetTurnUserMessage(conversationID, assistantMessageID); err == nil {
|
||||
out.UserMessage = msg
|
||||
}
|
||||
}
|
||||
if h.db != nil && assistantMessageID != "" {
|
||||
dbCog, err := h.db.GetAssistantCognitionTexts(assistantMessageID)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(out.Thinking) == "" {
|
||||
out.Thinking = dbCog.Thinking
|
||||
}
|
||||
if strings.TrimSpace(out.ReasoningChain) == "" {
|
||||
out.ReasoningChain = dbCog.ReasoningChain
|
||||
}
|
||||
if strings.TrimSpace(out.Planning) == "" {
|
||||
out.Planning = dbCog.Planning
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func snapshotHitlCognitionFromStreams(thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) (thinking, reasoningChain, planning string) {
|
||||
if len(thinkingStreams) > 0 {
|
||||
var thinkingParts, reasoningParts []string
|
||||
for _, tb := range thinkingStreams {
|
||||
if tb == nil {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(tb.b.String())
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if tb.persistAs == "reasoning_chain" {
|
||||
reasoningParts = append(reasoningParts, content)
|
||||
} else {
|
||||
thinkingParts = append(thinkingParts, content)
|
||||
}
|
||||
}
|
||||
thinking = strings.Join(thinkingParts, "\n\n")
|
||||
reasoningChain = strings.Join(reasoningParts, "\n\n")
|
||||
}
|
||||
if respPlan != nil {
|
||||
planning = strings.TrimSpace(respPlan.b.String())
|
||||
}
|
||||
return thinking, reasoningChain, planning
|
||||
}
|
||||
|
||||
func (h *AgentHandler) syncHitlCognitionFromProgress(conversationID, assistantMessageID string, thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) {
|
||||
if h == nil || h.tasks == nil {
|
||||
return
|
||||
}
|
||||
thinking, reasoning, planning := snapshotHitlCognitionFromStreams(thinkingStreams, respPlan)
|
||||
if thinking == "" && reasoning == "" && planning == "" {
|
||||
return
|
||||
}
|
||||
h.tasks.UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoning, planning)
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestEnrichHitlApprovalPayload(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("db: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
conv, err := db.CreateConversation("hitl ctx", database.ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
t.Fatalf("conv: %v", err)
|
||||
}
|
||||
if _, err := db.AddMessage(conv.ID, "user", "scan 10.0.0.1 please", nil); err != nil {
|
||||
t.Fatalf("user msg: %v", err)
|
||||
}
|
||||
asst, err := db.AddMessage(conv.ID, "assistant", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("asst msg: %v", err)
|
||||
}
|
||||
if err := db.AddProcessDetail(asst.ID, conv.ID, "thinking", "need port scan first", nil); err != nil {
|
||||
t.Fatalf("detail: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{db: db, tasks: NewAgentTaskManager()}
|
||||
payload := map[string]interface{}{"toolName": "nmap", "arguments": "{}"}
|
||||
h.enrichHitlApprovalPayload(conv.ID, asst.ID, payload)
|
||||
|
||||
if got := payload["userMessage"]; got != "scan 10.0.0.1 please" {
|
||||
t.Fatalf("userMessage=%v", got)
|
||||
}
|
||||
if got := payload["thinking"]; got != "need port scan first" {
|
||||
t.Fatalf("thinking=%v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const hitlPayloadExecutionResult = "executionResult"
|
||||
|
||||
type hitlExecutionResult struct {
|
||||
Success bool `json:"success"`
|
||||
Result string `json:"result,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
RecordedAt time.Time `json:"recordedAt"`
|
||||
}
|
||||
|
||||
type hitlApprovedExecTrack struct {
|
||||
InterruptID string
|
||||
ConversationID string
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// TrackApprovedHitlExecution 审批通过后登记,待 tool_result 回写执行结果。
|
||||
func (m *HITLManager) TrackApprovedHitlExecution(interruptID, conversationID, toolName, toolCallID string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
interruptID = strings.TrimSpace(interruptID)
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if interruptID == "" || conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.approvedExec == nil {
|
||||
m.approvedExec = make(map[string][]hitlApprovedExecTrack)
|
||||
}
|
||||
m.approvedExec[conversationID] = append(m.approvedExec[conversationID], hitlApprovedExecTrack{
|
||||
InterruptID: interruptID,
|
||||
ConversationID: conversationID,
|
||||
ToolName: strings.TrimSpace(toolName),
|
||||
ToolCallID: strings.TrimSpace(toolCallID),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *HITLManager) popApprovedInterruptForTool(conversationID, toolCallID, toolName string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
toolCallID = strings.TrimSpace(toolCallID)
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
queue := m.approvedExec[conversationID]
|
||||
if len(queue) == 0 {
|
||||
return ""
|
||||
}
|
||||
idx := -1
|
||||
if toolCallID != "" {
|
||||
for i, t := range queue {
|
||||
if t.ToolCallID == toolCallID {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 && toolName != "" {
|
||||
for i, t := range queue {
|
||||
if strings.EqualFold(t.ToolName, toolName) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
id := queue[idx].InterruptID
|
||||
queue = append(queue[:idx], queue[idx+1:]...)
|
||||
if len(queue) == 0 {
|
||||
delete(m.approvedExec, conversationID)
|
||||
} else {
|
||||
m.approvedExec[conversationID] = queue
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func mergeHitlPayloadExecutionResult(payloadJSON string, exec hitlExecutionResult) (string, error) {
|
||||
root := make(map[string]interface{})
|
||||
if strings.TrimSpace(payloadJSON) != "" {
|
||||
_ = json.Unmarshal([]byte(payloadJSON), &root)
|
||||
}
|
||||
if root == nil {
|
||||
root = make(map[string]interface{})
|
||||
}
|
||||
root[hitlPayloadExecutionResult] = exec
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return payloadJSON, err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) recordHitlToolExecutionResult(conversationID, toolCallID, toolName string, success bool, result string) {
|
||||
if h == nil || h.hitlManager == nil || h.db == nil {
|
||||
return
|
||||
}
|
||||
interruptID := h.hitlManager.popApprovedInterruptForTool(conversationID, toolCallID, toolName)
|
||||
if interruptID == "" {
|
||||
return
|
||||
}
|
||||
var payloadJSON string
|
||||
err := h.db.QueryRow(`SELECT payload FROM hitl_interrupts WHERE id = ?`, interruptID).Scan(&payloadJSON)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
merged, err := mergeHitlPayloadExecutionResult(payloadJSON, hitlExecutionResult{
|
||||
Success: success,
|
||||
Result: strings.TrimSpace(result),
|
||||
ToolName: strings.TrimSpace(toolName),
|
||||
ToolCallID: strings.TrimSpace(toolCallID),
|
||||
RecordedAt: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET payload = ? WHERE id = ?`, merged, interruptID)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeHitlPayloadExecutionResult(t *testing.T) {
|
||||
merged, err := mergeHitlPayloadExecutionResult(`{"userMessage":"hi","toolName":"nmap"}`, hitlExecutionResult{
|
||||
Success: true,
|
||||
Result: "open ports: 80",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var root map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(merged), &root); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if root["userMessage"] != "hi" {
|
||||
t.Fatalf("userMessage lost: %v", root["userMessage"])
|
||||
}
|
||||
exec, ok := root["executionResult"].(map[string]interface{})
|
||||
if !ok || exec["success"] != true {
|
||||
t.Fatalf("executionResult missing: %v", root["executionResult"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopApprovedInterruptForTool(t *testing.T) {
|
||||
m := NewHITLManager(nil, nil)
|
||||
m.TrackApprovedHitlExecution("hitl_a", "conv1", "nmap", "tc1")
|
||||
m.TrackApprovedHitlExecution("hitl_b", "conv1", "exec", "")
|
||||
if id := m.popApprovedInterruptForTool("conv1", "tc1", "nmap"); id != "hitl_a" {
|
||||
t.Fatalf("tc1 match=%q", id)
|
||||
}
|
||||
if id := m.popApprovedInterruptForTool("conv1", "", "exec"); id != "hitl_b" {
|
||||
t.Fatalf("tool name match=%q", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func normalizeHitlReviewer(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "audit_agent", "agent", "ai":
|
||||
return "audit_agent"
|
||||
default:
|
||||
return "human"
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeHitlDecidedBy(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "audit_agent", "agent", "ai":
|
||||
return "audit_agent"
|
||||
case "system", "timeout":
|
||||
return "system"
|
||||
case "manual":
|
||||
return "manual"
|
||||
default:
|
||||
return "human"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) migrateHitlSchemaColumns() {
|
||||
_, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`)
|
||||
_, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`)
|
||||
}
|
||||
|
||||
func hitlInterruptRowToMap(
|
||||
id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string,
|
||||
messageID sql.NullString,
|
||||
decision, comment sql.NullString,
|
||||
createdAt time.Time,
|
||||
decidedAt sql.NullTime,
|
||||
) map[string]interface{} {
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"decidedBy": decidedBy,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) {
|
||||
where, args := h.buildHitlLogsWhere(logs)
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts` + where
|
||||
return q, args
|
||||
}
|
||||
|
||||
func (h *AgentHandler) buildHitlLogsWhere(logs bool) (string, []interface{}) {
|
||||
q := " WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
if logs {
|
||||
q += " AND status != 'pending'"
|
||||
} else {
|
||||
q += " AND status = 'pending'"
|
||||
}
|
||||
return q, args
|
||||
}
|
||||
|
||||
func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
toolName := strings.TrimSpace(c.Query("toolName"))
|
||||
decision := strings.TrimSpace(c.Query("decision"))
|
||||
decidedBy := strings.TrimSpace(c.Query("decidedBy"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
search := strings.TrimSpace(c.Query("q"))
|
||||
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if toolName != "" {
|
||||
q += " AND tool_name LIKE ?"
|
||||
args = append(args, "%"+toolName+"%")
|
||||
}
|
||||
if decision != "" && decision != "all" {
|
||||
q += " AND decision = ?"
|
||||
args = append(args, decision)
|
||||
}
|
||||
if decidedBy != "" && decidedBy != "all" {
|
||||
q += " AND COALESCE(decided_by,'human') = ?"
|
||||
args = append(args, normalizeHitlDecidedBy(decidedBy))
|
||||
}
|
||||
if status != "" && status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
if search != "" {
|
||||
like := "%" + search + "%"
|
||||
q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)"
|
||||
args = append(args, like, like, like, like, like)
|
||||
}
|
||||
return q, args
|
||||
}
|
||||
|
||||
func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) {
|
||||
countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt"
|
||||
var total int
|
||||
if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLLogs(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
q, args := h.buildHitlListQuery(true)
|
||||
q, args = h.appendHitlListFilters(q, args, c)
|
||||
total, err := h.countHitlQuery(q, args)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
rows, err := h.db.Query(q, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items, err := h.scanHitlInterruptRows(rows)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total, "retentionDays": h.hitlRetentionDays()})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) hitlRetentionDays() int {
|
||||
if h.config != nil {
|
||||
return h.config.Hitl.RetentionDaysEffective()
|
||||
}
|
||||
return config.HitlConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
// DeleteHITLLogs 批量删除或按筛选清空已决策的人机协同审计日志(不删除 pending)。
|
||||
func (h *AgentHandler) DeleteHITLLogs(c *gin.Context) {
|
||||
var request struct {
|
||||
IDs []string `json:"ids"`
|
||||
All bool `json:"all"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var deleted int64
|
||||
var err error
|
||||
if request.All {
|
||||
where, args := h.buildHitlLogsWhere(true)
|
||||
where, args = h.appendHitlListFilters(where, args, c)
|
||||
deleted, err = h.db.DeleteHitlInterruptLogsMatching(where, args)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "logs_clear", "清空人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
||||
"deleted": deleted,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if len(request.IDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "审计日志 ID 列表不能为空"})
|
||||
return
|
||||
}
|
||||
deleted, err = h.db.DeleteHitlInterruptLogsByIDs(request.IDs)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "logs_delete_batch", "批量删除人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
||||
"count": len(request.IDs),
|
||||
"deleted": deleted,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功", "deleted": deleted})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLLog(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?`
|
||||
var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||
}
|
||||
@@ -133,6 +133,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
if h.runRoleWorkflowStreamIfBound(&req, prep, sendEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
@@ -188,6 +191,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
var emptyResponseContinueAttempt int
|
||||
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
@@ -251,6 +255,13 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
mw := &h.config.MultiAgent.EinoMiddleware
|
||||
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
|
||||
continue
|
||||
}
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
@@ -399,6 +410,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
if h.runRoleWorkflowJSONIfBound(c, &req, prep) {
|
||||
return
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
|
||||
+23
-23
@@ -506,7 +506,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
"CreateVulnerabilityRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"required": []string{"conversation_id", "title", "severity"},
|
||||
"required": []string{"conversation_id", "title", "description", "severity", "type", "target", "reproduction_steps", "evidence", "impact", "recommendation"},
|
||||
"properties": map[string]interface{}{
|
||||
"conversation_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
@@ -538,10 +538,9 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
"description": "受影响的目标",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明",
|
||||
},
|
||||
"preconditions": map[string]interface{}{"type": "string", "description": "前置条件"},
|
||||
"reproduction_steps": map[string]interface{}{"type": "string", "description": "复现步骤"},
|
||||
"evidence": map[string]interface{}{"type": "string", "description": "证据/POC,包含请求响应、命令输出、截图说明、日志等"},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "影响",
|
||||
@@ -550,6 +549,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
"retest_notes": map[string]interface{}{"type": "string", "description": "复测方式"},
|
||||
},
|
||||
},
|
||||
"UpdateVulnerabilityRequest": map[string]interface{}{
|
||||
@@ -581,10 +581,9 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
"description": "受影响的目标",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明",
|
||||
},
|
||||
"preconditions": map[string]interface{}{"type": "string", "description": "前置条件"},
|
||||
"reproduction_steps": map[string]interface{}{"type": "string", "description": "复现步骤"},
|
||||
"evidence": map[string]interface{}{"type": "string", "description": "证据/POC,包含请求响应、命令输出、截图说明、日志等"},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "影响",
|
||||
@@ -593,6 +592,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
"retest_notes": map[string]interface{}{"type": "string", "description": "复测方式"},
|
||||
},
|
||||
},
|
||||
"ListVulnerabilitiesResponse": map[string]interface{}{
|
||||
@@ -805,18 +805,18 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "object",
|
||||
"description": "视觉分析(analyze_image MCP 工具);enabled 且 model 非空时注册工具",
|
||||
"properties": map[string]interface{}{
|
||||
"enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"},
|
||||
"model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"},
|
||||
"api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"},
|
||||
"base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"},
|
||||
"provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"},
|
||||
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"},
|
||||
"max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"},
|
||||
"max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"},
|
||||
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
|
||||
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
|
||||
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
|
||||
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
|
||||
"enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"},
|
||||
"model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"},
|
||||
"api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"},
|
||||
"base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"},
|
||||
"provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"},
|
||||
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"},
|
||||
"max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"},
|
||||
"max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"},
|
||||
"jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"},
|
||||
"max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"},
|
||||
"skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"},
|
||||
"detail": map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"},
|
||||
},
|
||||
},
|
||||
"AnalyzeImageToolCall": map[string]interface{}{
|
||||
@@ -1432,7 +1432,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
{
|
||||
"name": "id", "in": "path", "required": true,
|
||||
"description": "对话ID",
|
||||
"schema": map[string]interface{}{"type": "string"},
|
||||
"schema": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
@@ -2570,7 +2570,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"source_fact_key", "target_fact_key", "edge_type"},
|
||||
"properties": map[string]interface{}{
|
||||
"source_fact_key": map[string]interface{}{"type": "string"},
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// agentSessionContextBlock 注入会话工作目录、项目黑板与用户原文锚点(用于 system prompt 追加块)。
|
||||
// agentSessionContextBlock 注入会话工作目录与项目黑板(用于 system prompt 追加块)。
|
||||
// 用户输入由 message history 承载;压缩后由 summarization 摘要指令保留关键约束。
|
||||
func (h *AgentHandler) agentSessionContextBlock(conversationID string) string {
|
||||
var parts []string
|
||||
if ws := h.buildWorkspaceBlock(conversationID); ws != "" {
|
||||
@@ -16,9 +17,6 @@ func (h *AgentHandler) agentSessionContextBlock(conversationID string) string {
|
||||
if bb := h.projectBlackboardBlock(conversationID); bb != "" {
|
||||
parts = append(parts, bb)
|
||||
}
|
||||
if uv := h.userVerbatimAnchorBlock(conversationID); uv != "" {
|
||||
parts = append(parts, uv)
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
@@ -70,29 +68,6 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
|
||||
// userVerbatimAnchorBlock 从 messages 表构建用户各轮原文锚点(压缩后仍由 summarization Finalize 刷新)。
|
||||
func (h *AgentHandler) userVerbatimAnchorBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
maxRunes := h.config.MultiAgent.UserVerbatimAnchorMaxRunesEffective()
|
||||
if maxRunes < 0 {
|
||||
return ""
|
||||
}
|
||||
msgs, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("构建用户原文锚点失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return project.BuildUserVerbatimAnchorBlockFromMessages(msgs, maxRunes)
|
||||
}
|
||||
|
||||
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||
if h == nil || h.db == nil {
|
||||
|
||||
+45
-22
@@ -711,12 +711,27 @@ type wecomReplyXML struct {
|
||||
Content string `xml:"Content"`
|
||||
}
|
||||
|
||||
// wecomRequireToken 企业微信回调必须配置 Token;未配置时拒绝请求,防止未授权触发 Agent。
|
||||
func (h *RobotHandler) wecomRequireToken(c *gin.Context) (string, bool) {
|
||||
token := strings.TrimSpace(h.config.Robots.Wecom.Token)
|
||||
if token == "" {
|
||||
h.logger.Warn("企业微信已启用但未配置 token,已拒绝回调(请在配置中设置 robots.wecom.token)")
|
||||
c.String(http.StatusForbidden, "")
|
||||
return "", false
|
||||
}
|
||||
return token, true
|
||||
}
|
||||
|
||||
// HandleWecomGET 企业微信 URL 校验(GET)
|
||||
func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
||||
if !h.config.Robots.Wecom.Enabled {
|
||||
c.String(http.StatusNotFound, "")
|
||||
return
|
||||
}
|
||||
token, ok := h.wecomRequireToken(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
|
||||
echostr := c.Query("echostr")
|
||||
msgSignature := c.Query("msg_signature")
|
||||
@@ -724,7 +739,7 @@ func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
||||
nonce := c.Query("nonce")
|
||||
|
||||
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
|
||||
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
|
||||
signature := h.signWecomRequest(token, timestamp, nonce, echostr)
|
||||
if signature != msgSignature {
|
||||
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
|
||||
c.String(http.StatusBadRequest, "invalid signature")
|
||||
@@ -865,27 +880,28 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
}
|
||||
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
|
||||
|
||||
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
|
||||
// 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管)
|
||||
token := h.config.Robots.Wecom.Token
|
||||
if token != "" {
|
||||
if msgSignature == "" {
|
||||
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature)")
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
var tmp wecomXML
|
||||
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
|
||||
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
|
||||
if expected != msgSignature {
|
||||
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段。
|
||||
// 启用企业微信时必须配置 token 并校验签名,避免未授权请求触发 Agent。
|
||||
token, ok := h.wecomRequireToken(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msgSignature == "" {
|
||||
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需确保回调携带 msg_signature)")
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
var tmp wecomXML
|
||||
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
|
||||
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
|
||||
if expected != msgSignature {
|
||||
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
|
||||
var body wecomXML
|
||||
@@ -899,6 +915,13 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
// 保存企业 ID(用于明文模式回复)
|
||||
enterpriseID := body.ToUserName
|
||||
|
||||
// 配置了 EncodingAESKey 时必须走加密消息,拒绝明文 XML 绕过
|
||||
if strings.TrimSpace(h.config.Robots.Wecom.EncodingAESKey) != "" && strings.TrimSpace(body.Encrypt) == "" {
|
||||
h.logger.Warn("企业微信已配置加密模式但收到明文消息,已拒绝")
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 加密模式:先解密再解析内层 XML
|
||||
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
|
||||
h.logger.Debug("企业微信进入加密模式解密流程")
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func newWecomTestHandler(token string, aesKey string) *RobotHandler {
|
||||
return &RobotHandler{
|
||||
config: &config.Config{
|
||||
Robots: config.RobotsConfig{
|
||||
Wecom: config.RobotWecomConfig{
|
||||
Enabled: true,
|
||||
Token: token,
|
||||
EncodingAESKey: aesKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
logger: zap.NewNop(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWecomPOST_rejectsWhenTokenEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newWecomTestHandler("", "")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom", strings.NewReader(body))
|
||||
|
||||
h.HandleWecomPOST(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
if w.Body.String() == "success" {
|
||||
t.Fatal("expected rejection, got success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWecomPOST_rejectsPlaintextWhenEncryptionConfigured(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newWecomTestHandler("secret-token", "abcdefghijklmnopqrstuvwxyz0123456789ABCD")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom?timestamp=1&nonce=2&msg_signature=fake", strings.NewReader(body))
|
||||
|
||||
h.HandleWecomPOST(c)
|
||||
|
||||
if w.Body.String() == "success" {
|
||||
t.Fatal("expected rejection for plaintext in encryption mode, got success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWecomGET_rejectsWhenTokenEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newWecomTestHandler("", "")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/robot/wecom?msg_signature=x×tamp=1&nonce=2&echostr=abc", nil)
|
||||
|
||||
h.HandleWecomGET(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
|
||||
// AgentTask 描述正在运行的Agent任务
|
||||
type AgentTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
Status string `json:"status"`
|
||||
@@ -42,6 +43,9 @@ type AgentTask struct {
|
||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||
activeEinoExecuteAbortNote string
|
||||
|
||||
// hitlCognition 本轮运行中供 HITL/审计 Agent 读取的上下文(用户原话 + 思考,不含会话历史)
|
||||
hitlCognition *hitlCognitionState
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -233,6 +237,7 @@ func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||
// CompletedTask 已完成的任务(用于历史记录)
|
||||
type CompletedTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt time.Time `json:"completedAt"`
|
||||
@@ -352,6 +357,7 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont
|
||||
}
|
||||
|
||||
m.tasks[conversationID] = task
|
||||
task.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(message)}
|
||||
return task, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,12 @@ type CreateVulnerabilityRequest struct {
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Preconditions string `json:"preconditions"`
|
||||
ReproSteps string `json:"reproduction_steps"`
|
||||
Evidence string `json:"evidence"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
RetestNotes string `json:"retest_notes"`
|
||||
}
|
||||
|
||||
// CreateVulnerability 创建漏洞
|
||||
@@ -69,9 +72,12 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
Status: req.Status,
|
||||
Type: req.Type,
|
||||
Target: req.Target,
|
||||
Proof: req.Proof,
|
||||
Preconditions: req.Preconditions,
|
||||
ReproSteps: req.ReproSteps,
|
||||
Evidence: req.Evidence,
|
||||
Impact: req.Impact,
|
||||
Recommendation: req.Recommendation,
|
||||
RetestNotes: req.RetestNotes,
|
||||
}
|
||||
|
||||
created, err := h.db.CreateVulnerability(vuln)
|
||||
@@ -118,7 +124,7 @@ func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilt
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ProjectID: c.Query("project_id"),
|
||||
ProjectID: c.Query("project_id"),
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
@@ -197,17 +203,20 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ProjectID *string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
ConversationTag *string `json:"conversation_tag"`
|
||||
TaskTag *string `json:"task_tag"`
|
||||
Title *string `json:"title"`
|
||||
Description *string `json:"description"`
|
||||
Severity *string `json:"severity"`
|
||||
Status *string `json:"status"`
|
||||
Type *string `json:"type"`
|
||||
Target *string `json:"target"`
|
||||
Preconditions *string `json:"preconditions"`
|
||||
ReproSteps *string `json:"reproduction_steps"`
|
||||
Evidence *string `json:"evidence"`
|
||||
Impact *string `json:"impact"`
|
||||
Recommendation *string `json:"recommendation"`
|
||||
RetestNotes *string `json:"retest_notes"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
@@ -231,38 +240,47 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
if req.ProjectID != nil {
|
||||
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||
}
|
||||
if req.ConversationTag != "" {
|
||||
existing.ConversationTag = req.ConversationTag
|
||||
if req.ConversationTag != nil {
|
||||
existing.ConversationTag = *req.ConversationTag
|
||||
}
|
||||
if req.TaskTag != "" {
|
||||
existing.TaskTag = req.TaskTag
|
||||
if req.TaskTag != nil {
|
||||
existing.TaskTag = *req.TaskTag
|
||||
}
|
||||
if req.Title != "" {
|
||||
existing.Title = req.Title
|
||||
if req.Title != nil {
|
||||
existing.Title = *req.Title
|
||||
}
|
||||
if req.Description != "" {
|
||||
existing.Description = req.Description
|
||||
if req.Description != nil {
|
||||
existing.Description = *req.Description
|
||||
}
|
||||
if req.Severity != "" {
|
||||
existing.Severity = req.Severity
|
||||
if req.Severity != nil {
|
||||
existing.Severity = *req.Severity
|
||||
}
|
||||
if req.Status != "" {
|
||||
existing.Status = req.Status
|
||||
if req.Status != nil {
|
||||
existing.Status = *req.Status
|
||||
}
|
||||
if req.Type != "" {
|
||||
existing.Type = req.Type
|
||||
if req.Type != nil {
|
||||
existing.Type = *req.Type
|
||||
}
|
||||
if req.Target != "" {
|
||||
existing.Target = req.Target
|
||||
if req.Target != nil {
|
||||
existing.Target = *req.Target
|
||||
}
|
||||
if req.Proof != "" {
|
||||
existing.Proof = req.Proof
|
||||
if req.Preconditions != nil {
|
||||
existing.Preconditions = *req.Preconditions
|
||||
}
|
||||
if req.Impact != "" {
|
||||
existing.Impact = req.Impact
|
||||
if req.ReproSteps != nil {
|
||||
existing.ReproSteps = *req.ReproSteps
|
||||
}
|
||||
if req.Recommendation != "" {
|
||||
existing.Recommendation = req.Recommendation
|
||||
if req.Evidence != nil {
|
||||
existing.Evidence = *req.Evidence
|
||||
}
|
||||
if req.Impact != nil {
|
||||
existing.Impact = *req.Impact
|
||||
}
|
||||
if req.Recommendation != nil {
|
||||
existing.Recommendation = *req.Recommendation
|
||||
}
|
||||
if req.RetestNotes != nil {
|
||||
existing.RetestNotes = *req.RetestNotes
|
||||
}
|
||||
|
||||
if err := h.db.UpdateVulnerability(id, existing); err != nil {
|
||||
@@ -495,9 +513,19 @@ func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability,
|
||||
b.WriteString(v.Description)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Proof != "" {
|
||||
b.WriteString("\n#### 证明(POC)\n\n```\n")
|
||||
b.WriteString(v.Proof)
|
||||
if v.Preconditions != "" {
|
||||
b.WriteString("\n#### 前置条件\n\n")
|
||||
b.WriteString(v.Preconditions)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.ReproSteps != "" {
|
||||
b.WriteString("\n#### 复现步骤\n\n")
|
||||
b.WriteString(v.ReproSteps)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Evidence != "" {
|
||||
b.WriteString("\n#### 证据 / POC\n\n```\n")
|
||||
b.WriteString(v.Evidence)
|
||||
b.WriteString("\n```\n")
|
||||
}
|
||||
if v.Impact != "" {
|
||||
@@ -510,6 +538,11 @@ func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability,
|
||||
b.WriteString(v.Recommendation)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.RetestNotes != "" {
|
||||
b.WriteString("\n#### 复测方式\n\n")
|
||||
b.WriteString(v.RetestNotes)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type WorkflowHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
agent *agent.Agent
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewWorkflowHandler(db *database.DB, logger *zap.Logger) *WorkflowHandler {
|
||||
return &WorkflowHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
type workflowSaveRequest struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Version int `json:"version,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
Graph json.RawMessage `json:"graph,omitempty"`
|
||||
GraphJSON json.RawMessage `json:"graph_json,omitempty"`
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) List(c *gin.Context) {
|
||||
includeDisabled := strings.EqualFold(c.Query("includeDisabled"), "true") || c.Query("include_disabled") == "1"
|
||||
items, err := h.db.ListWorkflowDefinitions(includeDisabled)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"workflows": items})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Get(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
wf, err := h.db.GetWorkflowDefinition(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if wf == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"workflow": wf})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Create(c *gin.Context) {
|
||||
h.save(c, "")
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Update(c *gin.Context) {
|
||||
h.save(c, c.Param("id"))
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) save(c *gin.Context, pathID string) {
|
||||
var req workflowSaveRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(req.ID)
|
||||
if strings.TrimSpace(pathID) != "" {
|
||||
id = strings.TrimSpace(pathID)
|
||||
}
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if id == "" || name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 和 name 不能为空"})
|
||||
return
|
||||
}
|
||||
graph := req.Graph
|
||||
if len(graph) == 0 {
|
||||
graph = req.GraphJSON
|
||||
}
|
||||
if len(graph) == 0 {
|
||||
graph = []byte(`{"nodes":[],"edges":[],"config":{}}`)
|
||||
}
|
||||
if !json.Valid(graph) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph 必须是合法 JSON"})
|
||||
return
|
||||
}
|
||||
if err := workflowrunner.ValidateGraphJSON(c.Request.Context(), string(graph)); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流图无法编译: " + err.Error()})
|
||||
return
|
||||
}
|
||||
var probe interface{}
|
||||
if err := json.Unmarshal(graph, &probe); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "graph JSON 解析失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
wf := &database.WorkflowDefinition{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Description: strings.TrimSpace(req.Description),
|
||||
Version: req.Version,
|
||||
GraphJSON: string(graph),
|
||||
Enabled: enabled,
|
||||
}
|
||||
if err := h.db.UpsertWorkflowDefinition(wf); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("保存工作流失败", zap.String("id", id), zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
saved, _ := h.db.GetWorkflowDefinition(id)
|
||||
workflowrunner.InvalidateCompiledCache(id)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "save", "保存图编排流程", "workflow", id, map[string]interface{}{"name": name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "工作流已保存", "workflow": saved})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) Delete(c *gin.Context) {
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流 id 不能为空"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteWorkflowDefinition(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
workflowrunner.InvalidateCompiledCache(id)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "workflow", "delete", "删除图编排流程", "workflow", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "工作流已删除"})
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) roleForWorkflow(req *ChatRequest) (config.RoleConfig, bool) {
|
||||
if h == nil || h.config == nil || h.config.Roles == nil || req == nil {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
roleName := strings.TrimSpace(req.Role)
|
||||
if roleName == "" {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
role, ok := h.config.Roles[roleName]
|
||||
if !ok || !role.Enabled {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
if !workflowrunner.ShouldAutoRunRoleWorkflow(role) {
|
||||
return config.RoleConfig{}, false
|
||||
}
|
||||
return role, true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runRoleWorkflowStreamIfBound(
|
||||
req *ChatRequest,
|
||||
prep *multiAgentPrepared,
|
||||
sendEvent func(eventType, message string, data interface{}),
|
||||
) bool {
|
||||
role, ok := h.roleForWorkflow(req)
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
userMessage := ""
|
||||
if req != nil {
|
||||
userMessage = req.Message
|
||||
}
|
||||
|
||||
taskStatus := "completed"
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
defer cancelWithCause(nil)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: conversationID,
|
||||
ProjectID: h.conversationProjectID(conversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(conversationID),
|
||||
AssistantMessageID: assistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{"conversationId": conversationID})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "")
|
||||
}
|
||||
payload := map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
"messageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
}
|
||||
if result.AwaitingHITL {
|
||||
payload["workflowStatus"] = "awaiting_hitl"
|
||||
payload["awaitingHitl"] = true
|
||||
}
|
||||
sendEvent("response", result.Response, payload)
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": prep.ConversationID})
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runRoleWorkflowJSONIfBound(c *gin.Context, req *ChatRequest, prep *multiAgentPrepared) bool {
|
||||
role, ok := h.roleForWorkflow(req)
|
||||
if !ok || prep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
userMessage := ""
|
||||
if req != nil {
|
||||
userMessage = req.Message
|
||||
}
|
||||
|
||||
taskStatus := "completed"
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, userMessage, cancelWithCause); err != nil {
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。",
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "❌ 无法启动任务: " + err.Error()})
|
||||
}
|
||||
return true
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progress := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||
result, err := workflowrunner.RunRoleBoundWorkflow(taskCtx, workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.config,
|
||||
Agent: h.agent,
|
||||
ConversationID: conversationID,
|
||||
ProjectID: h.conversationProjectID(conversationID),
|
||||
UserMessage: prep.FinalMessage,
|
||||
History: prep.History,
|
||||
RoleTools: prep.RoleTools,
|
||||
AgentsMarkdownDir: h.agentsMarkdownDir,
|
||||
SystemPromptExtra: h.agentSessionContextBlock(conversationID),
|
||||
AssistantMessageID: assistantMessageID,
|
||||
Progress: progress,
|
||||
})
|
||||
if err != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
_ = h.appendAssistantMessageNotice(assistantMessageID, cancelMsg)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "cancelled",
|
||||
"message": cancelMsg,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return true
|
||||
}
|
||||
errMsg := "执行角色绑定流程失败: " + err.Error()
|
||||
taskStatus = "failed"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg, "conversationId": conversationID})
|
||||
return true
|
||||
}
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, nil, "")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"conversationId": prep.ConversationID,
|
||||
"assistantMessageId": prep.AssistantMessageID,
|
||||
"agentMode": "workflow",
|
||||
"workflowRunId": result.RunID,
|
||||
"workflowStatus": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
})
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
workflowrunner "cyberstrike-ai/internal/workflow"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *WorkflowHandler) SetRuntime(agent *agent.Agent, cfg *config.Config) {
|
||||
h.agent = agent
|
||||
h.cfg = cfg
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) GetRun(c *gin.Context) {
|
||||
runID := strings.TrimSpace(c.Param("runId"))
|
||||
run, err := h.db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if run == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"run": run})
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) ListPendingRuns(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
runs, err := h.db.ListWorkflowRunsAwaitingHITLFiltered(conversationID, 50)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"runs": runs})
|
||||
}
|
||||
|
||||
type workflowResumeRequest struct {
|
||||
Approved bool `json:"approved"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
func (h *WorkflowHandler) ResumeRun(c *gin.Context) {
|
||||
if h.agent == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "工作流运行时未初始化"})
|
||||
return
|
||||
}
|
||||
runID := strings.TrimSpace(c.Param("runId"))
|
||||
var req workflowResumeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
run, err := h.db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if run == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "工作流运行不存在"})
|
||||
return
|
||||
}
|
||||
role := config.RoleConfig{Name: strings.TrimSpace(run.RoleID)}
|
||||
if role.Name != "" && h.cfg.Roles != nil {
|
||||
if r, ok := h.cfg.Roles[role.Name]; ok {
|
||||
role = r
|
||||
if role.Name == "" {
|
||||
role.Name = run.RoleID
|
||||
}
|
||||
}
|
||||
}
|
||||
if run.Status != "awaiting_hitl" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作流运行不在等待审批状态: " + run.Status})
|
||||
return
|
||||
}
|
||||
if err := h.db.RecordWorkflowRunHITLDecision(runID, req.Approved, req.Comment); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
decision := workflowrunner.HITLDecision{
|
||||
Approved: req.Approved,
|
||||
Comment: strings.TrimSpace(req.Comment),
|
||||
}
|
||||
delegated := workflowrunner.NotifyHITLDecision(runID, decision)
|
||||
if !delegated {
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if workflowrunner.NotifyHITLDecision(runID, decision) {
|
||||
delegated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if delegated {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workflowRunId": runID,
|
||||
"status": "delegated",
|
||||
"streamResuming": true,
|
||||
"approved": req.Approved,
|
||||
})
|
||||
return
|
||||
}
|
||||
result, err := workflowrunner.ResumeWorkflowRun(c.Request.Context(), workflowrunner.RunArgs{
|
||||
DB: h.db,
|
||||
Logger: h.logger,
|
||||
Role: role,
|
||||
AppCfg: h.cfg,
|
||||
Agent: h.agent,
|
||||
ConversationID: run.ConversationID,
|
||||
ProjectID: run.ProjectID,
|
||||
}, runID, req.Approved, req.Comment)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"workflowRunId": result.RunID,
|
||||
"status": result.Status,
|
||||
"awaitingHitl": result.AwaitingHITL,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package hitl
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const retentionPurgeInterval = time.Hour
|
||||
|
||||
// Service manages HITL audit log retention (decided hitl_interrupts rows).
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewService creates a HITL audit log retention service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{db: db, cfg: cfg, logger: logger}
|
||||
}
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.HitlConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
return s.cfg.Hitl.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
// PurgeExpired deletes decided HITL log rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Hitl.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.PurgeHitlInterruptLogsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期人机协同审计日志失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期人机协同审计日志", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||
}
|
||||
}
|
||||
|
||||
// StartRetentionLoop periodically purges expired HITL audit log rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(retentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("hitl audit log retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package hitl
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
appconfig "cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
mode TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
decision TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
decided_at DATETIME
|
||||
)`); err != nil {
|
||||
t.Fatalf("create table: %v", err)
|
||||
}
|
||||
|
||||
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
|
||||
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||
VALUES ('old-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, old, old); err != nil {
|
||||
t.Fatalf("insert: %v", err)
|
||||
}
|
||||
|
||||
zero := 0
|
||||
svc := NewService(db, &appconfig.Config{
|
||||
Hitl: appconfig.HitlConfig{RetentionDays: &zero},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err != nil {
|
||||
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// knowledgePipelineRetriever: MultiQuery → vector candidates → rerank → post-process.
|
||||
type knowledgePipelineRetriever struct {
|
||||
inner retriever.Retriever
|
||||
base *Retriever
|
||||
}
|
||||
|
||||
func newKnowledgePipelineRetriever(inner retriever.Retriever, base *Retriever) *knowledgePipelineRetriever {
|
||||
if inner == nil || base == nil {
|
||||
return nil
|
||||
}
|
||||
return &knowledgePipelineRetriever{inner: inner, base: base}
|
||||
}
|
||||
|
||||
func (p *knowledgePipelineRetriever) GetType() string {
|
||||
return "KnowledgeRAGPipeline"
|
||||
}
|
||||
|
||||
func (p *knowledgePipelineRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
|
||||
if p == nil || p.inner == nil || p.base == nil {
|
||||
return nil, fmt.Errorf("knowledge pipeline retriever: nil")
|
||||
}
|
||||
q := strings.TrimSpace(query)
|
||||
if q == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
ro := retriever.GetCommonOptions(nil, opts...)
|
||||
finalTopK := p.base.config.TopK
|
||||
if finalTopK <= 0 {
|
||||
finalTopK = 5
|
||||
}
|
||||
if ro.TopK != nil && *ro.TopK > 0 {
|
||||
finalTopK = *ro.TopK
|
||||
}
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, p.GetType(), components.ComponentOfRetriever)
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{Query: q, TopK: finalTopK, Extra: ro.DSLInfo})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
|
||||
}()
|
||||
|
||||
out, err = p.inner.Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
if rr := p.base.documentReranker(); rr != nil && len(out) > 1 {
|
||||
reranked, rerr := rr.Rerank(ctx, q, out)
|
||||
if rerr != nil {
|
||||
if p.base.logger != nil {
|
||||
p.base.logger.Warn("知识检索重排失败,已使用融合序", zap.Error(rerr))
|
||||
}
|
||||
} else if len(reranked) > 0 {
|
||||
out = reranked
|
||||
}
|
||||
}
|
||||
|
||||
tokenModel := ""
|
||||
if p.base.embedder != nil {
|
||||
tokenModel = p.base.embedder.EmbeddingModelName()
|
||||
}
|
||||
var postPO *config.PostRetrieveConfig
|
||||
if p.base.config != nil {
|
||||
postPO = &p.base.config.PostRetrieve
|
||||
}
|
||||
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
var _ retriever.Retriever = (*knowledgePipelineRetriever)(nil)
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
|
||||
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
|
||||
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain(MultiQuery → 向量 → 重排 → 后处理)。
|
||||
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("retriever is nil")
|
||||
|
||||
@@ -11,19 +11,10 @@ import (
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
|
||||
//
|
||||
// Options:
|
||||
// - [retriever.WithTopK]
|
||||
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string)
|
||||
//
|
||||
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
|
||||
//
|
||||
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
|
||||
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
|
||||
// It returns prefetch-sized vector candidates only; rerank and post-process run in [knowledgePipelineRetriever].
|
||||
type VectorEinoRetriever struct {
|
||||
inner *Retriever
|
||||
}
|
||||
@@ -119,26 +110,6 @@ func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts .
|
||||
return nil, err
|
||||
}
|
||||
out = retrievalResultsToDocuments(results)
|
||||
|
||||
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
|
||||
reranked, rerr := rr.Rerank(ctx, q, out)
|
||||
if rerr != nil {
|
||||
if h.inner.logger != nil {
|
||||
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
|
||||
}
|
||||
} else if len(reranked) > 0 {
|
||||
out = reranked
|
||||
}
|
||||
}
|
||||
|
||||
tokenModel := ""
|
||||
if h.inner.embedder != nil {
|
||||
tokenModel = h.inner.embedder.EmbeddingModelName()
|
||||
}
|
||||
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPReranker calls a hosted rerank API (DashScope or Cohere-compatible).
|
||||
type HTTPReranker struct {
|
||||
provider string
|
||||
model string
|
||||
baseURL string
|
||||
apiKey string
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHTTPReranker builds a rerank client from knowledge retrieval config; openAI supplies fallback credentials.
|
||||
func NewHTTPReranker(rc *config.RerankConfig, openAI *config.OpenAIConfig, logger *zap.Logger) (*HTTPReranker, error) {
|
||||
if rc == nil {
|
||||
return nil, fmt.Errorf("rerank config is nil")
|
||||
}
|
||||
baseURL := strings.TrimSpace(rc.BaseURL)
|
||||
apiKey := strings.TrimSpace(rc.APIKey)
|
||||
if openAI != nil {
|
||||
if baseURL == "" {
|
||||
baseURL = strings.TrimSpace(openAI.BaseURL)
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = strings.TrimSpace(openAI.APIKey)
|
||||
}
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("rerank api_key is required")
|
||||
}
|
||||
provider := rc.ProviderEffective(baseURL)
|
||||
model := rc.ModelEffective(provider)
|
||||
return &HTTPReranker{
|
||||
provider: provider,
|
||||
model: model,
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
apiKey: apiKey,
|
||||
client: &http.Client{Timeout: 60 * time.Second},
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
|
||||
if r == nil {
|
||||
return docs, nil
|
||||
}
|
||||
q := strings.TrimSpace(query)
|
||||
if q == "" || len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
if len(docs) == 1 {
|
||||
return docs, nil
|
||||
}
|
||||
texts := make([]string, 0, len(docs))
|
||||
for _, d := range docs {
|
||||
if d == nil {
|
||||
texts = append(texts, "")
|
||||
continue
|
||||
}
|
||||
texts = append(texts, d.Content)
|
||||
}
|
||||
var order []int
|
||||
var err error
|
||||
switch r.provider {
|
||||
case "dashscope":
|
||||
order, err = r.rerankDashScope(ctx, q, texts, len(docs))
|
||||
default:
|
||||
order, err = r.rerankCohere(ctx, q, texts, len(docs))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]*schema.Document, 0, len(order))
|
||||
for _, idx := range order {
|
||||
if idx < 0 || idx >= len(docs) || docs[idx] == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, docs[idx])
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) rerankCohere(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
||||
url := r.cohereRerankURL()
|
||||
body := map[string]any{
|
||||
"model": r.model,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rerank request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
||||
}
|
||||
var parsed struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("rerank decode: %w", err)
|
||||
}
|
||||
order := make([]int, 0, len(parsed.Results))
|
||||
for _, row := range parsed.Results {
|
||||
order = append(order, row.Index)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) rerankDashScope(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
||||
url := r.dashscopeRerankURL()
|
||||
body := map[string]any{
|
||||
"model": r.model,
|
||||
"input": map[string]any{
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
"parameters": map[string]any{
|
||||
"return_documents": false,
|
||||
"top_n": topN,
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dashscope rerank: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("dashscope rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
||||
}
|
||||
var parsed struct {
|
||||
Output struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
} `json:"results"`
|
||||
} `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("dashscope rerank decode: %w", err)
|
||||
}
|
||||
order := make([]int, 0, len(parsed.Output.Results))
|
||||
for _, row := range parsed.Output.Results {
|
||||
order = append(order, row.Index)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) cohereRerankURL() string {
|
||||
base := r.baseURL
|
||||
if base == "" {
|
||||
base = "https://api.cohere.com"
|
||||
}
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
return base + "/rerank"
|
||||
}
|
||||
return base + "/v1/rerank"
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) dashscopeRerankURL() string {
|
||||
base := strings.TrimSpace(r.baseURL)
|
||||
if base == "" {
|
||||
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
||||
}
|
||||
if strings.Contains(base, "/api/v1/services/rerank") {
|
||||
return base
|
||||
}
|
||||
if strings.Contains(base, "dashscope.aliyuncs.com") || strings.Contains(base, "compatible-mode") {
|
||||
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
||||
}
|
||||
return strings.TrimSuffix(base, "/")
|
||||
}
|
||||
|
||||
func truncateForRerankLog(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) > 512 {
|
||||
return s[:512] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
var _ DocumentReranker = (*HTTPReranker)(nil)
|
||||
@@ -0,0 +1,97 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestHTTPReranker_CohereOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/rerank" {
|
||||
t.Fatalf("path %s", r.URL.Path)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"index": 2, "relevance_score": 0.9},
|
||||
{"index": 0, "relevance_score": 0.5},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rr, err := NewHTTPReranker(&config.RerankConfig{
|
||||
Provider: "cohere",
|
||||
Model: "rerank-multilingual-v3.0",
|
||||
BaseURL: srv.URL,
|
||||
APIKey: "test-key",
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
docs := []*schema.Document{
|
||||
{ID: "a", Content: "alpha"},
|
||||
{ID: "b", Content: "beta"},
|
||||
{ID: "c", Content: "gamma"},
|
||||
}
|
||||
out, err := rr.Rerank(context.Background(), "query", docs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 2 || out[0].ID != "c" || out[1].ID != "a" {
|
||||
t.Fatalf("order wrong: %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPReranker_DashScopeOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"output": map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"index": 1, "relevance_score": 0.88},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rr, err := NewHTTPReranker(&config.RerankConfig{
|
||||
Provider: "dashscope",
|
||||
Model: "gte-rerank",
|
||||
BaseURL: srv.URL,
|
||||
APIKey: "test-key",
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
docs := []*schema.Document{{ID: "a", Content: "a"}, {ID: "b", Content: "b"}}
|
||||
out, err := rr.Rerank(context.Background(), "q", docs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 1 || out[0].ID != "b" {
|
||||
t.Fatalf("got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRerankConfigDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
rc := config.RerankConfig{}
|
||||
if rc.ProviderEffective("https://dashscope.aliyuncs.com/x") != "dashscope" {
|
||||
t.Fatal("dashscope detect")
|
||||
}
|
||||
if rc.ModelEffective("dashscope") != "gte-rerank" {
|
||||
t.Fatal("dashscope model")
|
||||
}
|
||||
if rc.ModelEffective("cohere") != "rerank-multilingual-v3.0" {
|
||||
t.Fatal("cohere model")
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
|
||||
const postRetrieveMaxPrefetchCap = 200
|
||||
|
||||
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
|
||||
// DocumentReranker 精排(HTTP dashscope / Cohere 兼容 API),由 [WireRetrieverPipeline] 注入。
|
||||
type DocumentReranker interface {
|
||||
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
|
||||
}
|
||||
@@ -167,13 +167,16 @@ func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int,
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
|
||||
// EffectivePrefetchTopK 计算每条 MultiQuery 变体在向量阶段的候选条数(供融合 / 重排 / 后处理)。
|
||||
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
if topK < 1 {
|
||||
topK = 5
|
||||
}
|
||||
fetch := topK
|
||||
if po != nil && po.PrefetchTopK > fetch {
|
||||
fetch := topK * 4
|
||||
if fetch < 20 {
|
||||
fetch = 20
|
||||
}
|
||||
if po != nil && po.PrefetchTopK > 0 {
|
||||
fetch = po.PrefetchTopK
|
||||
}
|
||||
if fetch > postRetrieveMaxPrefetchCap {
|
||||
@@ -182,7 +185,7 @@ func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
return fetch
|
||||
}
|
||||
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK(精排已在流水线中完成)。
|
||||
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
|
||||
if finalTopK < 1 {
|
||||
finalTopK = 5
|
||||
|
||||
@@ -28,8 +28,8 @@ func TestDedupeByNormalizedContent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectivePrefetchTopK(t *testing.T) {
|
||||
if g := EffectivePrefetchTopK(5, nil); g != 5 {
|
||||
t.Fatalf("got %d", g)
|
||||
if g := EffectivePrefetchTopK(5, nil); g != 20 {
|
||||
t.Fatalf("default prefetch got %d want 20", g)
|
||||
}
|
||||
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
|
||||
t.Fatalf("got %d", g)
|
||||
|
||||
@@ -27,15 +27,19 @@ type Retriever struct {
|
||||
|
||||
rerankMu sync.RWMutex
|
||||
reranker DocumentReranker
|
||||
|
||||
pipeline retriever.Retriever
|
||||
wireOpenAI *config.OpenAIConfig
|
||||
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int
|
||||
SimilarityThreshold float64
|
||||
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
|
||||
SubIndexFilter string
|
||||
PostRetrieve config.PostRetrieveConfig
|
||||
SubIndexFilter string
|
||||
MultiQuery config.MultiQueryConfig
|
||||
Rerank config.RerankConfig
|
||||
PostRetrieve config.PostRetrieveConfig
|
||||
}
|
||||
|
||||
// NewRetriever 创建新的检索器
|
||||
@@ -48,7 +52,7 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新检索配置
|
||||
// UpdateConfig 更新检索配置并重建 Eino MultiQuery + 重排流水线。
|
||||
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
|
||||
if cfg != nil {
|
||||
r.config = cfg
|
||||
@@ -57,12 +61,18 @@ func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
|
||||
zap.Int("top_k", cfg.TopK),
|
||||
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
|
||||
zap.String("sub_index_filter", cfg.SubIndexFilter),
|
||||
zap.Int("multi_query_max", cfg.MultiQuery.MaxQueriesEffective()),
|
||||
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
|
||||
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
|
||||
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
|
||||
)
|
||||
}
|
||||
}
|
||||
if r.wireOpenAI != nil {
|
||||
if err := WireRetrieverPipeline(context.Background(), r, r.wireOpenAI); err != nil && r.logger != nil {
|
||||
r.logger.Warn("检索流水线重建失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
|
||||
@@ -103,7 +113,7 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。
|
||||
// Search 搜索知识库(Eino MultiQuery → 向量检索 → 重排 → 后处理)。
|
||||
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("请求不能为空")
|
||||
@@ -113,7 +123,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
opts := r.einoRetrieverOptions(req)
|
||||
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
|
||||
docs, err := r.activeEinoRetriever().Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -143,7 +153,19 @@ func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option
|
||||
|
||||
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
|
||||
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
|
||||
return r.activeEinoRetriever().Retrieve(ctx, query, opts...)
|
||||
}
|
||||
|
||||
func (r *Retriever) activeEinoRetriever() retriever.Retriever {
|
||||
if r != nil && r.pipeline != nil {
|
||||
return r.pipeline
|
||||
}
|
||||
return NewVectorEinoRetriever(r)
|
||||
}
|
||||
|
||||
// AsEinoRetriever 将知识库检索流水线暴露为 Eino [retriever.Retriever]。
|
||||
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
|
||||
return r.activeEinoRetriever()
|
||||
}
|
||||
|
||||
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
|
||||
@@ -299,7 +321,14 @@ func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*Re
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
|
||||
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
|
||||
return NewVectorEinoRetriever(r)
|
||||
// RetrievalConfigFromYAML maps API/YAML retrieval settings into the knowledge package.
|
||||
func RetrievalConfigFromYAML(r config.RetrievalConfig) *RetrievalConfig {
|
||||
return &RetrievalConfig{
|
||||
TopK: r.TopK,
|
||||
SimilarityThreshold: r.SimilarityThreshold,
|
||||
SubIndexFilter: r.SubIndexFilter,
|
||||
MultiQuery: r.MultiQuery,
|
||||
Rerank: r.Rerank,
|
||||
PostRetrieve: r.PostRetrieve,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/flow/retriever/multiquery"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WireRetrieverPipeline builds Eino MultiQuery + HTTP rerank + post-process pipeline on r.
|
||||
// Call once after NewRetriever; UpdateConfig re-invokes when wireOpenAI is set.
|
||||
func WireRetrieverPipeline(ctx context.Context, r *Retriever, openAI *config.OpenAIConfig) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("retriever is nil")
|
||||
}
|
||||
if openAI == nil {
|
||||
return fmt.Errorf("openai config is nil")
|
||||
}
|
||||
if r.config == nil {
|
||||
return fmt.Errorf("retrieval config is nil")
|
||||
}
|
||||
r.wireOpenAI = openAI
|
||||
|
||||
httpClient := openai.NewEinoHTTPClient(openAI, &http.Client{Timeout: 120 * time.Second})
|
||||
chatCfg := &einoopenai.ChatModelConfig{
|
||||
APIKey: strings.TrimSpace(openAI.APIKey),
|
||||
BaseURL: strings.TrimSuffix(strings.TrimSpace(openAI.BaseURL), "/"),
|
||||
Model: strings.TrimSpace(openAI.Model),
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
if chatCfg.Model == "" {
|
||||
chatCfg.Model = "gpt-4o"
|
||||
}
|
||||
rewriteLLM, err := einoopenai.NewChatModel(ctx, chatCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("multi_query rewrite model: %w", err)
|
||||
}
|
||||
|
||||
reranker, err := NewHTTPReranker(&r.config.Rerank, openAI, r.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reranker: %w", err)
|
||||
}
|
||||
r.SetDocumentReranker(reranker)
|
||||
|
||||
vec := NewVectorEinoRetriever(r)
|
||||
mq, err := multiquery.NewRetriever(ctx, &multiquery.Config{
|
||||
RewriteLLM: rewriteLLM,
|
||||
MaxQueriesNum: r.config.MultiQuery.MaxQueriesEffective(),
|
||||
OrigRetriever: vec,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("multi_query: %w", err)
|
||||
}
|
||||
|
||||
r.pipeline = newKnowledgePipelineRetriever(mq, r)
|
||||
if r.logger != nil {
|
||||
provider := r.config.Rerank.ProviderEffective(strings.TrimSpace(openAI.BaseURL))
|
||||
r.logger.Info("知识库检索流水线已启用",
|
||||
zap.String("pipeline", "MultiQuery→Vector→Rerank→PostRetrieve"),
|
||||
zap.Int("multi_query_max", r.config.MultiQuery.MaxQueriesEffective()),
|
||||
zap.String("rerank_provider", provider),
|
||||
zap.String("rerank_model", r.config.Rerank.ModelEffective(provider)),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
const defaultEmptyResponseContinueMaxAttempts = 5
|
||||
|
||||
// IsEinoEmptyResponseResult 判断 Run 是否以「未捕获助手正文」占位结束(非真实用户可见回复)。
|
||||
func IsEinoEmptyResponseResult(result *RunResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return isEinoEmptyResponseText(result.Response)
|
||||
}
|
||||
|
||||
func isEinoEmptyResponseText(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(s, "no assistant text was captured") ||
|
||||
strings.Contains(s, "未捕获到助手文本输出")
|
||||
}
|
||||
|
||||
// HasEinoResumeTrace 轨迹非空,续跑才有上下文可恢复。
|
||||
func HasEinoResumeTrace(result *RunResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.TrimSpace(result.LastAgentTraceInput)
|
||||
return s != "" && s != "[]" && s != "null"
|
||||
}
|
||||
|
||||
// EmptyResponseContinueMaxAttemptsFromConfig 无助手正文时 Handler 层退避续跑上限;0=默认 5。
|
||||
func EmptyResponseContinueMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.EmptyResponseContinueMaxAttempts > 0 {
|
||||
return mw.EmptyResponseContinueMaxAttempts
|
||||
}
|
||||
return defaultEmptyResponseContinueMaxAttempts
|
||||
}
|
||||
|
||||
// EmptyResponseContinueBackoff 与 run_retry 相同指数退避(2s, 4s, 8s… capped)。
|
||||
func EmptyResponseContinueBackoff(attempt int, mw *config.MultiAgentEinoMiddlewareConfig) time.Duration {
|
||||
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, maxBackoff)
|
||||
}
|
||||
|
||||
// FormatEmptyResponseContinueUserMessage 系统自动续跑时注入的 user 轮次(不写入 messages 表气泡)。
|
||||
func FormatEmptyResponseContinueUserMessage() string {
|
||||
return strings.TrimSpace(`【系统自动续跑 / Auto resume】
|
||||
上一轮 Eino 会话未产出可见助手正文(可能流式中断或仅完成工具调用)。请基于已有轨迹与工具结果继续推进,并给出阶段性总结;勿重复已完成步骤。`)
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsEinoEmptyResponseResult(t *testing.T) {
|
||||
empty := &RunResult{
|
||||
Response: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
||||
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||
}
|
||||
if !IsEinoEmptyResponseResult(empty) {
|
||||
t.Fatal("expected empty placeholder response")
|
||||
}
|
||||
ok := &RunResult{Response: "扫描完成,发现 2 个开放端口。"}
|
||||
if IsEinoEmptyResponseResult(ok) {
|
||||
t.Fatalf("expected real response, got placeholder match")
|
||||
}
|
||||
if IsEinoEmptyResponseResult(nil) {
|
||||
t.Fatal("nil result should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasEinoResumeTrace(t *testing.T) {
|
||||
if HasEinoResumeTrace(nil) {
|
||||
t.Fatal("nil")
|
||||
}
|
||||
if HasEinoResumeTrace(&RunResult{LastAgentTraceInput: "[]"}) {
|
||||
t.Fatal("enable resume on empty trace")
|
||||
}
|
||||
if !HasEinoResumeTrace(&RunResult{LastAgentTraceInput: `[{"role":"user","content":"hi"}]`}) {
|
||||
t.Fatal("expected resume trace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyResponseContinueMaxAttemptsFromConfig(t *testing.T) {
|
||||
if got := EmptyResponseContinueMaxAttemptsFromConfig(nil); got != defaultEmptyResponseContinueMaxAttempts {
|
||||
t.Fatalf("default: got %d want %d", got, defaultEmptyResponseContinueMaxAttempts)
|
||||
}
|
||||
}
|
||||
@@ -80,34 +80,9 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
return nil, fmt.Errorf("plan_execute replanner: %w", err)
|
||||
}
|
||||
|
||||
// 组装 executor handler 栈,顺序与 Deep/Supervisor 主代理一致(outermost first)。
|
||||
var execHandlers []adk.ChatModelAgentMiddleware
|
||||
// 1. patchtoolcalls, reduction, toolsearch, plantask(来自 prependEinoMiddlewares)
|
||||
if len(a.ExecPreMiddlewares) > 0 {
|
||||
execHandlers = append(execHandlers, a.ExecPreMiddlewares...)
|
||||
}
|
||||
// 2. filesystem 中间件(可选)
|
||||
if a.FilesystemMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.FilesystemMiddleware)
|
||||
}
|
||||
// 3. skill 中间件(可选)
|
||||
if a.SkillMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||
}
|
||||
// 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||
logger: a.Logger,
|
||||
phase: "plan_execute_executor",
|
||||
summarization: sumMw,
|
||||
modelName: a.ModelName,
|
||||
conversationID: a.ConversationID,
|
||||
trace: a.ModelFacingTrace,
|
||||
})
|
||||
execHandlers, err := buildPlanExecuteExecutorHandlers(ctx, a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
@@ -130,6 +105,39 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
})
|
||||
}
|
||||
|
||||
// buildPlanExecuteExecutorHandlers 组装 Executor 中间件栈(outermost first),与 Deep/Supervisor 主代理对齐:
|
||||
// ExecPreMiddlewares(patch / reduction / toolsearch / plantask)→ filesystem → skill → summarization tail。
|
||||
func buildPlanExecuteExecutorHandlers(ctx context.Context, a *PlanExecuteRootArgs) ([]adk.ChatModelAgentMiddleware, error) {
|
||||
if a == nil {
|
||||
return nil, fmt.Errorf("plan_execute: args 为空")
|
||||
}
|
||||
var execHandlers []adk.ChatModelAgentMiddleware
|
||||
if len(a.ExecPreMiddlewares) > 0 {
|
||||
execHandlers = append(execHandlers, a.ExecPreMiddlewares...)
|
||||
}
|
||||
if a.FilesystemMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.FilesystemMiddleware)
|
||||
}
|
||||
if a.SkillMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||
}
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||
logger: a.Logger,
|
||||
phase: "plan_execute_executor",
|
||||
summarization: sumMw,
|
||||
modelName: a.ModelName,
|
||||
conversationID: a.ConversationID,
|
||||
trace: a.ModelFacingTrace,
|
||||
})
|
||||
}
|
||||
return execHandlers, nil
|
||||
}
|
||||
|
||||
// planExecutePlannerGenInput 将 orchestrator instruction 作为 SystemMessage 注入 planner 输入。
|
||||
// 返回 nil 时 Eino 使用内置默认 planner prompt。
|
||||
func planExecutePlannerGenInput(
|
||||
|
||||
@@ -22,15 +22,60 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试与用户约束关键信息。
|
||||
// 结构对齐 Eino 最佳实践(禁止工具、<analysis>+<summary>、<all_user_messages>),章节为安全测试领域化。
|
||||
const einoSummarizeUserInstruction = `关键:仅以纯文本响应。禁止调用任何工具(read_file、exec、grep、glob、write、edit 等)。
|
||||
上述对话中已包含全部待压缩上下文;不要要求用户粘贴历史,不要输出「请提供待压缩的对话历史」等占位/meta 回复。
|
||||
工具调用将被拒绝并浪费唯一一次摘要机会。
|
||||
|
||||
必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
|
||||
保留精确技术细节(URL、路径、参数、Payload、版本号、报错原文可摘要但要点不丢)。
|
||||
将冗长扫描输出概括为结论;重复发现合并表述。
|
||||
已枚举资产须保留**可继承的摘要**:主域、关键子域/主机短表(或数量+代表样例)、高价值目标与已识别服务/端口要点,避免后续子代理因「看不见清单」而重复全量枚举。
|
||||
你的任务:在保持所有关键安全测试信息完整的前提下压缩对话历史,使后续代理能无缝继续同一授权测试任务。
|
||||
|
||||
输出须使后续代理能无缝继续同一授权测试任务。`
|
||||
压缩原则:
|
||||
- 必须保留:已确认漏洞与攻击路径、工具输出核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策
|
||||
- 保留精确技术细节(URL、路径、参数、Payload、版本号;报错原文可摘要但要点不丢)
|
||||
- 冗长扫描输出概括为结论;重复发现合并表述
|
||||
- 已枚举资产须保留可继承摘要:主域、关键子域/主机短表(或数量+代表样例)、高价值目标、已识别服务/端口要点
|
||||
|
||||
输出格式(严格遵循,仅一轮回复):
|
||||
1. 先输出 <analysis> 块:按时间顺序梳理对话,检查是否涵盖下方各章节要点;analysis 仅供自检,保持简洁(建议 ≤400 字)
|
||||
2. 再输出 <summary> 块:按以下章节写入可继承的压缩报告(无信息处写「无」,禁止留空模板占位符)
|
||||
|
||||
<summary>
|
||||
## 1. 授权范围与约束
|
||||
- 目标/范围/禁止项(域名、路径、IP、环境)
|
||||
- 凭证/认证信息(账号、Token、Cookie;敏感值原文保留)
|
||||
- 用户指定的方法、工具、优先级与待办
|
||||
- 否定约束(不测什么、不用什么手法)
|
||||
|
||||
## 2. 资产与服务枚举摘要
|
||||
- 主域/核心资产、关键子域或主机短表(或数量+代表样例)
|
||||
- 高价值目标、已识别服务/端口要点
|
||||
- 资产状态(存活/可攻/已排除/待验证)
|
||||
|
||||
## 3. 架构与已知薄弱点
|
||||
- 技术栈/部署拓扑/信任边界
|
||||
- 已识别薄弱点列表
|
||||
|
||||
## 4. 已确认漏洞与攻击路径
|
||||
- 漏洞名/CVE、URL/路径、参数/Payload、PoC 要点、影响等级
|
||||
- 攻击链/利用路径(步骤化)
|
||||
|
||||
## 5. 工具核心发现与扫描结论
|
||||
- 各工具结论(概括核心输出,非冗长日志)
|
||||
- 重复发现合并表述
|
||||
|
||||
## 6. 所有用户消息
|
||||
<all_user_messages>
|
||||
- [逐条列出非 tool 结果的用户消息要点;敏感约束与原文措辞尽量保留]
|
||||
</all_user_messages>
|
||||
|
||||
## 7. 当前进度、策略决策与下一步
|
||||
- 当前位置(已完成/进行中/卡点)
|
||||
- 失败尝试与死路(方法、现象/报错摘要、结论)
|
||||
- 策略决策与下一步具体操作(须与最近用户请求及未完成任务一致)
|
||||
</summary>
|
||||
|
||||
提醒:不要调用任何工具;必须基于上文已有对话直接输出 <analysis> 与 <summary>,勿输出 analysis 以外的正文。`
|
||||
|
||||
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
|
||||
// 触发阈值:估算 token 超过 openai.max_total_tokens * summarization_trigger_ratio(默认 0.8)时摘要。
|
||||
@@ -144,13 +189,13 @@ func newEinoSummarizationMiddleware(
|
||||
},
|
||||
},
|
||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||
summary = stripAnalysisFromSummarizationMessage(summary)
|
||||
out, ferr := summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||||
if ferr != nil {
|
||||
return nil, ferr
|
||||
}
|
||||
if appCfg != nil {
|
||||
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||
out = refreshUserVerbatimAnchorInMessages(out, db, conversationID, appCfg.MultiAgent.UserVerbatimAnchorMaxRunesEffective(), logger)
|
||||
}
|
||||
return out, nil
|
||||
},
|
||||
@@ -414,36 +459,6 @@ func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshUserVerbatimAnchorInMessages 压缩后从 messages 表刷新 system 中的用户原文锚点。
|
||||
func refreshUserVerbatimAnchorInMessages(msgs []adk.Message, db *database.DB, conversationID string, maxRunes int, logger *zap.Logger) []adk.Message {
|
||||
if maxRunes < 0 || db == nil {
|
||||
return msgs
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return msgs
|
||||
}
|
||||
rows, err := db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("summarization: 刷新用户原文锚点失败",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
block := project.BuildUserVerbatimAnchorBlockFromMessages(rows, maxRunes)
|
||||
if block == "" {
|
||||
return msgs
|
||||
}
|
||||
out := project.RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||
if logger != nil {
|
||||
logger.Info("summarization: 已刷新用户原文锚点", zap.String("conversationId", conversationID))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
summarizationAnalysisBlockRegex = regexp.MustCompile(`(?is)<analysis>\s*.*?\s*</analysis>`)
|
||||
summarizationSummaryBlockRegex = regexp.MustCompile(`(?is)<summary>\s*(.*?)\s*</summary>`)
|
||||
)
|
||||
|
||||
// stripAnalysisFromSummarizationMessage removes the <analysis> block from a post-processed
|
||||
// Eino summary user message. Analysis helps one-shot generation quality but should not
|
||||
// occupy continuation context after compaction.
|
||||
func stripAnalysisFromSummarizationMessage(msg adk.Message) adk.Message {
|
||||
if msg == nil {
|
||||
return msg
|
||||
}
|
||||
cloned := *msg
|
||||
if cloned.Content != "" {
|
||||
cloned.Content = stripAnalysisFromSummarizationText(cloned.Content)
|
||||
}
|
||||
if len(cloned.UserInputMultiContent) > 0 {
|
||||
parts := make([]schema.MessageInputPart, len(cloned.UserInputMultiContent))
|
||||
copy(parts, cloned.UserInputMultiContent)
|
||||
// Only the first text part carries model output plus Eino preamble/transcript path.
|
||||
for i := range parts {
|
||||
if parts[i].Type != schema.ChatMessagePartTypeText || parts[i].Text == "" {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
parts[i].Text = stripAnalysisFromSummarizationText(parts[i].Text)
|
||||
}
|
||||
break
|
||||
}
|
||||
cloned.UserInputMultiContent = parts
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func stripAnalysisFromSummarizationText(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
stripped := strings.TrimSpace(summarizationAnalysisBlockRegex.ReplaceAllString(text, ""))
|
||||
if stripped == "" {
|
||||
return text
|
||||
}
|
||||
return stripped
|
||||
}
|
||||
|
||||
// extractSummarizationSummaryBody returns the inner text of the last <summary> block when present.
|
||||
// Used by tests and optional strict compaction paths.
|
||||
func extractSummarizationSummaryBody(text string) (string, bool) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return "", false
|
||||
}
|
||||
all := summarizationSummaryBlockRegex.FindAllStringSubmatch(text, -1)
|
||||
if len(all) == 0 || len(all[len(all)-1]) < 2 {
|
||||
return "", false
|
||||
}
|
||||
body := strings.TrimSpace(all[len(all)-1][1])
|
||||
if body == "" {
|
||||
return "", false
|
||||
}
|
||||
return body, true
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestStripAnalysisFromSummarizationText(t *testing.T) {
|
||||
in := "<analysis>internal notes</analysis>\n\n<summary>\n## 1. 授权\n- example.com\n</summary>"
|
||||
got := stripAnalysisFromSummarizationText(in)
|
||||
if strings.Contains(got, "<analysis>") {
|
||||
t.Fatalf("analysis block should be removed: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "## 1. 授权") {
|
||||
t.Fatalf("summary body should remain: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripAnalysisFromSummarizationMessage_UserInputMultiContent(t *testing.T) {
|
||||
msg := &schema.Message{
|
||||
Role: schema.User,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: "此会话延续自此前一段因上下文耗尽而终止的对话。\n\n<analysis>draft</analysis>\n<summary>body</summary>\n\n完整记录位于:/tmp/transcript.txt",
|
||||
},
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: "请从我们中断的地方继续对话,无需向用户提出任何进一步的问题。",
|
||||
},
|
||||
},
|
||||
}
|
||||
out := stripAnalysisFromSummarizationMessage(msg)
|
||||
if len(out.UserInputMultiContent) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d", len(out.UserInputMultiContent))
|
||||
}
|
||||
if strings.Contains(out.UserInputMultiContent[0].Text, "<analysis>") {
|
||||
t.Fatalf("part 0 should drop analysis: %q", out.UserInputMultiContent[0].Text)
|
||||
}
|
||||
if !strings.Contains(out.UserInputMultiContent[0].Text, "<summary>body</summary>") {
|
||||
t.Fatalf("part 0 should keep summary: %q", out.UserInputMultiContent[0].Text)
|
||||
}
|
||||
if out.UserInputMultiContent[1].Text != "请从我们中断的地方继续对话,无需向用户提出任何进一步的问题。" {
|
||||
t.Fatalf("continue instruction part should be unchanged: %q", out.UserInputMultiContent[1].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSummarizationSummaryBody(t *testing.T) {
|
||||
body, ok := extractSummarizationSummaryBody("<analysis>x</analysis><summary> kept </summary>")
|
||||
if !ok || body != "kept" {
|
||||
t.Fatalf("extract summary body: ok=%v body=%q", ok, body)
|
||||
}
|
||||
_, ok = extractSummarizationSummaryBody("plain text only")
|
||||
if ok {
|
||||
t.Fatal("expected false for plain text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripAnalysisFromSummarizationText_NoAnalysisUnchanged(t *testing.T) {
|
||||
in := "<summary>only summary</summary>"
|
||||
got := stripAnalysisFromSummarizationText(in)
|
||||
if got != in {
|
||||
t.Fatalf("expected unchanged text, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -1,35 +1,12 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
copenai "cyberstrike-ai/internal/openai"
|
||||
)
|
||||
|
||||
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
|
||||
// chat-completions JSON body. Applied only to summarization Generate calls via
|
||||
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
|
||||
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||
return rawBody, nil
|
||||
}
|
||||
changed := false
|
||||
for _, key := range []string{
|
||||
"thinking",
|
||||
"reasoning_effort",
|
||||
"output_config",
|
||||
"reasoning",
|
||||
} {
|
||||
if _, ok := payload[key]; ok {
|
||||
delete(payload, key)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return rawBody, nil
|
||||
}
|
||||
out, err := sonic.Marshal(payload)
|
||||
if err != nil {
|
||||
return rawBody, err
|
||||
}
|
||||
return out, nil
|
||||
return copenai.StripReasoningFromChatCompletionBody(rawBody)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -75,8 +74,8 @@ func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
if err != nil {
|
||||
if IsHumanRejectError(err) {
|
||||
// Human rejection should be a soft tool result so the model can continue iterating.
|
||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||
input.Name, strings.TrimSpace(err.Error()))
|
||||
// tool_search 须保持 JSON,否则 Eino toolsearch 中间件解析历史时会硬崩 ChatModel。
|
||||
msg := HitlRejectToolResult(input.Name, err.Error())
|
||||
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
||||
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
||||
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
||||
@@ -103,8 +102,7 @@ func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||
edited, err := fn(ctx, input.Name, input.Arguments)
|
||||
if err != nil {
|
||||
if IsHumanRejectError(err) {
|
||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||
input.Name, strings.TrimSpace(err.Error()))
|
||||
msg := HitlRejectToolResult(input.Name, err.Error())
|
||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||
return &compose.StreamToolOutput{
|
||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const toolSearchToolName = "tool_search"
|
||||
|
||||
// HitlExemptMetaTools 为编排/元工具:不直接执行攻击动作,但会阻塞 agent 控制流。
|
||||
// tool_search 必须免审批,否则其 HITL 拒绝结果与 Eino toolsearch 中间件不兼容(会硬崩 ChatModel)。
|
||||
var HitlExemptMetaTools = []string{
|
||||
toolSearchToolName,
|
||||
"skill",
|
||||
"task",
|
||||
"write_todos",
|
||||
"transfer_to_agent",
|
||||
"exit",
|
||||
"TaskCreate",
|
||||
"TaskGet",
|
||||
"TaskUpdate",
|
||||
"TaskList",
|
||||
}
|
||||
|
||||
// IsToolSearchTool reports whether name is the Eino dynamictool tool_search meta-tool.
|
||||
func IsToolSearchTool(name string) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(name), toolSearchToolName)
|
||||
}
|
||||
|
||||
// MergeHitlExemptMetaTools unions configured whitelist with built-in meta-tool exemptions.
|
||||
func MergeHitlExemptMetaTools(configured []string) []string {
|
||||
merged := make([]string, 0, len(configured)+len(HitlExemptMetaTools))
|
||||
seen := make(map[string]struct{}, len(configured)+len(HitlExemptMetaTools))
|
||||
add := func(name string) {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
if n == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
return
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
merged = append(merged, strings.TrimSpace(name))
|
||||
}
|
||||
for _, t := range configured {
|
||||
add(t)
|
||||
}
|
||||
for _, t := range HitlExemptMetaTools {
|
||||
add(t)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
type toolSearchHitlRejectPayload struct {
|
||||
SelectedTools []string `json:"selectedTools"`
|
||||
HitlRejected bool `json:"_hitlRejected"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// HitlRejectToolResult returns a tool result body safe for downstream consumers.
|
||||
// tool_search must stay JSON-shaped so toolsearch.extractSelectedTools does not terminate the graph.
|
||||
func HitlRejectToolResult(toolName, reason string) string {
|
||||
reason = strings.TrimSpace(reason)
|
||||
if !IsToolSearchTool(toolName) {
|
||||
if reason == "" {
|
||||
reason = "rejected by reviewer"
|
||||
}
|
||||
return fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||
strings.TrimSpace(toolName), reason)
|
||||
}
|
||||
payload := toolSearchHitlRejectPayload{
|
||||
SelectedTools: []string{},
|
||||
HitlRejected: true,
|
||||
Reason: reason,
|
||||
}
|
||||
if payload.Reason == "" {
|
||||
payload.Reason = "tool_search rejected by reviewer; no dynamic tools unlocked"
|
||||
}
|
||||
out, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return `{"selectedTools":[],"_hitlRejected":true,"reason":"tool_search rejected by reviewer"}`
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHitlRejectToolResult_toolSearchIsJSON(t *testing.T) {
|
||||
raw := HitlRejectToolResult("tool_search", "rejected by user: timeout")
|
||||
var payload toolSearchHitlRejectPayload
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if len(payload.SelectedTools) != 0 {
|
||||
t.Fatalf("expected empty selectedTools, got %v", payload.SelectedTools)
|
||||
}
|
||||
if !payload.HitlRejected {
|
||||
t.Fatal("expected _hitlRejected true")
|
||||
}
|
||||
if !strings.Contains(payload.Reason, "timeout") {
|
||||
t.Fatalf("reason=%q", payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHitlRejectToolResult_otherToolKeepsLegacyText(t *testing.T) {
|
||||
raw := HitlRejectToolResult("nmap", "too risky")
|
||||
if strings.HasPrefix(raw, "{") {
|
||||
t.Fatalf("expected legacy text, got %q", raw)
|
||||
}
|
||||
if !strings.HasPrefix(raw, "[HITL Reject]") {
|
||||
t.Fatalf("expected [HITL Reject] prefix, got %q", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeHitlExemptMetaTools_includesToolSearch(t *testing.T) {
|
||||
merged := MergeHitlExemptMetaTools([]string{"read_file"})
|
||||
found := false
|
||||
for _, name := range merged {
|
||||
if IsToolSearchTool(name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("tool_search missing from %v", merged)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
type stubChatModelAgentMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
tag string
|
||||
}
|
||||
|
||||
func stubMW(tag string) adk.ChatModelAgentMiddleware {
|
||||
return &stubChatModelAgentMiddleware{tag: tag}
|
||||
}
|
||||
|
||||
func TestBuildPlanExecuteExecutorHandlers_IncludesExecPreMiddlewares(t *testing.T) {
|
||||
t.Parallel()
|
||||
pre := []adk.ChatModelAgentMiddleware{
|
||||
stubMW("patch"),
|
||||
stubMW("reduction"),
|
||||
}
|
||||
|
||||
got, err := buildPlanExecuteExecutorHandlers(context.Background(), &PlanExecuteRootArgs{
|
||||
ExecPreMiddlewares: pre,
|
||||
FilesystemMiddleware: stubMW("filesystem"),
|
||||
SkillMiddleware: stubMW("skill"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildPlanExecuteExecutorHandlers: %v", err)
|
||||
}
|
||||
if len(got) != 4 {
|
||||
t.Fatalf("expected 4 pre-tail handlers (2 pre + fs + skill), got %d", len(got))
|
||||
}
|
||||
for i, want := range []string{"patch", "reduction", "filesystem", "skill"} {
|
||||
st, ok := got[i].(*stubChatModelAgentMiddleware)
|
||||
if !ok || st.tag != want {
|
||||
t.Fatalf("handler[%d]: got %#v want tag %q", i, got[i], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stubTools(n int) []tool.BaseTool {
|
||||
out := make([]tool.BaseTool, n)
|
||||
for i := 0; i < n; i++ {
|
||||
out[i] = stubTool{name: fmt.Sprintf("t%d", i)}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestBuildPlanExecuteExecutorHandlers_NilArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, err := buildPlanExecuteExecutorHandlers(context.Background(), nil); err == nil {
|
||||
t.Fatal("expected error for nil args")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrependEinoMiddlewares_Main_IncludesPatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
mw := configMultiAgentEinoMiddlewareForTest()
|
||||
mw.ReductionEnable = false
|
||||
mw.ToolSearchEnable = false
|
||||
mw.PlantaskEnable = false
|
||||
_, extra, _, err := prependEinoMiddlewares(ctx, mw, einoMWMain, stubTools(25), nil, "", "conv-test", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("prependEinoMiddlewares: %v", err)
|
||||
}
|
||||
if len(extra) == 0 {
|
||||
t.Fatal("expected patch middleware on einoMWMain when patch_tool_calls enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func configMultiAgentEinoMiddlewareForTest() *config.MultiAgentEinoMiddlewareConfig {
|
||||
patch := true
|
||||
return &config.MultiAgentEinoMiddlewareConfig{
|
||||
PatchToolCalls: &patch,
|
||||
}
|
||||
}
|
||||
@@ -432,6 +432,22 @@ func RunDeepAgent(
|
||||
var da adk.Agent
|
||||
switch orchMode {
|
||||
case "plan_execute":
|
||||
plannerModelCfg := &einoopenai.ChatModelConfig{
|
||||
APIKey: appCfg.OpenAI.APIKey,
|
||||
BaseURL: strings.TrimSuffix(appCfg.OpenAI.BaseURL, "/"),
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyPlanExecutePlannerModelConfig(plannerModelCfg, &appCfg.OpenAI)
|
||||
peMainModel, perr := einoopenai.NewChatModel(ctx, plannerModelCfg)
|
||||
if perr != nil {
|
||||
return nil, fmt.Errorf("plan_execute 规划模型: %w", perr)
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Info("plan_execute: planner/replanner 使用无 reasoning 的独立 ChatModel(ToolChoiceForced 兼容)",
|
||||
zap.String("model", appCfg.OpenAI.Model),
|
||||
)
|
||||
}
|
||||
execModel, perr := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||
if perr != nil {
|
||||
return nil, fmt.Errorf("plan_execute 执行器模型: %w", perr)
|
||||
@@ -445,7 +461,7 @@ func RunDeepAgent(
|
||||
}
|
||||
}
|
||||
peRoot, perr := NewPlanExecuteRoot(ctx, &PlanExecuteRootArgs{
|
||||
MainToolCallingModel: mainModel,
|
||||
MainToolCallingModel: peMainModel,
|
||||
ExecModel: execModel,
|
||||
OrchInstruction: orchInstruction,
|
||||
ToolsCfg: mainToolsCfg,
|
||||
@@ -458,6 +474,7 @@ func RunDeepAgent(
|
||||
ProjectID: projectID,
|
||||
Logger: logger,
|
||||
ModelName: appCfg.OpenAI.Model,
|
||||
// 与 Deep/Supervisor 主代理同源:patch / reduction / toolsearch / plantask(见 buildPlanExecuteExecutorHandlers)。
|
||||
ExecPreMiddlewares: mainOrchestratorPre,
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
|
||||
@@ -806,10 +806,12 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
|
||||
// Eino HTTP Client Bridge
|
||||
// ============================================================
|
||||
|
||||
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装:
|
||||
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明
|
||||
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含多层 transport 包装:
|
||||
// 1. 当 cfg.Provider 为 claude 时,套 claudeRoundTripper,把 OpenAI /chat/completions 透明
|
||||
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
|
||||
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
|
||||
// 2. reasoningToolChoiceCompatRoundTripper:tool_choice=required/object 时剥离 thinking 字段,避免
|
||||
// plan_execute replanner 等强制工具调用与推理模式冲突(部分网关返回 400)。
|
||||
// 3. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
|
||||
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
|
||||
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
|
||||
//
|
||||
@@ -825,6 +827,7 @@ func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
transport = &reasoningToolChoiceCompatRoundTripper{base: transport}
|
||||
if isClaudeProvider(cfg) {
|
||||
transport = &claudeRoundTripper{
|
||||
base: transport,
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// reasoningPayloadKeys are OpenAI-compatible root fields that enable "thinking" /
|
||||
// extended-reasoning modes on gateways such as DashScope/Qwen and MiniMax.
|
||||
var reasoningPayloadKeys = []string{
|
||||
"thinking",
|
||||
"reasoning_effort",
|
||||
"output_config",
|
||||
"reasoning",
|
||||
}
|
||||
|
||||
// StripReasoningFromChatCompletionBody removes thinking / reasoning fields from a
|
||||
// chat-completions JSON body.
|
||||
func StripReasoningFromChatCompletionBody(rawBody []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||
return rawBody, nil
|
||||
}
|
||||
if !stripReasoningFields(payload) {
|
||||
return rawBody, nil
|
||||
}
|
||||
out, err := sonic.Marshal(payload)
|
||||
if err != nil {
|
||||
return rawBody, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// StripReasoningIfForcedToolChoice removes thinking / reasoning fields when the
|
||||
// request sets tool_choice to "required" or an object. Several providers reject
|
||||
// that combination (e.g. DashScope: "tool_choice does not support being set to
|
||||
// required or object in thinking mode").
|
||||
func StripReasoningIfForcedToolChoice(rawBody []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||
return rawBody, nil
|
||||
}
|
||||
if !forcedToolChoiceIncompatibleWithThinking(payload) {
|
||||
return rawBody, nil
|
||||
}
|
||||
if !stripReasoningFields(payload) {
|
||||
return rawBody, nil
|
||||
}
|
||||
out, err := sonic.Marshal(payload)
|
||||
if err != nil {
|
||||
return rawBody, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func stripReasoningFields(payload map[string]any) bool {
|
||||
changed := false
|
||||
for _, key := range reasoningPayloadKeys {
|
||||
if _, ok := payload[key]; ok {
|
||||
delete(payload, key)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func forcedToolChoiceIncompatibleWithThinking(payload map[string]any) bool {
|
||||
tc, ok := payload["tool_choice"]
|
||||
if !ok || tc == nil {
|
||||
return false
|
||||
}
|
||||
switch v := tc.(type) {
|
||||
case string:
|
||||
return v == "required"
|
||||
case map[string]any:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStripReasoningFromChatCompletionBody(t *testing.T) {
|
||||
in := []byte(`{"model":"deepseek-chat","messages":[],"thinking":{"type":"enabled"},"reasoning_effort":"high"}`)
|
||||
out, err := StripReasoningFromChatCompletionBody(in)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := string(out)
|
||||
if strings.Contains(s, "thinking") || strings.Contains(s, "reasoning_effort") {
|
||||
t.Fatalf("expected reasoning fields stripped, got %s", s)
|
||||
}
|
||||
if !strings.Contains(s, `"model":"deepseek-chat"`) {
|
||||
t.Fatalf("expected model preserved, got %s", s)
|
||||
}
|
||||
|
||||
plain := []byte(`{"model":"gpt-4o","messages":[]}`)
|
||||
out2, err := StripReasoningFromChatCompletionBody(plain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(out2) != string(plain) {
|
||||
t.Fatalf("expected unchanged payload, got %s", out2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripReasoningIfForcedToolChoice(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
strip bool
|
||||
contain string
|
||||
}{
|
||||
{
|
||||
name: "required strips thinking",
|
||||
in: `{"model":"minimax","messages":[],"thinking":{"type":"enabled"},"tool_choice":"required","tools":[]}`,
|
||||
strip: true,
|
||||
},
|
||||
{
|
||||
name: "object tool_choice strips thinking",
|
||||
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":{"type":"function","function":{"name":"respond"}}}`,
|
||||
strip: true,
|
||||
},
|
||||
{
|
||||
name: "auto keeps thinking",
|
||||
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":"auto"}`,
|
||||
strip: false,
|
||||
contain: "thinking",
|
||||
},
|
||||
{
|
||||
name: "no tool_choice keeps thinking",
|
||||
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"}}`,
|
||||
strip: false,
|
||||
contain: "thinking",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out, err := StripReasoningIfForcedToolChoice([]byte(tc.in))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := string(out)
|
||||
hasThinking := strings.Contains(s, "thinking")
|
||||
if tc.strip && hasThinking {
|
||||
t.Fatalf("expected thinking stripped, got %s", s)
|
||||
}
|
||||
if !tc.strip && tc.contain != "" && !strings.Contains(s, tc.contain) {
|
||||
t.Fatalf("expected %q in %s", tc.contain, s)
|
||||
}
|
||||
if !tc.strip && string(out) != tc.in {
|
||||
t.Fatalf("expected unchanged payload, got %s", s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningToolChoiceCompatRoundTripper(t *testing.T) {
|
||||
var gotBody string
|
||||
rt := &reasoningToolChoiceCompatRoundTripper{
|
||||
base: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
gotBody = string(b)
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/chat/completions", strings.NewReader(
|
||||
`{"model":"m","thinking":{"type":"enabled"},"tool_choice":"required","messages":[]}`,
|
||||
))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(gotBody, "thinking") {
|
||||
t.Fatalf("expected thinking stripped in transit, got %s", gotBody)
|
||||
}
|
||||
if !strings.Contains(gotBody, `"tool_choice":"required"`) {
|
||||
t.Fatalf("expected tool_choice preserved, got %s", gotBody)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// reasoningToolChoiceCompatRoundTripper strips thinking/reasoning fields from
|
||||
// chat/completions requests that force tool_choice, which some gateways reject
|
||||
// when thinking mode is enabled on the same request.
|
||||
type reasoningToolChoiceCompatRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *reasoningToolChoiceCompatRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if rt == nil || rt.base == nil || req == nil || req.Body == nil {
|
||||
if rt != nil && rt.base != nil {
|
||||
return rt.base.RoundTrip(req)
|
||||
}
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
if req.Method != http.MethodPost || !strings.HasSuffix(req.URL.Path, "/chat/completions") {
|
||||
return rt.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
_ = req.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
patched, perr := StripReasoningIfForcedToolChoice(body)
|
||||
if perr != nil {
|
||||
patched = body
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(patched))
|
||||
req.ContentLength = int64(len(patched))
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(patched)))
|
||||
return rt.base.RoundTrip(req)
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
// UserVerbatimSectionHeading 用户原文锚点可读标题(块内保留,供 Agent 阅读)。
|
||||
UserVerbatimSectionHeading = "## 用户历史输入(原文保留,勿省略或改写)"
|
||||
|
||||
// UserVerbatimSectionStartMarker / EndMarker:HTML 注释边界,供程序化替换;对模型无指令语义。
|
||||
UserVerbatimSectionStartMarker = "<!-- user-verbatim-start -->"
|
||||
UserVerbatimSectionEndMarker = "<!-- user-verbatim-end -->"
|
||||
)
|
||||
|
||||
// ExtractUserContentsFromMessages 按时间顺序提取 user 角色消息的原文(跳过空白)。
|
||||
func ExtractUserContentsFromMessages(msgs []database.Message) []string {
|
||||
out := make([]string, 0, len(msgs))
|
||||
for i := range msgs {
|
||||
if !strings.EqualFold(strings.TrimSpace(msgs[i].Role), "user") {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(msgs[i].Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, content)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BuildUserVerbatimAnchorBlockFromMessages 从 messages 表行构建用户原文锚点块。
|
||||
// maxRunes: 0 = 不截断;>0 = 总 rune 上限(仍保留每一轮,仅对超长单条做尾部截断提示)。
|
||||
func BuildUserVerbatimAnchorBlockFromMessages(msgs []database.Message, maxRunes int) string {
|
||||
return BuildUserVerbatimAnchorBlock(ExtractUserContentsFromMessages(msgs), maxRunes)
|
||||
}
|
||||
|
||||
// BuildUserVerbatimAnchorBlock 将各轮用户原文格式化为 system prompt 锚点块。
|
||||
func BuildUserVerbatimAnchorBlock(userContents []string, maxRunes int) string {
|
||||
if len(userContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
lines := make([]string, 0, len(userContents))
|
||||
for _, content := range userContents {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("[第%d轮] %s", len(lines)+1, content))
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
body := strings.Join(lines, "\n")
|
||||
if maxRunes > 0 {
|
||||
body = capUserVerbatimBody(body, maxRunes)
|
||||
}
|
||||
return wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n" + body)
|
||||
}
|
||||
|
||||
func capUserVerbatimBody(body string, maxRunes int) string {
|
||||
rs := []rune(body)
|
||||
if len(rs) <= maxRunes {
|
||||
return body
|
||||
}
|
||||
suffix := "\n\n...(用户原文锚点已达配置上限,更早轮次可能被截断;完整原文见 messages 表)..."
|
||||
suffixRunes := []rune(suffix)
|
||||
keep := maxRunes - len(suffixRunes)
|
||||
if keep <= 0 {
|
||||
return string(rs[:maxRunes])
|
||||
}
|
||||
return string(rs[:keep]) + suffix
|
||||
}
|
||||
|
||||
func wrapUserVerbatimBlock(content string) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
return UserVerbatimSectionStartMarker + "\n" + content + "\n" + UserVerbatimSectionEndMarker + "\n"
|
||||
}
|
||||
|
||||
// ReplaceUserVerbatimAnchorSection 用 freshBlock 替换 content 中已有的用户原文锚点段。
|
||||
func ReplaceUserVerbatimAnchorSection(content, freshBlock string) (string, bool) {
|
||||
content = strings.TrimSpace(content)
|
||||
freshBlock = strings.TrimSpace(freshBlock)
|
||||
if freshBlock == "" {
|
||||
return content, false
|
||||
}
|
||||
start, ok := userVerbatimSectionStart(content)
|
||||
if !ok {
|
||||
return content, false
|
||||
}
|
||||
end, ok := userVerbatimSectionEnd(content, start)
|
||||
if !ok {
|
||||
return content, false
|
||||
}
|
||||
return strings.TrimSpace(content[:start] + freshBlock + content[end:]), true
|
||||
}
|
||||
|
||||
func userVerbatimSectionStart(content string) (int, bool) {
|
||||
idx := strings.Index(content, UserVerbatimSectionStartMarker)
|
||||
if idx < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return idx, true
|
||||
}
|
||||
|
||||
func userVerbatimSectionEnd(content string, start int) (int, bool) {
|
||||
if start < 0 || start >= len(content) {
|
||||
return 0, false
|
||||
}
|
||||
tail := content[start:]
|
||||
idx := strings.LastIndex(tail, UserVerbatimSectionEndMarker)
|
||||
if idx < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return start + idx + len(UserVerbatimSectionEndMarker), true
|
||||
}
|
||||
|
||||
// RefreshUserVerbatimAnchorInMessages 在 summarization 等压缩后,用 freshBlock 刷新 system 中的用户原文锚点。
|
||||
// 若尚无锚点段,则追加到首条 system 消息;若无 system 消息则在开头插入一条。
|
||||
func RefreshUserVerbatimAnchorInMessages(msgs []adk.Message, freshBlock string) []adk.Message {
|
||||
freshBlock = strings.TrimSpace(freshBlock)
|
||||
if freshBlock == "" || len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
|
||||
out := make([]adk.Message, len(msgs))
|
||||
changed := false
|
||||
for i, msg := range msgs {
|
||||
if msg == nil || msg.Role != schema.System {
|
||||
out[i] = msg
|
||||
continue
|
||||
}
|
||||
newContent, ok := ReplaceUserVerbatimAnchorSection(msg.Content, freshBlock)
|
||||
if !ok {
|
||||
out[i] = msg
|
||||
continue
|
||||
}
|
||||
cloned := *msg
|
||||
cloned.Content = newContent
|
||||
out[i] = &cloned
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
return out
|
||||
}
|
||||
|
||||
for i, msg := range msgs {
|
||||
if msg == nil || msg.Role != schema.System {
|
||||
continue
|
||||
}
|
||||
cloned := *msg
|
||||
cloned.Content = AppendSystemPromptBlock(cloned.Content, freshBlock)
|
||||
out[i] = &cloned
|
||||
return out
|
||||
}
|
||||
|
||||
prefix := make([]adk.Message, 0, len(msgs)+1)
|
||||
prefix = append(prefix, schema.SystemMessage(freshBlock))
|
||||
return append(prefix, msgs...)
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestBuildUserVerbatimAnchorBlock_MultiTurn(t *testing.T) {
|
||||
msgs := []database.Message{
|
||||
{Role: "user", Content: "目标 https://a.com 仅测 /api"},
|
||||
{Role: "assistant", Content: "好的"},
|
||||
{Role: "user", Content: "用 admin:test 登录"},
|
||||
}
|
||||
block := BuildUserVerbatimAnchorBlockFromMessages(msgs, 0)
|
||||
if block == "" {
|
||||
t.Fatal("expected non-empty block")
|
||||
}
|
||||
if !strings.Contains(block, UserVerbatimSectionStartMarker) {
|
||||
t.Error("missing start marker")
|
||||
}
|
||||
if !strings.Contains(block, "[第1轮]") || !strings.Contains(block, "https://a.com") {
|
||||
t.Error("missing first user turn")
|
||||
}
|
||||
if !strings.Contains(block, "[第2轮]") || !strings.Contains(block, "admin:test") {
|
||||
t.Error("missing second user turn")
|
||||
}
|
||||
if strings.Contains(block, "好的") {
|
||||
t.Error("assistant content should not appear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUserVerbatimAnchorSection(t *testing.T) {
|
||||
old := "prefix\n\n" + wrapUserVerbatimBlock("## old\n\n[第1轮] a") + "\nsuffix"
|
||||
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] b\n[第2轮] c")
|
||||
out, ok := ReplaceUserVerbatimAnchorSection(old, newBlock)
|
||||
if !ok {
|
||||
t.Fatal("expected replace ok")
|
||||
}
|
||||
if !strings.Contains(out, "[第2轮] c") {
|
||||
t.Errorf("expected new block, got %q", out)
|
||||
}
|
||||
if !strings.HasPrefix(strings.TrimSpace(out), "prefix") {
|
||||
t.Error("prefix should remain")
|
||||
}
|
||||
if !strings.Contains(out, "suffix") {
|
||||
t.Error("suffix should remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshUserVerbatimAnchorInMessages_ReplaceExisting(t *testing.T) {
|
||||
oldBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] old")
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("instr\n\n" + oldBlock),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] new")
|
||||
out := RefreshUserVerbatimAnchorInMessages(msgs, newBlock)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("message count: got %d", len(out))
|
||||
}
|
||||
if !strings.Contains(out[0].Content, "[第1轮] new") {
|
||||
t.Errorf("system content: %q", out[0].Content)
|
||||
}
|
||||
if strings.Contains(out[0].Content, "[第1轮] old") {
|
||||
t.Error("old anchor should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshUserVerbatimAnchorInMessages_InsertWhenMissing(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("base instruction"),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
block := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] anchor")
|
||||
out := RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||
if !strings.Contains(out[0].Content, "[第1轮] anchor") {
|
||||
t.Errorf("expected appended anchor, got %q", out[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserVerbatimAnchorBlock_MaxRunes(t *testing.T) {
|
||||
long := strings.Repeat("字", 200)
|
||||
block := BuildUserVerbatimAnchorBlock([]string{long}, 50)
|
||||
body := block
|
||||
if idx := strings.Index(body, UserVerbatimSectionStartMarker); idx >= 0 {
|
||||
body = strings.TrimPrefix(body[idx+len(UserVerbatimSectionStartMarker):], "\n")
|
||||
}
|
||||
if len([]rune(body)) > 120 {
|
||||
t.Errorf("expected capped body, got %d runes", len([]rune(body)))
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,35 @@ const (
|
||||
wireOutputConfig
|
||||
)
|
||||
|
||||
// ApplyPlanExecutePlannerModelConfig configures the plan_execute planner/replanner
|
||||
// ChatModel. Those Eino agents call WithToolChoice(Forced); several gateways reject
|
||||
// thinking / reasoning fields on the same request (tool_choice required/object).
|
||||
// Executor should keep the normal ApplyToEinoChatModelConfig path.
|
||||
func ApplyPlanExecutePlannerModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig) {
|
||||
if cfg == nil || oa == nil {
|
||||
return
|
||||
}
|
||||
offOA := *oa
|
||||
offReasoning := oa.Reasoning
|
||||
offReasoning.Mode = "off"
|
||||
offOA.Reasoning = offReasoning
|
||||
ApplyToEinoChatModelConfig(cfg, &offOA, nil)
|
||||
clearReasoningFromChatModelConfig(cfg)
|
||||
}
|
||||
|
||||
func clearReasoningFromChatModelConfig(cfg *einoopenai.ChatModelConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.ReasoningEffort = ""
|
||||
if cfg.ExtraFields != nil {
|
||||
for _, key := range []string{"thinking", "reasoning_effort", "output_config", "reasoning"} {
|
||||
delete(cfg.ExtraFields, key)
|
||||
}
|
||||
}
|
||||
applyThinkingDisabled(cfg)
|
||||
}
|
||||
|
||||
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
|
||||
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
|
||||
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
|
||||
|
||||
@@ -49,6 +49,30 @@ func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPlanExecutePlannerModelConfig_stripsReasoningWhenGlobalOn(t *testing.T) {
|
||||
cfg := &einoopenai.ChatModelConfig{}
|
||||
oa := &config.OpenAIConfig{
|
||||
BaseURL: "https://antchat.example.com/v1",
|
||||
Model: "minimax-m3",
|
||||
Reasoning: config.OpenAIReasoningConfig{
|
||||
Profile: "openai_compat",
|
||||
Mode: "on",
|
||||
Effort: "high",
|
||||
},
|
||||
}
|
||||
ApplyPlanExecutePlannerModelConfig(cfg, oa)
|
||||
if cfg.ReasoningEffort != "" {
|
||||
t.Fatalf("expected ReasoningEffort cleared, got %q", cfg.ReasoningEffort)
|
||||
}
|
||||
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
|
||||
if !ok || th["type"] != "disabled" {
|
||||
t.Fatalf("expected thinking disabled, got %#v", cfg.ExtraFields)
|
||||
}
|
||||
if _, ok := cfg.ExtraFields["reasoning_effort"]; ok {
|
||||
t.Fatalf("expected reasoning_effort stripped, got %#v", cfg.ExtraFields)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyReasoningOff_disablesThinking(t *testing.T) {
|
||||
cfg := &einoopenai.ChatModelConfig{}
|
||||
oa := &config.OpenAIConfig{
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
// compileAgentSubgraph wraps an Agent canvas node as an Eino subgraph (AddGraphNode best practice).
|
||||
func compileAgentSubgraph(_ context.Context, node graphNode) (compose.AnyGraph, error) {
|
||||
n := node
|
||||
innerID := n.ID + "__agent"
|
||||
g := compose.NewGraph[WorkflowNodeOutput, WorkflowNodeOutput]()
|
||||
_ = g.AddLambdaNode(innerID, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
if err := g.AddEdge(compose.START, innerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := g.AddEdge(innerID, compose.END); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FieldBinding selects a value from workflow state (replaces {{...}} templates).
|
||||
type FieldBinding struct {
|
||||
From string `json:"from"` // inputs | previous | <nodeId>
|
||||
Field string `json:"field"` // e.g. output, message
|
||||
}
|
||||
|
||||
func parseFieldBinding(cfg map[string]any, keys ...string) (FieldBinding, bool) {
|
||||
for _, key := range keys {
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
raw, ok := cfg[key]
|
||||
if !ok || raw == nil {
|
||||
continue
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case map[string]any:
|
||||
return FieldBinding{
|
||||
From: strings.TrimSpace(fmt.Sprint(v["from"])),
|
||||
Field: strings.TrimSpace(fmt.Sprint(v["field"])),
|
||||
}, true
|
||||
case string:
|
||||
s := strings.TrimSpace(v)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
var b FieldBinding
|
||||
if err := json.Unmarshal([]byte(s), &b); err == nil && (b.From != "" || b.Field != "") {
|
||||
return b, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return FieldBinding{}, false
|
||||
}
|
||||
|
||||
func defaultBinding(from, field string) FieldBinding {
|
||||
return FieldBinding{From: from, Field: field}
|
||||
}
|
||||
|
||||
func resolveBinding(b FieldBinding, state *WorkflowLocalState) any {
|
||||
from := strings.TrimSpace(b.From)
|
||||
field := strings.TrimSpace(b.Field)
|
||||
if field == "" {
|
||||
field = "output"
|
||||
}
|
||||
if from == "" || from == "previous" || from == "prev" {
|
||||
if field == "output" && state.LastOutput != nil {
|
||||
return state.LastOutput["output"]
|
||||
}
|
||||
return valueFromPath("previous."+field, state)
|
||||
}
|
||||
if from == "inputs" || from == "input" {
|
||||
if field == "" {
|
||||
return state.Inputs
|
||||
}
|
||||
return valueFromPath("inputs."+field, state)
|
||||
}
|
||||
if from == "outputs" {
|
||||
return valueFromPath("outputs."+field, state)
|
||||
}
|
||||
return valueFromPath(from+"."+field, state)
|
||||
}
|
||||
|
||||
func resolveBindingString(b FieldBinding, state *WorkflowLocalState) string {
|
||||
return strings.TrimSpace(fmt.Sprint(resolveBinding(b, state)))
|
||||
}
|
||||
|
||||
func resolveNodeInputBinding(cfg map[string]any, state *WorkflowLocalState) string {
|
||||
if b, ok := parseFieldBinding(cfg, "input_binding"); ok {
|
||||
return resolveBindingString(b, state)
|
||||
}
|
||||
// legacy template field removed — default previous.output
|
||||
return resolveBindingString(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func resolveOutputSourceBinding(cfg map[string]any, state *WorkflowLocalState) any {
|
||||
if b, ok := parseFieldBinding(cfg, "source_binding"); ok {
|
||||
return resolveBinding(b, state)
|
||||
}
|
||||
return resolveBinding(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func resolveHITLPromptBinding(cfg map[string]any, state *WorkflowLocalState) string {
|
||||
if b, ok := parseFieldBinding(cfg, "prompt_binding"); ok {
|
||||
return resolveBindingString(b, state)
|
||||
}
|
||||
if s := cfgString(cfg, "prompt"); s != "" {
|
||||
return s
|
||||
}
|
||||
return resolveBindingString(defaultBinding("previous", "output"), state)
|
||||
}
|
||||
|
||||
func toolArgumentBindings(cfg map[string]any) map[string]FieldBinding {
|
||||
raw, ok := cfg["argument_bindings"].(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]FieldBinding, len(raw))
|
||||
for argName, v := range raw {
|
||||
m, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[argName] = FieldBinding{
|
||||
From: strings.TrimSpace(fmt.Sprint(m["from"])),
|
||||
Field: strings.TrimSpace(fmt.Sprint(m["field"])),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveToolArguments(cfg map[string]any, state *WorkflowLocalState) (map[string]interface{}, error) {
|
||||
bindings := toolArgumentBindings(cfg)
|
||||
if len(bindings) > 0 {
|
||||
args := make(map[string]interface{}, len(bindings))
|
||||
for k, b := range bindings {
|
||||
args[k] = resolveBinding(b, state)
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
raw := cfgString(cfg, "arguments")
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// fileCheckPointStore persists Eino workflow checkpoints on disk (per run id).
|
||||
type fileCheckPointStore struct {
|
||||
dir string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newFileCheckPointStore(dir string) (*fileCheckPointStore, error) {
|
||||
dir = strings.TrimSpace(dir)
|
||||
if dir == "" {
|
||||
dir = filepath.Join("data", "workflow-checkpoints")
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create workflow checkpoint dir: %w", err)
|
||||
}
|
||||
return &fileCheckPointStore{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) path(id string) (string, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return "", fmt.Errorf("checkpoint id is empty")
|
||||
}
|
||||
if strings.Contains(id, "..") || strings.ContainsAny(id, `/\`) {
|
||||
return "", fmt.Errorf("invalid checkpoint id")
|
||||
}
|
||||
return filepath.Join(s.dir, id+".ckpt"), nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) Get(_ context.Context, checkPointID string) ([]byte, bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
p, err := s.path(checkPointID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
data, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
}
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
func (s *fileCheckPointStore) Set(_ context.Context, checkPointID string, checkPoint []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
p, err := s.path(checkPointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmp := p + ".tmp"
|
||||
if err := os.WriteFile(tmp, checkPoint, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, p)
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func hasConditionalOutgoingEdges(idx *graphIndex, nodeID string) bool {
|
||||
for _, edge := range idx.outgoing[nodeID] {
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
if cond != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func wireConditionBranch(
|
||||
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
|
||||
nodeRefs map[string]*compose.WorkflowNode,
|
||||
idx *graphIndex,
|
||||
condID string,
|
||||
condNode graphNode,
|
||||
) error {
|
||||
edges := idx.outgoing[condID]
|
||||
if len(edges) == 0 {
|
||||
return nil
|
||||
}
|
||||
branchID := branchNodeID(condID)
|
||||
wf.AddPassthroughNode(branchID).AddInput(condID)
|
||||
|
||||
endNodes := map[string]bool{compose.END: true}
|
||||
for _, edge := range edges {
|
||||
endNodes[edge.Target] = true
|
||||
}
|
||||
|
||||
sortedEdges := append([]graphEdge(nil), edges...)
|
||||
sortEdgesByCanvas(sortedEdges, idx.nodes)
|
||||
|
||||
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
|
||||
rt := workflowRuntimeFrom(runCtx)
|
||||
if rt == nil {
|
||||
return compose.END, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
emitConditionBranchProgress(rt.args, rt.runID, condNode, sortedEdges, idx.nodes, rt.state)
|
||||
for edgeIdx, edge := range sortedEdges {
|
||||
if conditionBranchAllowed(edge, edgeIdx, rt.state) {
|
||||
return edge.Target, nil
|
||||
}
|
||||
}
|
||||
return compose.END, nil
|
||||
}, endNodes)
|
||||
wf.AddBranch(branchID, branch)
|
||||
|
||||
for _, edge := range edges {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(branchID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wireEdgeConditionBranch(
|
||||
wf *compose.Workflow[WorkflowInput, WorkflowOutput],
|
||||
nodeRefs map[string]*compose.WorkflowNode,
|
||||
idx *graphIndex,
|
||||
sourceID string,
|
||||
sourceNode graphNode,
|
||||
) error {
|
||||
edges := idx.outgoing[sourceID]
|
||||
if len(edges) == 0 {
|
||||
return nil
|
||||
}
|
||||
branchID := edgeBranchNodeID(sourceID)
|
||||
wf.AddPassthroughNode(branchID).AddInput(sourceID)
|
||||
|
||||
endNodes := map[string]bool{compose.END: true}
|
||||
for _, edge := range edges {
|
||||
endNodes[edge.Target] = true
|
||||
}
|
||||
|
||||
sortedEdges := append([]graphEdge(nil), edges...)
|
||||
sortEdgesByCanvas(sortedEdges, idx.nodes)
|
||||
|
||||
branch := compose.NewGraphBranch(func(runCtx context.Context, _ map[string]any) (string, error) {
|
||||
rt := workflowRuntimeFrom(runCtx)
|
||||
if rt == nil {
|
||||
return compose.END, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
for edgeIdx, edge := range sortedEdges {
|
||||
if edgeAllowed(edge, sourceNode, edgeIdx, rt.state) {
|
||||
return edge.Target, nil
|
||||
}
|
||||
}
|
||||
return compose.END, nil
|
||||
}, endNodes)
|
||||
wf.AddBranch(branchID, branch)
|
||||
|
||||
for _, edge := range edges {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(branchID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
)
|
||||
|
||||
func attachWorkflowCallbacks(ctx context.Context, cfg *config.Config, args RunArgs, workflowName string) context.Context {
|
||||
if cfg == nil {
|
||||
return ctx
|
||||
}
|
||||
cbCfg := &cfg.MultiAgent.EinoCallbacks
|
||||
return einoobserve.AttachAgentRunCallbacks(ctx, cbCfg, einoobserve.Params{
|
||||
Logger: args.Logger,
|
||||
Progress: args.Progress,
|
||||
ConversationID: args.ConversationID,
|
||||
OrchMode: "workflow",
|
||||
OrchestratorName: workflowName,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func executeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState) error {
|
||||
_, err := invokeEinoGraph(ctx, args, runID, workflowID, version, g, state, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func invokeEinoGraph(ctx context.Context, args RunArgs, runID string, workflowID string, version int, g *graphDef, state *WorkflowLocalState, resume bool) (bool, error) {
|
||||
wfInput := workflowInputFromMap(state.Inputs)
|
||||
if resume {
|
||||
wfInput = WorkflowInput{}
|
||||
}
|
||||
rt := &workflowRuntime{
|
||||
args: args,
|
||||
runID: runID,
|
||||
idx: indexGraph(g),
|
||||
state: state,
|
||||
}
|
||||
|
||||
art, err := defaultEngine.getOrCompile(ctx, workflowID, version, g)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("编译 Eino Workflow 失败: %w", err)
|
||||
}
|
||||
rt.idx = art.idx
|
||||
|
||||
runCtx := withWorkflowRuntime(ctx, rt)
|
||||
runCtx = attachWorkflowCallbacks(runCtx, args.AppCfg, args, workflowID)
|
||||
|
||||
invokeOpts := []compose.Option{compose.WithCheckPointID(runID)}
|
||||
for {
|
||||
_, err = art.runnable.Invoke(runCtx, wfInput, invokeOpts...)
|
||||
if err == nil {
|
||||
return false, nil
|
||||
}
|
||||
if hitlErr := extractAwaitingHITL(err, art, runID, args, state); hitlErr != nil {
|
||||
return true, hitlErr
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func extractAwaitingHITL(err error, art *compiledArtifact, runID string, args RunArgs, state *WorkflowLocalState) error {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok || len(art.hitlIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
nodeID := nextHITLNodeID(info, art.hitlIDs)
|
||||
node := art.idx.nodes[nodeID]
|
||||
if nodeID == "" {
|
||||
return nil
|
||||
}
|
||||
prompt := resolveHITLPromptBinding(node.Config, state)
|
||||
label := firstNonEmpty(node.Label, nodeID)
|
||||
if args.DB != nil {
|
||||
pending := map[string]any{
|
||||
"nodeId": nodeID,
|
||||
"label": label,
|
||||
"prompt": prompt,
|
||||
"reviewer": cfgString(node.Config, "reviewer"),
|
||||
}
|
||||
pendingJSON, _ := json.Marshal(pending)
|
||||
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, nodeID, string(pendingJSON))
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_waiting", fmt.Sprintf("等待人工确认:%s", label), map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": nodeID,
|
||||
"label": label,
|
||||
"prompt": prompt,
|
||||
"reviewer": cfgString(node.Config, "reviewer"),
|
||||
"mode": "interactive",
|
||||
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
|
||||
})
|
||||
}
|
||||
return &AwaitingHITLError{
|
||||
RunID: runID,
|
||||
NodeID: nodeID,
|
||||
NodeLabel: label,
|
||||
Prompt: prompt,
|
||||
Reviewer: cfgString(node.Config, "reviewer"),
|
||||
}
|
||||
}
|
||||
|
||||
func nextHITLNodeID(info *compose.InterruptInfo, hitlIDs []string) string {
|
||||
if info != nil && len(info.BeforeNodes) > 0 {
|
||||
for _, id := range info.BeforeNodes {
|
||||
for _, hitl := range hitlIDs {
|
||||
if id == hitl {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
return info.BeforeNodes[0]
|
||||
}
|
||||
if len(hitlIDs) == 0 {
|
||||
return ""
|
||||
}
|
||||
return hitlIDs[0]
|
||||
}
|
||||
|
||||
// ResumeWorkflowRun continues a run paused at HITL after human decision.
|
||||
func ResumeWorkflowRun(ctx context.Context, args RunArgs, runID string, approved bool, comment string) (*RunResult, error) {
|
||||
run, err := args.DB.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if run == nil {
|
||||
return nil, fmt.Errorf("工作流运行不存在")
|
||||
}
|
||||
if run.Status != "awaiting_hitl" {
|
||||
return nil, fmt.Errorf("工作流运行不在等待审批状态: %s", run.Status)
|
||||
}
|
||||
wf, err := args.DB.GetWorkflowDefinition(run.WorkflowID)
|
||||
if err != nil || wf == nil {
|
||||
return nil, fmt.Errorf("工作流定义不存在")
|
||||
}
|
||||
graph, err := parseGraph(wf.GraphJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var input map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(run.InputJSON), &input)
|
||||
state := newWorkflowLocalState(input, runID)
|
||||
if state.Inputs == nil {
|
||||
state.Inputs = map[string]any{}
|
||||
}
|
||||
state.Inputs["_hitl_approved"] = approved
|
||||
state.Inputs["_hitl_comment"] = strings.TrimSpace(comment)
|
||||
state.Inputs["_hitl_node_id"] = run.PendingHITLNodeID
|
||||
|
||||
if !approved {
|
||||
errText := strings.TrimSpace(comment)
|
||||
if errText == "" {
|
||||
errText = "人工审批拒绝"
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_rejected", fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID), map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": run.PendingHITLNodeID,
|
||||
"comment": errText,
|
||||
})
|
||||
}
|
||||
return &RunResult{
|
||||
RunID: runID,
|
||||
Response: fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", run.PendingHITLNodeID),
|
||||
Status: "rejected",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_resumed", "人工审批已通过,继续执行", map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": run.PendingHITLNodeID,
|
||||
"comment": strings.TrimSpace(comment),
|
||||
})
|
||||
}
|
||||
|
||||
_ = args.DB.SetWorkflowRunStatus(runID, "running")
|
||||
resumeArgs := args
|
||||
if strings.TrimSpace(resumeArgs.ConversationID) == "" {
|
||||
resumeArgs.ConversationID = run.ConversationID
|
||||
}
|
||||
|
||||
awaiting, err := invokeEinoGraph(ctx, resumeArgs, runID, wf.ID, run.WorkflowVersion, graph, state, true)
|
||||
if err != nil {
|
||||
if IsAwaitingHITL(err) {
|
||||
return &RunResult{
|
||||
RunID: runID,
|
||||
Status: "awaiting_hitl",
|
||||
Response: fmt.Sprintf("工作流在节点「%s」等待下一次人工确认。", err.(*AwaitingHITLError).NodeID),
|
||||
AwaitingHITL: true,
|
||||
}, nil
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
_ = awaiting
|
||||
|
||||
output := map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "completed",
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
outputJSON, _ := json.Marshal(output)
|
||||
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
|
||||
_ = args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), "")
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"workflowId": wf.ID,
|
||||
"outputs": state.Outputs,
|
||||
"response": response,
|
||||
"engine": "eino_workflow",
|
||||
})
|
||||
}
|
||||
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func testWorkflowDB(t *testing.T) *database.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(dir, "workflow.db"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func linearStartOutputGraph() string {
|
||||
return `{
|
||||
"nodes": [
|
||||
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
{"id": "out-1", "type": "output", "label": "输出", "position": {"x": 0, "y": 120}, "config": {"output_key": "result", "source_binding": {"from": "inputs", "field": "message"}}}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start-1", "target": "out-1"}
|
||||
],
|
||||
"config": {"schema_version": 1}
|
||||
}`
|
||||
}
|
||||
|
||||
func conditionBranchGraph() string {
|
||||
return `{
|
||||
"nodes": [
|
||||
{"id": "start-1", "type": "start", "label": "开始", "position": {"x": 0, "y": 0}, "config": {}},
|
||||
{"id": "cond-1", "type": "condition", "label": "判断", "position": {"x": 0, "y": 80}, "config": {"expression": "{{inputs.message}} == yes"}},
|
||||
{"id": "out-yes", "type": "output", "label": "是", "position": {"x": -80, "y": 160}, "config": {"output_key": "branch", "static_value": "yes"}},
|
||||
{"id": "out-no", "type": "output", "label": "否", "position": {"x": 80, "y": 160}, "config": {"output_key": "branch", "static_value": "no"}}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start-1", "target": "cond-1"},
|
||||
{"id": "e2", "source": "cond-1", "target": "out-yes", "label": "是"},
|
||||
{"id": "e3", "source": "cond-1", "target": "out-no", "label": "否"}
|
||||
],
|
||||
"config": {"schema_version": 1}
|
||||
}`
|
||||
}
|
||||
|
||||
func TestValidateGraphJSON_linear(t *testing.T) {
|
||||
if err := ValidateGraphJSON(context.Background(), linearStartOutputGraph()); err != nil {
|
||||
t.Fatalf("validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileEngine_linear(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := defaultEngine.compile(ctx, g); err != nil {
|
||||
t.Fatalf("compile: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestWorkflowRun(t *testing.T, db *database.DB, runID string) {
|
||||
t.Helper()
|
||||
if err := db.CreateWorkflowRun(&database.WorkflowRun{
|
||||
ID: runID,
|
||||
WorkflowID: "test-wf",
|
||||
Status: "running",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateWorkflowRun: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEinoGraph_linearStartOutput(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
createTestWorkflowRun(t, db, "run-linear")
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := newWorkflowLocalState(map[string]interface{}{"message": "ping"}, "run-linear")
|
||||
args := RunArgs{DB: db}
|
||||
if err := executeEinoGraph(ctx, args, "run-linear", "test-wf", 1, g, state); err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
if got := state.Outputs["result"]; got != "ping" {
|
||||
t.Fatalf("outputs[result] = %v, want ping", got)
|
||||
}
|
||||
if len(state.Executed) != 2 {
|
||||
t.Fatalf("executed nodes = %d, want 2", len(state.Executed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEinoGraph_conditionBranch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
createTestWorkflowRun(t, db, "run-yes")
|
||||
createTestWorkflowRun(t, db, "run-no")
|
||||
g, err := parseGraph(conditionBranchGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stateYes := newWorkflowLocalState(map[string]interface{}{"message": "yes"}, "run-yes")
|
||||
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-yes", "test-wf-branch", 1, g, stateYes); err != nil {
|
||||
t.Fatalf("execute yes: %v", err)
|
||||
}
|
||||
if got := stateYes.Outputs["branch"]; got != "yes" {
|
||||
t.Fatalf("yes branch output = %v", got)
|
||||
}
|
||||
|
||||
stateNo := newWorkflowLocalState(map[string]interface{}{"message": "no"}, "run-no")
|
||||
if err := executeEinoGraph(ctx, RunArgs{DB: db}, "run-no", "test-wf-branch", 1, g, stateNo); err != nil {
|
||||
t.Fatalf("execute no: %v", err)
|
||||
}
|
||||
if got := stateNo.Outputs["branch"]; got != "no" {
|
||||
t.Fatalf("no branch output = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRoleBoundWorkflow_integration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
db := testWorkflowDB(t)
|
||||
graph := linearStartOutputGraph()
|
||||
if err := db.UpsertWorkflowDefinition(&database.WorkflowDefinition{
|
||||
ID: "wf-linear",
|
||||
Name: "线性流程",
|
||||
Version: 1,
|
||||
GraphJSON: graph,
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
role := config.RoleConfig{
|
||||
Name: "tester",
|
||||
Enabled: true,
|
||||
WorkflowID: "wf-linear",
|
||||
WorkflowPolicy: "auto",
|
||||
}
|
||||
result, err := RunRoleBoundWorkflow(ctx, RunArgs{
|
||||
DB: db,
|
||||
Logger: zap.NewNop(),
|
||||
Role: role,
|
||||
UserMessage: "from-role",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RunRoleBoundWorkflow: %v", err)
|
||||
}
|
||||
if result == nil || result.RunID == "" {
|
||||
t.Fatal("expected run result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompiledCache_reuse(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetCheckpointDir(t.TempDir())
|
||||
InvalidateCompiledCache("cache-wf")
|
||||
g, err := parseGraph(linearStartOutputGraph())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a1, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a2, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a1 != a2 {
|
||||
t.Fatal("expected cached artifact pointer reuse")
|
||||
}
|
||||
InvalidateCompiledCache("cache-wf")
|
||||
a3, err := defaultEngine.getOrCompile(ctx, "cache-wf", 1, g)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a1 == a3 {
|
||||
t.Fatal("expected new artifact after invalidation")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type workflowRuntimeCtxKey struct{}
|
||||
|
||||
// workflowRuntime carries per-run execution context into Eino Workflow local state.
|
||||
type workflowRuntime struct {
|
||||
args RunArgs
|
||||
runID string
|
||||
idx *graphIndex
|
||||
state *WorkflowLocalState
|
||||
}
|
||||
|
||||
func withWorkflowRuntime(ctx context.Context, rt *workflowRuntime) context.Context {
|
||||
return context.WithValue(ctx, workflowRuntimeCtxKey{}, rt)
|
||||
}
|
||||
|
||||
func workflowRuntimeFrom(ctx context.Context) *workflowRuntime {
|
||||
rt, _ := ctx.Value(workflowRuntimeCtxKey{}).(*workflowRuntime)
|
||||
return rt
|
||||
}
|
||||
|
||||
func newWorkflowRuntime(args RunArgs, runID string, idx *graphIndex, inputs map[string]interface{}) *workflowRuntime {
|
||||
return &workflowRuntime{
|
||||
args: args,
|
||||
runID: runID,
|
||||
idx: idx,
|
||||
state: newWorkflowLocalState(inputs, runID),
|
||||
}
|
||||
}
|
||||
|
||||
// RunArgs is the execution context for a role-bound workflow run.
|
||||
type RunArgs struct {
|
||||
DB *database.DB
|
||||
Logger *zap.Logger
|
||||
Role config.RoleConfig
|
||||
AppCfg *config.Config
|
||||
Agent *agent.Agent
|
||||
ConversationID string
|
||||
ProjectID string
|
||||
UserMessage string
|
||||
History []agent.ChatMessage
|
||||
RoleTools []string
|
||||
AgentsMarkdownDir string
|
||||
SystemPromptExtra string
|
||||
AssistantMessageID string
|
||||
Progress agent.ProgressCallback
|
||||
}
|
||||
|
||||
type RunResult struct {
|
||||
Response string
|
||||
RunID string
|
||||
Status string
|
||||
AwaitingHITL bool
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
type compiledArtifact struct {
|
||||
runnable compose.Runnable[WorkflowInput, WorkflowOutput]
|
||||
idx *graphIndex
|
||||
hitlIDs []string
|
||||
}
|
||||
|
||||
// Engine compiles and caches Eino Workflow artifacts.
|
||||
type Engine struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*compiledArtifact
|
||||
cpStore compose.CheckPointStore
|
||||
cpStoreMu sync.Once
|
||||
cpStoreErr error
|
||||
checkpointDir string
|
||||
}
|
||||
|
||||
var defaultEngine = &Engine{
|
||||
cache: make(map[string]*compiledArtifact),
|
||||
checkpointDir: "data/workflow-checkpoints",
|
||||
}
|
||||
|
||||
// SetCheckpointDir overrides the workflow checkpoint root (mainly for tests).
|
||||
func SetCheckpointDir(dir string) {
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
defaultEngine.checkpointDir = strings.TrimSpace(dir)
|
||||
defaultEngine.cpStore = nil
|
||||
defaultEngine.cpStoreErr = nil
|
||||
defaultEngine.cpStoreMu = sync.Once{}
|
||||
}
|
||||
|
||||
func (e *Engine) checkpointStore() (compose.CheckPointStore, error) {
|
||||
e.cpStoreMu.Do(func() {
|
||||
e.cpStore, e.cpStoreErr = newFileCheckPointStore(e.checkpointDir)
|
||||
})
|
||||
return e.cpStore, e.cpStoreErr
|
||||
}
|
||||
|
||||
// InvalidateCompiledCache drops cached compilations for a workflow id.
|
||||
func InvalidateCompiledCache(workflowID string) {
|
||||
workflowID = strings.TrimSpace(workflowID)
|
||||
if workflowID == "" {
|
||||
return
|
||||
}
|
||||
defaultEngine.mu.Lock()
|
||||
defer defaultEngine.mu.Unlock()
|
||||
for key := range defaultEngine.cache {
|
||||
if strings.HasPrefix(key, workflowID+":") {
|
||||
delete(defaultEngine.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateGraphJSON parses and trial-compiles a canvas graph (save-time gate).
|
||||
func ValidateGraphJSON(ctx context.Context, graphJSON string) error {
|
||||
g, err := parseGraph(graphJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
if len(findStartNodeIDs(idx)) == 0 {
|
||||
return fmt.Errorf("工作流缺少可执行的起点节点")
|
||||
}
|
||||
if !hasTerminalNode(idx) {
|
||||
return fmt.Errorf("工作流至少需要一个无出边的终点或 output/end 节点")
|
||||
}
|
||||
_, err = defaultEngine.compile(ctx, g)
|
||||
return err
|
||||
}
|
||||
|
||||
func hasTerminalNode(idx *graphIndex) bool {
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(node.Type, "end") || strings.EqualFold(node.Type, "output") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Engine) getOrCompile(ctx context.Context, workflowID string, version int, g *graphDef) (*compiledArtifact, error) {
|
||||
key := cacheKey(workflowID, version)
|
||||
e.mu.RLock()
|
||||
if art, ok := e.cache[key]; ok {
|
||||
e.mu.RUnlock()
|
||||
return art, nil
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
art, err := e.compile(ctx, g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.mu.Lock()
|
||||
if existing, ok := e.cache[key]; ok {
|
||||
e.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
e.cache[key] = art
|
||||
e.mu.Unlock()
|
||||
return art, nil
|
||||
}
|
||||
|
||||
func (e *Engine) compile(ctx context.Context, g *graphDef) (*compiledArtifact, error) {
|
||||
cpStore, err := e.checkpointStore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx := indexGraph(g)
|
||||
hitlIDs := collectHITLNodeIDs(idx)
|
||||
compileOpts := []compose.GraphCompileOption{
|
||||
compose.WithGraphName("CyberStrikeWorkflow"),
|
||||
compose.WithCheckPointStore(cpStore),
|
||||
}
|
||||
if len(hitlIDs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithInterruptBeforeNodes(hitlIDs))
|
||||
}
|
||||
|
||||
wf := compose.NewWorkflow[WorkflowInput, WorkflowOutput](
|
||||
compose.WithGenLocalState(func(runCtx context.Context) *WorkflowLocalState {
|
||||
if rt := workflowRuntimeFrom(runCtx); rt != nil && rt.state != nil {
|
||||
return rt.state
|
||||
}
|
||||
return &WorkflowLocalState{
|
||||
Outputs: make(map[string]any),
|
||||
NodeOutputs: make(map[string]map[string]any),
|
||||
NodeProceed: make(map[string]bool),
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
nodeRefs := make(map[string]*compose.WorkflowNode, len(idx.nodes))
|
||||
for id, node := range idx.nodes {
|
||||
n := node
|
||||
if strings.EqualFold(n.Type, "agent") {
|
||||
sub, err := compileAgentSubgraph(ctx, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译 Agent 子图 %s 失败: %w", id, err)
|
||||
}
|
||||
nodeRefs[id] = wf.AddGraphNode(id, sub)
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(n.Type, "start") {
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowInput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
continue
|
||||
}
|
||||
nodeRefs[id] = wf.AddLambdaNode(id, compose.InvokableLambda(func(runCtx context.Context, _ WorkflowNodeOutput) (WorkflowNodeOutput, error) {
|
||||
return runWorkflowNodeLambda(runCtx, n)
|
||||
}))
|
||||
}
|
||||
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
if err := wireConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hasConditionalOutgoingEdges(idx, id) {
|
||||
if err := wireEdgeConditionBranch(wf, nodeRefs, idx, id, node); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, edge := range idx.outgoing[id] {
|
||||
if target, ok := nodeRefs[edge.Target]; ok {
|
||||
target.AddInput(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, startID := range findStartNodeIDs(idx) {
|
||||
if ref, ok := nodeRefs[startID]; ok {
|
||||
ref.AddInput(compose.START)
|
||||
}
|
||||
}
|
||||
|
||||
endNode := wf.End()
|
||||
for id, node := range idx.nodes {
|
||||
if len(idx.outgoing[id]) == 0 || strings.EqualFold(node.Type, "end") {
|
||||
endNode.AddInput(id, compose.ToField(id))
|
||||
}
|
||||
}
|
||||
|
||||
runnable, err := wf.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &compiledArtifact{runnable: runnable, idx: idx, hitlIDs: hitlIDs}, nil
|
||||
}
|
||||
|
||||
func collectHITLNodeIDs(idx *graphIndex) []string {
|
||||
var ids []string
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "hitl") {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func runWorkflowNodeLambda(runCtx context.Context, n graphNode) (WorkflowNodeOutput, error) {
|
||||
localRT := workflowRuntimeFrom(runCtx)
|
||||
if localRT == nil {
|
||||
return nil, fmt.Errorf("workflow runtime missing in context")
|
||||
}
|
||||
result, proceed, err := executeNode(runCtx, localRT.args, localRT.runID, n, localRT.state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localRT.state.NodeOutputs[n.ID] = result
|
||||
localRT.state.LastOutput = result
|
||||
if !proceed && !strings.EqualFold(n.Type, "end") {
|
||||
label := firstNonEmpty(n.Label, n.ID)
|
||||
if errText := cfgString(result, "error"); errText != "" {
|
||||
return result, fmt.Errorf("节点「%s」失败: %s", label, errText)
|
||||
}
|
||||
return result, fmt.Errorf("节点「%s」未继续执行", label)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package workflow
|
||||
|
||||
import "errors"
|
||||
|
||||
// AwaitingHITLError indicates the workflow paused before a HITL node for human approval.
|
||||
type AwaitingHITLError struct {
|
||||
RunID string
|
||||
NodeID string
|
||||
NodeLabel string
|
||||
Prompt string
|
||||
Reviewer string
|
||||
}
|
||||
|
||||
func (e *AwaitingHITLError) Error() string {
|
||||
if e == nil {
|
||||
return "workflow awaiting human approval"
|
||||
}
|
||||
return "workflow awaiting human approval at node " + e.NodeID
|
||||
}
|
||||
|
||||
func IsAwaitingHITL(err error) bool {
|
||||
var target *AwaitingHITLError
|
||||
return errors.As(err, &target)
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type graphDef struct {
|
||||
Nodes []graphNode `json:"nodes"`
|
||||
Edges []graphEdge `json:"edges"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphNode struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
Position graphPosition `json:"position"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphEdge struct {
|
||||
ID string `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Label string `json:"label"`
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type graphPosition struct {
|
||||
X float64 `json:"x"`
|
||||
Y float64 `json:"y"`
|
||||
}
|
||||
|
||||
type graphIndex struct {
|
||||
nodes map[string]graphNode
|
||||
outgoing map[string][]graphEdge
|
||||
incoming map[string][]graphEdge
|
||||
}
|
||||
|
||||
func parseGraph(raw string) (*graphDef, error) {
|
||||
var g graphDef
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &g); err != nil {
|
||||
return nil, fmt.Errorf("解析工作流图失败: %w", err)
|
||||
}
|
||||
if len(g.Nodes) == 0 {
|
||||
return nil, fmt.Errorf("工作流没有节点")
|
||||
}
|
||||
if g.Config == nil {
|
||||
g.Config = make(map[string]any)
|
||||
}
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
func indexGraph(g *graphDef) *graphIndex {
|
||||
idx := &graphIndex{
|
||||
nodes: make(map[string]graphNode, len(g.Nodes)),
|
||||
outgoing: make(map[string][]graphEdge),
|
||||
incoming: make(map[string][]graphEdge),
|
||||
}
|
||||
for _, node := range g.Nodes {
|
||||
node.ID = strings.TrimSpace(node.ID)
|
||||
if node.ID == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(node.Type) == "" {
|
||||
node.Type = "tool"
|
||||
}
|
||||
if node.Config == nil {
|
||||
node.Config = make(map[string]any)
|
||||
}
|
||||
idx.nodes[node.ID] = node
|
||||
}
|
||||
for _, edge := range g.Edges {
|
||||
if _, ok := idx.nodes[edge.Source]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := idx.nodes[edge.Target]; !ok {
|
||||
continue
|
||||
}
|
||||
idx.outgoing[edge.Source] = append(idx.outgoing[edge.Source], edge)
|
||||
idx.incoming[edge.Target] = append(idx.incoming[edge.Target], edge)
|
||||
}
|
||||
for source := range idx.outgoing {
|
||||
sortEdgesByCanvas(idx.outgoing[source], idx.nodes)
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func sortEdgesByCanvas(edges []graphEdge, nodes map[string]graphNode) {
|
||||
sort.SliceStable(edges, func(i, j int) bool {
|
||||
a := nodes[edges[i].Target]
|
||||
b := nodes[edges[j].Target]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return edges[i].Target < edges[j].Target
|
||||
})
|
||||
}
|
||||
|
||||
func sortNodeIDsByCanvas(ids []string, nodes map[string]graphNode) {
|
||||
sort.SliceStable(ids, func(i, j int) bool {
|
||||
a := nodes[ids[i]]
|
||||
b := nodes[ids[j]]
|
||||
if a.Position.Y != b.Position.Y {
|
||||
return a.Position.Y < b.Position.Y
|
||||
}
|
||||
if a.Position.X != b.Position.X {
|
||||
return a.Position.X < b.Position.X
|
||||
}
|
||||
return ids[i] < ids[j]
|
||||
})
|
||||
}
|
||||
|
||||
func findStartNodeIDs(idx *graphIndex) []string {
|
||||
var queue []string
|
||||
for id, node := range idx.nodes {
|
||||
if strings.EqualFold(node.Type, "start") {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
if len(queue) == 0 {
|
||||
inDegree := make(map[string]int, len(idx.nodes))
|
||||
for id := range idx.nodes {
|
||||
inDegree[id] = 0
|
||||
}
|
||||
for _, edges := range idx.outgoing {
|
||||
for _, edge := range edges {
|
||||
inDegree[edge.Target]++
|
||||
}
|
||||
}
|
||||
for id, deg := range inDegree {
|
||||
if deg == 0 {
|
||||
queue = append(queue, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
sortNodeIDsByCanvas(queue, idx.nodes)
|
||||
return queue
|
||||
}
|
||||
|
||||
func branchNodeID(nodeID string) string {
|
||||
return nodeID + "__eino_branch"
|
||||
}
|
||||
|
||||
func edgeBranchNodeID(nodeID string) string {
|
||||
return nodeID + "__eino_edge_branch"
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// HITLDecision is a human decision on a workflow approval node.
|
||||
type HITLDecision struct {
|
||||
Approved bool
|
||||
Comment string
|
||||
}
|
||||
|
||||
var hitlWaiters sync.Map // runID -> chan HITLDecision
|
||||
|
||||
func registerHITLWaiter(runID string) chan HITLDecision {
|
||||
ch := make(chan HITLDecision, 1)
|
||||
hitlWaiters.Store(runID, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func unregisterHITLWaiter(runID string, ch chan HITLDecision) {
|
||||
hitlWaiters.CompareAndDelete(runID, ch)
|
||||
}
|
||||
|
||||
// NotifyHITLDecision wakes a streaming workflow run waiting at a HITL node.
|
||||
// Returns true when an active waiter was signaled.
|
||||
func NotifyHITLDecision(runID string, decision HITLDecision) bool {
|
||||
v, ok := hitlWaiters.Load(runID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
ch, ok := v.(chan HITLDecision)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case ch <- decision:
|
||||
return true
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func readHITLDecisionFromDB(db *database.DB, runID string) (HITLDecision, bool, error) {
|
||||
if db == nil {
|
||||
return HITLDecision{}, false, nil
|
||||
}
|
||||
run, err := db.GetWorkflowRun(runID)
|
||||
if err != nil {
|
||||
return HITLDecision{}, false, err
|
||||
}
|
||||
if run == nil || strings.TrimSpace(run.PendingHITLJSON) == "" {
|
||||
return HITLDecision{}, false, nil
|
||||
}
|
||||
var pending map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(run.PendingHITLJSON), &pending); err != nil {
|
||||
return HITLDecision{}, false, nil
|
||||
}
|
||||
raw, ok := pending["decision"]
|
||||
if !ok {
|
||||
return HITLDecision{}, false, nil
|
||||
}
|
||||
decision := strings.ToLower(strings.TrimSpace(fmt.Sprint(raw)))
|
||||
switch decision {
|
||||
case "approved", "approve":
|
||||
comment := ""
|
||||
if v, ok := pending["comment"]; ok {
|
||||
comment = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
return HITLDecision{Approved: true, Comment: comment}, true, nil
|
||||
case "rejected", "reject":
|
||||
comment := ""
|
||||
if v, ok := pending["comment"]; ok {
|
||||
comment = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
return HITLDecision{Approved: false, Comment: comment}, true, nil
|
||||
default:
|
||||
return HITLDecision{}, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func waitWorkflowHITLDecision(ctx context.Context, db *database.DB, runID string) (HITLDecision, error) {
|
||||
ch := registerHITLWaiter(runID)
|
||||
defer unregisterHITLWaiter(runID, ch)
|
||||
return waitWorkflowHITLDecisionWithChannel(ctx, db, runID, ch)
|
||||
}
|
||||
|
||||
func waitWorkflowHITLDecisionWithChannel(ctx context.Context, db *database.DB, runID string, ch chan HITLDecision) (HITLDecision, error) {
|
||||
if d, ok, err := readHITLDecisionFromDB(db, runID); err != nil {
|
||||
return HITLDecision{}, err
|
||||
} else if ok {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return HITLDecision{}, ctx.Err()
|
||||
case d := <-ch:
|
||||
return d, nil
|
||||
case <-ticker.C:
|
||||
if d, ok, err := readHITLDecisionFromDB(db, runID); err != nil {
|
||||
return HITLDecision{}, err
|
||||
} else if ok {
|
||||
return d, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func executeNode(ctx context.Context, args RunArgs, runID string, node graphNode, state *WorkflowLocalState) (map[string]any, bool, error) {
|
||||
label := node.Label
|
||||
if strings.TrimSpace(label) == "" {
|
||||
label = node.ID
|
||||
}
|
||||
nodeRunID := uuid.NewString()
|
||||
input := map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"inputs": state.Inputs,
|
||||
"previous": state.LastOutput,
|
||||
}
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
if err := args.DB.CreateWorkflowNodeRun(&database.WorkflowNodeRun{
|
||||
ID: nodeRunID,
|
||||
RunID: runID,
|
||||
NodeID: node.ID,
|
||||
Status: "running",
|
||||
InputJSON: string(inputJSON),
|
||||
StartedAt: time.Now(),
|
||||
}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_node_start", fmt.Sprintf("开始节点:%s", label), map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
})
|
||||
}
|
||||
|
||||
result, proceed, status, errText := runBuiltinNode(ctx, args, node, state)
|
||||
outputJSON, _ := json.Marshal(result)
|
||||
if err := args.DB.FinishWorkflowNodeRun(nodeRunID, status, string(outputJSON), errText); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if status == "skipped" {
|
||||
state.Skipped = append(state.Skipped, label)
|
||||
} else {
|
||||
state.Executed = append(state.Executed, label)
|
||||
}
|
||||
if args.Progress != nil {
|
||||
progressData := map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeRunId": nodeRunID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": label,
|
||||
"status": status,
|
||||
"output": result,
|
||||
}
|
||||
progressMsg := fmt.Sprintf("节点完成:%s(%s)", label, status)
|
||||
if strings.EqualFold(node.Type, "condition") {
|
||||
matched := false
|
||||
if v, ok := result["matched"].(bool); ok {
|
||||
matched = v
|
||||
}
|
||||
expr := cfgString(node.Config, "expression")
|
||||
if matched {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 是", label)
|
||||
} else {
|
||||
progressMsg = fmt.Sprintf("条件判断:%s → 否", label)
|
||||
}
|
||||
progressData["expression"] = expr
|
||||
progressData["matched"] = matched
|
||||
}
|
||||
args.Progress("workflow_node_result", progressMsg, progressData)
|
||||
}
|
||||
state.NodeProceed[node.ID] = proceed
|
||||
return result, proceed, nil
|
||||
}
|
||||
|
||||
func emitConditionBranchProgress(args RunArgs, runID string, node graphNode, edges []graphEdge, nodes map[string]graphNode, state *WorkflowLocalState) {
|
||||
if args.Progress == nil || len(edges) == 0 {
|
||||
return
|
||||
}
|
||||
for edgeIdx, edge := range edges {
|
||||
allowed := edgeAllowed(edge, node, edgeIdx, state)
|
||||
target := nodes[edge.Target]
|
||||
targetLabel := strings.TrimSpace(target.Label)
|
||||
if targetLabel == "" {
|
||||
targetLabel = edge.Target
|
||||
}
|
||||
branchLabel := strings.TrimSpace(edge.Label)
|
||||
if branchLabel == "" {
|
||||
switch edgeIdx {
|
||||
case 0:
|
||||
branchLabel = "是"
|
||||
case 1:
|
||||
branchLabel = "否"
|
||||
default:
|
||||
branchLabel = fmt.Sprintf("分支 %d", edgeIdx+1)
|
||||
}
|
||||
}
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
eventType := "workflow_branch_skipped"
|
||||
msg := fmt.Sprintf("跳过分支「%s」→ %s", branchLabel, targetLabel)
|
||||
if allowed {
|
||||
eventType = "workflow_branch_taken"
|
||||
msg = fmt.Sprintf("执行分支「%s」→ %s", branchLabel, targetLabel)
|
||||
}
|
||||
args.Progress(eventType, msg, map[string]any{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": node.ID,
|
||||
"nodeType": node.Type,
|
||||
"label": node.Label,
|
||||
"branchLabel": branchLabel,
|
||||
"targetId": edge.Target,
|
||||
"targetLabel": targetLabel,
|
||||
"edgeCondition": cond,
|
||||
"matched": conditionMatched(state),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func runBuiltinNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
cfg := node.Config
|
||||
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||
case "start":
|
||||
out := map[string]any{
|
||||
"output": state.Inputs["message"],
|
||||
"message": state.Inputs["message"],
|
||||
"conversationId": state.Inputs["conversationId"],
|
||||
"projectId": state.Inputs["projectId"],
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
case "condition":
|
||||
expr := cfgString(cfg, "expression")
|
||||
ok := evalCondition(expr, state)
|
||||
out := map[string]any{"output": ok, "condition": expr, "matched": ok}
|
||||
return out, true, "completed", ""
|
||||
case "output":
|
||||
key := cfgString(cfg, "output_key")
|
||||
if key == "" {
|
||||
key = "result"
|
||||
}
|
||||
var value any
|
||||
if v := cfgString(cfg, "static_value"); v != "" {
|
||||
value = v
|
||||
} else {
|
||||
value = resolveOutputSourceBinding(cfg, state)
|
||||
}
|
||||
state.Outputs[key] = value
|
||||
return map[string]any{"output": value, "outputs": map[string]any{key: value}}, true, "completed", ""
|
||||
case "end":
|
||||
value := resolveOutputSourceBinding(cfg, state)
|
||||
if b, ok := parseFieldBinding(cfg, "result_binding"); ok {
|
||||
value = resolveBinding(b, state)
|
||||
}
|
||||
return map[string]any{"output": value}, false, "completed", ""
|
||||
case "tool":
|
||||
return runToolNode(ctx, args, node, state)
|
||||
case "agent":
|
||||
return runAgentNode(ctx, args, node, state)
|
||||
case "hitl":
|
||||
return runHITLNode(args, node, state)
|
||||
default:
|
||||
reason := "未知节点类型"
|
||||
return map[string]any{"output": "", "skipped": true, "reason": reason, "node_type": node.Type}, true, "skipped", reason
|
||||
}
|
||||
}
|
||||
|
||||
func runToolNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
toolName := cfgString(node.Config, "tool_name")
|
||||
if toolName == "" {
|
||||
errText := "工具节点未选择 MCP 工具"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Agent == nil {
|
||||
errText := "工具节点执行失败:Agent 为空"
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
toolArgs, err := resolveToolArguments(node.Config, state)
|
||||
if err != nil {
|
||||
errText := fmt.Sprintf("工具参数不是合法 JSON:%v", err)
|
||||
return map[string]any{"output": "", "tool_name": toolName, "error": errText}, false, "failed", errText
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_tool_start", fmt.Sprintf("调用工具:%s", toolName), map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"tool": toolName,
|
||||
"args": toolArgs,
|
||||
})
|
||||
}
|
||||
result, err := args.Agent.ExecuteMCPToolForConversation(ctx, args.ConversationID, toolName, toolArgs)
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
return map[string]any{"output": "", "tool_name": toolName, "arguments": toolArgs, "error": errText}, false, "failed", errText
|
||||
}
|
||||
output := ""
|
||||
executionID := ""
|
||||
isError := false
|
||||
if result != nil {
|
||||
output = result.Result
|
||||
executionID = result.ExecutionID
|
||||
isError = result.IsError
|
||||
}
|
||||
out := map[string]any{
|
||||
"output": output,
|
||||
"tool_name": toolName,
|
||||
"arguments": toolArgs,
|
||||
"execution_id": executionID,
|
||||
"is_error": isError,
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.Outputs[key] = output
|
||||
}
|
||||
if isError {
|
||||
errText := strings.TrimSpace(output)
|
||||
if errText == "" {
|
||||
errText = "工具返回错误"
|
||||
}
|
||||
return out, false, "failed", errText
|
||||
}
|
||||
return out, true, "completed", ""
|
||||
}
|
||||
|
||||
func runAgentNode(ctx context.Context, args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
if args.AppCfg == nil || args.Agent == nil {
|
||||
errText := "Agent 节点执行失败:应用配置或 Agent 为空"
|
||||
return map[string]any{"output": "", "error": errText}, false, "failed", errText
|
||||
}
|
||||
mode := strings.ToLower(cfgString(node.Config, "agent_mode"))
|
||||
if mode == "" {
|
||||
mode = "eino_single"
|
||||
}
|
||||
inputSource := resolveNodeInputBinding(node.Config, state)
|
||||
message := buildAgentNodeMessage(node, state, inputSource)
|
||||
var result *multiagent.RunResult
|
||||
var err error
|
||||
state.SegmentMaxIteration = 0
|
||||
agentProgress := workflowAgentProgress(args.Progress, state, node)
|
||||
switch mode {
|
||||
case "eino_single", "single", "chat":
|
||||
result, err = multiagent.RunEinoSingleChatModelAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
default:
|
||||
result, err = multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
args.AppCfg,
|
||||
&args.AppCfg.MultiAgent,
|
||||
args.Agent,
|
||||
args.DB,
|
||||
args.Logger,
|
||||
args.ConversationID,
|
||||
args.ProjectID,
|
||||
message,
|
||||
args.History,
|
||||
args.RoleTools,
|
||||
agentProgress,
|
||||
args.AgentsMarkdownDir,
|
||||
mode,
|
||||
nil,
|
||||
args.SystemPromptExtra,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
errText := err.Error()
|
||||
state.MainIterationOffset += state.SegmentMaxIteration
|
||||
return map[string]any{"output": "", "mode": mode, "error": errText}, false, "failed", errText
|
||||
}
|
||||
state.MainIterationOffset += state.SegmentMaxIteration
|
||||
response := ""
|
||||
mcpIDs := []string{}
|
||||
if result != nil {
|
||||
response = result.Response
|
||||
mcpIDs = result.MCPExecutionIDs
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_agent_output", response, map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"label": firstNonEmpty(node.Label, node.ID),
|
||||
"mode": mode,
|
||||
"inputSource": inputSource,
|
||||
"inputPreview": truncateWorkflowPreview(inputSource, 500),
|
||||
"mcpExecutionIds": mcpIDs,
|
||||
})
|
||||
}
|
||||
if key := cfgString(node.Config, "output_key"); key != "" {
|
||||
state.Outputs[key] = response
|
||||
}
|
||||
return map[string]any{
|
||||
"output": response,
|
||||
"mode": mode,
|
||||
"mcp_execution_ids": mcpIDs,
|
||||
}, true, "completed", ""
|
||||
}
|
||||
|
||||
func buildAgentNodeMessage(node graphNode, state *WorkflowLocalState, upstreamInput string) string {
|
||||
instruction := strings.TrimSpace(cfgString(node.Config, "instruction"))
|
||||
upstreamInput = strings.TrimSpace(upstreamInput)
|
||||
if instruction == "" {
|
||||
if upstreamInput != "" {
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%s", upstreamInput)
|
||||
}
|
||||
return fmt.Sprintf("请基于上游节点输出继续处理:\n%v", state.LastOutput["output"])
|
||||
}
|
||||
if upstreamInput == "" {
|
||||
return instruction
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("上游输入:\n%s\n\n节点指令:\n%s", upstreamInput, instruction))
|
||||
}
|
||||
|
||||
func workflowAgentProgress(progress agent.ProgressCallback, state *WorkflowLocalState, node graphNode) agent.ProgressCallback {
|
||||
if progress == nil {
|
||||
return nil
|
||||
}
|
||||
return func(eventType, message string, data interface{}) {
|
||||
switch eventType {
|
||||
case "response_start", "response_delta", "response", "done":
|
||||
return
|
||||
default:
|
||||
enrichWorkflowAgentEventData(data, state, node)
|
||||
if eventType == "iteration" {
|
||||
applyWorkflowMainIterationOffset(data, state)
|
||||
}
|
||||
progress(eventType, message, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enrichWorkflowAgentEventData(data interface{}, state *WorkflowLocalState, node graphNode) {
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
if node.ID != "" {
|
||||
m["workflowNodeId"] = node.ID
|
||||
}
|
||||
if state != nil && strings.TrimSpace(state.WorkflowRunID) != "" {
|
||||
m["workflowRunId"] = state.WorkflowRunID
|
||||
}
|
||||
}
|
||||
|
||||
func applyWorkflowMainIterationOffset(data interface{}, state *WorkflowLocalState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
m, ok := data.(map[string]interface{})
|
||||
if !ok || m == nil {
|
||||
return
|
||||
}
|
||||
scope, _ := m["einoScope"].(string)
|
||||
if strings.TrimSpace(scope) != "main" {
|
||||
return
|
||||
}
|
||||
raw := iterationNumberFromProgressData(m)
|
||||
if raw <= 0 {
|
||||
return
|
||||
}
|
||||
if raw > state.SegmentMaxIteration {
|
||||
state.SegmentMaxIteration = raw
|
||||
}
|
||||
m["iteration"] = raw + state.MainIterationOffset
|
||||
}
|
||||
|
||||
func iterationNumberFromProgressData(m map[string]interface{}) int {
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
return v
|
||||
case int32:
|
||||
return int(v)
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case float32:
|
||||
return int(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func runHITLNode(args RunArgs, node graphNode, state *WorkflowLocalState) (map[string]any, bool, string, string) {
|
||||
prompt := resolveHITLPromptBinding(node.Config, state)
|
||||
reviewer := cfgString(node.Config, "reviewer")
|
||||
if reviewer == "" {
|
||||
reviewer = "human"
|
||||
}
|
||||
approved := true
|
||||
if state != nil && state.Inputs != nil {
|
||||
if v, ok := state.Inputs["_hitl_approved"]; ok {
|
||||
approved = fmt.Sprint(v) == "true"
|
||||
}
|
||||
}
|
||||
if !approved {
|
||||
reason := "人工审批已拒绝"
|
||||
if state != nil && state.Inputs != nil {
|
||||
if v, ok := state.Inputs["_hitl_comment"]; ok {
|
||||
if s := strings.TrimSpace(fmt.Sprint(v)); s != "" {
|
||||
reason = s
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]any{"output": "", "prompt": prompt, "approved": false, "mode": "interactive"}, false, "failed", reason
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_checkpoint", "人工确认节点已通过", map[string]any{
|
||||
"nodeId": node.ID,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"mode": "interactive",
|
||||
"approved": true,
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"output": prompt,
|
||||
"prompt": prompt,
|
||||
"reviewer": reviewer,
|
||||
"approved": true,
|
||||
"mode": "interactive",
|
||||
}, true, "completed", ""
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ShouldAutoRunRoleWorkflow returns true when a role explicitly binds a workflow
|
||||
// and does not turn it off. Empty policy defaults to auto to keep role UX simple.
|
||||
func ShouldAutoRunRoleWorkflow(role config.RoleConfig) bool {
|
||||
if strings.TrimSpace(role.WorkflowID) == "" {
|
||||
return false
|
||||
}
|
||||
policy := strings.ToLower(strings.TrimSpace(role.WorkflowPolicy))
|
||||
return policy == "" || policy == "auto"
|
||||
}
|
||||
|
||||
// RunRoleBoundWorkflow executes the persisted role-bound workflow via cached Eino Workflow.
|
||||
func RunRoleBoundWorkflow(ctx context.Context, args RunArgs) (*RunResult, error) {
|
||||
if args.DB == nil {
|
||||
return nil, fmt.Errorf("workflow db is nil")
|
||||
}
|
||||
workflowID := strings.TrimSpace(args.Role.WorkflowID)
|
||||
if workflowID == "" {
|
||||
return nil, fmt.Errorf("角色未绑定工作流")
|
||||
}
|
||||
wf, err := args.DB.GetWorkflowDefinition(workflowID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wf == nil {
|
||||
return nil, fmt.Errorf("角色绑定的工作流不存在: %s", workflowID)
|
||||
}
|
||||
if !wf.Enabled {
|
||||
return nil, fmt.Errorf("角色绑定的工作流已禁用: %s", workflowID)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
runID := uuid.NewString()
|
||||
input := map[string]interface{}{
|
||||
"message": args.UserMessage,
|
||||
"conversationId": args.ConversationID,
|
||||
"projectId": args.ProjectID,
|
||||
"role": args.Role.Name,
|
||||
"workflowId": wf.ID,
|
||||
"workflowVersion": wf.Version,
|
||||
}
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
run := &database.WorkflowRun{
|
||||
ID: runID,
|
||||
WorkflowID: wf.ID,
|
||||
WorkflowVersion: wf.Version,
|
||||
ConversationID: args.ConversationID,
|
||||
ProjectID: args.ProjectID,
|
||||
RoleID: args.Role.Name,
|
||||
Status: "running",
|
||||
InputJSON: string(inputJSON),
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
if err := args.DB.CreateWorkflowRun(run); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_start", fmt.Sprintf("开始运行流程「%s」", wf.Name), map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"conversationId": args.ConversationID,
|
||||
"engine": "eino_workflow",
|
||||
})
|
||||
}
|
||||
|
||||
graph, err := parseGraph(wf.GraphJSON)
|
||||
if err != nil {
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
state := newWorkflowLocalState(input, runID)
|
||||
streaming := args.Progress != nil
|
||||
resuming := false
|
||||
for {
|
||||
_, err := invokeEinoGraph(ctx, args, runID, wf.ID, wf.Version, graph, state, resuming)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if !IsAwaitingHITL(err) {
|
||||
_ = args.DB.FinishWorkflowRun(runID, "failed", "", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
hitl := err.(*AwaitingHITLError)
|
||||
partial := map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "awaiting_hitl",
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"pendingHitl": map[string]interface{}{
|
||||
"nodeId": hitl.NodeID,
|
||||
"label": hitl.NodeLabel,
|
||||
"prompt": hitl.Prompt,
|
||||
},
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
partialJSON, _ := json.Marshal(partial)
|
||||
_ = args.DB.SetWorkflowRunAwaitingHITL(runID, hitl.NodeID, string(partialJSON))
|
||||
response := fmt.Sprintf("工作流「%s」已在节点「%s」暂停,等待人工审批。\n运行 ID:%s", wf.Name, firstNonEmpty(hitl.NodeLabel, hitl.NodeID), runID)
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_paused", response, map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"status": "awaiting_hitl",
|
||||
"nodeId": hitl.NodeID,
|
||||
"resumeApi": fmt.Sprintf("/api/workflows/runs/%s/resume", runID),
|
||||
})
|
||||
}
|
||||
if !streaming {
|
||||
return &RunResult{
|
||||
Response: response,
|
||||
RunID: runID,
|
||||
Status: "awaiting_hitl",
|
||||
AwaitingHITL: true,
|
||||
}, nil
|
||||
}
|
||||
ch := registerHITLWaiter(runID)
|
||||
decision, waitErr := waitWorkflowHITLDecisionWithChannel(ctx, args.DB, runID, ch)
|
||||
unregisterHITLWaiter(runID, ch)
|
||||
if waitErr != nil {
|
||||
_ = args.DB.FinishWorkflowRun(runID, "cancelled", "", waitErr.Error())
|
||||
return nil, waitErr
|
||||
}
|
||||
if !decision.Approved {
|
||||
errText := strings.TrimSpace(decision.Comment)
|
||||
if errText == "" {
|
||||
errText = "人工审批拒绝"
|
||||
}
|
||||
_ = args.DB.FinishWorkflowRun(runID, "rejected", "", errText)
|
||||
rejectResponse := fmt.Sprintf("工作流已在审批节点「%s」被拒绝。", firstNonEmpty(hitl.NodeLabel, hitl.NodeID))
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_rejected", rejectResponse, map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": hitl.NodeID,
|
||||
"comment": errText,
|
||||
})
|
||||
}
|
||||
return &RunResult{
|
||||
Response: rejectResponse,
|
||||
RunID: runID,
|
||||
Status: "rejected",
|
||||
}, nil
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_hitl_resumed", "人工审批已通过,继续执行", map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"nodeId": hitl.NodeID,
|
||||
"comment": decision.Comment,
|
||||
})
|
||||
}
|
||||
if state.Inputs == nil {
|
||||
state.Inputs = map[string]any{}
|
||||
}
|
||||
state.Inputs["_hitl_approved"] = true
|
||||
state.Inputs["_hitl_comment"] = decision.Comment
|
||||
state.Inputs["_hitl_node_id"] = hitl.NodeID
|
||||
_ = args.DB.SetWorkflowRunStatus(runID, "running")
|
||||
resuming = true
|
||||
}
|
||||
|
||||
output := map[string]interface{}{
|
||||
"workflowId": wf.ID,
|
||||
"workflowName": wf.Name,
|
||||
"workflowVersion": wf.Version,
|
||||
"workflowRunId": runID,
|
||||
"status": "completed",
|
||||
"outputs": state.Outputs,
|
||||
"executedNodes": state.Executed,
|
||||
"skippedNodes": state.Skipped,
|
||||
"engine": "eino_workflow",
|
||||
}
|
||||
outputJSON, _ := json.Marshal(output)
|
||||
|
||||
response := renderWorkflowResponse(args.Role.Name, wf.Name, wf.Version, runID, state)
|
||||
if err := args.DB.FinishWorkflowRun(runID, "completed", string(outputJSON), ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args.Progress != nil {
|
||||
args.Progress("workflow_done", fmt.Sprintf("流程「%s」运行完成", wf.Name), map[string]interface{}{
|
||||
"workflowRunId": runID,
|
||||
"workflowId": wf.ID,
|
||||
"outputs": state.Outputs,
|
||||
"response": response,
|
||||
"engine": "eino_workflow",
|
||||
})
|
||||
}
|
||||
if args.Logger != nil {
|
||||
args.Logger.Info("role-bound workflow completed",
|
||||
zap.String("workflow_id", wf.ID),
|
||||
zap.String("workflow_run_id", runID),
|
||||
zap.String("conversation_id", args.ConversationID),
|
||||
zap.String("role", args.Role.Name),
|
||||
zap.String("engine", "eino_workflow"),
|
||||
)
|
||||
}
|
||||
return &RunResult{Response: response, RunID: runID, Status: "completed"}, nil
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func init() {
|
||||
schema.RegisterName[*WorkflowLocalState]("_cyberstrike_workflow_local_state")
|
||||
}
|
||||
|
||||
// WorkflowLocalState is the Eino WithGenLocalState payload (checkpoint-serializable).
|
||||
type WorkflowLocalState struct {
|
||||
Inputs map[string]any `json:"inputs,omitempty"`
|
||||
Outputs map[string]any `json:"outputs,omitempty"`
|
||||
NodeOutputs map[string]map[string]any `json:"nodeOutputs,omitempty"`
|
||||
NodeProceed map[string]bool `json:"nodeProceed,omitempty"`
|
||||
LastOutput map[string]any `json:"lastOutput,omitempty"`
|
||||
Executed []string `json:"executed,omitempty"`
|
||||
Skipped []string `json:"skipped,omitempty"`
|
||||
WorkflowRunID string `json:"workflowRunId,omitempty"`
|
||||
MainIterationOffset int `json:"mainIterationOffset,omitempty"`
|
||||
SegmentMaxIteration int `json:"segmentMaxIteration,omitempty"`
|
||||
}
|
||||
|
||||
func newWorkflowLocalState(inputs map[string]interface{}, runID string) *WorkflowLocalState {
|
||||
in := make(map[string]any, len(inputs))
|
||||
for k, v := range inputs {
|
||||
in[k] = v
|
||||
}
|
||||
return &WorkflowLocalState{
|
||||
Inputs: in,
|
||||
Outputs: make(map[string]any),
|
||||
NodeOutputs: make(map[string]map[string]any),
|
||||
NodeProceed: make(map[string]bool),
|
||||
WorkflowRunID: runID,
|
||||
}
|
||||
}
|
||||
|
||||
var templateVarRe = regexp.MustCompile(`\{\{\s*([a-zA-Z0-9_.-]+)\s*\}\}`)
|
||||
|
||||
func resolveTemplate(s string, state *WorkflowLocalState) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return fmt.Sprint(valueFromPath("previous.output", state))
|
||||
}
|
||||
return templateVarRe.ReplaceAllStringFunc(s, func(match string) string {
|
||||
m := templateVarRe.FindStringSubmatch(match)
|
||||
if len(m) != 2 {
|
||||
return match
|
||||
}
|
||||
return fmt.Sprint(valueFromPath(m[1], state))
|
||||
})
|
||||
}
|
||||
|
||||
func valueFromPath(path string, state *WorkflowLocalState) any {
|
||||
parts := strings.Split(path, ".")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
var cur any
|
||||
switch parts[0] {
|
||||
case "inputs", "input":
|
||||
cur = state.Inputs
|
||||
case "previous", "prev":
|
||||
cur = state.LastOutput
|
||||
case "outputs":
|
||||
cur = state.Outputs
|
||||
default:
|
||||
if v, ok := state.Inputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else if v, ok := state.NodeOutputs[parts[0]]; ok {
|
||||
cur = v
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
for _, p := range parts[1:] {
|
||||
m, ok := cur.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
cur = m[p]
|
||||
}
|
||||
if cur == nil {
|
||||
return ""
|
||||
}
|
||||
return cur
|
||||
}
|
||||
|
||||
func evalCondition(expr string, state *WorkflowLocalState) bool {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if expr == "" {
|
||||
return true
|
||||
}
|
||||
resolved := strings.TrimSpace(resolveTemplate(expr, state))
|
||||
switch {
|
||||
case strings.Contains(resolved, "!="):
|
||||
parts := strings.SplitN(resolved, "!=", 2)
|
||||
return cleanComparable(parts[0]) != cleanComparable(parts[1])
|
||||
case strings.Contains(resolved, "=="):
|
||||
parts := strings.SplitN(resolved, "==", 2)
|
||||
return cleanComparable(parts[0]) == cleanComparable(parts[1])
|
||||
default:
|
||||
v := strings.ToLower(cleanComparable(resolved))
|
||||
return v != "" && v != "false" && v != "0" && v != "null"
|
||||
}
|
||||
}
|
||||
|
||||
func cleanComparable(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
s = strings.Trim(s, `"'`)
|
||||
return s
|
||||
}
|
||||
|
||||
func edgeAllowed(edge graphEdge, sourceNode graphNode, edgeIndex int, state *WorkflowLocalState) bool {
|
||||
cond := firstNonEmpty(cfgString(edge.Config, "condition"), cfgString(edge.Config, "expression"))
|
||||
if cond != "" {
|
||||
return evalCondition(cond, state)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(sourceNode.Type), "condition") {
|
||||
return conditionBranchAllowed(edge, edgeIndex, state)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func conditionBranchAllowed(edge graphEdge, edgeIndex int, state *WorkflowLocalState) bool {
|
||||
matched := conditionMatched(state)
|
||||
if branch := conditionBranchHint(edge); branch != "" {
|
||||
return (branch == "true" && matched) || (branch == "false" && !matched)
|
||||
}
|
||||
switch edgeIndex {
|
||||
case 0:
|
||||
return matched
|
||||
case 1:
|
||||
return !matched
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func conditionMatched(state *WorkflowLocalState) bool {
|
||||
v := strings.ToLower(cleanComparable(fmt.Sprint(valueFromPath("previous.matched", state))))
|
||||
return v == "true" || v == "1"
|
||||
}
|
||||
|
||||
func conditionBranchHint(edge graphEdge) string {
|
||||
if edge.Config != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(cfgString(edge.Config, "branch"))) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(edge.Label)) {
|
||||
case "true", "yes", "y", "是":
|
||||
return "true"
|
||||
case "false", "no", "n", "否":
|
||||
return "false"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cfgString(cfg map[string]any, key string) string {
|
||||
if cfg == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if s := strings.TrimSpace(value); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func truncateWorkflowPreview(s string, limit int) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if limit <= 0 || len([]rune(s)) <= limit {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
return string(runes[:limit]) + "..."
|
||||
}
|
||||
|
||||
func renderWorkflowResponse(roleName, workflowName string, version int, runID string, state *WorkflowLocalState) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("角色「%s」已完成工作流「%s」(版本 %d)。\n\n", roleName, workflowName, version))
|
||||
sb.WriteString(fmt.Sprintf("运行 ID:%s\n", runID))
|
||||
sb.WriteString(fmt.Sprintf("已执行节点:%d", len(state.Executed)))
|
||||
if len(state.Skipped) > 0 {
|
||||
sb.WriteString(fmt.Sprintf(",跳过节点:%d", len(state.Skipped)))
|
||||
}
|
||||
sb.WriteString("\n\n")
|
||||
if len(state.Outputs) > 0 {
|
||||
sb.WriteString("输出:\n")
|
||||
keys := make([]string, 0, len(state.Outputs))
|
||||
for k := range state.Outputs {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
sb.WriteString(fmt.Sprintf("- %s:%v\n", k, state.Outputs[k]))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("暂无输出。请检查是否配置了输出节点,或条件分支是否命中。\n")
|
||||
}
|
||||
if len(state.Skipped) > 0 {
|
||||
sb.WriteString("\n未执行的节点类型仍会保留运行记录:")
|
||||
sb.WriteString(strings.Join(state.Skipped, "、"))
|
||||
sb.WriteString("。")
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// WorkflowInput is the typed entry for Eino compose.Workflow[I,O].
|
||||
type WorkflowInput struct {
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
ProjectID string `json:"projectId"`
|
||||
Role string `json:"role"`
|
||||
WorkflowID string `json:"workflowId"`
|
||||
WorkflowVersion int `json:"workflowVersion"`
|
||||
}
|
||||
|
||||
// WorkflowOutput aggregates terminal node payloads keyed by canvas node id.
|
||||
type WorkflowOutput map[string]any
|
||||
|
||||
// WorkflowNodeOutput is the per-node lambda payload (alias for Eino edge type alignment).
|
||||
type WorkflowNodeOutput = map[string]interface{}
|
||||
|
||||
func workflowInputFromMap(m map[string]interface{}) WorkflowInput {
|
||||
in := WorkflowInput{}
|
||||
if m == nil {
|
||||
return in
|
||||
}
|
||||
if v, ok := m["message"].(string); ok {
|
||||
in.Message = v
|
||||
} else if m["message"] != nil {
|
||||
in.Message = fmt.Sprint(m["message"])
|
||||
}
|
||||
if v, ok := m["conversationId"].(string); ok {
|
||||
in.ConversationID = v
|
||||
}
|
||||
if v, ok := m["projectId"].(string); ok {
|
||||
in.ProjectID = v
|
||||
}
|
||||
if v, ok := m["role"].(string); ok {
|
||||
in.Role = v
|
||||
}
|
||||
if v, ok := m["workflowId"].(string); ok {
|
||||
in.WorkflowID = v
|
||||
}
|
||||
switch v := m["workflowVersion"].(type) {
|
||||
case int:
|
||||
in.WorkflowVersion = v
|
||||
case int64:
|
||||
in.WorkflowVersion = int(v)
|
||||
case float64:
|
||||
in.WorkflowVersion = int(v)
|
||||
case string:
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
in.WorkflowVersion = n
|
||||
}
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
func (in WorkflowInput) toStateInputs() map[string]any {
|
||||
return map[string]any{
|
||||
"message": in.Message,
|
||||
"conversationId": in.ConversationID,
|
||||
"projectId": in.ProjectID,
|
||||
"role": in.Role,
|
||||
"workflowId": in.WorkflowID,
|
||||
"workflowVersion": in.WorkflowVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func cacheKey(workflowID string, version int) string {
|
||||
return workflowID + ":" + strconv.Itoa(version)
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
name: "virustotal_search"
|
||||
command: "python3"
|
||||
args:
|
||||
- "-c"
|
||||
- |
|
||||
import sys
|
||||
import json
|
||||
import requests
|
||||
import os
|
||||
import time
|
||||
|
||||
# ==================== VirusTotal 配置 ====================
|
||||
# 请在此处配置您的 VirusTotal API 密钥
|
||||
# 您也可以在环境变量中设置:VT_API_KEY
|
||||
# enable 默认为 false,需开启才能调用该MCP
|
||||
VT_API_KEY = "" # 请填写您的 VirusTotal API 密钥
|
||||
# =======================================================
|
||||
|
||||
# VirusTotal API 基础 URL
|
||||
BASE_URL = "https://www.virustotal.com/api/v3"
|
||||
|
||||
def parse_args():
|
||||
"""解析命令行参数"""
|
||||
# 尝试从第一个参数读取 JSON 配置
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
arg1 = str(sys.argv[1])
|
||||
config = json.loads(arg1)
|
||||
if isinstance(config, dict):
|
||||
return config
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 传统位置参数方式
|
||||
config = {}
|
||||
if len(sys.argv) > 1:
|
||||
config['domain'] = str(sys.argv[1])
|
||||
if len(sys.argv) > 2:
|
||||
try:
|
||||
config['limit'] = int(sys.argv[2])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if len(sys.argv) > 3:
|
||||
config['include_ips'] = sys.argv[3].lower() in ('true', '1', 'yes')
|
||||
return config
|
||||
|
||||
def query_virustotal_subdomains(domain, api_key, limit=100, include_ips=False):
|
||||
"""
|
||||
查询 VirusTotal 的子域名信息
|
||||
|
||||
Args:
|
||||
domain: 要查询的域名
|
||||
api_key: VirusTotal API 密钥
|
||||
limit: 返回结果数量限制
|
||||
include_ips: 是否包含 IP 地址信息
|
||||
|
||||
Returns:
|
||||
dict: 包含查询结果的字典
|
||||
"""
|
||||
# 构建 API 请求 URL
|
||||
url = f"{BASE_URL}/domains/{domain}/subdomains"
|
||||
|
||||
headers = {
|
||||
"x-apikey": api_key,
|
||||
"accept": "application/json"
|
||||
}
|
||||
|
||||
params = {
|
||||
"limit": min(limit, 40) # API 限制最大 40
|
||||
}
|
||||
|
||||
all_results = []
|
||||
next_url = None
|
||||
|
||||
try:
|
||||
# 处理分页
|
||||
while True:
|
||||
if next_url:
|
||||
response = requests.get(next_url, headers=headers, timeout=30)
|
||||
else:
|
||||
response = requests.get(url, headers=headers, params=params, timeout=30)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# 提取子域名数据
|
||||
if 'data' in data and data['data']:
|
||||
for item in data['data']:
|
||||
if 'id' in item:
|
||||
subdomain_info = {
|
||||
'subdomain': item['id'],
|
||||
'type': item.get('type', 'domain'),
|
||||
}
|
||||
|
||||
# 如果 include_ips 为 True,尝试获取解析 IP
|
||||
if include_ips and 'attributes' in item:
|
||||
attributes = item.get('attributes', {})
|
||||
# 这里简化处理,实际可能需要额外的 API 调用
|
||||
subdomain_info['last_dns_records'] = attributes.get('last_dns_records', [])
|
||||
|
||||
all_results.append(subdomain_info)
|
||||
|
||||
# 检查是否有下一页
|
||||
if 'links' in data and 'next' in data['links'] and len(all_results) < limit:
|
||||
next_url = data['links']['next']
|
||||
# 避免请求过快
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果已达到限制,停止获取
|
||||
if len(all_results) >= limit:
|
||||
break
|
||||
|
||||
# 处理返回结果
|
||||
if all_results:
|
||||
return {
|
||||
"status": "success",
|
||||
"domain": domain,
|
||||
"total_found": len(all_results),
|
||||
"results": all_results[:limit],
|
||||
"message": f"成功获取 {len(all_results[:limit])} 个子域名"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "success",
|
||||
"domain": domain,
|
||||
"total_found": 0,
|
||||
"results": [],
|
||||
"message": f"未找到 {domain} 的子域名"
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = str(e)
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"API 请求失败: {error_msg}",
|
||||
"suggestion": "请检查网络连接、API 密钥是否正确,或 VirusTotal API 服务是否可用"
|
||||
}
|
||||
|
||||
# 处理特定 HTTP 状态码
|
||||
if hasattr(e, 'response') and e.response:
|
||||
status_code = e.response.status_code
|
||||
if status_code == 401:
|
||||
error_result["message"] = "API 密钥无效或未授权"
|
||||
error_result["suggestion"] = "请检查 VirusTotal API 密钥是否正确,或在 https://www.virustotal.com/ 获取有效密钥"
|
||||
elif status_code == 429:
|
||||
error_result["message"] = "API 请求频率超限"
|
||||
error_result["suggestion"] = "请稍后再试,VirusTotal API 有严格的速率限制(免费版每分钟4次)"
|
||||
elif status_code == 404:
|
||||
error_result["message"] = f"域名 '{domain}' 不存在或未找到"
|
||||
|
||||
return error_result
|
||||
|
||||
try:
|
||||
config = parse_args()
|
||||
|
||||
if not isinstance(config, dict):
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"参数解析错误: 期望字典类型,但得到 {type(config).__name__}",
|
||||
"type": "TypeError"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
# 获取 API 密钥(从配置或环境变量)
|
||||
api_key = os.getenv('VT_API_KEY', VT_API_KEY).strip()
|
||||
|
||||
if not api_key:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少 VirusTotal API 密钥",
|
||||
"required_config": ["VT_API_KEY"],
|
||||
"note": "请在 YAML 文件的 VT_API_KEY 配置项中填写您的 VirusTotal API 密钥,或在环境变量 VT_API_KEY 中设置。API 密钥可在 https://www.virustotal.com/ 注册获取"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
# 获取必需参数
|
||||
domain = config.get('domain', '').strip()
|
||||
if not domain:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少必需参数: domain(要查询的域名)",
|
||||
"required_params": ["domain"],
|
||||
"examples": [
|
||||
"example.com",
|
||||
"google.com",
|
||||
"baidu.com"
|
||||
]
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
# 获取可选参数
|
||||
limit = config.get('limit', 100)
|
||||
try:
|
||||
limit = int(limit)
|
||||
if limit < 1:
|
||||
limit = 100
|
||||
elif limit > 1000:
|
||||
limit = 1000 # 限制最大 1000
|
||||
except (ValueError, TypeError):
|
||||
limit = 100
|
||||
|
||||
include_ips = config.get('include_ips', False)
|
||||
if isinstance(include_ips, str):
|
||||
include_ips = include_ips.lower() in ('true', '1', 'yes')
|
||||
|
||||
# 执行查询
|
||||
result = query_virustotal_subdomains(domain, api_key, limit, include_ips)
|
||||
|
||||
# 输出结果
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"执行出错: {str(e)}",
|
||||
"type": type(e).__name__
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
enabled: false
|
||||
|
||||
short_description: "VirusTotal 子域名查询工具,通过 VirusTotal API 被动收集域名子域名"
|
||||
|
||||
description: |
|
||||
VirusTotal 子域名查询工具,利用 VirusTotal 聚合的历史 DNS 数据来发现目标域名的子域名。
|
||||
|
||||
**主要功能:**
|
||||
- 被动子域名收集:从 VirusTotal 历史 DNS 数据中检索子域名
|
||||
- 分页查询:支持大量子域名的获取
|
||||
- IP 关联:可选包含 DNS 解析记录
|
||||
- 去重处理:自动去重返回结果
|
||||
|
||||
**使用场景:**
|
||||
- 安全测试前期信息收集
|
||||
- 企业网络资产发现
|
||||
- 攻击面分析
|
||||
- 威胁情报收集
|
||||
- 渗透测试信息收集
|
||||
|
||||
**数据来源:**
|
||||
VirusTotal 聚合了来自多个来源的 DNS 数据,包括:
|
||||
- 历史 DNS 解析记录
|
||||
- 被动 DNS 数据库
|
||||
- 证书透明度日志
|
||||
- 安全扫描数据
|
||||
|
||||
**注意事项:**
|
||||
- **API 密钥必需**:需要在 VirusTotal 注册账号并获取 API 密钥
|
||||
- **速率限制**:免费版 API 每分钟限制 4 次请求
|
||||
- **数据时效性**:数据基于历史扫描记录,可能不是实时的
|
||||
- **使用授权**:仅允许对您拥有合法授权的目标进行查询
|
||||
- **配额限制**:免费版每月有查询配额限制
|
||||
|
||||
parameters:
|
||||
- name: "domain"
|
||||
type: "string"
|
||||
description: |
|
||||
要查询的目标域名(必需)。
|
||||
|
||||
**格式要求:**
|
||||
- 仅输入主域名,不要包含协议头(http://)或路径
|
||||
- 支持二级域名查询
|
||||
|
||||
**示例值:**
|
||||
- "example.com"
|
||||
- "google.com"
|
||||
- "baidu.com"
|
||||
- "github.com"
|
||||
|
||||
**注意事项:**
|
||||
- 域名格式必须正确
|
||||
- 查询结果可能包含跨域子域名
|
||||
required: true
|
||||
position: 2
|
||||
format: "positional"
|
||||
|
||||
- name: "limit"
|
||||
type: "int"
|
||||
description: |
|
||||
返回结果数量限制(可选)。
|
||||
|
||||
**说明:**
|
||||
- 默认值:40
|
||||
- 最大值:1000(API 限制)
|
||||
- 建议值:100-500
|
||||
|
||||
**注意事项:**
|
||||
- 设置过大的值可能导致请求超时
|
||||
- API 单次返回限制为 40 条,超过会自动分页
|
||||
required: false
|
||||
position: 3
|
||||
format: "positional"
|
||||
default: 40
|
||||
|
||||
- name: "include_ips"
|
||||
type: "bool"
|
||||
description: |
|
||||
是否包含 IP 地址信息(可选)。
|
||||
|
||||
**说明:**
|
||||
- true:在结果中包含 DNS 解析记录
|
||||
- false:仅返回子域名列表
|
||||
|
||||
**注意事项:**
|
||||
- 包含 IP 信息会增加 API 调用次数
|
||||
- 可能包含历史解析 IP,不一定准确
|
||||
required: false
|
||||
position: 4
|
||||
format: "positional"
|
||||
default: false
|
||||
+88
-1
@@ -33,6 +33,93 @@
|
||||
--c2-mono: 'SF Mono', 'Fira Code', 'JetBrains Mono', 'Cascadia Code', monospace;
|
||||
}
|
||||
|
||||
html[data-theme="dark"] {
|
||||
--c2-accent: #60a5fa;
|
||||
--c2-accent-hover: #93c5fd;
|
||||
--c2-accent-dim: rgba(96, 165, 250, 0.14);
|
||||
--c2-accent-glow: rgba(96, 165, 250, 0.28);
|
||||
--c2-green: #34d399;
|
||||
--c2-green-dim: rgba(52, 211, 153, 0.14);
|
||||
--c2-red: #f87171;
|
||||
--c2-red-dim: rgba(248, 113, 113, 0.14);
|
||||
--c2-amber: #fbbf24;
|
||||
--c2-amber-dim: rgba(251, 191, 36, 0.14);
|
||||
--c2-purple: #a78bfa;
|
||||
--c2-purple-dim: rgba(167, 139, 250, 0.14);
|
||||
--c2-surface: #111827;
|
||||
--c2-surface-alt: #0b1120;
|
||||
--c2-border: #263244;
|
||||
--c2-border-hover: #3b4a63;
|
||||
--c2-text: #e5e7eb;
|
||||
--c2-text-dim: #a7b0c0;
|
||||
--c2-text-muted: #6b7280;
|
||||
--c2-shadow-sm: 0 1px 3px rgba(0,0,0,0.34);
|
||||
--c2-shadow-md: 0 8px 24px rgba(0,0,0,0.38);
|
||||
--c2-shadow-lg: 0 18px 48px rgba(0,0,0,0.45);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-modal,
|
||||
html[data-theme="dark"] .c2-modal-content,
|
||||
html[data-theme="dark"] .c2-tab-panel--card,
|
||||
html[data-theme="dark"] .c2-session-detail,
|
||||
html[data-theme="dark"] .c2-payload-card,
|
||||
html[data-theme="dark"] .c2-profile-card,
|
||||
html[data-theme="dark"] .c2-event-card,
|
||||
html[data-theme="dark"] .c2-task-row,
|
||||
html[data-theme="dark"] .c2-listener-card {
|
||||
background: var(--c2-surface);
|
||||
color: var(--c2-text);
|
||||
border-color: var(--c2-border);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-listener-card:hover,
|
||||
html[data-theme="dark"] .c2-payload-card:hover,
|
||||
html[data-theme="dark"] .c2-profile-card:hover {
|
||||
border-color: var(--c2-border-hover);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-session-chip,
|
||||
html[data-theme="dark"] .c2-listener-pill,
|
||||
html[data-theme="dark"] .c2-task-type-badge,
|
||||
html[data-theme="dark"] .c2-tab-btn {
|
||||
background: var(--c2-surface-alt);
|
||||
border-color: var(--c2-border);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] #page-c2-listeners,
|
||||
html[data-theme="dark"] #page-c2-sessions,
|
||||
html[data-theme="dark"] #page-c2-tasks,
|
||||
html[data-theme="dark"] #page-c2-payloads,
|
||||
html[data-theme="dark"] #page-c2-events,
|
||||
html[data-theme="dark"] #page-c2-profiles,
|
||||
html[data-theme="dark"] #page-c2-listeners .page-content,
|
||||
html[data-theme="dark"] #page-c2-sessions .page-content,
|
||||
html[data-theme="dark"] #page-c2-tasks .page-content,
|
||||
html[data-theme="dark"] #page-c2-payloads .page-content,
|
||||
html[data-theme="dark"] #page-c2-events .page-content,
|
||||
html[data-theme="dark"] #page-c2-profiles .page-content,
|
||||
html[data-theme="dark"] .c2-session-layout,
|
||||
html[data-theme="dark"] .c2-session-main {
|
||||
background: var(--c2-surface-alt) !important;
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-session-sidebar-wrap,
|
||||
html[data-theme="dark"] .c2-sessions-toolbar,
|
||||
html[data-theme="dark"] .c2-session-sidebar {
|
||||
background: #0b1120 !important;
|
||||
border-color: var(--c2-border) !important;
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-session-main-empty {
|
||||
background: transparent !important;
|
||||
color: var(--c2-text-dim) !important;
|
||||
}
|
||||
|
||||
html[data-theme="dark"] .c2-session-main-empty__icon {
|
||||
background: var(--c2-accent-dim) !important;
|
||||
border-color: rgba(96, 165, 250, 0.35) !important;
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Form Controls (scoped to C2 pages)
|
||||
============================================================================ */
|
||||
@@ -533,7 +620,7 @@
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
padding: 12px 16px 16px;
|
||||
background: linear-gradient(180deg, #f8fafc 0%, #ffffff 180px);
|
||||
background: linear-gradient(180deg, var(--c2-surface-alt) 0%, var(--c2-surface) 180px);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
+3888
-24
File diff suppressed because it is too large
Load Diff
+261
-16
@@ -39,6 +39,15 @@
|
||||
"version": "Current version",
|
||||
"toggleSidebar": "Collapse/expand sidebar"
|
||||
},
|
||||
"theme": {
|
||||
"system": "System",
|
||||
"light": "Light",
|
||||
"dark": "Dark",
|
||||
"titleSystem": "Current: system theme. Click to switch to light.",
|
||||
"titleLight": "Current: light theme. Click to switch to dark.",
|
||||
"titleDark": "Current: dark theme. Click to switch to system.",
|
||||
"toggle": "Toggle theme"
|
||||
},
|
||||
"notifications": {
|
||||
"title": "Notifications",
|
||||
"empty": "No new events",
|
||||
@@ -76,6 +85,7 @@
|
||||
"agentsManagement": "Agent management",
|
||||
"roles": "Roles",
|
||||
"rolesManagement": "Roles Management",
|
||||
"workflows": "Graph Orchestration",
|
||||
"settings": "System settings",
|
||||
"hitl": "Human-in-the-loop",
|
||||
"c2": "C2",
|
||||
@@ -504,6 +514,12 @@
|
||||
"filterByProject": "Filter by project",
|
||||
"filterAllProjects": "All projects",
|
||||
"filterUnboundProjects": "Unbound",
|
||||
"filterProjectSearch": "Search projects…",
|
||||
"filterProjectSearchEmpty": "No matching projects",
|
||||
"filterProjectSearchHint": "Type to search projects",
|
||||
"filterProjectSearchMore": "Type to find more projects",
|
||||
"filterProjectSearchLoading": "Searching…",
|
||||
"filterProjectSearchFailed": "Failed to load projects. Try again.",
|
||||
"projectConversationsTitle": "{{name}} · Conversations",
|
||||
"unboundConversationsTitle": "Unbound conversations",
|
||||
"noProjectConversations": "No conversations in this project",
|
||||
@@ -645,7 +661,10 @@
|
||||
"agentModeOrchSupervisor": "Supervisor",
|
||||
"hitlTitle": "Human-in-the-loop",
|
||||
"hitlCardSubtitle": "Approvals & allowlist",
|
||||
"hitlReviewer": "Review",
|
||||
"hitlReviewerLabel": "Reviewer",
|
||||
"hitlReviewerHuman": "Human approval",
|
||||
"hitlReviewerAgent": "Audit Agent",
|
||||
"hitlReviewerHint": "Switch between human and Audit Agent anytime; rules and whitelist stay the same. You can pre-select even when HITL is off.",
|
||||
"hitlConfigTitle": "Collaboration mode config",
|
||||
"hitlModeLabel": "Mode",
|
||||
"hitlModeOff": "Off",
|
||||
@@ -664,7 +683,89 @@
|
||||
},
|
||||
"hitl": {
|
||||
"pageTitle": "HITL approvals",
|
||||
"pageReviewerLabel": "Current reviewer",
|
||||
"pageReviewerHint": "Applies to the selected conversation. Without a conversation, saved to config.yaml as the global default for new chats. Takes effect immediately.",
|
||||
"pageReviewerSaved": "Reviewer saved.",
|
||||
"whitelistLabel": "Tool whitelist (no approval)",
|
||||
"whitelistHint": "One per line or comma-separated. Saved to config.yaml global whitelist and takes effect immediately (synced with chat sidebar).",
|
||||
"whitelistSaved": "Whitelist saved.",
|
||||
"whitelistSaveFailed": "Failed to save whitelist",
|
||||
"strategyLabel": "Audit strategy (prompt)",
|
||||
"strategyHint": "Whitelisted tools skip approval. Other tools are judged by the model using this prompt when Audit Agent is selected.",
|
||||
"strategyTabApproval": "Approval mode",
|
||||
"strategyTabReviewEdit": "Review & edit mode",
|
||||
"strategyHintApproval": "Whitelisted tools skip approval. In approval mode the Audit Agent only approves or rejects.",
|
||||
"strategyHintReviewEdit": "In review & edit mode the Audit Agent may narrow parameters via editedArguments before approve; reject if parameters cannot be safely adjusted.",
|
||||
"strategyReset": "Reset to default",
|
||||
"strategySaved": "Audit strategy saved.",
|
||||
"strategySaveFailed": "Failed to save audit strategy",
|
||||
"tabPending": "Pending",
|
||||
"tabLogs": "Audit logs",
|
||||
"tabStrategy": "Audit strategy",
|
||||
"tabWhitelist": "Tool whitelist",
|
||||
"pendingTitle": "Pending approvals",
|
||||
"searchLabel": "Search",
|
||||
"searchPlaceholder": "Tool, conversation, payload, comment…",
|
||||
"searchApply": "Search",
|
||||
"filterDecision": "Decision",
|
||||
"filterDecidedBy": "Reviewer",
|
||||
"filterAll": "All",
|
||||
"decisionApprove": "Approve",
|
||||
"decisionReject": "Reject",
|
||||
"reviewerHuman": "Human",
|
||||
"reviewerAgent": "Audit Agent",
|
||||
"reviewerSystem": "System",
|
||||
"reviewerManual": "Manual entry",
|
||||
"logCreate": "New log",
|
||||
"logModalTitle": "Audit log",
|
||||
"logModalEdit": "Edit audit log",
|
||||
"fieldConversation": "Conversation ID",
|
||||
"fieldTool": "Tool name",
|
||||
"fieldComment": "Comment",
|
||||
"fieldPayload": "Payload (JSON)",
|
||||
"fieldUserMessage": "User message",
|
||||
"fieldThinking": "Thinking",
|
||||
"fieldReasoning": "Reasoning chain",
|
||||
"fieldPlanning": "Planning",
|
||||
"colId": "ID",
|
||||
"colTool": "Tool",
|
||||
"colConversation": "Conversation",
|
||||
"colDecision": "Decision",
|
||||
"colDecidedBy": "Reviewer",
|
||||
"colContext": "Context",
|
||||
"colTime": "Time",
|
||||
"colActions": "Actions",
|
||||
"viewDetail": "Detail",
|
||||
"logModalView": "Audit log detail",
|
||||
"fieldExecutionResult": "Execution result",
|
||||
"executionSuccess": "success",
|
||||
"executionFailed": "failed",
|
||||
"edit": "Edit",
|
||||
"delete": "Delete",
|
||||
"logsEmpty": "No audit logs",
|
||||
"logsEmptyHint": "Records are created automatically when HITL approvals are approved or rejected.",
|
||||
"pageInfo": "{{total}} total",
|
||||
"prevPage": "Previous",
|
||||
"nextPage": "Next",
|
||||
"conversationRequired": "Conversation ID is required",
|
||||
"toolRequired": "Tool name is required",
|
||||
"saveFailed": "Save failed",
|
||||
"deleteConfirm": "Delete this audit log?",
|
||||
"deleteFailed": "Delete failed",
|
||||
"retentionHint": "Audit logs are kept for {{days}} days, then purged automatically.",
|
||||
"selectedCount": "{{count}} selected",
|
||||
"selectAll": "Select all",
|
||||
"deselectAll": "Deselect all",
|
||||
"batchDelete": "Batch delete",
|
||||
"batchDeleteConfirm": "Delete the selected {{count}} audit log(s)? This cannot be undone.",
|
||||
"batchDeleteSuccess": "Successfully deleted {{count}} audit log(s)",
|
||||
"batchDeleteFailed": "Batch delete failed",
|
||||
"clearAll": "Clear all",
|
||||
"clearAllConfirm": "Clear all {{count}} audit log(s) matching the current filters? This cannot be undone.",
|
||||
"clearAllConfirmNoFilter": "No filters are set. This will clear all {{count}} audit log(s). This cannot be undone. Continue?",
|
||||
"clearAllSuccess": "Cleared {{count}} audit log(s)",
|
||||
"clearAllFailed": "Clear failed",
|
||||
"selectLogsFirst": "Select audit logs to delete first",
|
||||
"loading": "Loading...",
|
||||
"emptyState": "No pending approvals",
|
||||
"dismiss": "Dismiss",
|
||||
@@ -1815,7 +1916,7 @@
|
||||
},
|
||||
"chatFilesPage": {
|
||||
"title": "File Management",
|
||||
"intro": "Files uploaded in chat appear here. Click “Copy path” to copy the server absolute path and paste it into a conversation so the model can reference the file.",
|
||||
"intro": "Files uploaded in chat appear here. Drag files into the list below, or click Upload to pick files (multiple allowed). Click “Copy path” to copy the server absolute path and paste it into a conversation so the model can reference the file.",
|
||||
"upload": "Upload",
|
||||
"conversationFilter": "Conversation ID",
|
||||
"conversationPlaceholder": "Leave empty for all",
|
||||
@@ -1927,7 +2028,7 @@
|
||||
"exportNoResults": "No vulnerabilities match the current filters",
|
||||
"exportStarted": "Started downloading {{count}} file(s)",
|
||||
"exportFailed": "Export failed",
|
||||
"saveRequiredFields": "Please fill in conversation ID, title, and severity",
|
||||
"saveRequiredFields": "Please fill in conversation ID, title, description, severity, type, target, reproduction steps, evidence/POC, impact, and remediation",
|
||||
"saveFailed": "Save failed",
|
||||
"fetchFailed": "Failed to fetch vulnerability",
|
||||
"deleteFailed": "Delete failed",
|
||||
@@ -1946,9 +2047,12 @@
|
||||
"detailTaskQueueId": "Task queue ID",
|
||||
"detailConversationTag": "Conversation tag",
|
||||
"detailTaskTag": "Task tag",
|
||||
"detailProof": "Proof",
|
||||
"detailPreconditions": "Preconditions",
|
||||
"detailReproductionSteps": "Reproduction steps",
|
||||
"detailEvidence": "Evidence / POC",
|
||||
"detailImpact": "Impact",
|
||||
"detailRecommendation": "Remediation",
|
||||
"detailRetestNotes": "Retest method",
|
||||
"downloadOkTitle": "Downloaded",
|
||||
"exportFailedMessage": "Export failed",
|
||||
"downloadFailed": "Download failed"
|
||||
@@ -2087,11 +2191,27 @@
|
||||
"subIndexFilter": "Sub-index filter (optional)",
|
||||
"subIndexFilterPlaceholder": "e.g. prod, must match an indexing sub_indexes tag",
|
||||
"subIndexFilterHint": "Empty = no filter. When set, only rows whose sub_indexes contain this tag (legacy rows with empty sub_indexes still match).",
|
||||
"ragPipelineHeader": "RAG pipeline (MultiQuery + Rerank)",
|
||||
"ragPipelineHint": "MultiQuery and rerank are always on: LLM query rewrite → vector prefetch & fusion → HTTP rerank → dedupe & budget truncate.",
|
||||
"multiQueryMaxQueries": "MultiQuery rewrite variant limit",
|
||||
"multiQueryMaxQueriesPlaceholder": "4",
|
||||
"multiQueryMaxQueriesHint": "Max LLM-generated retrieval variants (including paraphrases of the original query). Recommended 3–4, max 8.",
|
||||
"rerankProvider": "Rerank provider",
|
||||
"rerankProviderAuto": "Auto (infer from Base URL)",
|
||||
"rerankProviderCohere": "Cohere-compatible API",
|
||||
"rerankProviderHint": "DashScope uses gte-rerank; other compatible endpoints use /v1/rerank. Leave empty to infer from Base URL below.",
|
||||
"rerankModel": "Rerank model (optional)",
|
||||
"rerankModelPlaceholder": "Empty: DashScope→gte-rerank, Cohere→rerank-multilingual-v3.0",
|
||||
"rerankBaseUrl": "Rerank Base URL (optional)",
|
||||
"rerankBaseUrlPlaceholder": "Leave empty to reuse embedding / OpenAI base_url",
|
||||
"rerankApiKey": "Rerank API Key (optional)",
|
||||
"rerankApiKeyPlaceholder": "Leave empty to reuse embedding / OpenAI api_key",
|
||||
"rerankApiKeyHint": "On rerank failure, results fall back to fusion order; search still works.",
|
||||
"postRetrieveHeader": "Post-retrieval (dedupe / budget)",
|
||||
"postRetrieveDedupeAuto": "Results are always deduped by normalized text (whitespace-collapsed bodies). No setting required.",
|
||||
"prefetchTopK": "Prefetch candidates (vector stage)",
|
||||
"prefetchTopKPlaceholder": "0",
|
||||
"prefetchTopKHint": "0 = same as Top-K; larger values fetch more vector hits before dedupe/truncate (max 200).",
|
||||
"prefetchTopKPlaceholder": "20",
|
||||
"prefetchTopKHint": "Vector candidates per MultiQuery variant; 0 uses built-in max(top_k×4, 20) (max 200).",
|
||||
"maxContextChars": "Max returned characters (Unicode)",
|
||||
"maxContextCharsPlaceholder": "0",
|
||||
"maxContextCharsHint": "0 = unlimited; keeps whole chunks in rank order until the budget is exceeded.",
|
||||
@@ -2542,6 +2662,7 @@
|
||||
"conversationName": "Conversation name",
|
||||
"project": "Project",
|
||||
"noProject": "No project",
|
||||
"unknownProject": "Unknown project",
|
||||
"filterByProject": "Filter by project",
|
||||
"lastTime": "Last activity",
|
||||
"action": "Action",
|
||||
@@ -2683,9 +2804,9 @@
|
||||
"taskTag": "Task tag",
|
||||
"taskTagPlaceholder": "e.g. batch scan Q2, retest",
|
||||
"title": "Title",
|
||||
"titlePlaceholder": "Vulnerability title",
|
||||
"titlePlaceholder": "/api/login is vulnerable to SQL injection",
|
||||
"description": "Description",
|
||||
"descriptionPlaceholder": "Detailed description",
|
||||
"descriptionPlaceholder": "Describe the summary, trigger point, observed abnormal behavior, and why it is exploitable.",
|
||||
"severity": "Severity",
|
||||
"pleaseSelect": "Please select",
|
||||
"severityCritical": "Critical",
|
||||
@@ -2702,13 +2823,19 @@
|
||||
"type": "Vulnerability type",
|
||||
"typePlaceholder": "e.g. SQL injection, XSS, CSRF",
|
||||
"target": "Target",
|
||||
"targetPlaceholder": "Affected target (URL, IP, etc.)",
|
||||
"proof": "Proof (POC)",
|
||||
"proofPlaceholder": "Proof: request/response, screenshots, etc.",
|
||||
"targetPlaceholder": "Be specific: URL, IP:port, endpoint path, and parameter name.",
|
||||
"preconditions": "Preconditions",
|
||||
"preconditionsPlaceholder": "Login state, permissions, account, headers/cookies, required data, environment/version; write none if not needed.",
|
||||
"reproductionSteps": "Reproduction steps",
|
||||
"reproductionStepsPlaceholder": "Number the steps and include entry point, parameter, payload, command, and observation point.",
|
||||
"evidence": "Evidence / POC",
|
||||
"evidencePlaceholder": "Raw request/response, curl/tool command, screenshot notes, logs, DNSLog/callback records, database results, file paths, timestamps, etc.",
|
||||
"impact": "Impact",
|
||||
"impactPlaceholder": "Impact description",
|
||||
"impactPlaceholder": "Describe the verified real-world impact, such as which data can be read or changed.",
|
||||
"recommendation": "Recommendation",
|
||||
"recommendationPlaceholder": "Remediation"
|
||||
"recommendationPlaceholder": "Write the concrete fix and retest criteria.",
|
||||
"retestNotes": "Retest method",
|
||||
"retestNotesPlaceholder": "How to verify the fix, including expected status code, error message, or access-control result."
|
||||
},
|
||||
"vulnerabilityMd": {
|
||||
"headingBasic": "Basic information",
|
||||
@@ -2725,9 +2852,12 @@
|
||||
"labelCreated": "Created at",
|
||||
"labelUpdated": "Updated at",
|
||||
"headingDescription": "Description",
|
||||
"headingProof": "Proof (POC)",
|
||||
"headingPreconditions": "Preconditions",
|
||||
"headingReproductionSteps": "Reproduction steps",
|
||||
"headingEvidence": "Evidence / POC",
|
||||
"headingImpact": "Impact",
|
||||
"headingRecommendation": "Remediation"
|
||||
"headingRecommendation": "Remediation",
|
||||
"headingRetestNotes": "Retest method"
|
||||
},
|
||||
"roleModal": {
|
||||
"addRole": "Add role",
|
||||
@@ -2786,7 +2916,122 @@
|
||||
"mcpDisabledBadgeTitle": "Off in MCP Management; check only expresses role linkage—turn on in MCP to run",
|
||||
"roleFilterOnBanner": "These tools are checked and linked to this role (independent of MCP-wide enable).",
|
||||
"roleFilterOffBanner": "These tools are unchecked and not linked to this role.",
|
||||
"checkboxLinkTitle": "Check to link this tool to this role"
|
||||
"checkboxLinkTitle": "Check to link this tool to this role",
|
||||
"bindWorkflow": "Bind graph workflow",
|
||||
"bindWorkflowHint": "When a workflow is selected, conversations with this role automatically run the bound graph; workflow fields are configured freely in the graph JSON.",
|
||||
"workflowPolicy": "Workflow trigger policy",
|
||||
"workflowPolicyAuto": "Auto trigger",
|
||||
"workflowPolicyOff": "Off",
|
||||
"noWorkflowBind": "No workflow",
|
||||
"workflowDisabledSuffix": " (disabled)"
|
||||
},
|
||||
"workflows": {
|
||||
"title": "Graph Orchestration",
|
||||
"newGraph": "New graph",
|
||||
"processLibrary": "Process library",
|
||||
"nodeLibrary": "Node library",
|
||||
"emptyList": "No graph workflows yet",
|
||||
"statusEnabled": "Enabled",
|
||||
"statusDisabled": "Disabled",
|
||||
"metaId": "ID",
|
||||
"metaName": "Name",
|
||||
"metaDescription": "Description",
|
||||
"metaEnabled": "Enabled",
|
||||
"namePlaceholder": "Basic Web scan",
|
||||
"descriptionPlaceholder": "Optional",
|
||||
"connect": "Connect",
|
||||
"connecting": "Connecting",
|
||||
"deleteSelected": "Delete selected",
|
||||
"autoLayout": "Auto layout",
|
||||
"canvasEmpty": "Drag nodes from the left onto the canvas, or click node buttons to add quickly",
|
||||
"properties": "Properties",
|
||||
"nodeProperties": "Node properties",
|
||||
"edgeProperties": "Edge properties",
|
||||
"deleteNode": "Delete node",
|
||||
"deleteEdge": "Delete edge",
|
||||
"propertyEmpty": "Select a node or edge to edit properties",
|
||||
"propLabel": "Name",
|
||||
"propType": "Type",
|
||||
"customFields": "Custom fields",
|
||||
"addField": "Add field",
|
||||
"noCustomFields": "No custom fields",
|
||||
"nodes": {
|
||||
"start": "Start",
|
||||
"tool": "Tool",
|
||||
"agent": "Agent",
|
||||
"condition": "Condition",
|
||||
"hitl": "Approval",
|
||||
"output": "Output",
|
||||
"end": "End",
|
||||
"default": "Node"
|
||||
},
|
||||
"edges": {
|
||||
"yes": "Yes",
|
||||
"no": "No"
|
||||
},
|
||||
"config": {
|
||||
"inputKeys": "Input variables",
|
||||
"mcpTool": "MCP tool",
|
||||
"selectTool": "Select a tool",
|
||||
"toolDisabled": " (disabled)",
|
||||
"argumentsTemplate": "Arguments template",
|
||||
"argumentsStatic": "Tool arguments (JSON)",
|
||||
"timeoutSeconds": "Timeout (seconds)",
|
||||
"optional": "Optional",
|
||||
"agentMode": "Agent mode",
|
||||
"inputSource": "Input source",
|
||||
"inputBinding": "Input field binding",
|
||||
"inputBindingHint": "from = data source, field = field name (e.g. output, message)",
|
||||
"nodeInstruction": "Node instruction",
|
||||
"instructionPlaceholder": "Describe what this node should accomplish",
|
||||
"outputKey": "Output variable name",
|
||||
"conditionExpression": "Condition expression",
|
||||
"conditionHint": "The node computes matched (true/false); outgoing edges define branches: first edge is \"Yes\", second is \"No\". You can also write <code>{{previous.matched}} == \"true\"</code> on the edge.",
|
||||
"edgeCondition": "Edge condition",
|
||||
"edgeConditionHintCondition": "{{previous.matched}} == \"true\" (Yes) or == \"false\" (No)",
|
||||
"edgeConditionHintExample": "e.g. {{previous.output}} == \"ok\"",
|
||||
"edgeBranchHint": "The first edge from a condition node defaults to the \"Yes\" branch, the second to \"No\"; you can customize conditions here.",
|
||||
"hitlPrompt": "Approval prompt",
|
||||
"hitlPromptPlaceholder": "Approve to continue",
|
||||
"hitlReviewer": "Reviewer",
|
||||
"hitlInteractiveHint": "The run pauses at this node; approve or reject via API or the monitor panel to continue.",
|
||||
"promptBinding": "Prompt field binding",
|
||||
"promptBindingHint": "When prompt text is empty, read approval text from the bound field",
|
||||
"outputSource": "Variable source",
|
||||
"sourceBinding": "Output field binding",
|
||||
"sourceBindingHint": "Write the bound field value to the output variable; static value below overrides when set",
|
||||
"staticValue": "Static output value",
|
||||
"resultBinding": "End summary binding",
|
||||
"resultBindingHint": "Field shown in the end node summary",
|
||||
"endTemplate": "End summary template"
|
||||
},
|
||||
"defaultHitlPrompt": "Please approve whether this step should continue",
|
||||
"nodeFallback": "Node {{n}}",
|
||||
"loadFailed": "Failed to load workflows",
|
||||
"saveFailed": "Failed to save workflow",
|
||||
"deleteFailed": "Failed to delete workflow",
|
||||
"saved": "Workflow saved",
|
||||
"deleted": "Workflow deleted",
|
||||
"idNameRequired": "Workflow ID and name are required",
|
||||
"selectToDelete": "Select a workflow to delete",
|
||||
"confirmDelete": "Delete workflow {{id}}?",
|
||||
"duplicateEdge": "An edge already exists between these two nodes",
|
||||
"connectModeOn": "Connect mode: click source node then target node",
|
||||
"connectModeOff": "Exited connect mode",
|
||||
"validation": {
|
||||
"needStart": "At least one Start node is required",
|
||||
"needOutput": "At least one Output node is required",
|
||||
"edgeSelfLoop": "Edge {{id}} cannot point to itself",
|
||||
"edgeSourceMissing": "Edge {{id}} source node does not exist",
|
||||
"edgeTargetMissing": "Edge {{id}} target node does not exist",
|
||||
"startIncoming": "Start node {{label}} must not have incoming edges",
|
||||
"outputOutgoing": "Output node {{label}} must not have outgoing edges",
|
||||
"toolNeedsMcp": "Tool node {{label}} requires an MCP tool",
|
||||
"conditionNeedsExpr": "Condition node {{label}} requires a condition expression",
|
||||
"conditionNeedsOutEdge": "Condition node {{label}} needs at least one outgoing edge (Yes/No branch)",
|
||||
"conditionTooManyEdges": "Condition node {{label}} should have at most two outgoing edges (Yes/No); configure edge conditions for a third and beyond",
|
||||
"outputNeedsKey": "Output node {{label}} requires an output variable name"
|
||||
}
|
||||
},
|
||||
"c2": {
|
||||
"clipboardCopied": "Copied to clipboard",
|
||||
|
||||
+262
-17
@@ -39,6 +39,15 @@
|
||||
"version": "当前版本",
|
||||
"toggleSidebar": "折叠/展开侧边栏"
|
||||
},
|
||||
"theme": {
|
||||
"system": "跟随系统",
|
||||
"light": "浅色",
|
||||
"dark": "暗色",
|
||||
"titleSystem": "当前:跟随系统主题。点击切换为浅色。",
|
||||
"titleLight": "当前:浅色主题。点击切换为暗色。",
|
||||
"titleDark": "当前:暗色主题。点击切换为跟随系统。",
|
||||
"toggle": "切换主题"
|
||||
},
|
||||
"notifications": {
|
||||
"title": "事件通知",
|
||||
"empty": "暂无新事件",
|
||||
@@ -76,6 +85,7 @@
|
||||
"agentsManagement": "Agent管理",
|
||||
"roles": "角色",
|
||||
"rolesManagement": "角色管理",
|
||||
"workflows": "图编排",
|
||||
"settings": "系统设置",
|
||||
"hitl": "人机协同",
|
||||
"c2": "C2",
|
||||
@@ -492,6 +502,12 @@
|
||||
"filterByProject": "按项目筛选",
|
||||
"filterAllProjects": "全部项目",
|
||||
"filterUnboundProjects": "未绑定项目",
|
||||
"filterProjectSearch": "搜索项目…",
|
||||
"filterProjectSearchEmpty": "没有匹配的项目",
|
||||
"filterProjectSearchHint": "输入关键字搜索项目",
|
||||
"filterProjectSearchMore": "更多项目请输入关键字搜索",
|
||||
"filterProjectSearchLoading": "搜索中…",
|
||||
"filterProjectSearchFailed": "加载项目失败,请重试",
|
||||
"projectConversationsTitle": "{{name}} · 对话",
|
||||
"unboundConversationsTitle": "未绑定项目",
|
||||
"noProjectConversations": "该项目暂无对话",
|
||||
@@ -633,7 +649,10 @@
|
||||
"agentModeOrchSupervisor": "Supervisor",
|
||||
"hitlTitle": "人机协同",
|
||||
"hitlCardSubtitle": "审批与白名单",
|
||||
"hitlReviewer": "Review",
|
||||
"hitlReviewerLabel": "审批方",
|
||||
"hitlReviewerHuman": "人工审批",
|
||||
"hitlReviewerAgent": "审计 Agent",
|
||||
"hitlReviewerHint": "可在人工与审计 Agent 之间随时切换;规则与白名单不变。人机协同为「关闭」时也可预先选择。",
|
||||
"hitlConfigTitle": "协同模式配置",
|
||||
"hitlModeLabel": "模式",
|
||||
"hitlModeOff": "关闭",
|
||||
@@ -642,7 +661,7 @@
|
||||
"hitlSensitiveTools": "敏感工具(逗号分隔)",
|
||||
"hitlWhitelistTools": "白名单工具(免审批,逗号分隔)",
|
||||
"hitlWhitelistPlaceholder": "例:read_file, grep 或每行一个工具名(与 config 全局白名单合并)",
|
||||
"hitlWhitelistHint": "每行一个或逗号分隔;与 config 中全局白名单合并展示。",
|
||||
"hitlWhitelistHint": "白名单内工具免审批;每行一个或逗号分隔,与 config 全局白名单合并。",
|
||||
"hitlApply": "应用",
|
||||
"hitlApplyOkSync": "人机协同配置已保存并同步到服务器。",
|
||||
"hitlApplyOkWhitelistYaml": "免审批工具已合并进 config.yaml 并生效。协同模式、超时等仍须选中会话后再点「应用」才会写入服务器。",
|
||||
@@ -652,7 +671,89 @@
|
||||
},
|
||||
"hitl": {
|
||||
"pageTitle": "人机协同审批",
|
||||
"pageReviewerLabel": "当前审批方",
|
||||
"pageReviewerHint": "作用于当前选中会话;未选会话时写入 config.yaml 作为全局默认,新建会话时沿用。切换后立即生效。",
|
||||
"pageReviewerSaved": "审批方已保存。",
|
||||
"whitelistLabel": "免审批工具白名单",
|
||||
"whitelistHint": "每行一个或逗号分隔;保存后写入 config.yaml 全局白名单并立即生效(与聊天侧栏同步展示)。",
|
||||
"whitelistSaved": "白名单已保存。",
|
||||
"whitelistSaveFailed": "保存白名单失败",
|
||||
"strategyLabel": "审计策略(提示词)",
|
||||
"strategyHint": "白名单内工具免审批;其余工具在审批方为「审计 Agent」时,由模型按此提示词自主裁决。",
|
||||
"strategyTabApproval": "审批模式",
|
||||
"strategyTabReviewEdit": "审查编辑模式",
|
||||
"strategyHintApproval": "白名单内工具免审批;审批模式下审计 Agent 仅裁决通过/拒绝。",
|
||||
"strategyHintReviewEdit": "审查编辑模式下审计 Agent 可通过 editedArguments 收窄参数后放行;无法安全改参时应拒绝。",
|
||||
"strategyReset": "恢复默认",
|
||||
"strategySaved": "审计策略已保存。",
|
||||
"strategySaveFailed": "保存审计策略失败",
|
||||
"tabPending": "待审计",
|
||||
"tabLogs": "审计日志",
|
||||
"tabStrategy": "审计策略",
|
||||
"tabWhitelist": "工具白名单",
|
||||
"pendingTitle": "待处理审批",
|
||||
"searchLabel": "搜索",
|
||||
"searchPlaceholder": "工具名、会话 ID、载荷、备注…",
|
||||
"searchApply": "搜索",
|
||||
"filterDecision": "决策",
|
||||
"filterDecidedBy": "审批方",
|
||||
"filterAll": "全部",
|
||||
"decisionApprove": "通过",
|
||||
"decisionReject": "拒绝",
|
||||
"reviewerHuman": "人工",
|
||||
"reviewerAgent": "审计 Agent",
|
||||
"reviewerSystem": "系统",
|
||||
"reviewerManual": "手动录入",
|
||||
"logCreate": "新建日志",
|
||||
"logModalTitle": "审计日志",
|
||||
"logModalEdit": "编辑审计日志",
|
||||
"fieldConversation": "会话 ID",
|
||||
"fieldTool": "工具名",
|
||||
"fieldComment": "备注",
|
||||
"fieldPayload": "载荷 (JSON)",
|
||||
"fieldUserMessage": "用户原话",
|
||||
"fieldThinking": "本轮思考",
|
||||
"fieldReasoning": "推理链",
|
||||
"fieldPlanning": "规划",
|
||||
"colId": "ID",
|
||||
"colTool": "工具",
|
||||
"colConversation": "会话",
|
||||
"colDecision": "决策",
|
||||
"colDecidedBy": "审批方",
|
||||
"colContext": "上下文",
|
||||
"colTime": "时间",
|
||||
"colActions": "操作",
|
||||
"viewDetail": "详情",
|
||||
"logModalView": "审计日志详情",
|
||||
"fieldExecutionResult": "执行结果",
|
||||
"executionSuccess": "成功",
|
||||
"executionFailed": "失败",
|
||||
"edit": "编辑",
|
||||
"delete": "删除",
|
||||
"logsEmpty": "暂无审计日志",
|
||||
"logsEmptyHint": "人机协同审批通过或拒绝后会自动记录在此。",
|
||||
"pageInfo": "共 {{total}} 条",
|
||||
"prevPage": "上一页",
|
||||
"nextPage": "下一页",
|
||||
"conversationRequired": "请填写会话 ID",
|
||||
"toolRequired": "请填写工具名",
|
||||
"saveFailed": "保存失败",
|
||||
"deleteConfirm": "确定删除这条审计日志?",
|
||||
"deleteFailed": "删除失败",
|
||||
"retentionHint": "审计日志保留 {{days}} 天,超期自动清理",
|
||||
"selectedCount": "已选择 {{count}} 项",
|
||||
"selectAll": "全选",
|
||||
"deselectAll": "取消全选",
|
||||
"batchDelete": "批量删除",
|
||||
"batchDeleteConfirm": "确定删除选中的 {{count}} 条审计日志?此操作不可恢复。",
|
||||
"batchDeleteSuccess": "成功删除 {{count}} 条审计日志",
|
||||
"batchDeleteFailed": "批量删除失败",
|
||||
"clearAll": "清空",
|
||||
"clearAllConfirm": "确定清空当前筛选条件下的全部 {{count}} 条审计日志?此操作不可恢复。",
|
||||
"clearAllConfirmNoFilter": "未设置筛选条件,将清空全部 {{count}} 条审计日志。此操作不可恢复,是否继续?",
|
||||
"clearAllSuccess": "已清空 {{count}} 条审计日志",
|
||||
"clearAllFailed": "清空失败",
|
||||
"selectLogsFirst": "请先选择要删除的审计日志",
|
||||
"loading": "加载中...",
|
||||
"emptyState": "暂无待审批项",
|
||||
"dismiss": "忽略",
|
||||
@@ -1803,7 +1904,7 @@
|
||||
},
|
||||
"chatFilesPage": {
|
||||
"title": "文件管理",
|
||||
"intro": "管理在对话中上传的文件。需要让 AI 引用某文件时,在列表中点击「复制路径」,到对话里粘贴即可(路径为服务器上的绝对路径,与对话附件保存位置一致)。",
|
||||
"intro": "管理在对话中上传的文件。可将文件拖拽到下方列表区域,或点击「上传文件」选择文件(支持多选)。需要让 AI 引用某文件时,在列表中点击「复制路径」,到对话里粘贴即可(路径为服务器上的绝对路径,与对话附件保存位置一致)。",
|
||||
"upload": "上传文件",
|
||||
"conversationFilter": "会话 ID",
|
||||
"conversationPlaceholder": "留空表示全部",
|
||||
@@ -1915,7 +2016,7 @@
|
||||
"exportNoResults": "当前筛选条件下无可导出漏洞",
|
||||
"exportStarted": "已开始下载 {{count}} 份报告",
|
||||
"exportFailed": "导出失败",
|
||||
"saveRequiredFields": "请填写必填字段:会话ID、标题和严重程度",
|
||||
"saveRequiredFields": "请填写必填字段:会话ID、标题、描述、严重程度、漏洞类型、目标、复现步骤、证据/POC、影响和修复建议",
|
||||
"saveFailed": "保存失败",
|
||||
"fetchFailed": "获取漏洞失败",
|
||||
"deleteFailed": "删除失败",
|
||||
@@ -1934,9 +2035,12 @@
|
||||
"detailTaskQueueId": "任务队列ID",
|
||||
"detailConversationTag": "对话标签",
|
||||
"detailTaskTag": "任务标签",
|
||||
"detailProof": "证明",
|
||||
"detailPreconditions": "前置条件",
|
||||
"detailReproductionSteps": "复现步骤",
|
||||
"detailEvidence": "证据 / POC",
|
||||
"detailImpact": "影响",
|
||||
"detailRecommendation": "修复建议",
|
||||
"detailRetestNotes": "复测方式",
|
||||
"downloadOkTitle": "下载成功",
|
||||
"exportFailedMessage": "导出失败",
|
||||
"downloadFailed": "下载失败"
|
||||
@@ -2075,11 +2179,27 @@
|
||||
"subIndexFilter": "子索引过滤(可选)",
|
||||
"subIndexFilterPlaceholder": "如 prod,与索引 sub_indexes 一致",
|
||||
"subIndexFilterHint": "留空不过滤;填写后仅检索向量行 sub_indexes 中含该标签的结果(未打标旧行仍保留)。",
|
||||
"ragPipelineHeader": "RAG 管线(MultiQuery + Rerank)",
|
||||
"ragPipelineHint": "MultiQuery 与精排始终启用:LLM 改写多路检索 → 向量预取与融合 → HTTP 精排 → 去重与预算截断。",
|
||||
"multiQueryMaxQueries": "MultiQuery 改写变体上限",
|
||||
"multiQueryMaxQueriesPlaceholder": "4",
|
||||
"multiQueryMaxQueriesHint": "LLM 生成的检索变体数量上限(含原问语义覆盖);建议 3~4,最大 8。",
|
||||
"rerankProvider": "精排提供商",
|
||||
"rerankProviderAuto": "自动(按 Base URL 推断)",
|
||||
"rerankProviderCohere": "Cohere 兼容 API",
|
||||
"rerankProviderHint": "DashScope 使用 gte-rerank;其他兼容端点走 /v1/rerank。留空时按下方 Base URL 自动推断。",
|
||||
"rerankModel": "精排模型(可选)",
|
||||
"rerankModelPlaceholder": "留空:DashScope→gte-rerank,Cohere→rerank-multilingual-v3.0",
|
||||
"rerankBaseUrl": "精排 Base URL(可选)",
|
||||
"rerankBaseUrlPlaceholder": "留空则复用嵌入 / OpenAI 的 base_url",
|
||||
"rerankApiKey": "精排 API Key(可选)",
|
||||
"rerankApiKeyPlaceholder": "留空则复用嵌入 / OpenAI 的 api_key",
|
||||
"rerankApiKeyHint": "精排失败时自动降级为融合排序,检索仍可用。",
|
||||
"postRetrieveHeader": "检索后处理(去重 / 预算)",
|
||||
"postRetrieveDedupeAuto": "检索结果会自动按正文规范化去重(合并仅空白不同的重复片段),无需配置。",
|
||||
"prefetchTopK": "预取候选数(向量阶段)",
|
||||
"prefetchTopKPlaceholder": "0",
|
||||
"prefetchTopKHint": "0 表示与 Top-K 相同;大于 Top-K 时先多取候选再经去重/截断回到 Top-K(上限 200)。",
|
||||
"prefetchTopKPlaceholder": "20",
|
||||
"prefetchTopKHint": "每条 MultiQuery 变体的向量候选数;0 表示内置 max(top_k×4, 20)(上限 200)。",
|
||||
"maxContextChars": "返回内容最大字符数(Unicode)",
|
||||
"maxContextCharsPlaceholder": "0",
|
||||
"maxContextCharsHint": "0 表示不限制;按检索顺序整段保留 chunk,超出则丢弃后续。",
|
||||
@@ -2530,6 +2650,7 @@
|
||||
"conversationName": "对话名称",
|
||||
"project": "项目",
|
||||
"noProject": "无项目",
|
||||
"unknownProject": "未知项目",
|
||||
"filterByProject": "按项目筛选",
|
||||
"lastTime": "最近一次对话时间",
|
||||
"action": "操作",
|
||||
@@ -2671,9 +2792,9 @@
|
||||
"taskTag": "任务标签",
|
||||
"taskTagPlaceholder": "如:批量扫描Q2、专项复测",
|
||||
"title": "标题",
|
||||
"titlePlaceholder": "漏洞标题",
|
||||
"titlePlaceholder": "/api/login 存在 SQL 注入",
|
||||
"description": "描述",
|
||||
"descriptionPlaceholder": "漏洞详细描述",
|
||||
"descriptionPlaceholder": "说明漏洞摘要、触发点、异常现象和为什么可被利用。",
|
||||
"severity": "严重程度",
|
||||
"pleaseSelect": "请选择",
|
||||
"severityCritical": "严重",
|
||||
@@ -2690,13 +2811,19 @@
|
||||
"type": "漏洞类型",
|
||||
"typePlaceholder": "如:SQL注入、XSS、CSRF等",
|
||||
"target": "目标",
|
||||
"targetPlaceholder": "受影响的目标(URL、IP地址等)",
|
||||
"proof": "证明(POC)",
|
||||
"proofPlaceholder": "漏洞证明,如请求/响应、截图等",
|
||||
"targetPlaceholder": "精确到 URL/IP:端口/接口路径/参数名",
|
||||
"preconditions": "前置条件",
|
||||
"preconditionsPlaceholder": "登录状态、权限、账号、Header/Cookie、特定数据、环境/版本;无则写无。",
|
||||
"reproductionSteps": "复现步骤",
|
||||
"reproductionStepsPlaceholder": "按 1/2/3 编号,写清入口、参数、payload、执行命令、观察点。",
|
||||
"evidence": "证据 / POC",
|
||||
"evidencePlaceholder": "原始请求/响应、curl/工具命令、截图说明、日志、DNSLog/回连记录、数据库结果、文件路径、时间戳等。",
|
||||
"impact": "影响",
|
||||
"impactPlaceholder": "漏洞影响说明",
|
||||
"impactPlaceholder": "结合已验证事实说明实际影响,例如越权读取哪些数据。",
|
||||
"recommendation": "修复建议",
|
||||
"recommendationPlaceholder": "修复建议"
|
||||
"recommendationPlaceholder": "写具体修复点和复测标准。",
|
||||
"retestNotes": "复测方式",
|
||||
"retestNotesPlaceholder": "修复后如何验证漏洞已关闭,包括应返回的状态码、错误信息或访问控制结果。"
|
||||
},
|
||||
"vulnerabilityMd": {
|
||||
"headingBasic": "基本信息",
|
||||
@@ -2713,9 +2840,12 @@
|
||||
"labelCreated": "创建时间",
|
||||
"labelUpdated": "更新时间",
|
||||
"headingDescription": "描述",
|
||||
"headingProof": "证明(POC)",
|
||||
"headingPreconditions": "前置条件",
|
||||
"headingReproductionSteps": "复现步骤",
|
||||
"headingEvidence": "证据 / POC",
|
||||
"headingImpact": "影响",
|
||||
"headingRecommendation": "修复建议"
|
||||
"headingRecommendation": "修复建议",
|
||||
"headingRetestNotes": "复测方式"
|
||||
},
|
||||
"roleModal": {
|
||||
"addRole": "添加角色",
|
||||
@@ -2774,7 +2904,122 @@
|
||||
"mcpDisabledBadgeTitle": "MCP 管理里该工具为关闭;勾选只表示想关联到本角色,实际调用需先在 MCP 中打开",
|
||||
"roleFilterOnBanner": "以下为「已勾选、关联到本角色」的工具(与 MCP 管理里全局开/关无关)。",
|
||||
"roleFilterOffBanner": "以下为「未勾选、未关联到本角色」的工具。",
|
||||
"checkboxLinkTitle": "勾选表示本角色关联使用该工具"
|
||||
"checkboxLinkTitle": "勾选表示本角色关联使用该工具",
|
||||
"bindWorkflow": "绑定图编排流程",
|
||||
"bindWorkflowHint": "选中流程后,对话页使用该角色会自动触发绑定图;流程字段由图定义 JSON 自由配置。",
|
||||
"workflowPolicy": "流程触发策略",
|
||||
"workflowPolicyAuto": "自动触发",
|
||||
"workflowPolicyOff": "关闭",
|
||||
"noWorkflowBind": "不绑定流程",
|
||||
"workflowDisabledSuffix": "(已禁用)"
|
||||
},
|
||||
"workflows": {
|
||||
"title": "图编排",
|
||||
"newGraph": "新建图",
|
||||
"processLibrary": "流程库",
|
||||
"nodeLibrary": "节点库",
|
||||
"emptyList": "暂无图编排流程",
|
||||
"statusEnabled": "启用",
|
||||
"statusDisabled": "禁用",
|
||||
"metaId": "ID",
|
||||
"metaName": "名称",
|
||||
"metaDescription": "描述",
|
||||
"metaEnabled": "启用",
|
||||
"namePlaceholder": "基础 Web 扫描",
|
||||
"descriptionPlaceholder": "可选",
|
||||
"connect": "连线",
|
||||
"connecting": "连线中",
|
||||
"deleteSelected": "删除选中",
|
||||
"autoLayout": "自动布局",
|
||||
"canvasEmpty": "从左侧拖拽节点到画布,或点击节点按钮快速添加",
|
||||
"properties": "属性",
|
||||
"nodeProperties": "节点属性",
|
||||
"edgeProperties": "连线属性",
|
||||
"deleteNode": "删除节点",
|
||||
"deleteEdge": "删除连线",
|
||||
"propertyEmpty": "选择一个节点或连线后编辑属性",
|
||||
"propLabel": "名称",
|
||||
"propType": "类型",
|
||||
"customFields": "自定义字段",
|
||||
"addField": "添加字段",
|
||||
"noCustomFields": "暂无自定义字段",
|
||||
"nodes": {
|
||||
"start": "开始",
|
||||
"tool": "工具",
|
||||
"agent": "Agent",
|
||||
"condition": "条件",
|
||||
"hitl": "审批",
|
||||
"output": "输出",
|
||||
"end": "结束",
|
||||
"default": "节点"
|
||||
},
|
||||
"edges": {
|
||||
"yes": "是",
|
||||
"no": "否"
|
||||
},
|
||||
"config": {
|
||||
"inputKeys": "输入变量",
|
||||
"mcpTool": "MCP 工具",
|
||||
"selectTool": "请选择工具",
|
||||
"toolDisabled": "(未启用)",
|
||||
"argumentsTemplate": "参数模板",
|
||||
"argumentsStatic": "工具参数(JSON)",
|
||||
"timeoutSeconds": "超时秒数",
|
||||
"optional": "可选",
|
||||
"agentMode": "Agent 模式",
|
||||
"inputSource": "输入来源",
|
||||
"inputBinding": "输入字段绑定",
|
||||
"inputBindingHint": "from 选数据来源,field 为字段名(如 output、message)",
|
||||
"nodeInstruction": "节点指令",
|
||||
"instructionPlaceholder": "描述该节点要完成的任务",
|
||||
"outputKey": "输出变量名",
|
||||
"conditionExpression": "条件表达式",
|
||||
"conditionHint": "节点会计算 matched(true/false),由出边决定分支:第一条线为「是」,第二条为「否」;也可在连线上写 <code>{{previous.matched}} == \"true\"</code>。",
|
||||
"edgeCondition": "连线条件",
|
||||
"edgeConditionHintCondition": "{{previous.matched}} == \"true\"(是)或 == \"false\"(否)",
|
||||
"edgeConditionHintExample": "例如: {{previous.output}} == \"ok\"",
|
||||
"edgeBranchHint": "从条件节点连出的第一条线默认为「是」分支,第二条为「否」分支;也可在此自定义条件。",
|
||||
"hitlPrompt": "审批提示",
|
||||
"hitlPromptPlaceholder": "请审批是否继续",
|
||||
"hitlReviewer": "审批方",
|
||||
"hitlInteractiveHint": "运行时在审批节点暂停,需通过 API 或监控面板人工通过/拒绝后继续。",
|
||||
"promptBinding": "提示字段绑定",
|
||||
"promptBindingHint": "留空提示文案时,从绑定字段读取审批说明",
|
||||
"outputSource": "变量来源",
|
||||
"sourceBinding": "输出字段绑定",
|
||||
"sourceBindingHint": "将绑定字段的值写入输出变量;也可填写下方固定值覆盖",
|
||||
"staticValue": "固定输出值",
|
||||
"resultBinding": "结束摘要绑定",
|
||||
"resultBindingHint": "结束节点展示的摘要字段",
|
||||
"endTemplate": "结束摘要模板"
|
||||
},
|
||||
"defaultHitlPrompt": "请审批该步骤是否继续执行",
|
||||
"nodeFallback": "节点 {{n}}",
|
||||
"loadFailed": "加载工作流失败",
|
||||
"saveFailed": "保存工作流失败",
|
||||
"deleteFailed": "删除工作流失败",
|
||||
"saved": "工作流已保存",
|
||||
"deleted": "工作流已删除",
|
||||
"idNameRequired": "工作流 ID 和名称不能为空",
|
||||
"selectToDelete": "请选择要删除的工作流",
|
||||
"confirmDelete": "确定删除工作流 {{id}}?",
|
||||
"duplicateEdge": "这两个节点之间已经有连线",
|
||||
"connectModeOn": "连线模式:依次点击源节点和目标节点",
|
||||
"connectModeOff": "已退出连线模式",
|
||||
"validation": {
|
||||
"needStart": "至少需要一个开始节点",
|
||||
"needOutput": "至少需要一个输出节点",
|
||||
"edgeSelfLoop": "连线 {{id}} 不能指向自身",
|
||||
"edgeSourceMissing": "连线 {{id}} 的源节点不存在",
|
||||
"edgeTargetMissing": "连线 {{id}} 的目标节点不存在",
|
||||
"startIncoming": "开始节点 {{label}} 不应有入边",
|
||||
"outputOutgoing": "输出节点 {{label}} 不应有出边",
|
||||
"toolNeedsMcp": "工具节点 {{label}} 需要选择 MCP 工具",
|
||||
"conditionNeedsExpr": "条件节点 {{label}} 需要条件表达式",
|
||||
"conditionNeedsOutEdge": "条件节点 {{label}} 至少需要一条出边(是/否分支)",
|
||||
"conditionTooManyEdges": "条件节点 {{label}} 建议最多两条出边(是/否);第三条及以后需配置连线条件",
|
||||
"outputNeedsKey": "输出节点 {{label}} 需要输出变量名"
|
||||
}
|
||||
},
|
||||
"c2": {
|
||||
"clipboardCopied": "已复制到剪贴板",
|
||||
|
||||
@@ -22,7 +22,7 @@ const AUDIT_ACTIONS_BY_CATEGORY = {
|
||||
task: ['create_queue', 'start_queue', 'delete_queue', 'pause_queue', 'rerun_queue', 'delete_batch_task'],
|
||||
tool: ['execution_delete', 'execution_delete_batch'],
|
||||
file: ['upload', 'delete'],
|
||||
hitl: ['decision'],
|
||||
hitl: ['decision', 'audit_strategy_update'],
|
||||
role: ['create', 'update', 'delete'],
|
||||
skill: ['create', 'update', 'delete'],
|
||||
agent: ['markdown_create', 'markdown_update', 'markdown_delete']
|
||||
|
||||
+83
-12
@@ -84,6 +84,7 @@ function initChatFilesPage() {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
setupChatFilesDragDrop();
|
||||
loadChatFilesPage();
|
||||
}
|
||||
|
||||
@@ -1226,21 +1227,31 @@ function chatFilesUploadToFolderClick(ev, btn) {
|
||||
if (inp) inp.click();
|
||||
}
|
||||
|
||||
async function onChatFilesUploadPick(ev) {
|
||||
const input = ev.target;
|
||||
const file = input && input.files && input.files[0];
|
||||
if (!file) return;
|
||||
const form = new FormData();
|
||||
form.append('file', file);
|
||||
function chatFilesResolveUploadTarget() {
|
||||
const pendingDir = chatFilesPendingUploadDir;
|
||||
chatFilesPendingUploadDir = '';
|
||||
if (pendingDir) {
|
||||
form.append('relativeDir', pendingDir);
|
||||
} else {
|
||||
const conv = document.getElementById('chat-files-filter-conv');
|
||||
if (conv && conv.value.trim()) {
|
||||
form.append('conversationId', conv.value.trim());
|
||||
}
|
||||
return { relativeDir: pendingDir };
|
||||
}
|
||||
if (chatFilesGetGroupByMode() === 'folder') {
|
||||
const dir = chatFilesBrowsePath.join('/');
|
||||
return dir ? { relativeDir: dir } : {};
|
||||
}
|
||||
const conv = document.getElementById('chat-files-filter-conv');
|
||||
if (conv && conv.value.trim()) {
|
||||
return { conversationId: conv.value.trim() };
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
async function chatFilesUploadFile(file, target) {
|
||||
if (!file || chatFilesXHRUploadBusy) return false;
|
||||
const form = new FormData();
|
||||
form.append('file', file);
|
||||
if (target && target.relativeDir) {
|
||||
form.append('relativeDir', target.relativeDir);
|
||||
} else if (target && target.conversationId) {
|
||||
form.append('conversationId', target.conversationId);
|
||||
}
|
||||
chatFilesSetUploadBusy(true);
|
||||
chatFilesSetUploadProgressUI(true, 0, file.name);
|
||||
@@ -1265,15 +1276,75 @@ async function onChatFilesUploadPick(ev) {
|
||||
: '上传成功。在列表中点击「复制路径」即可粘贴到对话中引用。';
|
||||
chatFilesShowToast(msg);
|
||||
}
|
||||
return true;
|
||||
} catch (e) {
|
||||
alert((e && e.message) ? e.message : String(e));
|
||||
return false;
|
||||
} finally {
|
||||
chatFilesSetUploadBusy(false);
|
||||
chatFilesSetUploadProgressUI(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function chatFilesUploadFiles(fileList) {
|
||||
if (!fileList || !fileList.length || chatFilesXHRUploadBusy) return;
|
||||
const files = Array.from(fileList).filter(function (f) {
|
||||
return f && (f.name || f.size > 0);
|
||||
});
|
||||
if (!files.length) return;
|
||||
const target = chatFilesResolveUploadTarget();
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
const ok = await chatFilesUploadFile(files[i], target);
|
||||
if (!ok) break;
|
||||
}
|
||||
}
|
||||
|
||||
async function onChatFilesUploadPick(ev) {
|
||||
const input = ev.target;
|
||||
const files = input && input.files;
|
||||
if (!files || !files.length) return;
|
||||
try {
|
||||
await chatFilesUploadFiles(files);
|
||||
} finally {
|
||||
input.value = '';
|
||||
}
|
||||
}
|
||||
|
||||
let chatFilesDragDropBound = false;
|
||||
|
||||
function setupChatFilesDragDrop() {
|
||||
if (chatFilesDragDropBound) return;
|
||||
const wrap = document.getElementById('chat-files-list-wrap');
|
||||
if (!wrap) return;
|
||||
chatFilesDragDropBound = true;
|
||||
|
||||
wrap.addEventListener('dragover', function (e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (chatFilesXHRUploadBusy) return;
|
||||
this.classList.add('drag-over');
|
||||
});
|
||||
wrap.addEventListener('dragleave', function (e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!this.contains(e.relatedTarget)) {
|
||||
this.classList.remove('drag-over');
|
||||
}
|
||||
});
|
||||
wrap.addEventListener('drop', function (e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
this.classList.remove('drag-over');
|
||||
if (chatFilesXHRUploadBusy) return;
|
||||
const files = e.dataTransfer && e.dataTransfer.files;
|
||||
if (files && files.length) {
|
||||
chatFilesUploadFiles(files).catch(function (err) {
|
||||
if (err) alert((err && err.message) ? err.message : String(err));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 语言切换后重新渲染列表:表头与「更多」菜单由 JS 拼接,无 data-i18n,需用当前语言的 t() 再生成一遍
|
||||
document.addEventListener('languagechange', function () {
|
||||
if (typeof window.currentPage !== 'function') return;
|
||||
|
||||
+619
-174
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user