mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-05 13:58:15 +02:00
Compare commits
121 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 11bab83fc5 | |||
| dc750e3680 | |||
| 0236d1c155 | |||
| be59ddcab6 | |||
| 25464a68e6 | |||
| eabfed09c9 | |||
| cbcbd414cd | |||
| 0933f9365b | |||
| e792891ff3 | |||
| e14e5f15d3 | |||
| 4d5e0c5f21 | |||
| b3238304ce | |||
| 665e2ec73a | |||
| d63d9c25b8 | |||
| d1c63d0ba7 | |||
| 55d6d449cd | |||
| d4bc9646d9 | |||
| b941f5a8d9 | |||
| 97e2c0fd43 | |||
| bd3e48c2d0 | |||
| 8b0b91fddc | |||
| 2b38595b42 | |||
| 5c795439ee | |||
| df531910cf | |||
| 8a089a826c | |||
| 60b32ffc69 | |||
| 21c36fcce8 | |||
| 4d048f6da0 | |||
| 03a2707b83 | |||
| 9941f51b3e | |||
| 1553e896c5 | |||
| ea2184773e | |||
| 764d8110ec | |||
| e037f383f5 | |||
| e40f7cb468 | |||
| 72aca69204 | |||
| 133da1c640 | |||
| af78b47517 | |||
| f5fabc05a4 | |||
| 5cc53b1076 | |||
| f1be2064db | |||
| 0c9c2ec606 | |||
| cf09dd36d8 | |||
| c6e2701b30 | |||
| 42b5901d99 | |||
| 117bed6839 | |||
| bad323cd0e | |||
| 8138f8b576 | |||
| 74627d214b | |||
| f622efe245 | |||
| 3924b5285b | |||
| 21f641bbd7 | |||
| d913695303 | |||
| 6bb3a73f73 | |||
| f0a80a8e58 | |||
| 3f9dbb4214 | |||
| c0f0861b31 | |||
| 704137aa34 | |||
| c56bf36df0 | |||
| 5560f34c6c | |||
| 70e9a73fc0 | |||
| 12bc9d8ab6 | |||
| f8db82a065 | |||
| 8ce30d9072 | |||
| e6506d00e8 | |||
| b2308617b8 | |||
| cd17fdca33 | |||
| 1acaccd09f | |||
| 983fe650c1 | |||
| 52d03dc849 | |||
| 9de72d9ad5 | |||
| d95275ffae | |||
| 6cef93dbb7 | |||
| dd3b1ae219 | |||
| f42209682a | |||
| 1b1aed1699 | |||
| 44ced98863 | |||
| 97834c162e | |||
| 9276f2f144 | |||
| a454cada6a | |||
| 99b53d4fbc | |||
| a43a9deaea | |||
| ce88da84c9 | |||
| 15855c7073 | |||
| 43eb3e546b | |||
| 2d52c9b6ac | |||
| d5401b8b4c | |||
| 5fd4393a2e | |||
| a049f6b5c2 | |||
| acba8e5a39 | |||
| f826b91362 | |||
| 98c2de2a60 | |||
| 1c4d4b305b | |||
| f210ac9a03 | |||
| 6685076dfb | |||
| 7f322653f6 | |||
| 66ac2f1357 | |||
| c446e22d0c | |||
| 0358d3a67d | |||
| 9b82f265fd | |||
| 3d9cae58e4 | |||
| 1f1eadee5e | |||
| 0569255189 | |||
| 8ccf90d067 | |||
| b3be89f47d | |||
| b9bf8f62d4 | |||
| 05ca0c1480 | |||
| 47a4f3fc5b | |||
| a3b378ae9e | |||
| a904d26e78 | |||
| 7ba7476c4f | |||
| ae25a243ac | |||
| 23bd6288ff | |||
| fef21d3a24 | |||
| 933bba4517 | |||
| e1d65437cc | |||
| 9325aed1eb | |||
| dee2b3ab42 | |||
| a69bc93fa1 | |||
| b1a620bfce | |||
| 61b164eec2 |
@@ -113,6 +113,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||
@@ -285,7 +286,7 @@ Requirements / tips:
|
||||
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||
|
||||
### Skills System (Agent Skills + Eino)
|
||||
@@ -536,7 +537,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||
|
||||
+3
-2
@@ -112,6 +112,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||
@@ -283,7 +284,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||
- **配置项**:`multi_agent`:`enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||
|
||||
### Skills 技能系统(Agent Skills + Eino)
|
||||
@@ -534,7 +535,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||
|
||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
||||
5) Follow-up Verification Plan(后续验证建议)
|
||||
- 对每个优先条目:建议由哪个阶段子代理接手、需要补测的最小证据集
|
||||
|
||||
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。遇到证据不足的条目标注为“需要补证据”。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 可能仍残留的风险类别与建议监控方式(只做高层建议)
|
||||
|
||||
4) Handoff to Reporting(交接给报告的要点)
|
||||
- 报告里应包含哪些字段以证明“合规清理”。
|
||||
- 报告里应包含哪些字段以证明“合规清理”。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -61,4 +61,8 @@ max_iterations: 0
|
||||
5) Open Questions(待澄清问题)
|
||||
- 不足以继续的关键问题(尽量少而关键)
|
||||
|
||||
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
||||
当你完成以上输出时,直接停止;不要向协调主代理以外的人解释过多背景。将所有不确定性标注为“需要补证据/需要澄清”。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -50,4 +50,8 @@ max_iterations: 0
|
||||
- 你要求执行的最小化原则(如不导出明文敏感字段、不保留原始样本等,用描述性语言)
|
||||
|
||||
4) Recommended Next Agent(下一步建议)
|
||||
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
||||
- 建议交给 `reporting-remediation` 和 `cleanup-rollback` 的证据输入要点。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
||||
- 优先用工具拿可验证事实,标注信息来源与置信度;避免无依据推测。
|
||||
- 输出结构化(目标、发现项、证据摘要、建议后续动作),便于协调者合并进总报告。
|
||||
- 不执行未授权的入侵或社工骚扰;双用途技术仅用于甲方书面授权场景。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -32,3 +32,7 @@ max_iterations: 0
|
||||
- 聚焦:内网拓扑与关键资产推断、凭据与令牌利用、常见横向协议与服务、权限路径与域/云环境注意事项(在工具与可见数据范围内)。
|
||||
- 每一步说明假设前提与证据;禁止对未授权网段、生产无关系统或真实用户数据进行操作。
|
||||
- 输出结构化:当前据点能力、发现的主机/服务、建议的下一步(可交给其他子代理或主代理编排)、风险与回滚注意点。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 建议记录哪些证据字段(时间戳、目标、请求摘要、响应摘要、变更清单、回滚确认)
|
||||
|
||||
4) Stop & Rollback Criteria(停止与回滚标准)
|
||||
- 触发阈值/不可控情况(用描述性语言即可)
|
||||
- 触发阈值/不可控情况(用描述性语言即可)
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -102,10 +102,34 @@ description: plan_execute 模式下的规划/重规划侧主代理:拆解目
|
||||
|
||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||
|
||||
## 证据与漏洞
|
||||
## 证据、黑板与漏洞
|
||||
|
||||
- 要求结论有证据支撑(请求/响应、命令输出、可复现步骤);禁止无依据的确定断言。
|
||||
- 发现有效漏洞时,在后续轮次通过 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、POC、影响、修复建议;级别 critical / high / medium / low / info)。
|
||||
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
- **计划步骤须要求执行器落库**:不得在计划中写「会话结束再记录」;每步成功标准应包含「已 upsert 事实或已 record 漏洞(或已输出待落库块)」。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
|
||||
## 执行器对用户输出(重要)
|
||||
|
||||
|
||||
@@ -117,9 +117,29 @@ description: supervisor 模式下的协调者:通过 transfer 委派专家子
|
||||
3. 期望交付物是否可验收(例如:可复现命令、截图要点、结论段落)?
|
||||
4. 是否已明确写出 URL/IP:Port/域名路径与 in-scope 边界(而非“按上文继续”)?
|
||||
|
||||
## 漏洞
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
有效漏洞应通过 **`record_vulnerability`** 记录(含 POC 与严重性)。
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
|
||||
## 表达
|
||||
|
||||
|
||||
+23
-1
@@ -127,7 +127,29 @@ description: 多代理模式下的 Deep 编排者:在已授权安全场景中
|
||||
## 工具与 MCP
|
||||
|
||||
- **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。
|
||||
- **漏洞记录**:发现**有效漏洞**时,必须使用 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度使用 critical / high / medium / low / info。记录后可在授权范围内继续测试。
|
||||
## 项目黑板(事实)与漏洞记录(分离)
|
||||
|
||||
当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。
|
||||
|
||||
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。
|
||||
- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。
|
||||
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
|
||||
- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
|
||||
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
|
||||
|
||||
### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 `upsert_project_fact` 的 body 字段;索引不含 body,后续会话须靠 `get_project_fact` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:`target/`、`auth/`、`infra/`、`business/`(body 用环境模板即可)
|
||||
- 发现与利用:`finding/`、`chain/`、`exploit/`、`poc/`(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:`record_vulnerability` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 `fact_key` 覆盖写入,勿散落多个 key 导致上下文丢失。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
|
||||
- **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。
|
||||
- **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。
|
||||
- **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。
|
||||
|
||||
@@ -31,5 +31,9 @@ max_iterations: 0
|
||||
- 禁止自行猜测目标、替换为历史目标或擅自发起全量探索。
|
||||
|
||||
- 以证据为中心:请求/响应、Payload、命令输出、截图说明等,便于审计与复现。
|
||||
- 先确认边界与禁止项(如拒绝 DoS、数据破坏);发现有效漏洞时按协调者要求使用 `record_vulnerability` 等流程(若你的工具集中包含)。
|
||||
- 先确认边界与禁止项(如拒绝 DoS、数据破坏)。
|
||||
- 输出包含:攻击路径摘要、关键步骤、影响评估、修复与缓解建议;语言简洁,便于主代理汇总。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -51,4 +51,8 @@ max_iterations: 0
|
||||
- 列出需要清理/验证的痕迹类型(配置、会话、日志、服务变更等层级描述即可)
|
||||
|
||||
4) Recommended Next Steps(下一步建议)
|
||||
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
||||
- 建议由哪个阶段子代理接手,以及需要哪些证据输入。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -53,4 +53,8 @@ max_iterations: 0
|
||||
4) Recommended Next Agent(下一步建议)
|
||||
- 明确建议由哪个子代理接手(例如 `lateral-movement` / `persistence-maintenance` / `impact-exfiltration` / `reporting-remediation`)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
@@ -34,3 +34,7 @@ max_iterations: 0
|
||||
|
||||
- 若 **`description` / 用户消息 / 上文交接包** 中已给出资产列表、枚举结论或明确写「跳过全量枚举 / 仅做增量 / 从端口扫描或验证开始」,则**不得**为走完整流程而重新执行等价的广域子域爆破或相同参数集的枚举;仅在交接包声明的**缺口**上补充侦察。
|
||||
- 若子目标实为**漏洞验证、协议利用、权限提升**等而非攻击面扩展,应**极短说明**「当前角色为侦察;建议协调者改派专项代理」并仅提供与侦察相关的最小补充信息,避免擅自把任务扩写成新一轮全盘资产收集。
|
||||
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
@@ -55,4 +55,8 @@ max_iterations: 0
|
||||
5) Appendix(附录)
|
||||
- 术语、假设、证据清单索引(按证据类型列出即可)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
@@ -57,4 +57,8 @@ max_iterations: 0
|
||||
4) Uncertainties & Missing Evidence(不确定性与缺口)
|
||||
- 列出最关键的缺口(尽量少,但要关键)
|
||||
|
||||
输出后直接结束。
|
||||
## 边渗透边记录
|
||||
|
||||
- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。
|
||||
|
||||
输出后直接结束。
|
||||
|
||||
+29
-6
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.18"
|
||||
version: "v1.6.28"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -34,6 +34,12 @@ auth:
|
||||
log:
|
||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
||||
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
|
||||
audit:
|
||||
enabled: true
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
@@ -55,7 +61,7 @@ openai:
|
||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||
reasoning:
|
||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||
effort: max # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
|
||||
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||
@@ -71,21 +77,23 @@ fofa:
|
||||
# Agent 配置
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
agent:
|
||||
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
max_iterations: 12000 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||
|
||||
system_prompt_path: ""
|
||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||
hitl:
|
||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
|
||||
multi_agent:
|
||||
enabled: true
|
||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
||||
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
|
||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||
@@ -108,7 +116,7 @@ multi_agent:
|
||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
||||
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||
@@ -125,6 +133,8 @@ multi_agent:
|
||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
@@ -254,11 +264,13 @@ robots:
|
||||
enabled: false
|
||||
client_id: ""
|
||||
client_secret: ""
|
||||
allow_conversation_id_fallback: false
|
||||
lark: # 飞书
|
||||
enabled: false
|
||||
app_id: ""
|
||||
app_secret: ""
|
||||
verify_token: ""
|
||||
allow_chat_id_fallback: false
|
||||
# ============================================
|
||||
# Skills 相关配置
|
||||
# ============================================
|
||||
@@ -280,3 +292,14 @@ agents_dir: agents
|
||||
# 系统会从该目录加载所有 .yaml 格式的角色配置文件
|
||||
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
|
||||
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
|
||||
|
||||
# ============================================
|
||||
# 项目管理与事实黑板
|
||||
# ============================================
|
||||
project:
|
||||
enabled: true
|
||||
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||
fact_index_max_runes: 3500
|
||||
fact_summary_max_runes: 240
|
||||
default_inject_deprecated: false
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 182 KiB After Width: | Height: | Size: 178 KiB |
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
@@ -365,12 +366,12 @@ type ProgressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// AgentLoop 执行Agent循环
|
||||
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, "")
|
||||
}
|
||||
|
||||
// AgentLoopWithConversationID 执行Agent循环(带对话ID)
|
||||
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, "")
|
||||
}
|
||||
|
||||
// EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。
|
||||
@@ -396,7 +397,7 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string {
|
||||
}
|
||||
|
||||
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, systemPromptExtra string) (*AgentLoopResult, error) {
|
||||
ctx = withAgentConversationID(ctx, conversationID)
|
||||
// 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准)
|
||||
a.mu.Lock()
|
||||
@@ -426,6 +427,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
}
|
||||
}
|
||||
}
|
||||
systemPrompt = project.AppendSystemPromptBlock(systemPrompt, systemPromptExtra)
|
||||
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package agent
|
||||
|
||||
import "cyberstrike-ai/internal/mcp/builtin"
|
||||
import (
|
||||
"cyberstrike-ai/internal/project"
|
||||
)
|
||||
|
||||
// DefaultSingleAgentSystemPrompt 单代理(ReAct / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||
func DefaultSingleAgentSystemPrompt() string {
|
||||
@@ -105,11 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||
|
||||
## 漏洞记录
|
||||
|
||||
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。
|
||||
` + project.FactRecordingBlackboardSection(false) + `
|
||||
|
||||
## 技能库(Skills)与知识库
|
||||
|
||||
|
||||
+60
-191
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
@@ -61,6 +62,7 @@ type App struct {
|
||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
||||
auditSvc *audit.Service
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
@@ -93,6 +95,11 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
auditSvc := audit.NewService(db, cfg, log.Logger)
|
||||
audit.RegisterConversationCreateHook(auditSvc)
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
@@ -104,7 +111,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
// 注册漏洞记录工具
|
||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
||||
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||
|
||||
if cfg.Auth.GeneratedPassword != "" {
|
||||
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
|
||||
@@ -222,6 +230,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
|
||||
knowledgeHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 扫描知识库并建立索引(异步)
|
||||
@@ -318,31 +327,43 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
|
||||
}
|
||||
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
|
||||
markdownAgentsHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||
agentHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetAgentsMarkdownDir(agentsDir)
|
||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||
if knowledgeManager != nil {
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
authHandler.SetAudit(auditSvc)
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
projectHandler := handler.NewProjectHandler(db, log.Logger)
|
||||
vulnerabilityHandler.SetAudit(auditSvc)
|
||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||
webshellHandler.SetAudit(auditSvc)
|
||||
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
|
||||
chatUploadsHandler.SetAudit(auditSvc)
|
||||
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
configHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
externalMCPHandler.SetAudit(auditSvc)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
roleHandler.SetAudit(auditSvc)
|
||||
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
||||
skillsHandler.SetAudit(auditSvc)
|
||||
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
||||
terminalHandler := handler.NewTerminalHandler(log.Logger)
|
||||
if db != nil {
|
||||
@@ -357,9 +378,12 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
||||
}
|
||||
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
||||
c2Handler.SetAudit(auditSvc)
|
||||
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
||||
|
||||
@@ -385,13 +409,15 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
c2Watchdog: c2Watchdog,
|
||||
c2WatchdogCancel: watchdogCancel,
|
||||
c2Handler: c2Handler,
|
||||
auditSvc: auditSvc,
|
||||
}
|
||||
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
||||
app.startRobotConnections()
|
||||
|
||||
// 设置漏洞工具注册器(内置工具,必须设置)
|
||||
vulnerabilityRegistrar := func() error {
|
||||
registerVulnerabilityTool(mcpServer, db, log.Logger)
|
||||
registerVulnerabilityTools(mcpServer, db, log.Logger)
|
||||
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
|
||||
return nil
|
||||
}
|
||||
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
|
||||
@@ -479,6 +505,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
attackChainHandler,
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
projectHandler,
|
||||
webshellHandler,
|
||||
chatUploadsHandler,
|
||||
roleHandler,
|
||||
@@ -487,6 +514,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
fofaHandler,
|
||||
terminalHandler,
|
||||
app.c2Handler,
|
||||
auditHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
openAPIHandler,
|
||||
@@ -723,6 +751,7 @@ func setupRoutes(
|
||||
attackChainHandler *handler.AttackChainHandler,
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
projectHandler *handler.ProjectHandler,
|
||||
webshellHandler *handler.WebShellHandler,
|
||||
chatUploadsHandler *handler.ChatUploadsHandler,
|
||||
roleHandler *handler.RoleHandler,
|
||||
@@ -731,6 +760,7 @@ func setupRoutes(
|
||||
fofaHandler *handler.FofaHandler,
|
||||
terminalHandler *handler.TerminalHandler,
|
||||
c2Handler *handler.C2Handler,
|
||||
auditHandler *handler.AuditHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
openAPIHandler *handler.OpenAPIHandler,
|
||||
@@ -826,6 +856,7 @@ func setupRoutes(
|
||||
protected.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
|
||||
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
|
||||
protected.PUT("/conversations/:id/project", conversationHandler.SetConversationProject)
|
||||
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
|
||||
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
|
||||
@@ -867,6 +898,13 @@ func setupRoutes(
|
||||
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
|
||||
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
|
||||
|
||||
// 平台审计日志
|
||||
protected.GET("/audit/meta", auditHandler.Meta)
|
||||
protected.GET("/audit/summary", auditHandler.Summary)
|
||||
protected.GET("/audit/logs", auditHandler.ListLogs)
|
||||
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
|
||||
protected.GET("/audit/logs/:id", auditHandler.GetLog)
|
||||
|
||||
// 外部MCP管理
|
||||
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
|
||||
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
|
||||
@@ -1035,6 +1073,23 @@ func setupRoutes(
|
||||
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||
|
||||
// 项目管理与事实黑板
|
||||
protected.GET("/projects", projectHandler.ListProjects)
|
||||
protected.POST("/projects", projectHandler.CreateProject)
|
||||
protected.GET("/projects/:id/stats", projectHandler.GetProjectStats)
|
||||
protected.GET("/projects/:id/conversations", projectHandler.ListProjectConversations)
|
||||
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||
protected.GET("/projects/:id/facts/:factId/previous-version", projectHandler.GetFactPreviousVersion)
|
||||
protected.GET("/projects/:id/facts/:factId/versions", projectHandler.ListFactVersions)
|
||||
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
|
||||
protected.POST("/projects/:id/facts/deprecate", projectHandler.DeprecateFact)
|
||||
protected.POST("/projects/:id/facts/restore", projectHandler.RestoreFact)
|
||||
|
||||
// WebShell 管理(代理执行 + 连接配置存 SQLite)
|
||||
protected.GET("/webshell/connections", webshellHandler.ListConnections)
|
||||
protected.POST("/webshell/connections", webshellHandler.CreateConnection)
|
||||
@@ -1155,195 +1210,6 @@ func setupRoutes(
|
||||
})
|
||||
}
|
||||
|
||||
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
|
||||
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞标题(必需)",
|
||||
},
|
||||
"description": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞详细描述",
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"vulnerability_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||
},
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||
},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞影响说明",
|
||||
},
|
||||
"recommendation": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "severity"},
|
||||
},
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
// 从参数中获取conversation_id(由Agent自动添加)
|
||||
conversationID, _ := args["conversation_id"].(string)
|
||||
if conversationID == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
title, ok := args["title"].(string)
|
||||
if !ok || title == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: title 参数必需且不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
severity, ok := args["severity"].(string)
|
||||
if !ok || severity == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: severity 参数必需且不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 验证严重程度
|
||||
validSeverities := map[string]bool{
|
||||
"critical": true,
|
||||
"high": true,
|
||||
"medium": true,
|
||||
"low": true,
|
||||
"info": true,
|
||||
}
|
||||
if !validSeverities[severity] {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 获取可选参数
|
||||
description := ""
|
||||
if d, ok := args["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
|
||||
vulnType := ""
|
||||
if t, ok := args["vulnerability_type"].(string); ok {
|
||||
vulnType = t
|
||||
}
|
||||
|
||||
target := ""
|
||||
if t, ok := args["target"].(string); ok {
|
||||
target = t
|
||||
}
|
||||
|
||||
proof := ""
|
||||
if p, ok := args["proof"].(string); ok {
|
||||
proof = p
|
||||
}
|
||||
|
||||
impact := ""
|
||||
if i, ok := args["impact"].(string); ok {
|
||||
impact = i
|
||||
}
|
||||
|
||||
recommendation := ""
|
||||
if r, ok := args["recommendation"].(string); ok {
|
||||
recommendation = r
|
||||
}
|
||||
|
||||
// 创建漏洞记录
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: conversationID,
|
||||
Title: title,
|
||||
Description: description,
|
||||
Severity: severity,
|
||||
Status: "open",
|
||||
Type: vulnType,
|
||||
Target: target,
|
||||
Proof: proof,
|
||||
Impact: impact,
|
||||
Recommendation: recommendation,
|
||||
}
|
||||
|
||||
created, err := db.CreateVulnerability(vuln)
|
||||
if err != nil {
|
||||
logger.Error("记录漏洞失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("记录漏洞失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Info("漏洞记录成功",
|
||||
zap.String("id", created.ID),
|
||||
zap.String("title", created.Title),
|
||||
zap.String("severity", created.Severity),
|
||||
zap.String("conversation_id", conversationID),
|
||||
)
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("漏洞记录工具注册成功")
|
||||
}
|
||||
|
||||
// registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作
|
||||
func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) {
|
||||
if db == nil || webshellHandler == nil {
|
||||
@@ -1928,6 +1794,9 @@ func initializeKnowledge(
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
||||
if app != nil && app.auditSvc != nil {
|
||||
knowledgeHandler.SetAudit(app.auditSvc)
|
||||
}
|
||||
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 设置知识库管理器到AgentHandler以便记录检索日志
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func projectIDFromConversation(db *database.DB, ctx context.Context) (string, error) {
|
||||
convID := agent.ConversationIDFromContext(ctx)
|
||||
if convID == "" {
|
||||
return "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用项目事实工具")
|
||||
}
|
||||
pid, err := db.GetConversationProjectID(convID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(pid) == "" {
|
||||
return "", fmt.Errorf("当前对话未绑定项目,请先在对话中选择项目或创建带项目的对话")
|
||||
}
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func textResult(msg string, isErr bool) *mcp.ToolResult {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: msg}},
|
||||
IsError: isErr,
|
||||
}
|
||||
}
|
||||
|
||||
// registerProjectFactTools 注册项目黑板 MCP 工具。
|
||||
func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *config.Config, logger *zap.Logger) {
|
||||
if db == nil || cfg == nil || !cfg.Project.Enabled {
|
||||
if logger != nil {
|
||||
logger.Info("项目黑板工具未注册(未启用)")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
upsertTool := mcp.Tool{
|
||||
Name: builtin.ToolUpsertProjectFact,
|
||||
Description: "写入或更新项目黑板事实,用于跨会话沉淀可复现上下文(非正式漏洞条目;可交付漏洞另用 record_vulnerability)。" +
|
||||
"边渗透边记录:每确认新认知(端口/入口/凭据/可利用点)后立即调用,同 fact_key 覆盖更新,勿等会话结束。" +
|
||||
"禁止仅写结论:summary 须含什么+在哪+如何验证;body 须含攻击链/请求响应/命令等复现细节。" +
|
||||
"发现类建议 fact_key 为 finding|chain|exploit|poc/<slug>,category 对应 finding|chain|exploit|poc,body 按攻击链模板填写。" +
|
||||
"环境类用 target|auth|infra|business/<slug>。同 fact_key 覆盖更新。需当前对话已绑定项目。",
|
||||
ShortDescription: "写入/更新项目事实(含攻击链 body)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "项目内唯一 key:target/primary_domain、finding/sqli-login、exploit/upload-rce 等",
|
||||
},
|
||||
"category": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "target | auth | infra | business | finding | chain | exploit | poc | note",
|
||||
"enum": []string{"target", "auth", "infra", "business", "finding", "chain", "exploit", "poc", "note"},
|
||||
},
|
||||
"summary": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "索引用一行:结论 + 位置 + 触发/验证要点(勿仅写「存在 XSS」等空话)",
|
||||
},
|
||||
"body": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "完整可复现详情(仅 get_project_fact 返回):须含攻击链步骤、原始 HTTP/命令、响应现象、证据与关联。" +
|
||||
"发现/利用类首次写入必填;环境类建议含来源证据。攻击链类可参考模板章节:结论、目标与入口、攻击链、Exploit/POC、关键证据、关联、备注。" +
|
||||
"更新已有 fact_key 时若省略或留空 body,将保留库中已有 body(可只改 summary)。",
|
||||
},
|
||||
"confidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "confirmed | tentative | deprecated",
|
||||
"enum": []string{"confirmed", "tentative", "deprecated"},
|
||||
},
|
||||
"pinned": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "是否优先出现在黑板索引",
|
||||
},
|
||||
"related_vulnerability_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:关联的漏洞记录 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"fact_key", "summary"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(upsertTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
factKey, _ := args["fact_key"].(string)
|
||||
summary, _ := args["summary"].(string)
|
||||
if strings.TrimSpace(factKey) == "" || strings.TrimSpace(summary) == "" {
|
||||
return textResult("错误: fact_key 与 summary 必填", true), nil
|
||||
}
|
||||
if len([]rune(summary)) > cfg.Project.FactSummaryMaxRunesEffective() {
|
||||
return textResult(fmt.Sprintf("错误: summary 过长(最多 %d 字)", cfg.Project.FactSummaryMaxRunesEffective()), true), nil
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: projectID,
|
||||
FactKey: factKey,
|
||||
Category: strArg(args, "category"),
|
||||
Summary: summary,
|
||||
Body: strArg(args, "body"),
|
||||
Confidence: strArg(args, "confidence"),
|
||||
Pinned: boolArg(args, "pinned"),
|
||||
RelatedVulnerabilityID: strArg(args, "related_vulnerability_id"),
|
||||
}
|
||||
if convID := agent.ConversationIDFromContext(ctx); convID != "" {
|
||||
f.SourceConversationID = convID
|
||||
}
|
||||
created, err := db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
}
|
||||
return textResult(msg, false), nil
|
||||
})
|
||||
|
||||
getTool := mcp.Tool{
|
||||
Name: builtin.ToolGetProjectFact,
|
||||
Description: "按 fact_key 获取项目事实完整 body 与元数据。摘要不足时必须调用本工具,禁止臆造细节。",
|
||||
ShortDescription: "按 key 获取事实详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string", "description": "事实 key"},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(getTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if key == "" {
|
||||
return textResult("错误: fact_key 必填", true), nil
|
||||
}
|
||||
f, err := db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
msg := fmt.Sprintf("fact_key: %s\ncategory: %s\nconfidence: %s\nsummary: %s\nupdated_at: %s",
|
||||
f.FactKey, f.Category, f.Confidence, f.Summary, f.UpdatedAt.Format("2006-01-02 15:04:05"))
|
||||
if f.RelatedVulnerabilityID != "" {
|
||||
msg += fmt.Sprintf("\nrelated_vulnerability_id: %s", f.RelatedVulnerabilityID)
|
||||
}
|
||||
if f.SourceConversationID != "" {
|
||||
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||
}
|
||||
msg += "\n\n--- body ---\n" + f.Body
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
}
|
||||
return textResult(msg, false), nil
|
||||
})
|
||||
|
||||
listTool := mcp.Tool{
|
||||
Name: builtin.ToolListProjectFacts,
|
||||
Description: "列出当前项目的事实(分页)。",
|
||||
ShortDescription: "列出项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"category": map[string]interface{}{"type": "string"},
|
||||
"confidence": map[string]interface{}{"type": "string"},
|
||||
"limit": map[string]interface{}{"type": "integer"},
|
||||
"offset": map[string]interface{}{"type": "integer"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(listTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
limit := intArg(args, "limit", 50)
|
||||
offset := intArg(args, "offset", 0)
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: strArg(args, "category"),
|
||||
Confidence: strArg(args, "confidence"),
|
||||
}
|
||||
list, err := db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("共 %d 条(limit=%d offset=%d):\n", len(list), limit, offset))
|
||||
for _, f := range list {
|
||||
b.WriteString(fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, f.Summary, f.Confidence))
|
||||
}
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
|
||||
searchTool := mcp.Tool{
|
||||
Name: builtin.ToolSearchProjectFacts,
|
||||
Description: "按关键词搜索项目事实(summary/body/fact_key)。",
|
||||
ShortDescription: "搜索项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{"type": "string"},
|
||||
"limit": map[string]interface{}{"type": "integer"},
|
||||
"offset": map[string]interface{}{"type": "integer"},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(searchTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
q := strings.TrimSpace(strArg(args, "query"))
|
||||
if q == "" {
|
||||
return textResult("错误: query 必填", true), nil
|
||||
}
|
||||
list, err := db.ListProjectFacts(projectID, database.ProjectFactListFilter{Search: q}, intArg(args, "limit", 30), intArg(args, "offset", 0))
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("搜索 \"%s\" 命中 %d 条:\n", q, len(list)))
|
||||
for _, f := range list {
|
||||
b.WriteString(fmt.Sprintf("- [%s] %s — %s\n", f.FactKey, f.Category, f.Summary))
|
||||
}
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
|
||||
deprecateTool := mcp.Tool{
|
||||
Name: builtin.ToolDeprecateProjectFact,
|
||||
Description: "将事实标记为 deprecated,从黑板索引中排除。",
|
||||
ShortDescription: "废弃项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(deprecateTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if err := db.DeprecateProjectFact(projectID, key); err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
return textResult("事实已标记为 deprecated: "+key, false), nil
|
||||
})
|
||||
|
||||
restoreTool := mcp.Tool{
|
||||
Name: builtin.ToolRestoreProjectFact,
|
||||
Description: "将已废弃(deprecated)的事实恢复为 tentative 或 confirmed,重新参与黑板索引。",
|
||||
ShortDescription: "恢复已废弃的项目事实",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string"},
|
||||
"confidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "恢复后的置信度:tentative(默认)或 confirmed",
|
||||
"enum": []string{"tentative", "confirmed"},
|
||||
},
|
||||
},
|
||||
"required": []string{"fact_key"},
|
||||
},
|
||||
}
|
||||
mcpServer.RegisterTool(restoreTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
projectID, err := projectIDFromConversation(db, ctx)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
key := strings.TrimSpace(strArg(args, "fact_key"))
|
||||
if key == "" {
|
||||
return textResult("错误: fact_key 必填", true), nil
|
||||
}
|
||||
conf := strArg(args, "confidence")
|
||||
if err := db.RestoreProjectFact(projectID, key, conf); err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
if conf == "" {
|
||||
conf = "tentative"
|
||||
}
|
||||
return textResult(fmt.Sprintf("事实已恢复为 %s: %s", conf, key), false), nil
|
||||
})
|
||||
|
||||
if logger != nil {
|
||||
logger.Info("项目黑板 MCP 工具注册成功")
|
||||
}
|
||||
}
|
||||
|
||||
func strArg(args map[string]interface{}, key string) string {
|
||||
if v, ok := args[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolArg(args map[string]interface{}, key string) bool {
|
||||
if v, ok := args[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func intArg(args map[string]interface{}, key string, def int) int {
|
||||
switch v := args[key].(type) {
|
||||
case float64:
|
||||
return int(v)
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func conversationIDFromToolCtx(ctx context.Context) string {
|
||||
if id := agent.ConversationIDFromContext(ctx); id != "" {
|
||||
return id
|
||||
}
|
||||
return mcp.MCPConversationIDFromContext(ctx)
|
||||
}
|
||||
|
||||
// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。
|
||||
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
|
||||
if vuln == nil || convID == "" {
|
||||
return false
|
||||
}
|
||||
if projectID != "" {
|
||||
if strings.TrimSpace(vuln.ProjectID) == projectID {
|
||||
return true
|
||||
}
|
||||
// 历史记录:写入时尚未绑定 project_id,但属于同一会话
|
||||
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return vuln.ConversationID == convID
|
||||
}
|
||||
|
||||
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
|
||||
convID := conversationIDFromToolCtx(ctx)
|
||||
if convID == "" {
|
||||
return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, err := db.GetConversationProjectID(convID); err == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
scope := strings.TrimSpace(strArg(args, "scope"))
|
||||
if scope == "" {
|
||||
if projectID != "" {
|
||||
scope = "project"
|
||||
} else {
|
||||
scope = "conversation"
|
||||
}
|
||||
}
|
||||
|
||||
filter := database.VulnerabilityListFilter{
|
||||
Severity: strings.TrimSpace(strArg(args, "severity")),
|
||||
Status: strings.TrimSpace(strArg(args, "status")),
|
||||
}
|
||||
if q := strings.TrimSpace(strArg(args, "q")); q != "" {
|
||||
filter.Search = q
|
||||
} else {
|
||||
filter.Search = strings.TrimSpace(strArg(args, "search"))
|
||||
}
|
||||
|
||||
var scopeLabel string
|
||||
switch scope {
|
||||
case "project":
|
||||
if projectID == "" {
|
||||
return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
|
||||
}
|
||||
filter.ProjectID = projectID
|
||||
scopeLabel = fmt.Sprintf("项目 %s", projectID)
|
||||
case "conversation":
|
||||
filter.ConversationID = convID
|
||||
scopeLabel = fmt.Sprintf("会话 %s", convID)
|
||||
default:
|
||||
return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
|
||||
}
|
||||
return filter, scopeLabel, nil
|
||||
}
|
||||
|
||||
func formatVulnerabilityListItem(v *database.Vulnerability) string {
|
||||
line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
|
||||
if v.Type != "" {
|
||||
line += fmt.Sprintf(" | type=%s", v.Type)
|
||||
}
|
||||
if v.Target != "" {
|
||||
line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80))
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
func formatVulnerabilityDetail(v *database.Vulnerability) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID))
|
||||
b.WriteString(fmt.Sprintf("标题: %s\n", v.Title))
|
||||
b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity))
|
||||
b.WriteString(fmt.Sprintf("状态: %s\n", v.Status))
|
||||
if v.Type != "" {
|
||||
b.WriteString(fmt.Sprintf("类型: %s\n", v.Type))
|
||||
}
|
||||
if v.Target != "" {
|
||||
b.WriteString(fmt.Sprintf("目标: %s\n", v.Target))
|
||||
}
|
||||
if v.ProjectID != "" {
|
||||
b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID))
|
||||
if !v.CreatedAt.IsZero() {
|
||||
b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
if v.Description != "" {
|
||||
b.WriteString("\n--- 描述 ---\n")
|
||||
b.WriteString(v.Description)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Proof != "" {
|
||||
b.WriteString("\n--- 证明(POC) ---\n")
|
||||
b.WriteString(v.Proof)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Impact != "" {
|
||||
b.WriteString("\n--- 影响 ---\n")
|
||||
b.WriteString(v.Impact)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Recommendation != "" {
|
||||
b.WriteString("\n--- 修复建议 ---\n")
|
||||
b.WriteString(v.Recommendation)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncateRunes(s string, max int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= max {
|
||||
return s
|
||||
}
|
||||
return string(r[:max]) + "…"
|
||||
}
|
||||
|
||||
// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。
|
||||
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
registerRecordVulnerabilityTool(mcpServer, db, logger)
|
||||
registerListVulnerabilitiesTool(mcpServer, db, logger)
|
||||
registerGetVulnerabilityTool(mcpServer, db, logger)
|
||||
if logger != nil {
|
||||
logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。边渗透边记录:每验证出一条可复现漏洞(含 POC/影响)后立即调用,勿等会话结束。包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞标题(必需)",
|
||||
},
|
||||
"description": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞详细描述",
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"vulnerability_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
|
||||
},
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "受影响的目标(URL、IP地址、服务等)",
|
||||
},
|
||||
"proof": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞证明(POC、截图、请求/响应等)",
|
||||
},
|
||||
"impact": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞影响说明",
|
||||
},
|
||||
"recommendation": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "修复建议",
|
||||
},
|
||||
},
|
||||
"required": []string{"title", "severity"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
|
||||
if conversationID == "" {
|
||||
conversationID = conversationIDFromToolCtx(ctx)
|
||||
}
|
||||
if conversationID == "" {
|
||||
return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(strArg(args, "title"))
|
||||
if title == "" {
|
||||
return textResult("错误: title 参数必需且不能为空", true), nil
|
||||
}
|
||||
|
||||
severity := strings.TrimSpace(strArg(args, "severity"))
|
||||
if severity == "" {
|
||||
return textResult("错误: severity 参数必需且不能为空", true), nil
|
||||
}
|
||||
|
||||
validSeverities := map[string]bool{
|
||||
"critical": true, "high": true, "medium": true, "low": true, "info": true,
|
||||
}
|
||||
if !validSeverities[severity] {
|
||||
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: conversationID,
|
||||
ProjectID: projectID,
|
||||
Title: title,
|
||||
Description: strArg(args, "description"),
|
||||
Severity: severity,
|
||||
Status: "open",
|
||||
Type: strArg(args, "vulnerability_type"),
|
||||
Target: strArg(args, "target"),
|
||||
Proof: strArg(args, "proof"),
|
||||
Impact: strArg(args, "impact"),
|
||||
Recommendation: strArg(args, "recommendation"),
|
||||
}
|
||||
|
||||
created, err := db.CreateVulnerability(vuln)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Error("记录漏洞失败", zap.Error(err))
|
||||
}
|
||||
return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Info("漏洞记录成功",
|
||||
zap.String("id", created.ID),
|
||||
zap.String("title", created.Title),
|
||||
zap.String("severity", created.Severity),
|
||||
zap.String("conversation_id", conversationID),
|
||||
)
|
||||
}
|
||||
|
||||
return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
|
||||
created.ID, created.Title, created.Severity, created.Status), false), nil
|
||||
})
|
||||
}
|
||||
|
||||
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolListVulnerabilities,
|
||||
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
|
||||
ShortDescription: "列出漏洞(默认当前项目)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
|
||||
"enum": []string{"project", "conversation"},
|
||||
},
|
||||
"severity": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按严重程度筛选:critical、high、medium、low、info",
|
||||
"enum": []string{"critical", "high", "medium", "low", "info"},
|
||||
},
|
||||
"status": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按状态筛选:open、confirmed、fixed、false_positive",
|
||||
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
|
||||
},
|
||||
"q": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "关键词搜索(标题、描述、类型、目标等)",
|
||||
},
|
||||
"limit": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "返回条数上限,默认 30,最大 100",
|
||||
},
|
||||
"offset": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "分页偏移,默认 0",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
|
||||
limit := intArg(args, "limit", 30)
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 30
|
||||
}
|
||||
offset := intArg(args, "offset", 0)
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
total, err := db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("统计漏洞失败", zap.Error(err))
|
||||
}
|
||||
total = 0
|
||||
}
|
||||
|
||||
list, err := db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset))
|
||||
if len(list) == 0 {
|
||||
b.WriteString("(暂无漏洞记录)\n")
|
||||
} else {
|
||||
for _, v := range list {
|
||||
b.WriteString(formatVulnerabilityListItem(v))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if total > offset+len(list) {
|
||||
b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n"))
|
||||
}
|
||||
}
|
||||
b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
|
||||
return textResult(b.String(), false), nil
|
||||
})
|
||||
}
|
||||
|
||||
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: builtin.ToolGetVulnerability,
|
||||
Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
|
||||
ShortDescription: "按 ID 获取漏洞详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "漏洞 ID(list_vulnerabilities 返回的 id)",
|
||||
},
|
||||
},
|
||||
"required": []string{"id"},
|
||||
},
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
convID := conversationIDFromToolCtx(ctx)
|
||||
if convID == "" {
|
||||
return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
|
||||
}
|
||||
|
||||
id := strings.TrimSpace(strArg(args, "id"))
|
||||
if id == "" {
|
||||
return textResult("错误: id 必填", true), nil
|
||||
}
|
||||
|
||||
vuln, err := db.GetVulnerability(id)
|
||||
if err != nil {
|
||||
return textResult("错误: 漏洞不存在或查询失败", true), nil
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if pid, perr := db.GetConversationProjectID(convID); perr == nil {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
|
||||
if !canAccessVulnerability(vuln, convID, projectID) {
|
||||
return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
|
||||
}
|
||||
|
||||
return textResult(formatVulnerabilityDetail(vuln), false), nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterConversationCreateHook records platform audit rows for every new conversation.
|
||||
func RegisterConversationCreateHook(s *Service) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
|
||||
detail := map[string]interface{}{
|
||||
"title": conv.Title,
|
||||
"source": meta.Source,
|
||||
}
|
||||
if meta.WebShellConnectionID != "" {
|
||||
detail["webshell_connection_id"] = meta.WebShellConnectionID
|
||||
}
|
||||
s.Record(nil, Entry{
|
||||
Category: "conversation",
|
||||
Action: "create",
|
||||
Result: "success",
|
||||
Message: "创建对话",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: conv.ID,
|
||||
Detail: detail,
|
||||
ClientIP: meta.ClientIP,
|
||||
SessionHint: meta.SessionHint,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ConversationCreateMeta builds audit metadata for conversation creation.
|
||||
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
|
||||
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
|
||||
}
|
||||
|
||||
// ConversationCreateMetaFromGin includes client IP and session hint when available.
|
||||
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
|
||||
m := ConversationCreateMeta(source)
|
||||
if c == nil {
|
||||
return m
|
||||
}
|
||||
m.ClientIP = c.ClientIP()
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
m.SessionHint = sessionHint(token)
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package audit
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
return s.cfg.Audit.RetentionDaysEffective()
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package audit
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// RecordAction writes a platform audit row with common defaults.
|
||||
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.Record(c, Entry{
|
||||
Category: category,
|
||||
Action: action,
|
||||
Result: result,
|
||||
Message: message,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
Detail: detail,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordOK is a shorthand for successful operations.
|
||||
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
|
||||
}
|
||||
|
||||
// RecordFail is a shorthand for failed operations.
|
||||
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "failure", message, "", "", detail)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
var auditActionsResourceRemoved = map[string]bool{
|
||||
"delete": true,
|
||||
"item_delete": true,
|
||||
"connection_delete": true,
|
||||
"listener_delete": true,
|
||||
"session_delete": true,
|
||||
"task_delete": true,
|
||||
"execution_delete": true,
|
||||
"execution_delete_batch": true,
|
||||
"delete_queue": true,
|
||||
"delete_batch_task": true,
|
||||
"markdown_delete": true,
|
||||
}
|
||||
|
||||
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
|
||||
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
|
||||
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
|
||||
return
|
||||
}
|
||||
if auditActionsResourceRemoved[log.Action] {
|
||||
f := false
|
||||
log.ResourceAvailable = &f
|
||||
return
|
||||
}
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
|
||||
if known {
|
||||
log.ResourceAvailable = &available
|
||||
}
|
||||
}
|
||||
|
||||
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
|
||||
resourceID = strings.TrimSpace(resourceID)
|
||||
if resourceID == "" {
|
||||
return false, false
|
||||
}
|
||||
t := strings.TrimSpace(resourceType)
|
||||
if t == "" {
|
||||
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
|
||||
t = "conversation"
|
||||
} else {
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
switch t {
|
||||
case "conversation":
|
||||
ok, err := db.ConversationExists(resourceID)
|
||||
return ok, err == nil
|
||||
case "vulnerability":
|
||||
_, err := db.GetVulnerability(resourceID)
|
||||
if err != nil {
|
||||
return false, strings.Contains(err.Error(), "不存在")
|
||||
}
|
||||
return true, true
|
||||
case "batch_queue":
|
||||
_, err := db.GetBatchQueue(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_listener":
|
||||
_, err := db.GetC2Listener(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_session":
|
||||
_, err := db.GetC2Session(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_task":
|
||||
_, err := db.GetC2Task(resourceID)
|
||||
return err == nil, true
|
||||
case "webshell_connection":
|
||||
c, err := db.GetWebshellConnection(resourceID)
|
||||
return err == nil && c != nil, true
|
||||
case "tool_execution":
|
||||
_, err := db.GetToolExecution(resourceID)
|
||||
return err == nil, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
|
||||
const auditRetentionPurgeInterval = time.Hour
|
||||
|
||||
// StartRetentionLoop periodically purges expired audit rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(auditRetentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("audit retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveKeySubstrings = []string{
|
||||
"password", "api_key", "apikey", "secret", "token", "authorization",
|
||||
"credential", "private_key", "access_key",
|
||||
}
|
||||
|
||||
// SanitizeDetail redacts sensitive keys and truncates serialized size.
|
||||
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
|
||||
if detail == nil {
|
||||
return nil
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 8192
|
||||
}
|
||||
out := sanitizeValue("", detail)
|
||||
if m, ok := out.(map[string]interface{}); ok {
|
||||
b, _ := json.Marshal(m)
|
||||
if len(b) > maxBytes {
|
||||
return map[string]interface{}{
|
||||
"_truncated": true,
|
||||
"_preview": string(b[:maxBytes]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
return map[string]interface{}{"value": out}
|
||||
}
|
||||
|
||||
func sanitizeValue(key string, v interface{}) interface{} {
|
||||
kl := strings.ToLower(key)
|
||||
for _, sub := range sensitiveKeySubstrings {
|
||||
if strings.Contains(kl, sub) {
|
||||
return "***"
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case map[string]interface{}:
|
||||
m := make(map[string]interface{}, len(t))
|
||||
for k, val := range t {
|
||||
m[k] = sanitizeValue(k, val)
|
||||
}
|
||||
return m
|
||||
case []interface{}:
|
||||
arr := make([]interface{}, len(t))
|
||||
for i, val := range t {
|
||||
arr[i] = sanitizeValue(key, val)
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service persists platform audit logs.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
failThrottle *failureThrottle
|
||||
}
|
||||
|
||||
// NewService creates an audit service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
failThrottle: newFailureThrottle(),
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled reports whether audit persistence is on.
|
||||
func (s *Service) Enabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return s.cfg.Audit.EnabledEffective()
|
||||
}
|
||||
|
||||
// Record writes one audit row from a Gin request context.
|
||||
func (s *Service) Record(c *gin.Context, e Entry) {
|
||||
if s == nil || !s.Enabled() || s.db == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
|
||||
return
|
||||
}
|
||||
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Result) == "" {
|
||||
e.Result = "success"
|
||||
}
|
||||
if strings.TrimSpace(e.Level) == "" {
|
||||
if e.Result == "failure" {
|
||||
e.Level = "warn"
|
||||
} else {
|
||||
e.Level = "info"
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(e.Actor) == "" {
|
||||
e.Actor = "admin"
|
||||
}
|
||||
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
|
||||
detail := SanitizeDetail(e.Detail, maxDetail)
|
||||
|
||||
sessionHintVal := e.SessionHint
|
||||
if sessionHintVal == "" && c != nil {
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
sessionHintVal = sessionHint(token)
|
||||
}
|
||||
}
|
||||
clientIPVal := e.ClientIP
|
||||
if clientIPVal == "" {
|
||||
clientIPVal = clientIP(c)
|
||||
}
|
||||
|
||||
row := &database.AuditLog{
|
||||
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
CreatedAt: time.Now(),
|
||||
Level: e.Level,
|
||||
Category: e.Category,
|
||||
Action: e.Action,
|
||||
Result: e.Result,
|
||||
Actor: e.Actor,
|
||||
SessionHint: sessionHintVal,
|
||||
ClientIP: clientIPVal,
|
||||
UserAgent: userAgent(c),
|
||||
ResourceType: e.ResourceType,
|
||||
ResourceID: e.ResourceID,
|
||||
Message: e.Message,
|
||||
Detail: detail,
|
||||
}
|
||||
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
|
||||
s.logger.Warn("写入审计日志失败",
|
||||
zap.String("action", e.Action),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
|
||||
func (s *Service) RecordSystem(e Entry) {
|
||||
s.Record(nil, e)
|
||||
}
|
||||
|
||||
// PurgeExpired deletes rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Audit.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.DeleteAuditLogsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
|
||||
}
|
||||
}
|
||||
|
||||
// HintFromToken returns a short stable hash prefix for a session token.
|
||||
func HintFromToken(token string) string {
|
||||
return sessionHint(token)
|
||||
}
|
||||
|
||||
func sessionHint(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:4])
|
||||
}
|
||||
|
||||
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
|
||||
if !isAuthFailureThrottled(e.Category, e.Action) {
|
||||
return true
|
||||
}
|
||||
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
|
||||
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
|
||||
return s.failThrottle.allow(key, cooldown)
|
||||
}
|
||||
|
||||
func clientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func userAgent(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
ua := c.GetHeader("User-Agent")
|
||||
if len(ua) > 512 {
|
||||
return ua[:512]
|
||||
}
|
||||
return ua
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
|
||||
type failureThrottle struct {
|
||||
mu sync.Mutex
|
||||
last map[string]time.Time
|
||||
}
|
||||
|
||||
func newFailureThrottle() *failureThrottle {
|
||||
return &failureThrottle{last: make(map[string]time.Time)}
|
||||
}
|
||||
|
||||
// allow reports whether a row with the given key may be written now.
|
||||
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
|
||||
if t == nil || cooldown <= 0 || key == "" {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
|
||||
return false
|
||||
}
|
||||
t.last[key] = now
|
||||
if len(t.last) > 4096 {
|
||||
for k, ts := range t.last {
|
||||
if now.Sub(ts) > cooldown*2 {
|
||||
delete(t.last, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
|
||||
func authFailureThrottleKey(category, action, clientIP string) string {
|
||||
return category + ":" + action + ":" + clientIP
|
||||
}
|
||||
|
||||
func isAuthFailureThrottled(category, action string) bool {
|
||||
if category != "auth" {
|
||||
return false
|
||||
}
|
||||
switch action {
|
||||
case "login", "change_password":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package audit
|
||||
|
||||
// Entry describes one platform audit record (not chat/tool execution bodies).
|
||||
type Entry struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string // success | failure
|
||||
Actor string
|
||||
SessionHint string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Message string
|
||||
Detail map[string]interface{}
|
||||
ClientIP string // optional when c is nil (robot, batch, DB hook)
|
||||
}
|
||||
+109
-10
@@ -26,6 +26,7 @@ type Config struct {
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -35,13 +36,39 @@ type Config struct {
|
||||
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
||||
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter)
|
||||
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
|
||||
Project ProjectConfig `yaml:"project,omitempty" json:"project,omitempty"`
|
||||
}
|
||||
|
||||
// ProjectConfig 项目黑板(跨对话共享事实)配置。
|
||||
type ProjectConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||
}
|
||||
|
||||
// FactIndexMaxRunesEffective 自动注入黑板索引的最大 rune 数。
|
||||
func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
||||
if c.FactIndexMaxRunes <= 0 {
|
||||
return 3500
|
||||
}
|
||||
return c.FactIndexMaxRunes
|
||||
}
|
||||
|
||||
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||
if c.FactSummaryMaxRunes <= 0 {
|
||||
return 200
|
||||
}
|
||||
return c.FactSummaryMaxRunes
|
||||
}
|
||||
|
||||
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
||||
type MultiAgentConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
||||
@@ -227,6 +254,10 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||
}
|
||||
@@ -362,9 +393,9 @@ type MultiAgentSubConfig struct {
|
||||
|
||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||
type MultiAgentPublic struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
@@ -372,6 +403,18 @@ type MultiAgentPublic struct {
|
||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
|
||||
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
|
||||
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
|
||||
if s == "" || s == "single" || s == "react" {
|
||||
return "react"
|
||||
}
|
||||
if s == "eino_single" {
|
||||
return "eino_single"
|
||||
}
|
||||
return NormalizeMultiAgentOrchestration(s)
|
||||
}
|
||||
|
||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||
func NormalizeMultiAgentOrchestration(s string) string {
|
||||
v := strings.TrimSpace(strings.ToLower(s))
|
||||
@@ -387,9 +430,9 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
||||
|
||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||
type MultiAgentAPIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
@@ -497,7 +540,7 @@ type OpenAIConfig struct {
|
||||
type OpenAIReasoningConfig struct {
|
||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
|
||||
// Effort: low | medium | high | max | xhigh;max/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度。
|
||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||
@@ -575,6 +618,51 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
|
||||
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// EnabledEffective returns true unless audit.enabled is explicitly false.
|
||||
func (a AuditConfig) EnabledEffective() bool {
|
||||
if a.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *a.Enabled
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever.
|
||||
func (a AuditConfig) RetentionDaysEffective() int {
|
||||
if a.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return a.RetentionDays
|
||||
}
|
||||
|
||||
// MaxDetailBytesEffective caps serialized detail JSON size.
|
||||
func (a AuditConfig) MaxDetailBytesEffective() int {
|
||||
if a.MaxDetailBytes <= 0 {
|
||||
return 8192
|
||||
}
|
||||
return a.MaxDetailBytes
|
||||
}
|
||||
|
||||
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
|
||||
func (a AuditConfig) AuthFailureCooldownEffective() int {
|
||||
if a.AuthFailureCooldownSeconds < 0 {
|
||||
return 0
|
||||
}
|
||||
if a.AuthFailureCooldownSeconds == 0 {
|
||||
return 60
|
||||
}
|
||||
return a.AuthFailureCooldownSeconds
|
||||
}
|
||||
|
||||
// ExternalMCPConfig 外部MCP配置
|
||||
type ExternalMCPConfig struct {
|
||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||
@@ -667,6 +755,9 @@ func Load(path string) (*Config, error) {
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
if cfg.Audit.MaxDetailBytes <= 0 {
|
||||
cfg.Audit.MaxDetailBytes = 8192
|
||||
}
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
@@ -1170,6 +1261,14 @@ func Default() *Config {
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
Audit: func() AuditConfig {
|
||||
on := true
|
||||
return AuditConfig{
|
||||
RetentionDays: 90,
|
||||
MaxDetailBytes: 8192,
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog platform operation audit record.
|
||||
type AuditLog struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Level string `json:"level"`
|
||||
Category string `json:"category"`
|
||||
Action string `json:"action"`
|
||||
Result string `json:"result"`
|
||||
Actor string `json:"actor"`
|
||||
SessionHint string `json:"sessionHint,omitempty"`
|
||||
ClientIP string `json:"clientIp,omitempty"`
|
||||
UserAgent string `json:"userAgent,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
ResourceID string `json:"resourceId,omitempty"`
|
||||
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
|
||||
Message string `json:"message"`
|
||||
Detail map[string]interface{} `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
// ListAuditLogsFilter query parameters.
|
||||
type ListAuditLogsFilter struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string
|
||||
Query string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
||||
conditions := []string{"1=1"}
|
||||
args := []interface{}{}
|
||||
if filter.Level != "" {
|
||||
conditions = append(conditions, "level = ?")
|
||||
args = append(args, filter.Level)
|
||||
}
|
||||
if filter.Category != "" {
|
||||
conditions = append(conditions, "category = ?")
|
||||
args = append(args, filter.Category)
|
||||
}
|
||||
if filter.Action != "" {
|
||||
conditions = append(conditions, "action = ?")
|
||||
args = append(args, filter.Action)
|
||||
}
|
||||
if filter.Result != "" {
|
||||
conditions = append(conditions, "result = ?")
|
||||
args = append(args, filter.Result)
|
||||
}
|
||||
if filter.ResourceType != "" {
|
||||
conditions = append(conditions, "resource_type = ?")
|
||||
args = append(args, filter.ResourceType)
|
||||
}
|
||||
if filter.ResourceID != "" {
|
||||
conditions = append(conditions, "resource_id = ?")
|
||||
args = append(args, filter.ResourceID)
|
||||
}
|
||||
if filter.Since != nil {
|
||||
conditions = append(conditions, "created_at >= ?")
|
||||
args = append(args, *filter.Since)
|
||||
}
|
||||
if filter.Until != nil {
|
||||
conditions = append(conditions, "created_at <= ?")
|
||||
args = append(args, *filter.Until)
|
||||
}
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
|
||||
args = append(args, like, like, like, like)
|
||||
}
|
||||
return strings.Join(conditions, " AND "), args
|
||||
}
|
||||
|
||||
// AppendAuditLog inserts one audit row.
|
||||
func (db *DB) AppendAuditLog(row *AuditLog) error {
|
||||
if row == nil {
|
||||
return errors.New("audit log is nil")
|
||||
}
|
||||
if strings.TrimSpace(row.ID) == "" {
|
||||
return errors.New("audit id is required")
|
||||
}
|
||||
if row.CreatedAt.IsZero() {
|
||||
row.CreatedAt = time.Now()
|
||||
}
|
||||
if strings.TrimSpace(row.Level) == "" {
|
||||
row.Level = "info"
|
||||
}
|
||||
detailJSON := ""
|
||||
if len(row.Detail) > 0 {
|
||||
if b, err := json.Marshal(row.Detail); err == nil {
|
||||
detailJSON = string(b)
|
||||
}
|
||||
}
|
||||
query := `
|
||||
INSERT INTO audit_logs (
|
||||
id, created_at, level, category, action, result, actor, session_hint,
|
||||
client_ip, user_agent, resource_type, resource_id, message, detail_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, err := db.Exec(query,
|
||||
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
||||
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAuditLogByID returns one row.
|
||||
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return nil, errors.New("id is required")
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs WHERE id = ?
|
||||
`
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// CountAuditLogs counts rows matching filter.
|
||||
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
|
||||
var n int64
|
||||
err := db.QueryRow(query, args...).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ListAuditLogs lists audit rows newest first.
|
||||
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
limit := filter.Limit
|
||||
if limit <= 0 || limit > 500 {
|
||||
limit = 50
|
||||
}
|
||||
offset := filter.Offset
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs
|
||||
WHERE ` + where + `
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
args = append(args, limit, offset)
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []*AuditLog
|
||||
for rows.Next() {
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
); err != nil {
|
||||
continue
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
list = append(list, &row)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
@@ -22,6 +22,7 @@ type BatchTaskQueueRow struct {
|
||||
LastScheduleTriggerAt sql.NullTime
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
ProjectID sql.NullString
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
@@ -51,6 +52,7 @@ func (db *DB) CreateBatchQueue(
|
||||
scheduleMode string,
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
projectID string,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
@@ -65,9 +67,13 @@ func (db *DB) CreateBatchQueue(
|
||||
nextRunAtValue = *nextRunAt
|
||||
}
|
||||
|
||||
var projectIDVal interface{}
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
projectIDVal = strings.TrimSpace(projectID)
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0,
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
@@ -101,9 +107,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -127,7 +133,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
@@ -138,7 +144,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -158,7 +164,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
@@ -186,7 +192,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
@@ -37,22 +38,41 @@ type Message struct {
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title)
|
||||
func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title, meta)
|
||||
}
|
||||
|
||||
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
projectID := strings.TrimSpace(meta.ProjectID)
|
||||
if projectID != "" {
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if webshellConnectionID != "" {
|
||||
wsID := strings.TrimSpace(webshellConnectionID)
|
||||
switch {
|
||||
case wsID != "" && projectID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id, project_id) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, title, now, now, wsID, projectID,
|
||||
)
|
||||
case wsID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
|
||||
id, title, now, now, webshellConnectionID,
|
||||
id, title, now, now, wsID,
|
||||
)
|
||||
} else {
|
||||
case projectID != "":
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, project_id) VALUES (?, ?, ?, ?, ?)",
|
||||
id, title, now, now, projectID,
|
||||
)
|
||||
default:
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||
id, title, now, now,
|
||||
@@ -62,12 +82,18 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
ProjectID: projectID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
if wsID != "" {
|
||||
meta.WebShellConnectionID = wsID
|
||||
}
|
||||
notifyConversationCreated(conv, meta)
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
||||
@@ -182,22 +208,43 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
|
||||
func (db *DB) ConversationExists(id string) (bool, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return false, nil
|
||||
}
|
||||
var one int
|
||||
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
var projectID sql.NullString
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
@@ -270,16 +317,20 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
var projectID sql.NullString
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
"SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
@@ -319,7 +370,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
rows, err = db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||
FROM conversations c
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||
@@ -329,7 +380,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
)
|
||||
} else {
|
||||
rows, err = db.Query(
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
)
|
||||
}
|
||||
@@ -344,10 +395,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var projectID sql.NullString
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
if projectID.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(projectID.String)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package database
|
||||
|
||||
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
|
||||
type ConversationCreateMeta struct {
|
||||
Source string
|
||||
WebShellConnectionID string
|
||||
ProjectID string
|
||||
ClientIP string
|
||||
SessionHint string
|
||||
}
|
||||
|
||||
// ConversationCreateHook is invoked after a conversation row is inserted.
|
||||
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
|
||||
|
||||
var conversationCreateHook ConversationCreateHook
|
||||
|
||||
// SetConversationCreateHook registers a global hook (e.g. platform audit).
|
||||
func SetConversationCreateHook(h ConversationCreateHook) {
|
||||
conversationCreateHook = h
|
||||
}
|
||||
|
||||
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
|
||||
if conversationCreateHook == nil || conv == nil {
|
||||
return
|
||||
}
|
||||
if meta.Source == "" {
|
||||
meta.Source = "unknown"
|
||||
}
|
||||
conversationCreateHook(conv, meta)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -12,19 +13,106 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// SQLite 在 WAL 模式下建议使用较保守的连接数,降低长读快照导致 checkpoint 饥饿的概率。
|
||||
sqliteMaxOpenConns = 25
|
||||
sqliteMaxIdleConns = 5
|
||||
// 以页为单位的自动 checkpoint 触发阈值(默认 1000 页,约 4MB @ 4KB/page)。
|
||||
sqliteWALAutoCheckpointPages = 1000
|
||||
// 控制 WAL 目标上限,避免异常场景持续膨胀(256MB)。
|
||||
sqliteJournalSizeLimitBytes = 256 * 1024 * 1024
|
||||
// 定时执行 PASSIVE checkpoint,平滑推进 WAL 回收。
|
||||
sqlitePassiveCheckpointInterval = 300 * time.Second
|
||||
)
|
||||
|
||||
// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性
|
||||
func configureDBPool(db *sql.DB) {
|
||||
// SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
// SQLite 同一时间只允许一个写入者;过高连接数会放大锁竞争和 WAL 回收延迟。
|
||||
db.SetMaxOpenConns(sqliteMaxOpenConns)
|
||||
db.SetMaxIdleConns(sqliteMaxIdleConns)
|
||||
db.SetConnMaxLifetime(30 * time.Minute)
|
||||
}
|
||||
|
||||
// configureSQLitePragmas 调整 WAL 回收行为,降低 -wal 文件长期膨胀风险。
|
||||
func configureSQLitePragmas(db *sql.DB) error {
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA wal_autocheckpoint=%d", sqliteWALAutoCheckpointPages)); err != nil {
|
||||
return fmt.Errorf("设置 wal_autocheckpoint 失败: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA journal_size_limit=%d", sqliteJournalSizeLimitBytes)); err != nil {
|
||||
return fmt.Errorf("设置 journal_size_limit 失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
conversationArtifactsDir string
|
||||
checkpointLoopName string
|
||||
checkpointStop chan struct{}
|
||||
checkpointDone chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// startPassiveCheckpointLoop 启动后台 PASSIVE checkpoint 循环。
|
||||
func (db *DB) startPassiveCheckpointLoop(name string) {
|
||||
if sqlitePassiveCheckpointInterval <= 0 || db == nil || db.DB == nil {
|
||||
return
|
||||
}
|
||||
db.checkpointLoopName = strings.TrimSpace(name)
|
||||
db.checkpointStop = make(chan struct{})
|
||||
db.checkpointDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(db.checkpointDone)
|
||||
ticker := time.NewTicker(sqlitePassiveCheckpointInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动后先尝试一次,尽快回收已有 WAL 堆积。
|
||||
db.runPassiveCheckpoint("startup")
|
||||
for {
|
||||
select {
|
||||
case <-db.checkpointStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
db.runPassiveCheckpoint("ticker")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// runPassiveCheckpoint 执行一次 PRAGMA wal_checkpoint(PASSIVE)。
|
||||
func (db *DB) runPassiveCheckpoint(trigger string) {
|
||||
if db == nil || db.DB == nil {
|
||||
return
|
||||
}
|
||||
startAt := time.Now()
|
||||
var busy, logFrames, checkpointed int
|
||||
err := db.QueryRow("PRAGMA wal_checkpoint(PASSIVE)").Scan(&busy, &logFrames, &checkpointed)
|
||||
if db.logger == nil {
|
||||
return
|
||||
}
|
||||
fields := []zap.Field{
|
||||
zap.String("db", db.checkpointLoopName),
|
||||
zap.String("trigger", trigger),
|
||||
zap.Int("busy", busy),
|
||||
zap.Int("log_frames", logFrames),
|
||||
zap.Int("checkpointed_frames", checkpointed),
|
||||
zap.Int64("elapsed_ms", time.Since(startAt).Milliseconds()),
|
||||
}
|
||||
if err != nil {
|
||||
db.logger.Warn("SQLite PASSIVE checkpoint 完成(失败)",
|
||||
append(fields, zap.Error(err))...,
|
||||
)
|
||||
return
|
||||
}
|
||||
if busy > 0 {
|
||||
db.logger.Info("SQLite PASSIVE checkpoint 完成(部分推进)", fields...)
|
||||
return
|
||||
}
|
||||
db.logger.Info("SQLite PASSIVE checkpoint 完成(成功)", fields...)
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
@@ -37,8 +125,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
configureDBPool(db)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
if err := configureSQLitePragmas(db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("配置数据库 PRAGMA 失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
@@ -54,8 +147,10 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||
}
|
||||
database.startPassiveCheckpointLoop("conversations")
|
||||
|
||||
return database, nil
|
||||
}
|
||||
@@ -213,6 +308,59 @@ func (db *DB) initTables() error {
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建项目表
|
||||
createProjectsTable := `
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
scope_json TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建项目事实表(黑板)
|
||||
createProjectFactsTable := `
|
||||
CREATE TABLE IF NOT EXISTS project_facts (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
supersedes_fact_id TEXT,
|
||||
related_vulnerability_id TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
UNIQUE(project_id, fact_key)
|
||||
);`
|
||||
|
||||
createProjectFactVersionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
fact_id TEXT NOT NULL,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
related_vulnerability_id TEXT,
|
||||
archived_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建漏洞表
|
||||
createVulnerabilitiesTable := `
|
||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||
@@ -387,6 +535,24 @@ func (db *DB) initTables() error {
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
createAuditLogsTable := `
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at DATETIME NOT NULL,
|
||||
level TEXT NOT NULL DEFAULT 'info',
|
||||
category TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
result TEXT NOT NULL,
|
||||
actor TEXT NOT NULL DEFAULT 'admin',
|
||||
session_hint TEXT,
|
||||
client_ip TEXT,
|
||||
user_agent TEXT,
|
||||
resource_type TEXT,
|
||||
resource_id TEXT,
|
||||
message TEXT NOT NULL,
|
||||
detail_json TEXT
|
||||
);`
|
||||
|
||||
createC2ProfilesTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_profiles (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -427,6 +593,14 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||
@@ -445,6 +619,10 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -494,6 +672,18 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectsTable); err != nil {
|
||||
return fmt.Errorf("创建projects表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectFactsTable); err != nil {
|
||||
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectFactVersionsTable); err != nil {
|
||||
return fmt.Errorf("创建project_fact_versions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
}
|
||||
@@ -514,6 +704,10 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAuditLogsTable); err != nil {
|
||||
return fmt.Errorf("创建audit_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"c2_listeners": createC2ListenersTable,
|
||||
"c2_sessions": createC2SessionsTable,
|
||||
@@ -557,6 +751,13 @@ func (db *DB) initTables() error {
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateProjectsTable(); err != nil {
|
||||
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
|
||||
}
|
||||
if err := db.migrateProjectFactVersionsTable(); err != nil {
|
||||
db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err))
|
||||
}
|
||||
|
||||
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
@@ -904,6 +1105,79 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
}
|
||||
}
|
||||
|
||||
var projectIDCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='project_id'").Scan(&projectIDCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if projectIDCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); err != nil {
|
||||
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateProjectsTable 迁移 projects / conversations / vulnerabilities 的项目关联字段。
|
||||
func (db *DB) migrateProjectsTable() error {
|
||||
for _, col := range []struct {
|
||||
table string
|
||||
name string
|
||||
stmt string
|
||||
}{
|
||||
{"conversations", "project_id", "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE SET NULL"},
|
||||
{"vulnerabilities", "project_id", "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||
} {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info(?) WHERE name=?", col.table, col.name).Scan(&count)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count == 0 {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateProjectFactVersionsTable 为已有库创建事实版本表。
|
||||
func (db *DB) migrateProjectFactVersionsTable() error {
|
||||
ddl := `
|
||||
CREATE TABLE IF NOT EXISTS project_fact_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
fact_id TEXT NOT NULL,
|
||||
project_id TEXT NOT NULL,
|
||||
fact_key TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'note',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
body TEXT,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
source_message_id TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
related_vulnerability_id TEXT,
|
||||
archived_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);`
|
||||
if _, err := db.Exec(ddl); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`)
|
||||
_, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -915,6 +1189,7 @@ func (db *DB) migrateVulnerabilitiesTable() error {
|
||||
}{
|
||||
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
|
||||
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
|
||||
{name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
@@ -979,8 +1254,13 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
configureDBPool(sqlDB)
|
||||
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
||||
}
|
||||
if err := configureSQLitePragmas(sqlDB); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("配置知识库数据库 PRAGMA 失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: sqlDB,
|
||||
@@ -989,8 +1269,10 @@ func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// 初始化知识库表
|
||||
if err := database.initKnowledgeTables(); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
|
||||
}
|
||||
database.startPassiveCheckpointLoop("knowledge")
|
||||
|
||||
return database, nil
|
||||
}
|
||||
@@ -1104,5 +1386,19 @@ func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
db.closeOnce.Do(func() {
|
||||
if db.checkpointStop != nil {
|
||||
close(db.checkpointStop)
|
||||
if db.checkpointDone != nil {
|
||||
<-db.checkpointDone
|
||||
}
|
||||
}
|
||||
if db.DB != nil {
|
||||
db.closeErr = db.DB.Close()
|
||||
}
|
||||
})
|
||||
return db.closeErr
|
||||
}
|
||||
|
||||
@@ -0,0 +1,513 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)
|
||||
|
||||
// ValidateFactKey 校验事实 key(项目内唯一标识)。
|
||||
func ValidateFactKey(key string) error {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return fmt.Errorf("fact_key 不能为空")
|
||||
}
|
||||
if len(key) > 128 {
|
||||
return fmt.Errorf("fact_key 过长(最多 128 字符)")
|
||||
}
|
||||
if !factKeyPattern.MatchString(key) {
|
||||
return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Project 渗透测试项目(跨对话共享黑板)。
|
||||
type Project struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
ScopeJSON string `json:"scope_json,omitempty"`
|
||||
Status string `json:"status"` // active | archived
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ProjectFact 项目事实(黑板条目)。
|
||||
type ProjectFact struct {
|
||||
ID string `json:"id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
FactKey string `json:"fact_key"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ProjectFactListFilter 事实列表筛选。
|
||||
type ProjectFactListFilter struct {
|
||||
Category string
|
||||
Confidence string
|
||||
Search string
|
||||
RelatedVulnerabilityID string
|
||||
ExcludeDeprecated bool // 为 true 时排除 confidence=deprecated
|
||||
}
|
||||
|
||||
// CreateProject 创建项目。
|
||||
func (db *DB) CreateProject(p *Project) (*Project, error) {
|
||||
if p.ID == "" {
|
||||
p.ID = uuid.New().String()
|
||||
}
|
||||
if strings.TrimSpace(p.Status) == "" {
|
||||
p.Status = "active"
|
||||
}
|
||||
now := time.Now()
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = now
|
||||
}
|
||||
p.UpdatedAt = now
|
||||
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建项目失败: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetProject 获取项目。
|
||||
func (db *DB) GetProject(id string) (*Project, error) {
|
||||
var p Project
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := db.QueryRow(
|
||||
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||
FROM projects WHERE id = ?`, id,
|
||||
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("项目不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取项目失败: %w", err)
|
||||
}
|
||||
p.Pinned = pinned != 0
|
||||
p.CreatedAt = parseDBTime(createdAt)
|
||||
p.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// ListProjects 列出项目。
|
||||
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
||||
FROM projects WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if s := strings.TrimSpace(status); s != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, s)
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("列出项目失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*Project
|
||||
for rows.Next() {
|
||||
var p Project
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Pinned = pinned != 0
|
||||
p.CreatedAt = parseDBTime(createdAt)
|
||||
p.UpdatedAt = parseDBTime(updatedAt)
|
||||
out = append(out, &p)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateProject 更新项目。
|
||||
func (db *DB) UpdateProject(p *Project) error {
|
||||
p.UpdatedAt = time.Now()
|
||||
_, err := db.Exec(
|
||||
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
|
||||
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理;漏洞 project_id 置空)。
|
||||
func (db *DB) DeleteProject(id string) error {
|
||||
if _, err := db.Exec(`UPDATE vulnerabilities SET project_id = NULL WHERE project_id = ?`, id); err != nil {
|
||||
return fmt.Errorf("解除漏洞项目关联失败: %w", err)
|
||||
}
|
||||
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConversationProjectID 返回对话绑定的项目 ID。
|
||||
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
|
||||
var pid sql.NullString
|
||||
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if pid.Valid {
|
||||
return strings.TrimSpace(pid.String), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
|
||||
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID != "" {
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var val interface{}
|
||||
if projectID == "" {
|
||||
val = nil
|
||||
} else {
|
||||
val = projectID
|
||||
}
|
||||
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置对话项目失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
|
||||
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ?`
|
||||
args := []interface{}{projectID}
|
||||
if !includeDeprecated {
|
||||
query += " AND confidence != 'deprecated'"
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC"
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFacts(rows)
|
||||
}
|
||||
|
||||
// ListProjectFacts 分页列出项目事实。
|
||||
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ?`
|
||||
args := []interface{}{projectID}
|
||||
if c := strings.TrimSpace(filter.Category); c != "" {
|
||||
query += " AND category = ?"
|
||||
args = append(args, c)
|
||||
}
|
||||
if c := strings.TrimSpace(filter.Confidence); c != "" {
|
||||
query += " AND confidence = ?"
|
||||
args = append(args, c)
|
||||
}
|
||||
if filter.ExcludeDeprecated {
|
||||
query += " AND confidence != 'deprecated'"
|
||||
}
|
||||
if rid := strings.TrimSpace(filter.RelatedVulnerabilityID); rid != "" {
|
||||
query += " AND related_vulnerability_id = ?"
|
||||
args = append(args, rid)
|
||||
}
|
||||
if s := strings.TrimSpace(filter.Search); s != "" {
|
||||
pat := "%" + s + "%"
|
||||
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
|
||||
args = append(args, pat, pat, pat)
|
||||
}
|
||||
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFacts(rows)
|
||||
}
|
||||
|
||||
// GetProjectFactByKey 按 key 获取事实。
|
||||
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
||||
projectID, factKey,
|
||||
)
|
||||
return scanProjectFactRow(row)
|
||||
}
|
||||
|
||||
// GetProjectFact 按 ID 获取事实。
|
||||
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
||||
FROM project_facts WHERE id = ?`, id,
|
||||
)
|
||||
return scanProjectFactRow(row)
|
||||
}
|
||||
|
||||
// mergeFactBodyOnUpdate 更新时若 incoming body 为空则保留已有内容,避免仅改 summary 时丢失攻击链。
|
||||
func mergeFactBodyOnUpdate(incoming, existing string) string {
|
||||
if strings.TrimSpace(incoming) == "" {
|
||||
return existing
|
||||
}
|
||||
return incoming
|
||||
}
|
||||
|
||||
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
|
||||
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
||||
if err := ValidateFactKey(f.FactKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(f.Category) == "" {
|
||||
f.Category = "note"
|
||||
}
|
||||
if strings.TrimSpace(f.Confidence) == "" {
|
||||
f.Confidence = "tentative"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
|
||||
if err == nil && existing != nil {
|
||||
f.ID = existing.ID
|
||||
f.CreatedAt = existing.CreatedAt
|
||||
f.UpdatedAt = now
|
||||
f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body)
|
||||
if strings.TrimSpace(f.Category) == "" {
|
||||
f.Category = existing.Category
|
||||
}
|
||||
if strings.TrimSpace(f.Confidence) == "" {
|
||||
f.Confidence = existing.Confidence
|
||||
}
|
||||
if projectFactContentChanged(existing, f) {
|
||||
versionID, verr := db.InsertProjectFactVersion(existing)
|
||||
if verr != nil {
|
||||
return nil, verr
|
||||
}
|
||||
f.SupersedesFactID = versionID
|
||||
} else if f.SupersedesFactID == "" {
|
||||
f.SupersedesFactID = existing.SupersedesFactID
|
||||
}
|
||||
_, err = db.Exec(
|
||||
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
||||
source_conversation_id = COALESCE(?, source_conversation_id),
|
||||
source_message_id = COALESCE(?, source_message_id),
|
||||
pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新事实失败: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
if f.ID == "" {
|
||||
f.ID = uuid.New().String()
|
||||
}
|
||||
f.CreatedAt = now
|
||||
f.UpdatedAt = now
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO project_facts (
|
||||
id, project_id, fact_key, category, summary, body, confidence,
|
||||
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
|
||||
f.CreatedAt, f.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建事实失败: %w", err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// DeprecateProjectFact 将事实标记为 deprecated。
|
||||
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||
res, err := db.Exec(
|
||||
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||
time.Now(), projectID, factKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("事实不存在")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||
func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
||||
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||
if confidence == "" {
|
||||
confidence = "tentative"
|
||||
}
|
||||
if confidence != "confirmed" && confidence != "tentative" {
|
||||
return fmt.Errorf("confidence 须为 confirmed 或 tentative")
|
||||
}
|
||||
|
||||
existing, err := db.GetProjectFactByKey(projectID, factKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("事实不存在")
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(existing.Confidence)) != "deprecated" {
|
||||
return fmt.Errorf("事实未处于废弃状态")
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
`UPDATE project_facts SET confidence = ?, updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||
confidence, time.Now(), projectID, factKey,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteProjectFact 删除事实。
|
||||
func (db *DB) DeleteProjectFact(id string) error {
|
||||
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
|
||||
var out []*ProjectFact
|
||||
for rows.Next() {
|
||||
f, err := scanProjectFactFromRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, f)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
|
||||
var f ProjectFact
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := row.Scan(
|
||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("事实不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
f.Pinned = pinned != 0
|
||||
f.CreatedAt = parseDBTime(createdAt)
|
||||
f.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
|
||||
var f ProjectFact
|
||||
var pinned int
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
||||
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
||||
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Pinned = pinned != 0
|
||||
f.CreatedAt = parseDBTime(createdAt)
|
||||
f.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func nullIfEmpty(s string) interface{} {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func parseDBTime(s string) time.Time {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
// go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态
|
||||
layouts := []string{
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02T15:04:05",
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, e := time.Parse(layout, s); e == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestUpsertProjectFact_preservesBodyOnEmptyUpdate(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const body = "## 攻击链\n1. step\n```http\nGET / HTTP/1.1\n```\n"
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/sqli-login",
|
||||
Category: "finding",
|
||||
Summary: "SQLi on /login",
|
||||
Body: body,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/sqli-login",
|
||||
Summary: "SQLi on /login (confirmed)",
|
||||
Body: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.Summary != "SQLi on /login (confirmed)" {
|
||||
t.Fatalf("summary=%q", updated.Summary)
|
||||
}
|
||||
if updated.Body != body {
|
||||
t.Fatalf("returned body=%q want preserved attack chain", updated.Body)
|
||||
}
|
||||
|
||||
fromDB, err := db.GetProjectFactByKey(proj.ID, "finding/sqli-login")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fromDB.Body != body {
|
||||
t.Fatalf("stored body=%q want preserved", fromDB.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertProjectFact_replacesBodyWhenProvided(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "test-facts"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "target/primary",
|
||||
Summary: "v1",
|
||||
Body: "old body",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const newBody = "new body with evidence"
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "target/primary",
|
||||
Summary: "v2",
|
||||
Body: newBody,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.Body != newBody {
|
||||
t.Fatalf("body=%q want %q", updated.Body, newBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreProjectFact(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "restore-test"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key := "target/restore-me"
|
||||
_, err = db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: key,
|
||||
Summary: "s",
|
||||
Confidence: "confirmed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.DeprecateProjectFact(proj.ID, key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.RestoreProjectFact(proj.ID, key, "confirmed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f, err := db.GetProjectFactByKey(proj.ID, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if f.Confidence != "confirmed" {
|
||||
t.Fatalf("confidence=%q want confirmed", f.Confidence)
|
||||
}
|
||||
if err := db.RestoreProjectFact(proj.ID, key, ""); err == nil {
|
||||
t.Fatal("expected error when not deprecated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&Project{Name: "version-test"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
created, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/xss",
|
||||
Category: "finding",
|
||||
Summary: "v1",
|
||||
Body: "body v1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if created.SupersedesFactID != "" {
|
||||
t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID)
|
||||
}
|
||||
|
||||
updated, err := db.UpsertProjectFact(&ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/xss",
|
||||
Summary: "v2",
|
||||
Body: "body v2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated.SupersedesFactID == "" {
|
||||
t.Fatal("expected supersedes_fact_id after content change")
|
||||
}
|
||||
prev, err := db.GetProjectFactVersion(updated.SupersedesFactID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if prev.Summary != "v1" || prev.Body != "body v1" {
|
||||
t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeFactBodyOnUpdate(t *testing.T) {
|
||||
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
|
||||
t.Fatalf("empty incoming: got %q", got)
|
||||
}
|
||||
if got := mergeFactBodyOnUpdate(" ", "keep"); got != "keep" {
|
||||
t.Fatalf("whitespace incoming: got %q", got)
|
||||
}
|
||||
if got := mergeFactBodyOnUpdate("new", "old"); got != "new" {
|
||||
t.Fatalf("non-empty incoming: got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。
|
||||
type ProjectFactVersion struct {
|
||||
ID string `json:"id"`
|
||||
FactID string `json:"fact_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
FactKey string `json:"fact_key"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||
SourceMessageID string `json:"source_message_id,omitempty"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
|
||||
ArchivedAt time.Time `json:"archived_at"`
|
||||
}
|
||||
|
||||
// InsertProjectFactVersion 将当前事实行快照写入版本表。
|
||||
func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) {
|
||||
if f == nil || f.ID == "" {
|
||||
return "", fmt.Errorf("无效的事实记录")
|
||||
}
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO project_fact_versions (
|
||||
id, fact_id, project_id, fact_key, category, summary, body, confidence,
|
||||
source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
||||
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
||||
nullIfEmpty(f.RelatedVulnerabilityID), now,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("归档事实版本失败: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetProjectFactVersion 按版本 ID 获取快照。
|
||||
func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(related_vulnerability_id,''), archived_at
|
||||
FROM project_fact_versions WHERE id = ?`, versionID,
|
||||
)
|
||||
return scanProjectFactVersionRow(row)
|
||||
}
|
||||
|
||||
// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。
|
||||
func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
rows, err := db.Query(
|
||||
`SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
||||
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
||||
COALESCE(related_vulnerability_id,''), archived_at
|
||||
FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`,
|
||||
factID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*ProjectFactVersion
|
||||
for rows.Next() {
|
||||
v, err := scanProjectFactVersionFromRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, v)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func projectFactContentChanged(existing, incoming *ProjectFact) bool {
|
||||
if existing == nil || incoming == nil {
|
||||
return false
|
||||
}
|
||||
mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body)
|
||||
inCat := stringsTrimDefault(incoming.Category, existing.Category)
|
||||
inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence)
|
||||
return existing.Summary != incoming.Summary ||
|
||||
existing.Body != mergedBody ||
|
||||
existing.Category != inCat ||
|
||||
existing.Confidence != inConf
|
||||
}
|
||||
|
||||
func stringsTrimDefault(s, fallback string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) {
|
||||
var v ProjectFactVersion
|
||||
var pinned int
|
||||
var archivedAt string
|
||||
err := row.Scan(
|
||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||
&v.RelatedVulnerabilityID, &archivedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("事实版本不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
v.Pinned = pinned != 0
|
||||
v.ArchivedAt = parseDBTime(archivedAt)
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) {
|
||||
var v ProjectFactVersion
|
||||
var pinned int
|
||||
var archivedAt string
|
||||
err := rows.Scan(
|
||||
&v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence,
|
||||
&v.SourceConversationID, &v.SourceMessageID, &pinned,
|
||||
&v.RelatedVulnerabilityID, &archivedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.Pinned = pinned != 0
|
||||
v.ArchivedAt = parseDBTime(archivedAt)
|
||||
return &v, nil
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProjectStats 项目聚合统计。
|
||||
type ProjectStats struct {
|
||||
FactCount int `json:"fact_count"`
|
||||
VulnCount int `json:"vuln_count"`
|
||||
ConversationCount int `json:"conversation_count"`
|
||||
SparseFactCount int `json:"sparse_fact_count"`
|
||||
}
|
||||
|
||||
// GetProjectStatsCounts 统计项目下事实、漏洞、对话数量(不含 sparse,由 project 包补全)。
|
||||
func (db *DB) GetProjectStatsCounts(projectID string) (*ProjectStats, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("project_id 不能为空")
|
||||
}
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats := &ProjectStats{}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||
projectID,
|
||||
).Scan(&stats.FactCount); err != nil {
|
||||
return nil, fmt.Errorf("统计事实失败: %w", err)
|
||||
}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM vulnerabilities WHERE project_id = ?`,
|
||||
projectID,
|
||||
).Scan(&stats.VulnCount); err != nil {
|
||||
return nil, fmt.Errorf("统计漏洞失败: %w", err)
|
||||
}
|
||||
if err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM conversations WHERE project_id = ?`,
|
||||
projectID,
|
||||
).Scan(&stats.ConversationCount); err != nil {
|
||||
return nil, fmt.Errorf("统计对话失败: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ListProjectFactsForSparseCheck 返回用于待补全检测的事实字段(非 deprecated)。
|
||||
func (db *DB) ListProjectFactsForSparseCheck(projectID string) ([]struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT category, fact_key, COALESCE(body,'') FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
|
||||
projectID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}
|
||||
for rows.Next() {
|
||||
var row struct {
|
||||
Category string
|
||||
FactKey string
|
||||
Body string
|
||||
}
|
||||
if err := rows.Scan(&row.Category, &row.FactKey, &row.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// ListConversationsByProjectID 列出绑定到项目的对话。
|
||||
func (db *DB) ListConversationsByProjectID(projectID string, limit, offset int) ([]*Conversation, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := db.Query(
|
||||
`SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id
|
||||
FROM conversations WHERE project_id = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?`,
|
||||
projectID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询项目对话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var pid sql.NullString
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &pid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pid.Valid {
|
||||
conv.ProjectID = strings.TrimSpace(pid.String)
|
||||
}
|
||||
conv.CreatedAt = parseDBTime(createdAt)
|
||||
conv.UpdatedAt = parseDBTime(updatedAt)
|
||||
conv.Pinned = pinned != 0
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
return conversations, rows.Err()
|
||||
}
|
||||
|
||||
// CountConversationsByProjectID 统计项目绑定对话数。
|
||||
func (db *DB) CountConversationsByProjectID(projectID string) (int, error) {
|
||||
var n int
|
||||
err := db.QueryRow(`SELECT COUNT(*) FROM conversations WHERE project_id = ?`, projectID).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestParseDBTime_projectFactFormats(t *testing.T) {
|
||||
cases := []string{
|
||||
"2026-05-26 11:13:07.442143+08:00",
|
||||
"2026-05-26 11:13:07",
|
||||
"2026-05-26T11:13:07.442143+08:00",
|
||||
}
|
||||
for _, s := range cases {
|
||||
got := parseDBTime(s)
|
||||
if got.IsZero() {
|
||||
t.Fatalf("parseDBTime(%q) returned zero", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListProjectFacts_updatedAtJSON(t *testing.T) {
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Skip(err)
|
||||
}
|
||||
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
|
||||
if _, err := os.Stat(dbPath); err != nil {
|
||||
t.Skip("conversations.db not found")
|
||||
}
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
projects, err := db.ListProjects("", 1, 0)
|
||||
if err != nil || len(projects) == 0 {
|
||||
t.Skip("no projects")
|
||||
}
|
||||
pid := projects[0].ID
|
||||
|
||||
list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(list) == 0 {
|
||||
t.Skip("no facts")
|
||||
}
|
||||
for _, f := range list {
|
||||
if f.UpdatedAt.IsZero() {
|
||||
t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey)
|
||||
}
|
||||
b, err := json.Marshal(f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw, ok := m["updated_at"].(string)
|
||||
if !ok || raw == "" || raw[:4] == "0001" {
|
||||
t.Fatalf("bad updated_at in JSON: %v", m["updated_at"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDBTime_zeroOnGarbage(t *testing.T) {
|
||||
if !parseDBTime("").IsZero() {
|
||||
t.Fatal("expected zero for empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure RFC3339 round-trip used by API is after year 2000.
|
||||
func TestParseDBTime_marshalRoundTrip(t *testing.T) {
|
||||
s := "2026-05-26 11:13:07.442143+08:00"
|
||||
tm := parseDBTime(s)
|
||||
b, err := json.Marshal(tm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var back time.Time
|
||||
if err := json.Unmarshal(b, &back); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if back.IsZero() {
|
||||
t.Fatalf("unmarshal zero from %s", string(b))
|
||||
}
|
||||
}
|
||||
@@ -3,16 +3,94 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
|
||||
type VulnerabilityListFilter struct {
|
||||
ID string
|
||||
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
|
||||
ConversationID string
|
||||
ProjectID string
|
||||
Severity string
|
||||
Status string
|
||||
TaskID string
|
||||
ConversationTag string
|
||||
TaskTag string
|
||||
}
|
||||
|
||||
func escapeVulnerabilityLikePattern(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return "%" + s + "%"
|
||||
}
|
||||
|
||||
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
|
||||
if f.ID != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, f.ID)
|
||||
}
|
||||
if f.ConversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, f.ConversationID)
|
||||
}
|
||||
if f.ProjectID != "" {
|
||||
query += " AND project_id = ?"
|
||||
args = append(args, f.ProjectID)
|
||||
}
|
||||
if f.TaskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, f.TaskID, f.TaskID)
|
||||
}
|
||||
if f.ConversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, f.ConversationTag)
|
||||
}
|
||||
if f.TaskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, f.TaskTag)
|
||||
}
|
||||
if f.Severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, f.Severity)
|
||||
}
|
||||
if f.Status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, f.Status)
|
||||
}
|
||||
search := strings.TrimSpace(f.Search)
|
||||
if search != "" {
|
||||
pattern := escapeVulnerabilityLikePattern(search)
|
||||
query += ` AND (
|
||||
LOWER(id) LIKE LOWER(?) OR
|
||||
LOWER(title) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||
)`
|
||||
for i := 0; i < 11; i++ {
|
||||
args = append(args, pattern)
|
||||
}
|
||||
}
|
||||
return query, args
|
||||
}
|
||||
|
||||
// Vulnerability 漏洞
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ConversationTag string `json:"conversation_tag,omitempty"`
|
||||
TaskTag string `json:"task_tag,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
@@ -44,17 +122,23 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
}
|
||||
vuln.UpdatedAt = now
|
||||
|
||||
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" {
|
||||
if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil {
|
||||
vuln.ProjectID = pid
|
||||
}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO vulnerabilities (
|
||||
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
|
||||
id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||
vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||
vuln.CreatedAt, vuln.UpdatedAt,
|
||||
@@ -70,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
var vuln Vulnerability
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status,
|
||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status,
|
||||
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
@@ -80,7 +164,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
`
|
||||
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
@@ -97,9 +181,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
||||
SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
||||
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
||||
@@ -108,35 +192,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
WHERE 1=1
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
@@ -151,7 +207,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
for rows.Next() {
|
||||
var vuln Vulnerability
|
||||
err := rows.Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.TaskID, &vuln.TaskQueueID,
|
||||
@@ -168,38 +224,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
}
|
||||
|
||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
||||
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
@@ -216,7 +244,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
query := `
|
||||
UPDATE vulnerabilities
|
||||
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||
SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
||||
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
||||
recommendation = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
@@ -224,7 +252,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
||||
vuln.Recommendation, vuln.UpdatedAt, id,
|
||||
)
|
||||
@@ -237,27 +265,32 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
|
||||
// DeleteVulnerability 删除漏洞
|
||||
func (db *DB) DeleteVulnerability(id string) error {
|
||||
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开启事务失败: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 删除漏洞前先解除项目事实中的关联,避免前端继续显示已删除漏洞的短 ID。
|
||||
if _, err := tx.Exec("UPDATE project_facts SET related_vulnerability_id = NULL WHERE related_vulnerability_id = ?", id); err != nil {
|
||||
return fmt.Errorf("清理事实漏洞关联失败: %w", err)
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM vulnerabilities WHERE id = ?", id); err != nil {
|
||||
return fmt.Errorf("删除漏洞失败: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
||||
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
|
||||
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
where := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
where += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
where, args = filter.appendWhere(where, args)
|
||||
|
||||
// 总漏洞数
|
||||
var totalCount int
|
||||
@@ -357,10 +390,15 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
|
||||
}
|
||||
projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询项目ID建议失败: %w", err)
|
||||
}
|
||||
|
||||
return map[string][]string{
|
||||
"vulnerability_ids": vulnIDs,
|
||||
"conversation_ids": conversationIDs,
|
||||
"project_ids": projectIDs,
|
||||
"task_ids": taskIDs,
|
||||
"queue_ids": queueIDs,
|
||||
"conversation_tags": conversationTags,
|
||||
|
||||
@@ -96,6 +96,17 @@ type runHandler struct {
|
||||
seq atomic.Uint64
|
||||
}
|
||||
|
||||
func safeRunInfo(info *callbacks.RunInfo) callbacks.RunInfo {
|
||||
if info == nil {
|
||||
return callbacks.RunInfo{
|
||||
Name: "unknown",
|
||||
Type: "unknown",
|
||||
Component: components.Component("unknown"),
|
||||
}
|
||||
}
|
||||
return *info
|
||||
}
|
||||
|
||||
func (h *runHandler) genSpanID() string {
|
||||
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||
}
|
||||
@@ -134,6 +145,7 @@ func (h *runHandler) popMatching(want string) string {
|
||||
}
|
||||
|
||||
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
var parentID string
|
||||
h.mu.Lock()
|
||||
if len(h.spanStack) > 0 {
|
||||
@@ -151,9 +163,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
||||
ctx, sp = tracer.Start(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindInternal),
|
||||
trace.WithAttributes(
|
||||
attribute.String("eino.component", string(info.Component)),
|
||||
attribute.String("eino.name", info.Name),
|
||||
attribute.String("eino.type", info.Type),
|
||||
attribute.String("eino.component", string(ri.Component)),
|
||||
attribute.String("eino.name", ri.Name),
|
||||
attribute.String("eino.type", ri.Type),
|
||||
attribute.String("cyberstrike.run_id", h.runID),
|
||||
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||
@@ -169,9 +181,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("parentSpanId", parentID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.String("phase", "start"),
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
@@ -195,9 +207,9 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
||||
"parentSpanId": parentID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"inputSummary": inSum,
|
||||
"source": "eino_callbacks",
|
||||
@@ -208,6 +220,7 @@ func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input
|
||||
}
|
||||
|
||||
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
@@ -226,9 +239,9 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.String("phase", "end"),
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
@@ -243,9 +256,9 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"outputSummary": outSum,
|
||||
"source": "eino_callbacks",
|
||||
@@ -255,6 +268,7 @@ func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output
|
||||
}
|
||||
|
||||
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
@@ -276,9 +290,9 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
||||
h.params.Logger.Warn("eino_callback_error",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
zap.String("type", ri.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
@@ -288,9 +302,9 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"component": string(ri.Component),
|
||||
"name": ri.Name,
|
||||
"type": ri.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"error": msg,
|
||||
"source": "eino_callbacks",
|
||||
@@ -300,28 +314,30 @@ func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err e
|
||||
}
|
||||
|
||||
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
if input != nil {
|
||||
input.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_in",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||
ri := safeRunInfo(info)
|
||||
if output != nil {
|
||||
output.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_out",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("component", string(ri.Component)),
|
||||
zap.String("name", ri.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
|
||||
+249
-55
@@ -17,6 +17,7 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
@@ -86,6 +87,23 @@ func normalizeProcessDetailText(s string) string {
|
||||
// discardPlanningIfEchoesToolResult drops buffered planning text when it only repeats the
|
||||
// upcoming tool_result body. Streaming models often echo tool stdout in chunk.Content; flushing
|
||||
// that into "planning" before persisting tool_result duplicates the output after page refresh.
|
||||
// sameResponseStreamMeta 判断是否为同一段主通道流(Eino ADK 可能对同一 MessageStream 重复发 response_start)。
|
||||
func sameResponseStreamMeta(a, b map[string]interface{}) bool {
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
agentA, _ := a["einoAgent"].(string)
|
||||
agentB, _ := b["einoAgent"].(string)
|
||||
agentA = strings.TrimSpace(agentA)
|
||||
agentB = strings.TrimSpace(agentB)
|
||||
if agentA == "" || !strings.EqualFold(agentA, agentB) {
|
||||
return false
|
||||
}
|
||||
orchA, _ := a["orchestration"].(string)
|
||||
orchB, _ := b["orchestration"].(string)
|
||||
return strings.TrimSpace(orchA) == strings.TrimSpace(orchB)
|
||||
}
|
||||
|
||||
func discardPlanningIfEchoesToolResult(respPlan *responsePlanAgg, toolData interface{}) {
|
||||
if respPlan == nil {
|
||||
return
|
||||
@@ -131,6 +149,12 @@ type AgentHandler struct {
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
@@ -207,7 +231,7 @@ type ChatAttachment struct {
|
||||
type ChatReasoningRequest struct {
|
||||
// Mode: default(跟随系统)| off | on | auto
|
||||
Mode string `json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)。
|
||||
// Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
@@ -215,6 +239,7 @@ type ChatReasoningRequest struct {
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id)
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
@@ -553,7 +578,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -628,6 +655,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -675,7 +704,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil)
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -717,11 +746,45 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
|
||||
if assistantMessageID != "" {
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
}
|
||||
|
||||
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(message, 50)
|
||||
conv, createErr := h.db.CreateConversation(title)
|
||||
src := "robot"
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
src = "robot:" + strings.TrimSpace(platform)
|
||||
}
|
||||
meta := audit.ConversationCreateMeta(src)
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, createErr := h.db.CreateConversation(title, meta)
|
||||
if createErr != nil {
|
||||
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
||||
}
|
||||
@@ -769,53 +832,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
|
||||
|
||||
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
|
||||
if useRobotMulti {
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
finalMessage,
|
||||
agentHistoryMessages,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
"deep",
|
||||
nil,
|
||||
)
|
||||
if errMA != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
// 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
|
||||
taskCtx, cancelWithCause := context.WithCancelCause(ctx)
|
||||
defer cancelWithCause(nil)
|
||||
taskStatus := "completed"
|
||||
defer func() {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}()
|
||||
if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
|
||||
}
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||
|
||||
robotMode := "react"
|
||||
if h.config != nil {
|
||||
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
|
||||
}
|
||||
switch robotMode {
|
||||
case "eino_single":
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
case "deep", "plan_execute", "supervisor":
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
|
||||
zap.String("robot_mode", robotMode))
|
||||
break
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
}
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
taskStatus = "failed"
|
||||
errMsg := "执行失败: " + err.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
@@ -847,6 +949,23 @@ type StreamEvent struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
|
||||
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
|
||||
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
|
||||
return
|
||||
}
|
||||
event := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sseLine := make([]byte, 0, len(eventJSON)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, eventJSON...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, sseLine)
|
||||
}
|
||||
|
||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
@@ -894,6 +1013,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
seenToolResultSigs := make(map[string]string) // toolCallId -> payload signature
|
||||
|
||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||
@@ -956,9 +1077,34 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
|
||||
return func(eventType, message string, data interface{}) {
|
||||
// 如果提供了sendEventFunc,发送流式事件
|
||||
// 上游在重试/补偿时可能重复回调相同 tool_call/tool_result。
|
||||
// 这里做幂等过滤,保证前端展示和 process_details 都以唯一事件为准。
|
||||
if (eventType == "tool_call" || eventType == "tool_result") && data != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolCallID := strings.TrimSpace(fmt.Sprint(dataMap["toolCallId"]))
|
||||
if toolCallID != "" && toolCallID != "<nil>" {
|
||||
payloadJSON, _ := json.Marshal(dataMap)
|
||||
sig := eventType + "|" + message + "|" + string(payloadJSON)
|
||||
seen := seenToolCallSigs
|
||||
if eventType == "tool_result" {
|
||||
seen = seenToolResultSigs
|
||||
}
|
||||
if prev, exists := seen[toolCallID]; exists && prev == sig {
|
||||
h.logger.Debug("跳过重复工具进度事件",
|
||||
zap.String("eventType", eventType),
|
||||
zap.String("toolCallId", toolCallID))
|
||||
return
|
||||
}
|
||||
seen[toolCallID] = sig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc(eventType, message, data)
|
||||
} else {
|
||||
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
|
||||
}
|
||||
|
||||
// 保存tool_call事件中的参数
|
||||
@@ -1147,6 +1293,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
|
||||
// 多代理主代理「规划中」:response_start / response_delta 仅用于 SSE,聚合落一条 planning
|
||||
if eventType == "response_start" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sameResponseStreamMeta(respPlan.meta, dataMap) {
|
||||
if respPlan.meta == nil {
|
||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||
}
|
||||
for k, v := range dataMap {
|
||||
respPlan.meta[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
flushResponsePlan()
|
||||
respPlan.meta = nil
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
@@ -1288,7 +1445,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 对于流式请求,也发送SSE格式的错误
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
event := StreamEvent{
|
||||
@@ -1310,7 +1467,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
|
||||
@@ -1420,10 +1577,13 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if req.WebShellConnectionID != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = "webshell_chat"
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
@@ -1496,6 +1656,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -1626,7 +1788,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -1886,7 +2048,7 @@ func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
@@ -1938,6 +2100,7 @@ type BatchTaskRequest struct {
|
||||
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||
}
|
||||
|
||||
func normalizeBatchQueueAgentMode(mode string) string {
|
||||
@@ -2018,7 +2181,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks)
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
||||
if createErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||
return
|
||||
@@ -2039,6 +2202,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
queue = refreshed
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
|
||||
"task_count": len(validTasks), "started": started,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"queueId": queue.ID,
|
||||
"queue": queue,
|
||||
@@ -2146,6 +2314,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2174,6 +2345,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2185,6 +2359,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||
}
|
||||
|
||||
@@ -2280,6 +2457,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "task",
|
||||
Action: "delete_queue",
|
||||
Result: "success",
|
||||
ResourceType: "batch_queue",
|
||||
ResourceID: queueID,
|
||||
Message: "删除批量任务队列",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
||||
}
|
||||
|
||||
@@ -2365,6 +2552,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||
"batch_queue_id": queueID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
@@ -2523,7 +2715,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
@@ -2673,15 +2867,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditHandler serves platform audit log APIs.
|
||||
type AuditHandler struct {
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuditHandler creates an audit log handler.
|
||||
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
|
||||
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
|
||||
}
|
||||
|
||||
// Meta GET /api/audit/meta
|
||||
func (h *AuditHandler) Meta(c *gin.Context) {
|
||||
enabled := false
|
||||
retentionDays := 0
|
||||
if h.audit != nil {
|
||||
enabled = h.audit.Enabled()
|
||||
retentionDays = h.audit.RetentionDays()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": enabled,
|
||||
"retention_days": retentionDays,
|
||||
"default_page_size": 20,
|
||||
"max_page_size": 100,
|
||||
"max_export": 5000,
|
||||
})
|
||||
}
|
||||
|
||||
// Summary GET /api/audit/summary
|
||||
func (h *AuditHandler) Summary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
base := auditFilterFromQuery(c)
|
||||
total, err := h.db.CountAuditLogs(base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
failFilter := base
|
||||
failFilter.Result = "failure"
|
||||
failures, err := h.db.CountAuditLogs(failFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
since := time.Now().AddDate(0, 0, -7)
|
||||
recentFilter := base
|
||||
recentFilter.Since = &since
|
||||
recent7d, err := h.db.CountAuditLogs(recentFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total": total,
|
||||
"failures": failures,
|
||||
"recent_7d": recent7d,
|
||||
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
|
||||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs GET /api/audit/logs
|
||||
func (h *AuditHandler) ListLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
page, pageSize := auditPaginationFromQuery(c)
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
total, err := h.db.CountAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLog GET /api/audit/logs/:id
|
||||
func (h *AuditHandler) GetLog(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
row, err := h.db.GetAuditLogByID(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
|
||||
return
|
||||
}
|
||||
audit.ApplyResourceAvailability(h.db, row)
|
||||
c.JSON(http.StatusOK, gin.H{"log": row})
|
||||
}
|
||||
|
||||
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
|
||||
func (h *AuditHandler) ExportLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
filter.Limit = 5000
|
||||
filter.Offset = 0
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if c.Query("format") == "csv" {
|
||||
writeAuditLogsCSV(c, logs)
|
||||
return
|
||||
}
|
||||
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"exported_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"logs": logs,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
|
||||
c.Header("Content-Type", "text/csv; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
|
||||
|
||||
w := csv.NewWriter(c.Writer)
|
||||
_ = w.Write([]string{
|
||||
"id", "created_at", "level", "category", "action", "result", "actor",
|
||||
"session_hint", "client_ip", "resource_type", "resource_id", "message",
|
||||
})
|
||||
for _, row := range logs {
|
||||
if row == nil {
|
||||
continue
|
||||
}
|
||||
_ = w.Write([]string{
|
||||
row.ID,
|
||||
row.CreatedAt.UTC().Format(time.RFC3339),
|
||||
row.Level,
|
||||
row.Category,
|
||||
row.Action,
|
||||
row.Result,
|
||||
row.Actor,
|
||||
row.SessionHint,
|
||||
row.ClientIP,
|
||||
row.ResourceType,
|
||||
row.ResourceID,
|
||||
row.Message,
|
||||
})
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
||||
filter := database.ListAuditLogsFilter{
|
||||
Level: c.Query("level"),
|
||||
Category: c.Query("category"),
|
||||
Action: c.Query("action"),
|
||||
Result: c.Query("result"),
|
||||
Query: c.Query("q"),
|
||||
ResourceType: c.Query("resource_type"),
|
||||
ResourceID: c.Query("resource_id"),
|
||||
}
|
||||
if since := c.Query("since"); since != "" {
|
||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
if until := c.Query("until"); until != "" {
|
||||
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
||||
filter.Until = &t
|
||||
}
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
return page, pageSize
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
@@ -18,6 +19,12 @@ type AuthHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AuthHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler.
|
||||
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
|
||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||
if err != nil {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "failure",
|
||||
Message: "登录失败:密码错误",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "success",
|
||||
SessionHint: audit.HintFromToken(token),
|
||||
Message: "登录成功",
|
||||
Detail: map[string]interface{}{
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.manager.RevokeToken(token)
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "logout",
|
||||
Result: "success",
|
||||
Message: "退出登录",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||
}
|
||||
|
||||
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "failure",
|
||||
Message: "修改密码失败:当前密码不正确",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "success",
|
||||
Message: "登录密码已修改",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ type BatchTaskQueue struct {
|
||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr string,
|
||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||
nextRunAt *time.Time,
|
||||
tasks []string,
|
||||
) (*BatchTaskQueue, error) {
|
||||
@@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
AgentMode: normalizeBatchQueueAgentMode(agentMode),
|
||||
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
@@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
queue.ScheduleMode,
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
queue.ProjectID,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
@@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
|
||||
@@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"type": "boolean",
|
||||
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
||||
},
|
||||
"project_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
@@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
if !ok {
|
||||
executeNow = false
|
||||
}
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
@@ -25,6 +26,12 @@ import (
|
||||
type C2Handler struct {
|
||||
mgrPtr atomic.Pointer[c2.Manager]
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *C2Handler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
||||
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
|
||||
implantToken := listener.ImplantToken
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
|
||||
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
||||
}
|
||||
|
||||
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
||||
}
|
||||
|
||||
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
|
||||
"count": n, "ids": req.IDs,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||
}
|
||||
|
||||
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
|
||||
"session_id": req.SessionID, "task_type": req.TaskType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
}
|
||||
|
||||
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -24,6 +26,12 @@ const (
|
||||
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
||||
type ChatUploadsHandler struct {
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewChatUploadsHandler 创建处理器
|
||||
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
|
||||
}
|
||||
rel, _ := filepath.Rel(root, fullPath)
|
||||
absSaved, _ := filepath.Abs(fullPath)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
|
||||
"name": unique,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"relativePath": filepath.ToSlash(rel),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -87,6 +88,7 @@ type ConfigHandler struct {
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
@@ -206,6 +208,13 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
|
||||
h.robotRestarter = restarter
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConfigHandler) SetAudit(s *audit.Service) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
|
||||
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
|
||||
h.mu.Lock()
|
||||
@@ -310,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
multiPub := config.MultiAgentPublic{
|
||||
Enabled: h.config.MultiAgent.Enabled,
|
||||
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
|
||||
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
|
||||
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
||||
SubAgentCount: subAgentCount,
|
||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||
@@ -770,8 +779,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||
if req.MultiAgent != nil {
|
||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
|
||||
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
||||
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = mode
|
||||
} else {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = "react"
|
||||
}
|
||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||
}
|
||||
@@ -780,7 +793,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
h.logger.Info("更新多代理配置",
|
||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
|
||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||
@@ -903,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -1033,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||
if _, err := knowledgeInitializer(); err != nil {
|
||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1067,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
||||
if _, err := reinitKnowledgeInitializer(); err != nil {
|
||||
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1080,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
if c2Rt != nil {
|
||||
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
||||
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1221,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "config",
|
||||
Action: "apply",
|
||||
Result: "success",
|
||||
Message: "配置已应用",
|
||||
Detail: map[string]interface{}{
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
"knowledge_enabled": h.config.Knowledge.Enabled,
|
||||
"c2_enabled": h.config.C2.EnabledEffective(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
@@ -1536,7 +1575,7 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||
root := doc.Content[0]
|
||||
maNode := ensureMap(root, "multi_agent")
|
||||
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
||||
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
|
||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||
mwNode := ensureMap(maNode, "eino_middleware")
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -14,6 +16,12 @@ import (
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
@@ -26,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
}
|
||||
|
||||
// SetConversationProjectRequest 设置对话所属项目
|
||||
type SetConversationProjectRequest struct {
|
||||
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
@@ -42,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "api")
|
||||
meta.ProjectID = strings.TrimSpace(req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -52,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// SetConversationProject 设置或清除对话绑定的项目
|
||||
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req SetConversationProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetConversation(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
@@ -189,6 +224,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "conversation",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: id,
|
||||
Message: "删除对话",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -227,6 +273,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
|
||||
"message_id": req.MessageID,
|
||||
"deleted": len(deletedIDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"deletedMessageIds": deletedIDs,
|
||||
"message": "ok",
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if segmentUserMessage != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
|
||||
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -176,9 +177,41 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
for {
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
@@ -197,17 +230,38 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
timeoutCancel()
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -231,10 +285,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -261,6 +319,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -278,6 +337,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -294,9 +354,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
@@ -326,7 +389,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -367,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewExternalMCPHandler 创建外部MCP处理器
|
||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||
return &ExternalMCPHandler{
|
||||
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "upsert",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "更新外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "删除外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
|
||||
|
||||
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
|
||||
"decision": req.Decision,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
|
||||
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建新的知识库处理器
|
||||
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
|
||||
|
||||
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
||||
type MarkdownAgentsHandler struct {
|
||||
dir string
|
||||
dir string
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
||||
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
|
||||
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
||||
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
||||
}
|
||||
|
||||
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
||||
}
|
||||
|
||||
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
@@ -23,6 +24,12 @@ type MonitorHandler struct {
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MonitorHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
|
||||
"tool_name": exec.ToolName,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
||||
return
|
||||
}
|
||||
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
|
||||
"count": len(request.IDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
|
||||
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -186,9 +187,41 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
for {
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
@@ -209,17 +242,38 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
timeoutCancel()
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -243,10 +297,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -273,6 +331,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -290,6 +349,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -306,9 +366,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
@@ -347,7 +410,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
|
||||
if err != nil {
|
||||
status, msg := multiAgentHTTPErrorStatus(err)
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
@@ -381,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
|
||||
UserMessageID string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
|
||||
if len(req.Attachments) > maxAttachments {
|
||||
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
||||
}
|
||||
@@ -33,10 +35,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = source + "_webshell"
|
||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
@@ -85,6 +91,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolUpsertProjectFact,
|
||||
builtin.ToolGetProjectFact,
|
||||
builtin.ToolListProjectFacts,
|
||||
builtin.ToolSearchProjectFacts,
|
||||
builtin.ToolDeprecateProjectFact,
|
||||
builtin.ToolRestoreProjectFact,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
|
||||
+139
-1
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"description": "对话标题",
|
||||
"example": "Web应用安全测试",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选,共享事实黑板)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"SetConversationProjectRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "项目 ID;空字符串表示解除绑定",
|
||||
},
|
||||
},
|
||||
"required": []string{"projectId"},
|
||||
},
|
||||
"Conversation": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"format": "date-time",
|
||||
"description": "更新时间",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ConversationDetail": map[string]interface{}{
|
||||
@@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/project": map[string]interface{}{
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
"summary": "设置对话所属项目",
|
||||
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
|
||||
"operationId": "setConversationProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id", "in": "path", "required": true,
|
||||
"description": "对话ID",
|
||||
"schema": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/components/schemas/SetConversationProjectRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "设置成功"},
|
||||
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
|
||||
"404": map[string]interface{}{"description": "对话不存在"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/results": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
@@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "列出项目",
|
||||
"operationId": "listProjects",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
|
||||
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "项目列表"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "创建项目",
|
||||
"operationId": "createProject",
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{"type": "string"},
|
||||
"description": map[string]interface{}{"type": "string"},
|
||||
"scope_json": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "创建成功"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
|
||||
},
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
|
||||
},
|
||||
"delete": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/facts": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||
},
|
||||
},
|
||||
"/api/vulnerabilities": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"漏洞管理"},
|
||||
@@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "project_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "项目ID",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "severity",
|
||||
"in": "query",
|
||||
@@ -6254,7 +6392,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
|
||||
if err != nil {
|
||||
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
||||
vulnList = []*database.Vulnerability{}
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProjectHandler 项目管理处理器。
|
||||
type ProjectHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProjectHandler 创建项目管理处理器。
|
||||
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
|
||||
return &ProjectHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
type createProjectRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ScopeJSON string `json:"scope_json"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
|
||||
type updateProjectRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
ScopeJSON *string `json:"scope_json"`
|
||||
Status *string `json:"status"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// CreateProject POST /api/projects
|
||||
func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
||||
var req createProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
p := &database.Project{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Description: req.Description,
|
||||
ScopeJSON: req.ScopeJSON,
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
h.logger.Error("创建项目失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// ListProjects GET /api/projects
|
||||
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListProjects(status, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Project{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// GetProjectStats GET /api/projects/:id/stats
|
||||
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
|
||||
stats, err := project.GetProjectStats(h.db, c.Param("id"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "不存在") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// ListProjectConversations GET /api/projects/:id/conversations
|
||||
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if _, err := h.db.GetProject(projectID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Conversation{}
|
||||
}
|
||||
total, _ := h.db.CountConversationsByProjectID(projectID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversations": list,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProject GET /api/projects/:id
|
||||
func (h *ProjectHandler) GetProject(c *gin.Context) {
|
||||
p, err := h.db.GetProject(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// UpdateProject PUT /api/projects/:id
|
||||
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
p, err := h.db.GetProject(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
var req updateProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.Name != nil {
|
||||
if s := strings.TrimSpace(*req.Name); s != "" {
|
||||
p.Name = s
|
||||
}
|
||||
}
|
||||
if req.Description != nil {
|
||||
p.Description = *req.Description
|
||||
}
|
||||
if req.ScopeJSON != nil {
|
||||
p.ScopeJSON = *req.ScopeJSON
|
||||
}
|
||||
if req.Status != nil {
|
||||
if s := strings.TrimSpace(*req.Status); s != "" {
|
||||
p.Status = s
|
||||
}
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
p.Pinned = *req.Pinned
|
||||
}
|
||||
if err := h.db.UpdateProject(p); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// DeleteProject DELETE /api/projects/:id
|
||||
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||
if err := h.db.DeleteProject(c.Param("id")); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type upsertFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
}
|
||||
|
||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||
type updateFactRequest struct {
|
||||
FactKey *string `json:"fact_key"`
|
||||
Category *string `json:"category"`
|
||||
Summary *string `json:"summary"`
|
||||
Body *string `json:"body"`
|
||||
Confidence *string `json:"confidence"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||
ClearBody bool `json:"clear_body"`
|
||||
}
|
||||
|
||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||
func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
|
||||
f, err := h.db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, f)
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: c.Query("category"),
|
||||
Confidence: c.Query("confidence"),
|
||||
Search: c.Query("search"),
|
||||
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
|
||||
}
|
||||
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
|
||||
filter.ExcludeDeprecated = true
|
||||
}
|
||||
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFact{}
|
||||
}
|
||||
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
|
||||
filtered := make([]*database.ProjectFact, 0, len(list))
|
||||
for _, f := range list {
|
||||
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
|
||||
filtered = append(filtered, f)
|
||||
}
|
||||
}
|
||||
list = filtered
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// GetFactPreviousVersion GET /api/projects/:id/facts/:factId/previous-version
|
||||
func (h *ProjectHandler) GetFactPreviousVersion(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(existing.SupersedesFactID) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "无上一版本"})
|
||||
return
|
||||
}
|
||||
v, err := h.db.GetProjectFactVersion(existing.SupersedesFactID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, v)
|
||||
}
|
||||
|
||||
// ListFactVersions GET /api/projects/:id/facts/:factId/versions
|
||||
func (h *ProjectHandler) ListFactVersions(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
list, err := h.db.ListProjectFactVersions(existing.ID, limit)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFactVersion{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// CreateFact POST /api/projects/:id/facts
|
||||
func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
var req upsertFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: c.Param("id"),
|
||||
FactKey: req.FactKey,
|
||||
Category: req.Category,
|
||||
Summary: req.Summary,
|
||||
Body: req.Body,
|
||||
Confidence: req.Confidence,
|
||||
Pinned: req.Pinned,
|
||||
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
|
||||
}
|
||||
created, err := h.db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
var req updateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.FactKey != nil {
|
||||
if k := strings.TrimSpace(*req.FactKey); k != "" {
|
||||
existing.FactKey = k
|
||||
}
|
||||
}
|
||||
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
|
||||
existing.Category = *req.Category
|
||||
}
|
||||
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
|
||||
existing.Summary = *req.Summary
|
||||
}
|
||||
if req.ClearBody {
|
||||
existing.Body = ""
|
||||
} else if req.Body != nil {
|
||||
existing.Body = *req.Body
|
||||
}
|
||||
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
|
||||
existing.Confidence = *req.Confidence
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
existing.Pinned = *req.Pinned
|
||||
}
|
||||
if req.RelatedVulnerabilityID != nil {
|
||||
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
|
||||
}
|
||||
updated, err := h.db.UpsertProjectFact(existing)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type deprecateFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
}
|
||||
|
||||
// DeprecateFact POST /api/projects/:id/facts/deprecate
|
||||
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
|
||||
var req deprecateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type restoreFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative
|
||||
}
|
||||
|
||||
// RestoreFact POST /api/projects/:id/facts/restore
|
||||
func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
||||
var req restoreFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/project"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
if !h.config.Project.Enabled {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||
if err != nil || projectID == "" {
|
||||
return ""
|
||||
}
|
||||
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
|
||||
if err != nil {
|
||||
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
|
||||
func effectiveProjectID(cfg *config.Config, explicit string) string {
|
||||
if pid := strings.TrimSpace(explicit); pid != "" {
|
||||
return pid
|
||||
}
|
||||
if cfg != nil {
|
||||
return strings.TrimSpace(cfg.Project.DefaultProjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+219
-22
@@ -40,8 +40,13 @@ const (
|
||||
robotCmdRoles = "角色"
|
||||
robotCmdRolesList = "角色列表"
|
||||
robotCmdSwitchRole = "切换角色"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
robotCmdProjects = "项目"
|
||||
robotCmdProjectsList = "项目列表"
|
||||
robotCmdBindProject = "绑定项目"
|
||||
robotCmdNewProject = "新建项目"
|
||||
robotCmdUnbindProject = "解除项目"
|
||||
)
|
||||
|
||||
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
||||
@@ -133,7 +138,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
||||
} else {
|
||||
t = safeTruncateString(t, 50)
|
||||
}
|
||||
conv, err := h.db.CreateConversation(t)
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(t, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||
return "", false
|
||||
@@ -188,7 +195,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
||||
// clearConversation 清空当前会话(切换到新对话)
|
||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||
return ""
|
||||
@@ -230,7 +239,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
_ = h.db.UpdateConversationTitle(convID, newTitle)
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.robotMessageTimeout())
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.cancelMu.Lock()
|
||||
h.runningCancels[sk] = cancel
|
||||
@@ -242,12 +251,15 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
h.cancelMu.Unlock()
|
||||
}()
|
||||
role := h.getRole(platform, userID)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return "任务已取消。"
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return "任务执行超时,请稍后重试或精简本次请求范围。"
|
||||
}
|
||||
return "处理失败: " + err.Error()
|
||||
}
|
||||
if newConvID != convID {
|
||||
@@ -256,22 +268,182 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
return resp
|
||||
}
|
||||
|
||||
func (h *RobotHandler) robotMessageTimeout() time.Duration {
|
||||
// 机器人整次消息处理超时(与单次工具超时 agent.tool_timeout_minutes 解耦)。
|
||||
return 10 * time.Hour
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdHelp() string {
|
||||
return "**【CyberStrikeAI 机器人命令】**\n\n" +
|
||||
"- `帮助` `help` — 显示本帮助 | Show this help\n" +
|
||||
"- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" +
|
||||
"- `切换 <ID>` `switch <ID>` — 指定对话继续 | Switch to conversation\n" +
|
||||
"- `新对话` `new` — 开启新对话 | Start new conversation\n" +
|
||||
"- `清空` `clear` — 清空当前上下文 | Clear context\n" +
|
||||
"- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" +
|
||||
"- `停止` `stop` — 中断当前任务 | Stop running task\n" +
|
||||
"- `角色` `roles` — 列出所有可用角色 | List roles\n" +
|
||||
"- `角色 <名>` `role <name>` — 切换当前角色 | Switch role\n" +
|
||||
"- `删除 <ID>` `delete <ID>` — 删除指定对话 | Delete conversation\n" +
|
||||
"- `版本` `version` — 显示当前版本号 | Show version\n\n" +
|
||||
"---\n" +
|
||||
"除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" +
|
||||
"Otherwise, send any text for AI penetration testing / security analysis."
|
||||
var b strings.Builder
|
||||
b.WriteString("【CyberStrikeAI 机器人命令】\n\n")
|
||||
b.WriteString("【通用 General】\n")
|
||||
b.WriteString("· 帮助 / help — 显示本帮助\n")
|
||||
b.WriteString("· 版本 / version — 显示当前版本号\n")
|
||||
b.WriteString("\n【对话 Conversation】\n")
|
||||
b.WriteString("· 列表 / list — 列出所有对话标题与 ID\n")
|
||||
b.WriteString("· 切换 <ID> / switch <ID> — 指定对话继续\n")
|
||||
b.WriteString("· 新对话 / new — 开启新对话\n")
|
||||
b.WriteString("· 清空 / clear — 清空当前上下文\n")
|
||||
b.WriteString("· 当前 / current — 显示当前对话、角色与项目\n")
|
||||
b.WriteString("· 停止 / stop — 中断当前任务\n")
|
||||
b.WriteString("· 删除 <ID> / delete <ID> — 删除指定对话\n")
|
||||
b.WriteString("\n【角色 Role】\n")
|
||||
b.WriteString("· 角色 / roles — 列出所有可用角色\n")
|
||||
b.WriteString("· 角色 <名> / role <name> — 切换当前角色\n")
|
||||
if h.projectsEnabled() {
|
||||
b.WriteString("\n【项目 Project】\n")
|
||||
b.WriteString("· 项目 / projects — 列出所有项目\n")
|
||||
b.WriteString("· 新建项目 <名称> / new project <name> — 创建并绑定当前对话\n")
|
||||
b.WriteString("· 绑定项目 <ID或名称> / bind project <ID|name> — 绑定到已有项目\n")
|
||||
b.WriteString("· 解除项目 / unbind project — 解除项目绑定\n")
|
||||
}
|
||||
b.WriteString("\n──────────────\n")
|
||||
b.WriteString("除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (h *RobotHandler) projectsEnabled() bool {
|
||||
return h.config != nil && h.config.Project.Enabled
|
||||
}
|
||||
|
||||
func (h *RobotHandler) resolveProjectByIDOrName(idOrName string) (*database.Project, string) {
|
||||
idOrName = strings.TrimSpace(idOrName)
|
||||
if idOrName == "" {
|
||||
return nil, "请指定项目 ID 或名称,例如:绑定项目 xxx-xxx"
|
||||
}
|
||||
if p, err := h.db.GetProject(idOrName); err == nil {
|
||||
return p, ""
|
||||
}
|
||||
list, err := h.db.ListProjects("", 200, 0)
|
||||
if err != nil {
|
||||
return nil, "查询项目失败: " + err.Error()
|
||||
}
|
||||
var matches []*database.Project
|
||||
for _, p := range list {
|
||||
if p.Name == idOrName {
|
||||
matches = append(matches, p)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return nil, fmt.Sprintf("项目「%s」不存在。发送「项目」查看列表。", idOrName)
|
||||
case 1:
|
||||
return matches[0], ""
|
||||
default:
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("名称「%s」匹配到多个项目,请使用 ID 绑定:\n", idOrName))
|
||||
for _, p := range matches {
|
||||
b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", p.Name, p.ID))
|
||||
}
|
||||
return nil, strings.TrimSuffix(b.String(), "\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RobotHandler) formatProjectLabel(projectID string) string {
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return "未绑定"
|
||||
}
|
||||
if p, err := h.db.GetProject(projectID); err == nil {
|
||||
return fmt.Sprintf("「%s」 (%s)", p.Name, p.ID)
|
||||
}
|
||||
return projectID
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdProjects() string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
list, err := h.db.ListProjects("", 50, 0)
|
||||
if err != nil {
|
||||
return "获取项目列表失败: " + err.Error()
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return "暂无项目。发送「新建项目 <名称>」创建并绑定到当前对话。"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("【项目列表】\n")
|
||||
for i, p := range list {
|
||||
if i >= 20 {
|
||||
b.WriteString("… 仅显示前 20 条\n")
|
||||
break
|
||||
}
|
||||
status := p.Status
|
||||
if status == "" {
|
||||
status = "active"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("· %s [%s]\n ID: %s\n", p.Name, status, p.ID))
|
||||
}
|
||||
return strings.TrimSuffix(b.String(), "\n")
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdBindProject(platform, userID, idOrName string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
p, errMsg := h.resolveProjectByIDOrName(idOrName)
|
||||
if p == nil {
|
||||
return errMsg
|
||||
}
|
||||
convID, _ := h.getOrCreateConversation(platform, userID, "")
|
||||
if convID == "" {
|
||||
return "无法获取当前对话,请稍后再试。"
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, p.ID); err != nil {
|
||||
return "绑定失败: " + err.Error()
|
||||
}
|
||||
return fmt.Sprintf("已将当前对话绑定到项目:「%s」\nID: %s", p.Name, p.ID)
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdNewProject(platform, userID, name string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "请指定项目名称,例如:新建项目 某目标渗透"
|
||||
}
|
||||
p := &database.Project{Name: name, Status: "active"}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
return "创建项目失败: " + err.Error()
|
||||
}
|
||||
convID, _ := h.getOrCreateConversation(platform, userID, name)
|
||||
if convID == "" {
|
||||
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n(绑定当前对话失败,请手动发送「绑定项目 %s」)", created.Name, created.ID, created.ID)
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, created.ID); err != nil {
|
||||
return fmt.Sprintf("项目已创建:「%s」\nID: %s\n绑定失败: %s", created.Name, created.ID, err.Error())
|
||||
}
|
||||
return fmt.Sprintf("已创建项目并绑定当前对话:「%s」\nID: %s", created.Name, created.ID)
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
||||
if !h.projectsEnabled() {
|
||||
return "项目功能未启用(config.project.enabled)。"
|
||||
}
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.RLock()
|
||||
convID := h.sessions[sk]
|
||||
h.mu.RUnlock()
|
||||
if convID == "" {
|
||||
if persistedConvID, _ := h.loadSessionBinding(sk); persistedConvID != "" {
|
||||
convID = persistedConvID
|
||||
}
|
||||
}
|
||||
if convID == "" {
|
||||
return "当前没有进行中的对话,无需解除绑定。"
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(convID)
|
||||
if err != nil {
|
||||
return "获取对话项目失败: " + err.Error()
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return "当前对话未绑定项目。"
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(convID, ""); err != nil {
|
||||
return "解除绑定失败: " + err.Error()
|
||||
}
|
||||
return "已解除当前对话的项目绑定。"
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdList() string {
|
||||
@@ -345,7 +517,12 @@ func (h *RobotHandler) cmdCurrent(platform, userID string) string {
|
||||
return "当前对话 ID: " + convID + "(获取标题失败)"
|
||||
}
|
||||
role := h.getRole(platform, userID)
|
||||
return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
||||
reply := fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
|
||||
if h.projectsEnabled() {
|
||||
projectID, _ := h.db.GetConversationProjectID(conv.ID)
|
||||
reply += "\n当前项目: " + h.formatProjectLabel(projectID)
|
||||
}
|
||||
return reply
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdRoles() string {
|
||||
@@ -482,6 +659,26 @@ func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string
|
||||
return h.cmdDelete(platform, userID, convID), true
|
||||
case text == robotCmdVersion || text == "version":
|
||||
return h.cmdVersion(), true
|
||||
case text == robotCmdProjects || text == robotCmdProjectsList || text == "projects":
|
||||
return h.cmdProjects(), true
|
||||
case text == robotCmdUnbindProject || text == "unbind project":
|
||||
return h.cmdUnbindProject(platform, userID), true
|
||||
case strings.HasPrefix(text, robotCmdNewProject+" ") || strings.HasPrefix(text, "new project "):
|
||||
var name string
|
||||
if strings.HasPrefix(text, robotCmdNewProject+" ") {
|
||||
name = strings.TrimSpace(text[len(robotCmdNewProject)+1:])
|
||||
} else {
|
||||
name = strings.TrimSpace(text[len("new project "):])
|
||||
}
|
||||
return h.cmdNewProject(platform, userID, name), true
|
||||
case strings.HasPrefix(text, robotCmdBindProject+" ") || strings.HasPrefix(text, "bind project "):
|
||||
var idOrName string
|
||||
if strings.HasPrefix(text, robotCmdBindProject+" ") {
|
||||
idOrName = strings.TrimSpace(text[len(robotCmdBindProject)+1:])
|
||||
} else {
|
||||
idOrName = strings.TrimSpace(text[len("bind project "):])
|
||||
}
|
||||
return h.cmdBindProject(platform, userID, idOrName), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -21,6 +22,12 @@ type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *RoleHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -23,6 +24,12 @@ type SkillsHandler struct {
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *SkillsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
|
||||
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ const ptyCols = 256
|
||||
const ptyRows = 40
|
||||
|
||||
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
)
|
||||
|
||||
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
|
||||
normalize := func(s string) string {
|
||||
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -16,6 +17,12 @@ import (
|
||||
type VulnerabilityHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewVulnerabilityHandler 创建新的漏洞处理器
|
||||
@@ -29,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
||||
// CreateVulnerabilityRequest 创建漏洞请求
|
||||
type CreateVulnerabilityRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ProjectID string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
@@ -52,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: req.ConversationID,
|
||||
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||
ConversationTag: req.ConversationTag,
|
||||
TaskTag: req.TaskTag,
|
||||
Title: req.Title,
|
||||
@@ -72,6 +81,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
|
||||
"severity": created.Severity, "title": created.Title,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
@@ -98,18 +112,30 @@ type ListVulnerabilitiesResponse struct {
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||
q := strings.TrimSpace(c.Query("q"))
|
||||
if q == "" {
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ProjectID: c.Query("project_id"),
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
Severity: c.Query("severity"),
|
||||
Status: c.Query("status"),
|
||||
TaskID: c.Query("task_id"),
|
||||
ConversationTag: c.Query("conversation_tag"),
|
||||
TaskTag: c.Query("task_tag"),
|
||||
}
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "20")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
pageStr := c.Query("page")
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
@@ -131,7 +157,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||
// 继续执行,使用0作为总数
|
||||
@@ -139,7 +165,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -170,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
|
||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
ProjectID *string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
@@ -201,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.ProjectID != nil {
|
||||
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||
}
|
||||
if req.ConversationTag != "" {
|
||||
existing.ConversationTag = req.ConversationTag
|
||||
}
|
||||
@@ -249,6 +279,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
@@ -262,15 +297,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "vulnerability",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "vulnerability",
|
||||
ResourceID: id,
|
||||
Message: "删除漏洞记录",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计
|
||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||
conversationID := c.Query("conversation_id")
|
||||
taskID := c.Query("task_id")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
|
||||
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -304,15 +349,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -322,7 +361,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -132,6 +134,16 @@ func quoteCmdPath(p string) string {
|
||||
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
||||
}
|
||||
|
||||
// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。
|
||||
// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。
|
||||
func normalizeWindowsCmdPath(p string) string {
|
||||
s := strings.TrimSpace(p)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ReplaceAll(s, "/", "\\")
|
||||
}
|
||||
|
||||
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
||||
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
||||
func quotePsSingle(s string) string {
|
||||
@@ -196,6 +208,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
p = "."
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
p = normalizeWindowsCmdPath(p)
|
||||
return "dir /a " + quoteCmdPath(p), nil
|
||||
}
|
||||
return "ls -la " + quoteShellSinglePosix(p), nil
|
||||
@@ -205,6 +218,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "type " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "cat " + quoteShellSinglePosix(path), nil
|
||||
@@ -214,6 +228,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "del /q /f " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "rm -f " + quoteShellSinglePosix(path), nil
|
||||
@@ -223,6 +238,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
||||
return "md " + quoteCmdPath(path), nil
|
||||
}
|
||||
@@ -235,6 +251,8 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpRenameNeedsBothPaths
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
oldPath = normalizeWindowsCmdPath(oldPath)
|
||||
newPath = normalizeWindowsCmdPath(newPath)
|
||||
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
||||
}
|
||||
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
||||
@@ -247,6 +265,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, b64), nil
|
||||
}
|
||||
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
@@ -259,6 +278,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpUploadTooLarge
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
@@ -268,6 +288,7 @@ func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error)
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
if in.ChunkIndex == 0 {
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
@@ -304,6 +325,12 @@ type WebShellHandler struct {
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *WebShellHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
||||
@@ -311,8 +338,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
||||
return &WebShellHandler{
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{DisableKeepAlives: false},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||
},
|
||||
},
|
||||
db: db,
|
||||
}
|
||||
@@ -403,6 +434,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
host := req.URL
|
||||
if u, err := url.Parse(req.URL); err == nil {
|
||||
host = u.Host
|
||||
}
|
||||
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
|
||||
"host": host, "type": shellType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, conn)
|
||||
}
|
||||
|
||||
@@ -485,6 +525,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -714,8 +757,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
ok := resp.StatusCode == http.StatusOK
|
||||
c.JSON(http.StatusOK, ExecResponse{
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
OK: ok,
|
||||
Output: output,
|
||||
HTTPCode: httpCode,
|
||||
})
|
||||
|
||||
@@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s
|
||||
|
||||
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
||||
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、restore_project_fact、list_knowledge_risk_types、search_knowledge_base"
|
||||
|
||||
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
||||
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
||||
@@ -65,7 +65,7 @@ func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint,
|
||||
b.WriteString(conn.ID)
|
||||
b.WriteString("\"):")
|
||||
b.WriteString(webshellAssistantToolList)
|
||||
b.WriteString("。")
|
||||
b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。")
|
||||
b.WriteString(skillHint)
|
||||
b.WriteString("\n\n用户请求:")
|
||||
b.WriteString(userMsg)
|
||||
|
||||
@@ -4,7 +4,17 @@ package builtin
|
||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||
const (
|
||||
// 漏洞管理工具
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
ToolListVulnerabilities = "list_vulnerabilities"
|
||||
ToolGetVulnerability = "get_vulnerability"
|
||||
|
||||
// 项目黑板(事实)工具
|
||||
ToolUpsertProjectFact = "upsert_project_fact"
|
||||
ToolGetProjectFact = "get_project_fact"
|
||||
ToolListProjectFacts = "list_project_facts"
|
||||
ToolSearchProjectFacts = "search_project_facts"
|
||||
ToolDeprecateProjectFact = "deprecate_project_fact"
|
||||
ToolRestoreProjectFact = "restore_project_fact"
|
||||
|
||||
// 知识库工具
|
||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||
@@ -53,6 +63,14 @@ const (
|
||||
func IsBuiltinTool(toolName string) bool {
|
||||
switch toolName {
|
||||
case ToolRecordVulnerability,
|
||||
ToolListVulnerabilities,
|
||||
ToolGetVulnerability,
|
||||
ToolUpsertProjectFact,
|
||||
ToolGetProjectFact,
|
||||
ToolListProjectFacts,
|
||||
ToolSearchProjectFacts,
|
||||
ToolDeprecateProjectFact,
|
||||
ToolRestoreProjectFact,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
@@ -96,6 +114,14 @@ func IsBuiltinTool(toolName string) bool {
|
||||
func GetAllBuiltinTools() []string {
|
||||
return []string{
|
||||
ToolRecordVulnerability,
|
||||
ToolListVulnerabilities,
|
||||
ToolGetVulnerability,
|
||||
ToolUpsertProjectFact,
|
||||
ToolGetProjectFact,
|
||||
ToolListProjectFacts,
|
||||
ToolSearchProjectFacts,
|
||||
ToolDeprecateProjectFact,
|
||||
ToolRestoreProjectFact,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
|
||||
@@ -77,6 +77,9 @@ type einoADKRunLoopArgs struct {
|
||||
StreamsMainAssistant func(agent string) bool
|
||||
EinoRoleTag func(agent string) string
|
||||
CheckpointDir string
|
||||
// RunRetryMaxAttempts / RunRetryMaxBackoffSec:429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
|
||||
RunRetryMaxAttempts int
|
||||
RunRetryMaxBackoffSec int
|
||||
|
||||
McpIDsMu *sync.Mutex
|
||||
McpIDs *[]string
|
||||
@@ -181,14 +184,19 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
mainAgentToolStep := make(map[string]int)
|
||||
pendingByID := make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent := make(map[string][]string)
|
||||
var pendingMu sync.Mutex
|
||||
markPending := func(tc toolCallPendingInfo) {
|
||||
if tc.ToolCallID == "" {
|
||||
return
|
||||
}
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
pendingByID[tc.ToolCallID] = tc
|
||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||
}
|
||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
q := pendingQueueByAgent[agentName]
|
||||
for len(q) > 0 {
|
||||
id := q[0]
|
||||
@@ -205,19 +213,42 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
delete(pendingByID, toolCallID)
|
||||
}
|
||||
popAnyPending := func() (toolCallPendingInfo, bool) {
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
for id, tc := range pendingByID {
|
||||
delete(pendingByID, id)
|
||||
return tc, true
|
||||
}
|
||||
return toolCallPendingInfo{}, false
|
||||
}
|
||||
pendingCount := func() int {
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
return len(pendingByID)
|
||||
}
|
||||
flushAllPendingAsFailed := func(err error) {
|
||||
pendingMu.Lock()
|
||||
pendingSnapshot := make([]toolCallPendingInfo, 0, len(pendingByID))
|
||||
for _, tc := range pendingByID {
|
||||
pendingSnapshot = append(pendingSnapshot, tc)
|
||||
}
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
pendingMu.Unlock()
|
||||
|
||||
if progress == nil {
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = err.Error()
|
||||
}
|
||||
for _, tc := range pendingByID {
|
||||
for _, tc := range pendingSnapshot {
|
||||
toolName := tc.ToolName
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = "unknown"
|
||||
@@ -235,8 +266,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
}
|
||||
|
||||
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
||||
@@ -316,7 +345,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
|
||||
runnerCfg := adk.RunnerConfig{
|
||||
Agent: da,
|
||||
Agent: da,
|
||||
// 启用 ADK 流式事件:plan_execute 也需要输出 reasoning/response 流,
|
||||
// 与 deep/supervisor/eino_single 的前端体验保持一致。
|
||||
EnableStreaming: true,
|
||||
}
|
||||
var cpStore *fileCheckPointStore
|
||||
@@ -437,6 +468,28 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
return runErr
|
||||
}
|
||||
|
||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
return false, handleRunErr(runErr)
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||
zap.Error(runErr),
|
||||
zap.String("orchestration", orchMode))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"error": runErr.Error(),
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
return false, ErrTransientRetryContinue
|
||||
}
|
||||
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||
return nil, runErr
|
||||
@@ -494,8 +547,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
return takePartial(ctxErr)
|
||||
}
|
||||
if len(pendingByID) > 0 {
|
||||
orphanCount := len(pendingByID)
|
||||
if orphanCount := pendingCount(); orphanCount > 0 {
|
||||
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
|
||||
if progress != nil {
|
||||
progress("eino_pending_orphaned", "pending tool calls were force-closed at run end", map[string]interface{}{
|
||||
@@ -519,7 +571,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
@@ -821,7 +873,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
@@ -932,12 +984,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else {
|
||||
for id := range pendingByID {
|
||||
toolCallID = id
|
||||
delete(pendingByID, id)
|
||||
break
|
||||
}
|
||||
} else if inferred, ok := popAnyPending(); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
}
|
||||
}
|
||||
if toolCallID != "" {
|
||||
|
||||
@@ -59,6 +59,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
}
|
||||
plannerCfg := &planexecute.PlannerConfig{
|
||||
ToolCallingChatModel: tcm,
|
||||
NewPlan: newLenientPlan,
|
||||
}
|
||||
if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
|
||||
plannerCfg.GenInputFn = fn
|
||||
@@ -70,6 +71,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
|
||||
ChatModel: tcm,
|
||||
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
|
||||
NewPlan: newLenientPlan,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute replanner: %w", err)
|
||||
@@ -146,14 +148,12 @@ func planExecutePlannerGenInput(
|
||||
}
|
||||
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
|
||||
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
|
||||
msgs := make([]adk.Message, 0, 1+len(userInput))
|
||||
if oi != "" {
|
||||
msgs = append(msgs, schema.SystemMessage(oi))
|
||||
}
|
||||
msgs := make([]adk.Message, 0, len(userInput))
|
||||
msgs = append(msgs, userInput...)
|
||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||
msgs = rewritten
|
||||
}
|
||||
msgs = normalizeSingleLeadingSystemMessage(msgs, oi)
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
|
||||
return msgs, nil
|
||||
}
|
||||
@@ -182,9 +182,7 @@ func planExecuteExecutorGenInput(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oi != "" {
|
||||
userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...)
|
||||
}
|
||||
userMsgs = normalizeSingleLeadingSystemMessage(userMsgs, oi)
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
|
||||
return userMsgs, nil
|
||||
}
|
||||
@@ -231,17 +229,54 @@ func planExecuteReplannerGenInput(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oi != "" {
|
||||
msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...)
|
||||
}
|
||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||
msgs = rewritten
|
||||
}
|
||||
msgs = normalizeSingleLeadingSystemMessage(msgs, oi)
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
|
||||
return msgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeSingleLeadingSystemMessage enforces a provider-friendly message shape:
|
||||
// exactly one system message at index 0 (when any system context exists).
|
||||
// For strict OpenAI-compatible backends (e.g. qwen/vllm templates), this avoids
|
||||
// "System message must be at the beginning" caused by multiple/disordered system messages.
|
||||
func normalizeSingleLeadingSystemMessage(msgs []adk.Message, extraSystem string) []adk.Message {
|
||||
extraSystem = strings.TrimSpace(extraSystem)
|
||||
if len(msgs) == 0 {
|
||||
if extraSystem == "" {
|
||||
return msgs
|
||||
}
|
||||
return []adk.Message{schema.SystemMessage(extraSystem)}
|
||||
}
|
||||
|
||||
systemParts := make([]string, 0, 2)
|
||||
if extraSystem != "" {
|
||||
systemParts = append(systemParts, extraSystem)
|
||||
}
|
||||
nonSystem := make([]adk.Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.System {
|
||||
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||
systemParts = append(systemParts, s)
|
||||
}
|
||||
continue
|
||||
}
|
||||
nonSystem = append(nonSystem, msg)
|
||||
}
|
||||
if len(systemParts) == 0 {
|
||||
return nonSystem
|
||||
}
|
||||
out := make([]adk.Message, 0, len(nonSystem)+1)
|
||||
out = append(out, schema.SystemMessage(strings.Join(systemParts, "\n\n")))
|
||||
out = append(out, nonSystem...)
|
||||
return out
|
||||
}
|
||||
|
||||
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||
if len(input) == 0 {
|
||||
return input
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestNormalizeSingleLeadingSystemMessage_MergesMultipleSystems(t *testing.T) {
|
||||
in := []adk.Message{
|
||||
schema.SystemMessage("sys-1"),
|
||||
schema.UserMessage("u1"),
|
||||
schema.SystemMessage("sys-2"),
|
||||
schema.AssistantMessage("a1", nil),
|
||||
}
|
||||
out := normalizeSingleLeadingSystemMessage(in, "orch")
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("unexpected output length: got %d want 3", len(out))
|
||||
}
|
||||
if out[0].Role != schema.System {
|
||||
t.Fatalf("first message role must be system, got %s", out[0].Role)
|
||||
}
|
||||
if got := out[0].Content; got != "orch\n\nsys-1\n\nsys-2" {
|
||||
t.Fatalf("unexpected merged system content: %q", got)
|
||||
}
|
||||
if out[1].Role != schema.User || out[2].Role != schema.Assistant {
|
||||
t.Fatalf("non-system message order changed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSingleLeadingSystemMessage_NoSystemKeepsFlow(t *testing.T) {
|
||||
in := []adk.Message{
|
||||
schema.UserMessage("u1"),
|
||||
schema.AssistantMessage("a1", nil),
|
||||
}
|
||||
out := normalizeSingleLeadingSystemMessage(in, "")
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("unexpected output length: got %d want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
|
||||
t.Fatalf("message order changed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -39,6 +39,7 @@ func RunEinoSingleChatModelAgent(
|
||||
roleTools []string,
|
||||
progress func(eventType, message string, data interface{}),
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
systemPromptExtra string,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ag == nil {
|
||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||
@@ -178,7 +179,8 @@ func RunEinoSingleChatModelAgent(
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
||||
ins := project.AppendSystemPromptBlock(ag.EinoSingleAgentSystemInstruction(), systemPromptExtra)
|
||||
ins = injectToolNamesOnlyInstruction(ctx, ins, mainTools, singleToolSearchActive)
|
||||
if logger != nil {
|
||||
names := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
@@ -213,7 +215,7 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
return agent == "" || agent == einoSingleAgentName
|
||||
@@ -233,6 +235,8 @@ func RunEinoSingleChatModelAgent(
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEinoRunRetryMaxAttempts = 10
|
||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||
func isEinoTransientRunError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// io.EOF 常见于流式正常收尾,不应触发分段重试。
|
||||
if errors.Is(err, io.EOF) {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if isEinoIterationLimitError(err) {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
transientMarkers := []string{
|
||||
"406",
|
||||
"429",
|
||||
"too many requests",
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"ratelimit",
|
||||
"quota exceeded",
|
||||
"overloaded",
|
||||
"capacity",
|
||||
"temporarily unavailable",
|
||||
"service unavailable",
|
||||
"bad gateway",
|
||||
"gateway timeout",
|
||||
"internal server error",
|
||||
"connection reset",
|
||||
"connection refused",
|
||||
"connection closed",
|
||||
"i/o timeout",
|
||||
"no such host",
|
||||
"network is unreachable",
|
||||
"broken pipe",
|
||||
"read tcp",
|
||||
"write tcp",
|
||||
"dial tcp",
|
||||
"tls handshake timeout",
|
||||
"stream error",
|
||||
"unexpected eof",
|
||||
"unexpected end of json",
|
||||
"status code: 406",
|
||||
"status code: 502",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"500",
|
||||
}
|
||||
for _, m := range transientMarkers {
|
||||
if strings.Contains(msg, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||
return args.RunRetryMaxAttempts
|
||||
}
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||
return mw.RunRetryMaxAttempts
|
||||
}
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||
max := defaultEinoRunRetryMaxBackoff
|
||||
if maxBackoffSec > 0 {
|
||||
max = time.Duration(maxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, max)
|
||||
}
|
||||
|
||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return defaultEinoRunRetryMaxBackoff
|
||||
}
|
||||
|
||||
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
|
||||
type einoRunRestartContextSource string
|
||||
|
||||
const (
|
||||
einoRestartContextInitial einoRunRestartContextSource = "initial"
|
||||
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
|
||||
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
|
||||
)
|
||||
|
||||
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
|
||||
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||
}
|
||||
if len(accumulated) > baseCount {
|
||||
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||
}
|
||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||
}
|
||||
|
||||
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
||||
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||
want = strings.TrimSpace(want)
|
||||
if want == "" {
|
||||
return true
|
||||
}
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.User {
|
||||
return strings.TrimSpace(m.Content) == want
|
||||
}
|
||||
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
||||
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||
return msgs
|
||||
}
|
||||
return append(msgs, schema.UserMessage(userMessage))
|
||||
}
|
||||
|
||||
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
|
||||
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if maxBackoff > 0 && backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestIsEinoTransientRunError(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"io eof", io.EOF, false},
|
||||
{"plain eof text", errors.New("EOF"), false},
|
||||
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||
{"unexpected eof", errors.New("unexpected EOF"), true},
|
||||
{"503", errors.New("upstream returned 503"), true},
|
||||
{"iteration limit", errors.New("max iteration reached"), false},
|
||||
{"canceled", context.Canceled, false},
|
||||
{"deadline", context.DeadlineExceeded, false},
|
||||
{"auth", errors.New("invalid api key"), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isEinoTransientRunError(tc.err); got != tc.want {
|
||||
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRetryBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
max := 30 * time.Second
|
||||
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
|
||||
t.Fatalf("attempt 0: got %v", got)
|
||||
}
|
||||
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
|
||||
t.Fatalf("attempt 4 capped: got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoMessagesForRunRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := []adk.Message{schema.UserMessage("hi")}
|
||||
acc := append([]adk.Message(nil), base...)
|
||||
acc = append(acc, schema.AssistantMessage("step1", nil))
|
||||
|
||||
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
|
||||
if src != einoRestartContextAccumulated || len(got) != 2 {
|
||||
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
|
||||
}
|
||||
|
||||
holder := newModelFacingTraceHolder()
|
||||
holder.storeFromState(&adk.ChatModelAgentState{
|
||||
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
|
||||
})
|
||||
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
|
||||
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
|
||||
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
|
||||
t.Fatal("nil args should use default")
|
||||
}
|
||||
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
|
||||
t.Fatal("custom max attempts")
|
||||
}
|
||||
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
|
||||
t.Fatal("config nil should use default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
|
||||
if len(out) != 2 || out[1].Content != "你好,你是谁" {
|
||||
t.Fatalf("should append user: len=%d", len(out))
|
||||
}
|
||||
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
|
||||
if len(dup) != 2 {
|
||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrTransientRetryContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||
t.Fatal("sentinel should match")
|
||||
}
|
||||
}
|
||||
@@ -5,3 +5,7 @@ import "errors"
|
||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
|
||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/project"
|
||||
)
|
||||
|
||||
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
||||
@@ -106,16 +106,14 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
|
||||
|
||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||
|
||||
## 漏洞记录
|
||||
` + project.FactRecordingBlackboardSection(true) + `
|
||||
|
||||
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
|
||||
|
||||
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。
|
||||
- **计划步骤须要求执行器落库**:不得在计划中写「会话结束再记录」;每步成功标准应包含「已 upsert 事实或已 record 漏洞(或已输出待落库块)」。
|
||||
|
||||
## 技能库(Skills)与知识库
|
||||
|
||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||
- plan_execute 执行器通过 MCP 使用知识库与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。
|
||||
- plan_execute 执行器通过 MCP 使用知识库、项目事实与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」等模式中由内置 skill 工具完成(需 multi_agent.eino_skills)。
|
||||
- 若需要完整 Skill 工作流而当前会话无 skill 工具,请在计划或对用户说明中建议切换多代理或 Eino 编排会话。
|
||||
|
||||
## 执行器对用户输出(重要)
|
||||
@@ -206,7 +204,8 @@ func DefaultSupervisorOrchestratorInstruction() string {
|
||||
- **委派优先**:可独立封装、需要专项上下文的子目标(枚举、验证、归纳、报告素材)优先 transfer 给匹配子代理,并在委派说明中写清:子目标、约束、期望交付物结构、证据要求。
|
||||
- **亲自执行**:仅当无合适专家、需全局衔接或子代理结果不足时,由你直接调用工具。
|
||||
- **汇总**:子代理输出是证据来源;你要对齐矛盾、补全上下文,给出统一结论与可复现验证步骤,避免机械拼接。
|
||||
- **漏洞**:有效漏洞应通过 ` + builtin.ToolRecordVulnerability + ` 记录(含 POC 与严重性:critical / high / medium / low / info)。
|
||||
|
||||
` + project.FactRecordingBlackboardSection(true) + `
|
||||
|
||||
## transfer 交接与防重复劳动
|
||||
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk/prebuilt/planexecute"
|
||||
)
|
||||
|
||||
// lenientPlan keeps plan_execute running even when model tool arguments contain minor JSON defects.
|
||||
// It first tries strict JSON, then falls back to lightweight step extraction heuristics.
|
||||
type lenientPlan struct {
|
||||
Steps []string `json:"steps"`
|
||||
}
|
||||
|
||||
func newLenientPlan(context.Context) planexecute.Plan {
|
||||
return &lenientPlan{}
|
||||
}
|
||||
|
||||
func (p *lenientPlan) FirstStep() string {
|
||||
if p == nil || len(p.Steps) == 0 {
|
||||
return ""
|
||||
}
|
||||
return p.Steps[0]
|
||||
}
|
||||
|
||||
func (p *lenientPlan) MarshalJSON() ([]byte, error) {
|
||||
type alias lenientPlan
|
||||
return json.Marshal((*alias)(p))
|
||||
}
|
||||
|
||||
func (p *lenientPlan) UnmarshalJSON(b []byte) error {
|
||||
type alias lenientPlan
|
||||
var strict alias
|
||||
if err := json.Unmarshal(b, &strict); err == nil {
|
||||
strict.Steps = normalizePlanSteps(strict.Steps)
|
||||
if len(strict.Steps) > 0 {
|
||||
*p = lenientPlan(strict)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
steps := extractPlanStepsLenient(string(b))
|
||||
if len(steps) == 0 {
|
||||
steps = []string{"继续按当前目标执行下一步,并输出可验证证据。"}
|
||||
}
|
||||
p.Steps = steps
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractPlanStepsLenient(raw string) []string {
|
||||
s := strings.TrimSpace(stripCodeFence(raw))
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if extracted, ok := sliceByStepsArray(s); ok {
|
||||
var arr []string
|
||||
if err := json.Unmarshal([]byte(extracted), &arr); err == nil {
|
||||
arr = normalizePlanSteps(arr)
|
||||
if len(arr) > 0 {
|
||||
return arr
|
||||
}
|
||||
}
|
||||
if arr := splitStepsHeuristically(strings.Trim(extracted, "[]")); len(arr) > 0 {
|
||||
return arr
|
||||
}
|
||||
}
|
||||
|
||||
// Last-resort: treat plaintext body as one actionable step.
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
func sliceByStepsArray(s string) (string, bool) {
|
||||
lower := strings.ToLower(s)
|
||||
key := `"steps"`
|
||||
i := strings.Index(lower, key)
|
||||
if i < 0 {
|
||||
return "", false
|
||||
}
|
||||
start := strings.Index(s[i:], "[")
|
||||
if start < 0 {
|
||||
return "", false
|
||||
}
|
||||
start += i
|
||||
depth := 0
|
||||
for j := start; j < len(s); j++ {
|
||||
switch s[j] {
|
||||
case '[':
|
||||
depth++
|
||||
case ']':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[start : j+1], true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func splitStepsHeuristically(body string) []string {
|
||||
body = strings.ReplaceAll(body, "\r\n", "\n")
|
||||
body = strings.ReplaceAll(body, "\\n", "\n")
|
||||
var parts []string
|
||||
if strings.Contains(body, "\n") {
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
parts = append(parts, line)
|
||||
}
|
||||
} else {
|
||||
for _, seg := range strings.Split(body, ",") {
|
||||
parts = append(parts, seg)
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
t := strings.TrimSpace(part)
|
||||
t = strings.Trim(t, "\"'`")
|
||||
t = strings.TrimLeft(t, "-*0123456789.、 \t")
|
||||
t = strings.TrimSpace(strings.ReplaceAll(t, `\"`, `"`))
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
return normalizePlanSteps(out)
|
||||
}
|
||||
|
||||
func normalizePlanSteps(in []string) []string {
|
||||
out := make([]string, 0, len(in))
|
||||
for _, step := range in {
|
||||
t := strings.TrimSpace(step)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stripCodeFence(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if !strings.HasPrefix(s, "```") {
|
||||
return s
|
||||
}
|
||||
s = strings.TrimPrefix(s, "```json")
|
||||
s = strings.TrimPrefix(s, "```JSON")
|
||||
s = strings.TrimPrefix(s, "```")
|
||||
s = strings.TrimSuffix(strings.TrimSpace(s), "```")
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
@@ -64,6 +65,7 @@ func RunDeepAgent(
|
||||
agentsMarkdownDir string,
|
||||
orchestrationOverride string,
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
systemPromptExtra string,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ma == nil || ag == nil {
|
||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||
@@ -339,6 +341,7 @@ func RunDeepAgent(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orchInstruction = project.AppendSystemPromptBlock(orchInstruction, systemPromptExtra)
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||
if logger != nil {
|
||||
mainNames := collectToolNames(ctx, mainTools)
|
||||
@@ -387,7 +390,8 @@ func RunDeepAgent(
|
||||
|
||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes); mw != nil {
|
||||
taskEnrichExtra := systemPromptExtra
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
|
||||
deepHandlers = append(deepHandlers, mw)
|
||||
}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -538,7 +542,7 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
if orchMode == "plan_execute" {
|
||||
@@ -566,6 +570,8 @@ func RunDeepAgent(
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
@@ -595,6 +601,13 @@ func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
||||
argsStr = string(b)
|
||||
}
|
||||
}
|
||||
// Some OpenAI-compatible gateways require `function.arguments` to exist
|
||||
// on every assistant tool_call message. When args are empty, omitempty may
|
||||
// drop the field during serialization and cause "missing field arguments"
|
||||
// on the next turn history replay.
|
||||
if strings.TrimSpace(argsStr) == "" {
|
||||
argsStr = "{}"
|
||||
}
|
||||
typ := tc.Type
|
||||
if typ == "" {
|
||||
typ = "function"
|
||||
|
||||
@@ -30,8 +30,15 @@ type taskContextEnrichMiddleware struct {
|
||||
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
||||
// descriptions with user conversation context. Returns nil if disabled
|
||||
// (maxRunes < 0) or no user messages exist.
|
||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware {
|
||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
||||
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
||||
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
||||
if supplement != "" {
|
||||
supplement += "\n\n## 项目黑板索引\n" + bb
|
||||
} else {
|
||||
supplement = "\n\n## 项目黑板索引\n" + bb
|
||||
}
|
||||
}
|
||||
if supplement == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
|
||||
"继续测试",
|
||||
[]agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}},
|
||||
0,
|
||||
"",
|
||||
)
|
||||
if mw == nil {
|
||||
t.Fatal("expected non-nil middleware")
|
||||
@@ -149,7 +150,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
|
||||
mw := newTaskContextEnrichMiddleware("test", nil, 0)
|
||||
mw := newTaskContextEnrichMiddleware("test", nil, 0, "")
|
||||
if mw == nil {
|
||||
t.Fatal("expected non-nil middleware")
|
||||
}
|
||||
@@ -175,7 +176,7 @@ func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) {
|
||||
mw := newTaskContextEnrichMiddleware("test", nil, -1)
|
||||
mw := newTaskContextEnrichMiddleware("test", nil, -1, "")
|
||||
if mw != nil {
|
||||
t.Error("middleware should be nil when disabled")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// AppendSystemPromptBlock 将附加块追加到 system prompt。
|
||||
func AppendSystemPromptBlock(base, block string) string {
|
||||
base = strings.TrimSpace(base)
|
||||
block = strings.TrimSpace(block)
|
||||
if block == "" {
|
||||
return base
|
||||
}
|
||||
if base == "" {
|
||||
return block
|
||||
}
|
||||
return base + "\n\n" + block
|
||||
}
|
||||
|
||||
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
|
||||
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||
if db == nil || !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
proj, err := db.GetProject(projectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(facts) == 0 {
|
||||
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
|
||||
}
|
||||
|
||||
sort.SliceStable(facts, func(i, j int) bool {
|
||||
if facts[i].Pinned != facts[j].Pinned {
|
||||
return facts[i].Pinned
|
||||
}
|
||||
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||
})
|
||||
|
||||
maxRunes := cfg.FactIndexMaxRunesEffective()
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||
used := len([]rune(b.String()))
|
||||
omitted := 0
|
||||
|
||||
for _, f := range facts {
|
||||
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||
lineRunes := len([]rune(line))
|
||||
if used+lineRunes > maxRunes {
|
||||
omitted++
|
||||
continue
|
||||
}
|
||||
b.WriteString(line)
|
||||
used += lineRunes
|
||||
}
|
||||
|
||||
if omitted > 0 {
|
||||
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
||||
}
|
||||
b.WriteString("需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
|
||||
b.WriteString("写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。\n")
|
||||
return b.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
)
|
||||
|
||||
// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。
|
||||
const (
|
||||
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
||||
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
||||
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
||||
)
|
||||
|
||||
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
||||
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
||||
b.WriteString(factRhythmCore)
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
||||
// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。
|
||||
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
||||
b.WriteString(builtin.ToolGetProjectFact)
|
||||
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
||||
b.WriteString(builtin.ToolListVulnerabilities)
|
||||
b.WriteString(" 查重,详情用 ")
|
||||
b.WriteString(builtin.ToolGetVulnerability)
|
||||
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
||||
b.WriteString(builtin.ToolDeprecateProjectFact)
|
||||
b.WriteString(" 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 ")
|
||||
b.WriteString(builtin.ToolListProjectFacts)
|
||||
b.WriteString(" / ")
|
||||
b.WriteString(builtin.ToolSearchProjectFacts)
|
||||
b.WriteString(" 检索。\n\n")
|
||||
b.WriteString(FactRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
||||
func FactRecordingSubAgentSection() string {
|
||||
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
||||
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
||||
b.WriteString(FactRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
|
||||
const (
|
||||
FactCategoryTarget = "target"
|
||||
FactCategoryAuth = "auth"
|
||||
FactCategoryInfra = "infra"
|
||||
FactCategoryBusiness = "business"
|
||||
FactCategoryFinding = "finding"
|
||||
FactCategoryChain = "chain"
|
||||
FactCategoryExploit = "exploit"
|
||||
FactCategoryPOC = "poc"
|
||||
FactCategoryNote = "note"
|
||||
)
|
||||
|
||||
// RequiresAttackChainBody 判断该事实是否应携带可复现的攻击链 / exploit 详情(写在 body,非仅 summary)。
|
||||
func RequiresAttackChainBody(category, factKey string) bool {
|
||||
c := strings.ToLower(strings.TrimSpace(category))
|
||||
switch c {
|
||||
case FactCategoryFinding, FactCategoryChain, FactCategoryExploit, FactCategoryPOC, "vuln":
|
||||
return true
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||
for _, prefix := range []string{"finding/", "chain/", "exploit/", "poc/"} {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSparseFactBody 攻击链类事实 body 过短或缺少关键段落时返回 true(软校验,不阻断写入)。
|
||||
func IsSparseFactBody(category, factKey, body string) bool {
|
||||
if !RequiresAttackChainBody(category, factKey) {
|
||||
return false
|
||||
}
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(body)
|
||||
// 至少应包含可复现线索:步骤/请求/命令/代码块 之一
|
||||
hasSteps := strings.Contains(lower, "攻击链") || strings.Contains(lower, "## 攻击") ||
|
||||
strings.Contains(lower, "## exploit") || strings.Contains(lower, "## poc")
|
||||
hasHTTP := strings.Contains(lower, "```http") || strings.Contains(lower, "```bash") ||
|
||||
strings.Contains(lower, "curl ") || strings.Contains(lower, "get ") || strings.Contains(lower, "post ")
|
||||
hasReq := strings.Contains(lower, "请求") || strings.Contains(lower, "响应") || strings.Contains(lower, "payload")
|
||||
// 无攻击链/POC/请求等结构线索,视为仅结论性描述(不论长短)
|
||||
return !(hasSteps || hasHTTP || hasReq)
|
||||
}
|
||||
|
||||
// FactBodyTemplate 按 category 返回建议的 body Markdown 骨架(供 Agent 填入真实内容)。
|
||||
func FactBodyTemplate(category, factKey string) string {
|
||||
if RequiresAttackChainBody(category, factKey) {
|
||||
return attackChainFactBodyTemplate
|
||||
}
|
||||
return envFactBodyTemplate
|
||||
}
|
||||
|
||||
const attackChainFactBodyTemplate = `## 结论(可验证,一句话)
|
||||
<勿仅写「存在漏洞」;写明类型 + 位置 + 触发条件>
|
||||
|
||||
## 目标与入口
|
||||
- 目标: <URL / IP:Port / 主机名>
|
||||
- 入口: <路径 / 接口 / 参数>
|
||||
- 前置条件: <匿名 / 角色 / Cookie / 其他依赖>
|
||||
|
||||
## 攻击链(逐步可复现)
|
||||
1. <侦察/发现>
|
||||
2. <利用/触发>
|
||||
3. <影响证明(读文件、RCE 回显、越权数据等)>
|
||||
|
||||
## Exploit / POC
|
||||
### 请求
|
||||
` + "```http\n<METHOD> <path> HTTP/1.1\nHost: ...\n...\n\n<body>\n```" + `
|
||||
|
||||
### 响应 / 现象
|
||||
<关键响应片段、状态码、差异点>
|
||||
|
||||
### 命令 / 脚本(如有)
|
||||
` + "```bash\n<command>\n```" + `
|
||||
|
||||
## 关键证据
|
||||
- <工具输出摘要 / 截图路径 / 会话或消息 ID>
|
||||
|
||||
## 关联
|
||||
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
|
||||
- 依赖事实: <fact_key,如 auth/session_cookie>
|
||||
|
||||
## 备注与不确定性
|
||||
<待验证假设、环境差异、绕过尝试记录>`
|
||||
|
||||
const envFactBodyTemplate = `## 摘要
|
||||
<该事实的核心认知>
|
||||
|
||||
## 细节
|
||||
<端口/版本/路径/凭据特征/业务规则等>
|
||||
|
||||
## 来源与证据
|
||||
<命令输出、响应片段、发现时间>
|
||||
|
||||
## 关联
|
||||
- 相关 fact_key: <可选>`
|
||||
|
||||
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
|
||||
func FactRecordingGuidanceBlock() string {
|
||||
return `### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
||||
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
||||
}
|
||||
|
||||
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
|
||||
func SparseBodyWarning(category, factKey string) string {
|
||||
if !IsSparseFactBody(category, factKey, "") {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"\n\n⚠ 提示:category=%q / fact_key=%q 属于攻击链类事实,但 body 为空或过简。请补充完整攻击链与 POC(参考模板),便于后续审计复现。\n建议 body 骨架:\n%s",
|
||||
category, factKey, FactBodyTemplate(category, factKey),
|
||||
)
|
||||
}
|
||||
|
||||
// SparseBodyWarningIfNeeded 根据实际 body 判断是否追加警告。
|
||||
func SparseBodyWarningIfNeeded(category, factKey, body string) string {
|
||||
if !IsSparseFactBody(category, factKey, body) {
|
||||
return ""
|
||||
}
|
||||
return SparseBodyWarning(category, factKey)
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequiresAttackChainBody(t *testing.T) {
|
||||
cases := []struct {
|
||||
cat, key string
|
||||
want bool
|
||||
}{
|
||||
{"finding", "note/misc", true},
|
||||
{"note", "finding/sqli-login", true},
|
||||
{"target", "target/primary_domain", false},
|
||||
{"auth", "auth/admin_cookie", false},
|
||||
{"chain", "x", true},
|
||||
{"", "exploit/rce-upload", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := RequiresAttackChainBody(tc.cat, tc.key); got != tc.want {
|
||||
t.Errorf("RequiresAttackChainBody(%q,%q)=%v want %v", tc.cat, tc.key, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSparseFactBody(t *testing.T) {
|
||||
long := strings.Repeat("x", 150)
|
||||
if !IsSparseFactBody("finding", "finding/x", "") {
|
||||
t.Error("empty body should be sparse")
|
||||
}
|
||||
if !IsSparseFactBody("finding", "finding/x", long) {
|
||||
t.Error("body without repro clues should be sparse")
|
||||
}
|
||||
body := "## 攻击链\n1. step\n## Exploit\n```http\nGET / HTTP/1.1\n```\n"
|
||||
if IsSparseFactBody("finding", "finding/x", body) {
|
||||
t.Error("structured body should not be sparse")
|
||||
}
|
||||
if IsSparseFactBody("target", "target/x", "") {
|
||||
t.Error("env fact empty body is ok")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// projectScopePayload 解析 projects.scope_json(约定字段,可扩展)。
|
||||
type projectScopePayload struct {
|
||||
Targets []string `json:"targets"`
|
||||
Exclude []string `json:"exclude"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// BuildScopeBlock 将项目 scope_json 格式化为 Agent 可读的授权范围块。
|
||||
func BuildScopeBlock(proj *database.Project) string {
|
||||
if proj == nil {
|
||||
return ""
|
||||
}
|
||||
raw := strings.TrimSpace(proj.ScopeJSON)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var payload projectScopePayload
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return fmt.Sprintf("## 项目测试范围(project: %s)\n(scope_json 非合法 JSON,请人工核对配置)\n```\n%s\n```\n"+
|
||||
"仅对明确授权目标执行测试;超出范围须停止并说明。\n", proj.Name, truncateRunes(raw, 800))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("## 项目测试范围(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||
b.WriteString("以下为授权边界,**必须遵守**:仅测试列出的 targets,避开 exclude,不得擅自扩大范围。\n")
|
||||
|
||||
if len(payload.Targets) > 0 {
|
||||
b.WriteString("\n**允许测试(targets)**:\n")
|
||||
for _, t := range payload.Targets {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" {
|
||||
b.WriteString("- " + t + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(payload.Exclude) > 0 {
|
||||
b.WriteString("\n**明确排除(exclude)**:\n")
|
||||
for _, t := range payload.Exclude {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" {
|
||||
b.WriteString("- " + t + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
if n := strings.TrimSpace(payload.Notes); n != "" {
|
||||
b.WriteString("\n**说明(notes)**:\n" + n + "\n")
|
||||
}
|
||||
if len(payload.Targets) == 0 && len(payload.Exclude) == 0 && strings.TrimSpace(payload.Notes) == "" {
|
||||
b.WriteString("\n(scope_json 已配置但未识别 targets/exclude/notes 字段,原始内容供参考)\n```json\n")
|
||||
b.WriteString(truncateRunes(raw, 1200))
|
||||
b.WriteString("\n```\n")
|
||||
}
|
||||
b.WriteString("\n若目标不在 targets 内或命中 exclude,不得主动扫描/利用;需用户明确扩大授权后再继续。\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncateRunes(s string, max int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= max {
|
||||
return s
|
||||
}
|
||||
return string(r[:max]) + "…"
|
||||
}
|
||||
|
||||
// BuildProjectBlackboardBlock 组合测试范围 + 事实黑板索引。
|
||||
func BuildProjectBlackboardBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return "", nil
|
||||
}
|
||||
proj, err := db.GetProject(projectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
parts := []string{}
|
||||
if scope := strings.TrimSpace(BuildScopeBlock(proj)); scope != "" {
|
||||
parts = append(parts, scope)
|
||||
}
|
||||
index, err := BuildFactIndexBlock(db, projectID, cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(index) != "" {
|
||||
parts = append(parts, index)
|
||||
}
|
||||
return strings.Join(parts, "\n\n"), nil
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
func TestBuildScopeBlock_targetsExcludeNotes(t *testing.T) {
|
||||
proj := &database.Project{
|
||||
ID: "p1",
|
||||
Name: "Acme",
|
||||
ScopeJSON: `{"targets":["https://app.example.com"],"exclude":["*.cdn.example.com"],"notes":"仅 Web 层"}`,
|
||||
}
|
||||
block := BuildScopeBlock(proj)
|
||||
if !strings.Contains(block, "https://app.example.com") {
|
||||
t.Fatalf("missing target: %s", block)
|
||||
}
|
||||
if !strings.Contains(block, "cdn.example.com") {
|
||||
t.Fatalf("missing exclude: %s", block)
|
||||
}
|
||||
if !strings.Contains(block, "仅 Web 层") {
|
||||
t.Fatalf("missing notes: %s", block)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildScopeBlock_empty(t *testing.T) {
|
||||
if BuildScopeBlock(&database.Project{Name: "X"}) != "" {
|
||||
t.Fatal("expected empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildScopeBlock_invalidJSON(t *testing.T) {
|
||||
proj := &database.Project{Name: "X", ScopeJSON: `{not json`}
|
||||
block := BuildScopeBlock(proj)
|
||||
if !strings.Contains(block, "非合法 JSON") {
|
||||
t.Fatalf("unexpected: %s", block)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package project
|
||||
|
||||
import "cyberstrike-ai/internal/database"
|
||||
|
||||
// GetProjectStats 聚合项目统计(含待补全事实数)。
|
||||
func GetProjectStats(db *database.DB, projectID string) (*database.ProjectStats, error) {
|
||||
stats, err := db.GetProjectStatsCounts(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := db.ListProjectFactsForSparseCheck(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, r := range rows {
|
||||
if IsSparseFactBody(r.Category, r.FactKey, r.Body) {
|
||||
stats.SparseFactCount++
|
||||
}
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user